1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 HostedProjectId, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
8 ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
9 User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc, OnceLock,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// Grace period `connection_lost` waits after a connection drops before it
/// removes that session's room and channel-buffer state (unless the server is
/// torn down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay between `Server::start` spawning its cleanup task and that task
/// looking up and clearing stale rooms and channel-buffer collaborators.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three limits below are not referenced in this part of the
// file; their exact meaning should be confirmed at their call sites.
/// Presumably the number of channel messages returned per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the number of notifications returned per page.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`, keyed by payload `TypeId`.
/// It takes the boxed envelope plus the sender's `Session` and returns a
/// `'static` boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
93#[derive(Clone)]
94struct Session {
95 user_id: UserId,
96 connection_id: ConnectionId,
97 db: Arc<tokio::sync::Mutex<DbHandle>>,
98 peer: Arc<Peer>,
99 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
100 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
101 _executor: Executor,
102}
103
104impl Session {
105 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
106 #[cfg(test)]
107 tokio::task::yield_now().await;
108 let guard = self.db.lock().await;
109 #[cfg(test)]
110 tokio::task::yield_now().await;
111 guard
112 }
113
114 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
115 #[cfg(test)]
116 tokio::task::yield_now().await;
117 let guard = self.connection_pool.lock();
118 ConnectionPoolGuard {
119 guard,
120 _not_send: PhantomData,
121 }
122 }
123}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server: owns the peer, the shared connection pool and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    /// Id of this server instance; behind a mutex so the test-only `reset` can
    /// replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection tasks subscribe to it and return
    /// once it changes.
    teardown: watch::Sender<bool>,
}

/// Connection-pool guard that is deliberately `!Send` (`PhantomData<Rc<()>>`),
/// so a future that holds it across an `.await` is not `Send`. In test builds,
/// dropping it runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of server state produced by `Server::snapshot`; the
/// connection pool stays locked while the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// Request handlers (`add_request_handler`) reply through a `Response`;
    /// message handlers (`add_message_handler`) receive only the payload and
    /// send no reply. Registering two handlers for the same message type panics
    /// (see `add_handler`).
    ///
    /// Project requests are routed through either
    /// `forward_read_only_project_request` or `forward_mutating_project_request`.
    /// Judging by the names, the split presumably lets those helpers apply
    /// different permission checks; their bodies are outside this view.
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is stored; connections obtain receivers through
            // `teardown.subscribe()` in `handle_connection`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
285
    /// Spawns a detached background task that cleans up stale resources after
    /// `CLEANUP_TIMEOUT`, then returns `Ok(())` immediately.
    ///
    /// The task asks the database for the room and channel ids returned by
    /// `stale_server_resource_ids`, presumably left behind by earlier server
    /// instances in this environment. It then does the following:
    /// - For each channel, it clears stale buffer collaborators and sends the
    ///   refreshed collaborator list to the remaining connections.
    /// - For each room, it removes stale participants and re-broadcasts the
    ///   room (and its channel, if any).
    /// - It tells callees whose calls were canceled.
    /// - It pushes updated contact entries to the affected users' accepted
    ///   contacts.
    /// - It deletes the LiveKit room once no participants remain.
    ///
    /// Finally it calls `delete_stale_servers`. This happens even when looking
    /// up stale resources failed.
    ///
    /// Every database and send failure inside the task is only logged
    /// (`trace_err`); none of them are reported to the caller.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here but only awaited inside the task.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Remove stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every connection still in it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply if clearing the room fails: nobody is
                        // notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &pool.lock(),
                                );
                            }
                            // Both removed participants and callees of canceled calls
                            // need their contact entries refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each callee that their call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute each affected user's contact entry (including whether
                        // they are busy) and push it to all connections of their accepted
                        // contacts. Users whose busy/contacts lookup fails are skipped.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
437
438 pub fn teardown(&self) {
439 self.peer.teardown();
440 self.connection_pool.lock().reset();
441 let _ = self.teardown.send(true);
442 }
443
444 #[cfg(test)]
445 pub fn reset(&self, id: ServerId) {
446 self.teardown();
447 *self.id.lock() = id;
448 self.peer.reset(id.0 as u32);
449 let _ = self.teardown.send(false);
450 }
451
452 #[cfg(test)]
453 pub fn id(&self) -> ServerId {
454 *self.id.lock()
455 }
456
457 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
458 where
459 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
460 Fut: 'static + Send + Future<Output = Result<()>>,
461 M: EnvelopedMessage,
462 {
463 let prev_handler = self.handlers.insert(
464 TypeId::of::<M>(),
465 Box::new(move |envelope, session| {
466 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
467 let span = info_span!(
468 "handle message",
469 payload_type = envelope.payload_type_name()
470 );
471 span.in_scope(|| {
472 tracing::info!(
473 payload_type = envelope.payload_type_name(),
474 "message received"
475 );
476 });
477 let start_time = Instant::now();
478 let future = (handler)(*envelope, session);
479 async move {
480 let result = future.await;
481 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
482 match result {
483 Err(error) => {
484 tracing::error!(%error, ?duration_ms, "error handling message")
485 }
486 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
487 }
488 }
489 .instrument(span)
490 .boxed()
491 }),
492 );
493 if prev_handler.is_some() {
494 panic!("registered a handler for the same message twice");
495 }
496 self
497 }
498
499 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
500 where
501 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
502 Fut: 'static + Send + Future<Output = Result<()>>,
503 M: EnvelopedMessage,
504 {
505 self.add_handler(move |envelope, session| handler(envelope.payload, session));
506 self
507 }
508
509 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
510 where
511 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
512 Fut: Send + Future<Output = Result<()>>,
513 M: RequestMessage,
514 {
515 let handler = Arc::new(handler);
516 self.add_handler(move |envelope, session| {
517 let receipt = envelope.receipt();
518 let handler = handler.clone();
519 async move {
520 let peer = session.peer.clone();
521 let responded = Arc::new(AtomicBool::default());
522 let response = Response {
523 peer: peer.clone(),
524 responded: responded.clone(),
525 receipt,
526 };
527 match (handler)(envelope.payload, response, session).await {
528 Ok(()) => {
529 if responded.load(std::sync::atomic::Ordering::SeqCst) {
530 Ok(())
531 } else {
532 Err(anyhow!("handler did not send a response"))?
533 }
534 }
535 Err(error) => {
536 let proto_err = match &error {
537 Error::Internal(err) => err.to_proto(),
538 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
539 };
540 peer.respond_with_error(receipt, proto_err)?;
541 Err(error)
542 }
543 }
544 }
545 })
546 }
547
    /// Builds the future that serves one authenticated client connection.
    ///
    /// Setup happens in this order:
    /// 1. The connection is registered with the peer and sent `Hello`. On the
    ///    user's first ever connection it is also sent `ShowContacts`.
    /// 2. The connection is added to the pool.
    /// 3. The client is sent its contacts, channels, channel invites and any
    ///    pending incoming call.
    ///
    /// If `send_connection_id` is given, the new `ConnectionId` is sent through
    /// it once known.
    ///
    /// The future then dispatches incoming messages to the registered handlers
    /// until one of three things happens:
    /// - The connection closes or I/O fails: `connection_lost` runs.
    /// - The server is torn down: the future returns `Ok(())` immediately,
    ///   without running `connection_lost`.
    ///
    /// NOTE(review): the `?` returns between `add_connection` and the message
    /// loop skip `connection_lost`. A failure there (e.g. in
    /// `update_user_contacts`, after `pool.add_connection`) may leave the
    /// connection registered in the peer and the pool. Confirm this is intended.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                return Err(anyhow!("server is tearing down"))?;
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // The receiving side may already be gone; that is not an error here.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial state while holding the
            // pool lock; the lock is released at the end of this block.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight for this
            // connection: a permit is taken before the next message is read and released
            // when its handler completes, which applies back-pressure to the client.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased select: teardown wins over I/O completion, which wins over
                // driving foreground handlers, which wins over reading new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before storing the future; the future is
                                // instrumented with `span` below instead.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run on their own task; foreground ones
                                // are polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }
692
693 pub async fn invite_code_redeemed(
694 self: &Arc<Self>,
695 inviter_id: UserId,
696 invitee_id: UserId,
697 ) -> Result<()> {
698 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
699 if let Some(code) = &user.invite_code {
700 let pool = self.connection_pool.lock();
701 let invitee_contact = contact_for_user(invitee_id, false, &pool);
702 for connection_id in pool.user_connection_ids(inviter_id) {
703 self.peer.send(
704 connection_id,
705 proto::UpdateContacts {
706 contacts: vec![invitee_contact.clone()],
707 ..Default::default()
708 },
709 )?;
710 self.peer.send(
711 connection_id,
712 proto::UpdateInviteInfo {
713 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
714 count: user.invite_count as u32,
715 },
716 )?;
717 }
718 }
719 }
720 Ok(())
721 }
722
723 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
724 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
725 if let Some(invite_code) = &user.invite_code {
726 let pool = self.connection_pool.lock();
727 for connection_id in pool.user_connection_ids(user_id) {
728 self.peer.send(
729 connection_id,
730 proto::UpdateInviteInfo {
731 url: format!(
732 "{}{}",
733 self.app_state.config.invite_link_prefix, invite_code
734 ),
735 count: user.invite_count as u32,
736 },
737 )?;
738 }
739 }
740 }
741 Ok(())
742 }
743
744 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
745 ServerSnapshot {
746 connection_pool: ConnectionPoolGuard {
747 guard: self.connection_pool.lock(),
748 _not_send: PhantomData,
749 },
750 peer: &self.peer,
751 }
752 }
753}
754
755impl<'a> Deref for ConnectionPoolGuard<'a> {
756 type Target = ConnectionPool;
757
758 fn deref(&self) -> &Self::Target {
759 &self.guard
760 }
761}
762
763impl<'a> DerefMut for ConnectionPoolGuard<'a> {
764 fn deref_mut(&mut self) -> &mut Self::Target {
765 &mut self.guard
766 }
767}
768
769impl<'a> Drop for ConnectionPoolGuard<'a> {
770 fn drop(&mut self) {
771 #[cfg(test)]
772 self.check_invariants();
773 }
774}
775
776fn broadcast<F>(
777 sender_id: Option<ConnectionId>,
778 receiver_ids: impl IntoIterator<Item = ConnectionId>,
779 mut f: F,
780) where
781 F: FnMut(ConnectionId) -> anyhow::Result<()>,
782{
783 for receiver_id in receiver_ids {
784 if Some(receiver_id) != sender_id {
785 if let Err(error) = f(receiver_id) {
786 tracing::error!("failed to send to {:?} {}", receiver_id, error);
787 }
788 }
789 }
790}
791
792pub struct ProtocolVersion(u32);
793
794impl Header for ProtocolVersion {
795 fn name() -> &'static HeaderName {
796 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
797 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
798 }
799
800 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
801 where
802 Self: Sized,
803 I: Iterator<Item = &'i axum::http::HeaderValue>,
804 {
805 let version = values
806 .next()
807 .ok_or_else(axum::headers::Error::invalid)?
808 .to_str()
809 .map_err(|_| axum::headers::Error::invalid())?
810 .parse()
811 .map_err(|_| axum::headers::Error::invalid())?;
812 Ok(Self(version))
813 }
814
815 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
816 values.extend([self.0.to_string().parse().unwrap()]);
817 }
818}
819
820pub struct AppVersionHeader(SemanticVersion);
821impl Header for AppVersionHeader {
822 fn name() -> &'static HeaderName {
823 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
824 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
825 }
826
827 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
828 where
829 Self: Sized,
830 I: Iterator<Item = &'i axum::http::HeaderValue>,
831 {
832 let version = values
833 .next()
834 .ok_or_else(axum::headers::Error::invalid)?
835 .to_str()
836 .map_err(|_| axum::headers::Error::invalid())?
837 .parse()
838 .map_err(|_| axum::headers::Error::invalid())?;
839 Ok(Self(version))
840 }
841
842 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
843 values.extend([self.0.to_string().parse().unwrap()]);
844 }
845}
846
847pub fn routes(server: Arc<Server>) -> Router<Body> {
848 Router::new()
849 .route("/rpc", get(handle_websocket_request))
850 .layer(
851 ServiceBuilder::new()
852 .layer(Extension(server.app_state.clone()))
853 .layer(middleware::from_fn(auth::validate_header)),
854 )
855 .route("/metrics", get(handle_metrics))
856 .layer(Extension(server))
857}
858
859pub async fn handle_websocket_request(
860 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
861 app_version_header: Option<TypedHeader<AppVersionHeader>>,
862 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
863 Extension(server): Extension<Arc<Server>>,
864 Extension(user): Extension<User>,
865 Extension(impersonator): Extension<Impersonator>,
866 ws: WebSocketUpgrade,
867) -> axum::response::Response {
868 if protocol_version != rpc::PROTOCOL_VERSION {
869 return (
870 StatusCode::UPGRADE_REQUIRED,
871 "client must be upgraded".to_string(),
872 )
873 .into_response();
874 }
875
876 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
877 return (
878 StatusCode::UPGRADE_REQUIRED,
879 "no version header found".to_string(),
880 )
881 .into_response();
882 };
883
884 if !version.is_supported() {
885 return (
886 StatusCode::UPGRADE_REQUIRED,
887 "client must be upgraded".to_string(),
888 )
889 .into_response();
890 }
891
892 let socket_address = socket_address.to_string();
893 ws.on_upgrade(move |socket| {
894 use util::ResultExt;
895 let socket = socket
896 .map_ok(to_tungstenite_message)
897 .err_into()
898 .with(|message| async move { Ok(to_axum_message(message)) });
899 let connection = Connection::new(Box::pin(socket));
900 async move {
901 server
902 .handle_connection(
903 connection,
904 socket_address,
905 user,
906 version,
907 impersonator.0,
908 None,
909 Executor::Production,
910 )
911 .await
912 .log_err();
913 }
914 })
915}
916
917pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
918 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
919 let connections_metric = CONNECTIONS_METRIC
920 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
921
922 let connections = server
923 .connection_pool
924 .lock()
925 .connections()
926 .filter(|connection| !connection.admin)
927 .count();
928 connections_metric.set(connections as _);
929
930 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
931 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
932 register_int_gauge!(
933 "shared_projects",
934 "number of open projects with one or more guests"
935 )
936 .unwrap()
937 });
938
939 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
940 shared_projects_metric.set(shared_projects as _);
941
942 let encoder = prometheus::TextEncoder::new();
943 let metric_families = prometheus::gather();
944 let encoded_metrics = encoder
945 .encode_to_string(&metric_families)
946 .map_err(|err| anyhow!("{}", err))?;
947 Ok(encoded_metrics)
948}
949
/// Cleans up after a client connection ends.
///
/// Three steps run immediately:
/// 1. The peer connection is disconnected.
/// 2. The connection is removed from the pool; an error here is returned.
/// 3. The loss is recorded in the database; failures are only logged.
///
/// It then waits for `RECONNECT_TIMEOUT` or server teardown, whichever comes
/// first. On teardown nothing more happens. Otherwise it does the following:
/// - The session leaves its room and channel buffers.
/// - If the user has no other connection online, `decline_call(None, ..)` is
///   issued for them. When that yields a room, the room update is broadcast.
/// - The user's contacts are refreshed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call when no other connection of this
            // user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // On server teardown the delayed cleanup is skipped entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
995
996/// Acknowledges a ping from a client, used to keep the connection alive.
997async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
998 response.send(proto::Ack {})?;
999 Ok(())
1000}
1001
1002/// Creates a new room for calling (outside of channels)
1003async fn create_room(
1004 _request: proto::CreateRoom,
1005 response: Response<proto::CreateRoom>,
1006 session: Session,
1007) -> Result<()> {
1008 let live_kit_room = nanoid::nanoid!(30);
1009
1010 let live_kit_connection_info = {
1011 let live_kit_room = live_kit_room.clone();
1012 let live_kit = session.live_kit_client.as_ref();
1013
1014 util::async_maybe!({
1015 let live_kit = live_kit?;
1016
1017 let token = live_kit
1018 .room_token(&live_kit_room, &session.user_id.to_string())
1019 .trace_err()?;
1020
1021 Some(proto::LiveKitConnectionInfo {
1022 server_url: live_kit.url().into(),
1023 token,
1024 can_publish: true,
1025 })
1026 })
1027 }
1028 .await;
1029
1030 let room = session
1031 .db()
1032 .await
1033 .create_room(session.user_id, session.connection_id, &live_kit_room)
1034 .await?;
1035
1036 response.send(proto::CreateRoomResponse {
1037 room: Some(room.clone()),
1038 live_kit_connection_info,
1039 })?;
1040
1041 update_user_contacts(session.user_id, &session).await?;
1042 Ok(())
1043}
1044
1045/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1046async fn join_room(
1047 request: proto::JoinRoom,
1048 response: Response<proto::JoinRoom>,
1049 session: Session,
1050) -> Result<()> {
1051 let room_id = RoomId::from_proto(request.id);
1052
1053 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1054
1055 if let Some(channel_id) = channel_id {
1056 return join_channel_internal(channel_id, Box::new(response), session).await;
1057 }
1058
1059 let joined_room = {
1060 let room = session
1061 .db()
1062 .await
1063 .join_room(room_id, session.user_id, session.connection_id)
1064 .await?;
1065 room_updated(&room.room, &session.peer);
1066 room.into_inner()
1067 };
1068
1069 for connection_id in session
1070 .connection_pool()
1071 .await
1072 .user_connection_ids(session.user_id)
1073 {
1074 session
1075 .peer
1076 .send(
1077 connection_id,
1078 proto::CallCanceled {
1079 room_id: room_id.to_proto(),
1080 },
1081 )
1082 .trace_err();
1083 }
1084
1085 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1086 if let Some(token) = live_kit
1087 .room_token(
1088 &joined_room.room.live_kit_room,
1089 &session.user_id.to_string(),
1090 )
1091 .trace_err()
1092 {
1093 Some(proto::LiveKitConnectionInfo {
1094 server_url: live_kit.url().into(),
1095 token,
1096 can_publish: true,
1097 })
1098 } else {
1099 None
1100 }
1101 } else {
1102 None
1103 };
1104
1105 response.send(proto::JoinRoomResponse {
1106 room: Some(joined_room.room),
1107 channel_id: None,
1108 live_kit_connection_info,
1109 })?;
1110
1111 update_user_contacts(session.user_id, &session).await?;
1112 Ok(())
1113}
1114
1115/// Rejoin room is used to reconnect to a room after connection errors.
1116async fn rejoin_room(
1117 request: proto::RejoinRoom,
1118 response: Response<proto::RejoinRoom>,
1119 session: Session,
1120) -> Result<()> {
1121 let room;
1122 let channel_id;
1123 let channel_members;
1124 {
1125 let mut rejoined_room = session
1126 .db()
1127 .await
1128 .rejoin_room(request, session.user_id, session.connection_id)
1129 .await?;
1130
1131 response.send(proto::RejoinRoomResponse {
1132 room: Some(rejoined_room.room.clone()),
1133 reshared_projects: rejoined_room
1134 .reshared_projects
1135 .iter()
1136 .map(|project| proto::ResharedProject {
1137 id: project.id.to_proto(),
1138 collaborators: project
1139 .collaborators
1140 .iter()
1141 .map(|collaborator| collaborator.to_proto())
1142 .collect(),
1143 })
1144 .collect(),
1145 rejoined_projects: rejoined_room
1146 .rejoined_projects
1147 .iter()
1148 .map(|rejoined_project| proto::RejoinedProject {
1149 id: rejoined_project.id.to_proto(),
1150 worktrees: rejoined_project
1151 .worktrees
1152 .iter()
1153 .map(|worktree| proto::WorktreeMetadata {
1154 id: worktree.id,
1155 root_name: worktree.root_name.clone(),
1156 visible: worktree.visible,
1157 abs_path: worktree.abs_path.clone(),
1158 })
1159 .collect(),
1160 collaborators: rejoined_project
1161 .collaborators
1162 .iter()
1163 .map(|collaborator| collaborator.to_proto())
1164 .collect(),
1165 language_servers: rejoined_project.language_servers.clone(),
1166 })
1167 .collect(),
1168 })?;
1169 room_updated(&rejoined_room.room, &session.peer);
1170
1171 for project in &rejoined_room.reshared_projects {
1172 for collaborator in &project.collaborators {
1173 session
1174 .peer
1175 .send(
1176 collaborator.connection_id,
1177 proto::UpdateProjectCollaborator {
1178 project_id: project.id.to_proto(),
1179 old_peer_id: Some(project.old_connection_id.into()),
1180 new_peer_id: Some(session.connection_id.into()),
1181 },
1182 )
1183 .trace_err();
1184 }
1185
1186 broadcast(
1187 Some(session.connection_id),
1188 project
1189 .collaborators
1190 .iter()
1191 .map(|collaborator| collaborator.connection_id),
1192 |connection_id| {
1193 session.peer.forward_send(
1194 session.connection_id,
1195 connection_id,
1196 proto::UpdateProject {
1197 project_id: project.id.to_proto(),
1198 worktrees: project.worktrees.clone(),
1199 },
1200 )
1201 },
1202 );
1203 }
1204
1205 for project in &rejoined_room.rejoined_projects {
1206 for collaborator in &project.collaborators {
1207 session
1208 .peer
1209 .send(
1210 collaborator.connection_id,
1211 proto::UpdateProjectCollaborator {
1212 project_id: project.id.to_proto(),
1213 old_peer_id: Some(project.old_connection_id.into()),
1214 new_peer_id: Some(session.connection_id.into()),
1215 },
1216 )
1217 .trace_err();
1218 }
1219 }
1220
1221 for project in &mut rejoined_room.rejoined_projects {
1222 for worktree in mem::take(&mut project.worktrees) {
1223 #[cfg(any(test, feature = "test-support"))]
1224 const MAX_CHUNK_SIZE: usize = 2;
1225 #[cfg(not(any(test, feature = "test-support")))]
1226 const MAX_CHUNK_SIZE: usize = 256;
1227
1228 // Stream this worktree's entries.
1229 let message = proto::UpdateWorktree {
1230 project_id: project.id.to_proto(),
1231 worktree_id: worktree.id,
1232 abs_path: worktree.abs_path.clone(),
1233 root_name: worktree.root_name,
1234 updated_entries: worktree.updated_entries,
1235 removed_entries: worktree.removed_entries,
1236 scan_id: worktree.scan_id,
1237 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1238 updated_repositories: worktree.updated_repositories,
1239 removed_repositories: worktree.removed_repositories,
1240 };
1241 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1242 session.peer.send(session.connection_id, update.clone())?;
1243 }
1244
1245 // Stream this worktree's diagnostics.
1246 for summary in worktree.diagnostic_summaries {
1247 session.peer.send(
1248 session.connection_id,
1249 proto::UpdateDiagnosticSummary {
1250 project_id: project.id.to_proto(),
1251 worktree_id: worktree.id,
1252 summary: Some(summary),
1253 },
1254 )?;
1255 }
1256
1257 for settings_file in worktree.settings_files {
1258 session.peer.send(
1259 session.connection_id,
1260 proto::UpdateWorktreeSettings {
1261 project_id: project.id.to_proto(),
1262 worktree_id: worktree.id,
1263 path: settings_file.path,
1264 content: Some(settings_file.content),
1265 },
1266 )?;
1267 }
1268 }
1269
1270 for language_server in &project.language_servers {
1271 session.peer.send(
1272 session.connection_id,
1273 proto::UpdateLanguageServer {
1274 project_id: project.id.to_proto(),
1275 language_server_id: language_server.id,
1276 variant: Some(
1277 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1278 proto::LspDiskBasedDiagnosticsUpdated {},
1279 ),
1280 ),
1281 },
1282 )?;
1283 }
1284 }
1285
1286 let rejoined_room = rejoined_room.into_inner();
1287
1288 room = rejoined_room.room;
1289 channel_id = rejoined_room.channel_id;
1290 channel_members = rejoined_room.channel_members;
1291 }
1292
1293 if let Some(channel_id) = channel_id {
1294 channel_updated(
1295 channel_id,
1296 &room,
1297 &channel_members,
1298 &session.peer,
1299 &*session.connection_pool().await,
1300 );
1301 }
1302
1303 update_user_contacts(session.user_id, &session).await?;
1304 Ok(())
1305}
1306
1307/// leave room disconnects from the room.
1308async fn leave_room(
1309 _: proto::LeaveRoom,
1310 response: Response<proto::LeaveRoom>,
1311 session: Session,
1312) -> Result<()> {
1313 leave_room_for_session(&session).await?;
1314 response.send(proto::Ack {})?;
1315 Ok(())
1316}
1317
1318/// Updates the permissions of someone else in the room.
1319async fn set_room_participant_role(
1320 request: proto::SetRoomParticipantRole,
1321 response: Response<proto::SetRoomParticipantRole>,
1322 session: Session,
1323) -> Result<()> {
1324 let user_id = UserId::from_proto(request.user_id);
1325 let role = ChannelRole::from(request.role());
1326
1327 if role == ChannelRole::Talker {
1328 let pool = session.connection_pool().await;
1329
1330 for connection in pool.user_connections(user_id) {
1331 if !connection.zed_version.supports_talker_role() {
1332 Err(anyhow!(
1333 "This user is on zed {} which does not support unmute",
1334 connection.zed_version
1335 ))?;
1336 }
1337 }
1338 }
1339
1340 let (live_kit_room, can_publish) = {
1341 let room = session
1342 .db()
1343 .await
1344 .set_room_participant_role(
1345 session.user_id,
1346 RoomId::from_proto(request.room_id),
1347 user_id,
1348 role,
1349 )
1350 .await?;
1351
1352 let live_kit_room = room.live_kit_room.clone();
1353 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1354 room_updated(&room, &session.peer);
1355 (live_kit_room, can_publish)
1356 };
1357
1358 if let Some(live_kit) = session.live_kit_client.as_ref() {
1359 live_kit
1360 .update_participant(
1361 live_kit_room.clone(),
1362 request.user_id.to_string(),
1363 live_kit_server::proto::ParticipantPermission {
1364 can_subscribe: true,
1365 can_publish,
1366 can_publish_data: can_publish,
1367 hidden: false,
1368 recorder: false,
1369 },
1370 )
1371 .await
1372 .trace_err();
1373 }
1374
1375 response.send(proto::Ack {})?;
1376 Ok(())
1377}
1378
1379/// Call someone else into the current room
1380async fn call(
1381 request: proto::Call,
1382 response: Response<proto::Call>,
1383 session: Session,
1384) -> Result<()> {
1385 let room_id = RoomId::from_proto(request.room_id);
1386 let calling_user_id = session.user_id;
1387 let calling_connection_id = session.connection_id;
1388 let called_user_id = UserId::from_proto(request.called_user_id);
1389 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1390 if !session
1391 .db()
1392 .await
1393 .has_contact(calling_user_id, called_user_id)
1394 .await?
1395 {
1396 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1397 }
1398
1399 let incoming_call = {
1400 let (room, incoming_call) = &mut *session
1401 .db()
1402 .await
1403 .call(
1404 room_id,
1405 calling_user_id,
1406 calling_connection_id,
1407 called_user_id,
1408 initial_project_id,
1409 )
1410 .await?;
1411 room_updated(&room, &session.peer);
1412 mem::take(incoming_call)
1413 };
1414 update_user_contacts(called_user_id, &session).await?;
1415
1416 let mut calls = session
1417 .connection_pool()
1418 .await
1419 .user_connection_ids(called_user_id)
1420 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1421 .collect::<FuturesUnordered<_>>();
1422
1423 while let Some(call_response) = calls.next().await {
1424 match call_response.as_ref() {
1425 Ok(_) => {
1426 response.send(proto::Ack {})?;
1427 return Ok(());
1428 }
1429 Err(_) => {
1430 call_response.trace_err();
1431 }
1432 }
1433 }
1434
1435 {
1436 let room = session
1437 .db()
1438 .await
1439 .call_failed(room_id, called_user_id)
1440 .await?;
1441 room_updated(&room, &session.peer);
1442 }
1443 update_user_contacts(called_user_id, &session).await?;
1444
1445 Err(anyhow!("failed to ring user"))?
1446}
1447
1448/// Cancel an outgoing call.
1449async fn cancel_call(
1450 request: proto::CancelCall,
1451 response: Response<proto::CancelCall>,
1452 session: Session,
1453) -> Result<()> {
1454 let called_user_id = UserId::from_proto(request.called_user_id);
1455 let room_id = RoomId::from_proto(request.room_id);
1456 {
1457 let room = session
1458 .db()
1459 .await
1460 .cancel_call(room_id, session.connection_id, called_user_id)
1461 .await?;
1462 room_updated(&room, &session.peer);
1463 }
1464
1465 for connection_id in session
1466 .connection_pool()
1467 .await
1468 .user_connection_ids(called_user_id)
1469 {
1470 session
1471 .peer
1472 .send(
1473 connection_id,
1474 proto::CallCanceled {
1475 room_id: room_id.to_proto(),
1476 },
1477 )
1478 .trace_err();
1479 }
1480 response.send(proto::Ack {})?;
1481
1482 update_user_contacts(called_user_id, &session).await?;
1483 Ok(())
1484}
1485
1486/// Decline an incoming call.
1487async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1488 let room_id = RoomId::from_proto(message.room_id);
1489 {
1490 let room = session
1491 .db()
1492 .await
1493 .decline_call(Some(room_id), session.user_id)
1494 .await?
1495 .ok_or_else(|| anyhow!("failed to decline call"))?;
1496 room_updated(&room, &session.peer);
1497 }
1498
1499 for connection_id in session
1500 .connection_pool()
1501 .await
1502 .user_connection_ids(session.user_id)
1503 {
1504 session
1505 .peer
1506 .send(
1507 connection_id,
1508 proto::CallCanceled {
1509 room_id: room_id.to_proto(),
1510 },
1511 )
1512 .trace_err();
1513 }
1514 update_user_contacts(session.user_id, &session).await?;
1515 Ok(())
1516}
1517
1518/// Updates other participants in the room with your current location.
1519async fn update_participant_location(
1520 request: proto::UpdateParticipantLocation,
1521 response: Response<proto::UpdateParticipantLocation>,
1522 session: Session,
1523) -> Result<()> {
1524 let room_id = RoomId::from_proto(request.room_id);
1525 let location = request
1526 .location
1527 .ok_or_else(|| anyhow!("invalid location"))?;
1528
1529 let db = session.db().await;
1530 let room = db
1531 .update_room_participant_location(room_id, session.connection_id, location)
1532 .await?;
1533
1534 room_updated(&room, &session.peer);
1535 response.send(proto::Ack {})?;
1536 Ok(())
1537}
1538
1539/// Share a project into the room.
1540async fn share_project(
1541 request: proto::ShareProject,
1542 response: Response<proto::ShareProject>,
1543 session: Session,
1544) -> Result<()> {
1545 let (project_id, room) = &*session
1546 .db()
1547 .await
1548 .share_project(
1549 RoomId::from_proto(request.room_id),
1550 session.connection_id,
1551 &request.worktrees,
1552 )
1553 .await?;
1554 response.send(proto::ShareProjectResponse {
1555 project_id: project_id.to_proto(),
1556 })?;
1557 room_updated(&room, &session.peer);
1558
1559 Ok(())
1560}
1561
1562/// Unshare a project from the room.
1563async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1564 let project_id = ProjectId::from_proto(message.project_id);
1565
1566 let (room, guest_connection_ids) = &*session
1567 .db()
1568 .await
1569 .unshare_project(project_id, session.connection_id)
1570 .await?;
1571
1572 broadcast(
1573 Some(session.connection_id),
1574 guest_connection_ids.iter().copied(),
1575 |conn_id| session.peer.send(conn_id, message.clone()),
1576 );
1577 room_updated(&room, &session.peer);
1578
1579 Ok(())
1580}
1581
1582/// Join someone elses shared project.
1583async fn join_project(
1584 request: proto::JoinProject,
1585 response: Response<proto::JoinProject>,
1586 session: Session,
1587) -> Result<()> {
1588 let project_id = ProjectId::from_proto(request.project_id);
1589
1590 tracing::info!(%project_id, "join project");
1591
1592 let (project, replica_id) = &mut *session
1593 .db()
1594 .await
1595 .join_project_in_room(project_id, session.connection_id)
1596 .await?;
1597
1598 join_project_internal(response, session, project, replica_id)
1599}
1600
/// A response handle that can reply with a `proto::JoinProjectResponse`.
///
/// Implemented for the responses to both `JoinProject` and `JoinHostedProject`
/// so that `join_project_internal` can serve either request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to the inherent `Response::send` via a fully-qualified call.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Delegates to the inherent `Response::send` via a fully-qualified call.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1614
1615fn join_project_internal(
1616 response: impl JoinProjectInternalResponse,
1617 session: Session,
1618 project: &mut Project,
1619 replica_id: &ReplicaId,
1620) -> Result<()> {
1621 let collaborators = project
1622 .collaborators
1623 .iter()
1624 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1625 .map(|collaborator| collaborator.to_proto())
1626 .collect::<Vec<_>>();
1627 let project_id = project.id;
1628 let guest_user_id = session.user_id;
1629
1630 let worktrees = project
1631 .worktrees
1632 .iter()
1633 .map(|(id, worktree)| proto::WorktreeMetadata {
1634 id: *id,
1635 root_name: worktree.root_name.clone(),
1636 visible: worktree.visible,
1637 abs_path: worktree.abs_path.clone(),
1638 })
1639 .collect::<Vec<_>>();
1640
1641 for collaborator in &collaborators {
1642 session
1643 .peer
1644 .send(
1645 collaborator.peer_id.unwrap().into(),
1646 proto::AddProjectCollaborator {
1647 project_id: project_id.to_proto(),
1648 collaborator: Some(proto::Collaborator {
1649 peer_id: Some(session.connection_id.into()),
1650 replica_id: replica_id.0 as u32,
1651 user_id: guest_user_id.to_proto(),
1652 }),
1653 },
1654 )
1655 .trace_err();
1656 }
1657
1658 // First, we send the metadata associated with each worktree.
1659 response.send(proto::JoinProjectResponse {
1660 project_id: project.id.0 as u64,
1661 worktrees: worktrees.clone(),
1662 replica_id: replica_id.0 as u32,
1663 collaborators: collaborators.clone(),
1664 language_servers: project.language_servers.clone(),
1665 role: project.role.into(), // todo
1666 })?;
1667
1668 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1669 #[cfg(any(test, feature = "test-support"))]
1670 const MAX_CHUNK_SIZE: usize = 2;
1671 #[cfg(not(any(test, feature = "test-support")))]
1672 const MAX_CHUNK_SIZE: usize = 256;
1673
1674 // Stream this worktree's entries.
1675 let message = proto::UpdateWorktree {
1676 project_id: project_id.to_proto(),
1677 worktree_id,
1678 abs_path: worktree.abs_path.clone(),
1679 root_name: worktree.root_name,
1680 updated_entries: worktree.entries,
1681 removed_entries: Default::default(),
1682 scan_id: worktree.scan_id,
1683 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1684 updated_repositories: worktree.repository_entries.into_values().collect(),
1685 removed_repositories: Default::default(),
1686 };
1687 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1688 session.peer.send(session.connection_id, update.clone())?;
1689 }
1690
1691 // Stream this worktree's diagnostics.
1692 for summary in worktree.diagnostic_summaries {
1693 session.peer.send(
1694 session.connection_id,
1695 proto::UpdateDiagnosticSummary {
1696 project_id: project_id.to_proto(),
1697 worktree_id: worktree.id,
1698 summary: Some(summary),
1699 },
1700 )?;
1701 }
1702
1703 for settings_file in worktree.settings_files {
1704 session.peer.send(
1705 session.connection_id,
1706 proto::UpdateWorktreeSettings {
1707 project_id: project_id.to_proto(),
1708 worktree_id: worktree.id,
1709 path: settings_file.path,
1710 content: Some(settings_file.content),
1711 },
1712 )?;
1713 }
1714 }
1715
1716 for language_server in &project.language_servers {
1717 session.peer.send(
1718 session.connection_id,
1719 proto::UpdateLanguageServer {
1720 project_id: project_id.to_proto(),
1721 language_server_id: language_server.id,
1722 variant: Some(
1723 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1724 proto::LspDiskBasedDiagnosticsUpdated {},
1725 ),
1726 ),
1727 },
1728 )?;
1729 }
1730
1731 Ok(())
1732}
1733
1734/// Leave someone elses shared project.
1735async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1736 let sender_id = session.connection_id;
1737 let project_id = ProjectId::from_proto(request.project_id);
1738 let db = session.db().await;
1739 if db.is_hosted_project(project_id).await? {
1740 let project = db.leave_hosted_project(project_id, sender_id).await?;
1741 project_left(&project, &session);
1742 return Ok(());
1743 }
1744
1745 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1746 tracing::info!(
1747 %project_id,
1748 host_user_id = ?project.host_user_id,
1749 host_connection_id = ?project.host_connection_id,
1750 "leave project"
1751 );
1752
1753 project_left(&project, &session);
1754 room_updated(&room, &session.peer);
1755
1756 Ok(())
1757}
1758
1759async fn join_hosted_project(
1760 request: proto::JoinHostedProject,
1761 response: Response<proto::JoinHostedProject>,
1762 session: Session,
1763) -> Result<()> {
1764 let (mut project, replica_id) = session
1765 .db()
1766 .await
1767 .join_hosted_project(
1768 HostedProjectId(request.id as i32),
1769 session.user_id,
1770 session.connection_id,
1771 )
1772 .await?;
1773
1774 join_project_internal(response, session, &mut project, &replica_id)
1775}
1776
1777/// Updates other participants with changes to the project
1778async fn update_project(
1779 request: proto::UpdateProject,
1780 response: Response<proto::UpdateProject>,
1781 session: Session,
1782) -> Result<()> {
1783 let project_id = ProjectId::from_proto(request.project_id);
1784 let (room, guest_connection_ids) = &*session
1785 .db()
1786 .await
1787 .update_project(project_id, session.connection_id, &request.worktrees)
1788 .await?;
1789 broadcast(
1790 Some(session.connection_id),
1791 guest_connection_ids.iter().copied(),
1792 |connection_id| {
1793 session
1794 .peer
1795 .forward_send(session.connection_id, connection_id, request.clone())
1796 },
1797 );
1798 room_updated(&room, &session.peer);
1799 response.send(proto::Ack {})?;
1800
1801 Ok(())
1802}
1803
1804/// Updates other participants with changes to the worktree
1805async fn update_worktree(
1806 request: proto::UpdateWorktree,
1807 response: Response<proto::UpdateWorktree>,
1808 session: Session,
1809) -> Result<()> {
1810 let guest_connection_ids = session
1811 .db()
1812 .await
1813 .update_worktree(&request, session.connection_id)
1814 .await?;
1815
1816 broadcast(
1817 Some(session.connection_id),
1818 guest_connection_ids.iter().copied(),
1819 |connection_id| {
1820 session
1821 .peer
1822 .forward_send(session.connection_id, connection_id, request.clone())
1823 },
1824 );
1825 response.send(proto::Ack {})?;
1826 Ok(())
1827}
1828
1829/// Updates other participants with changes to the diagnostics
1830async fn update_diagnostic_summary(
1831 message: proto::UpdateDiagnosticSummary,
1832 session: Session,
1833) -> Result<()> {
1834 let guest_connection_ids = session
1835 .db()
1836 .await
1837 .update_diagnostic_summary(&message, session.connection_id)
1838 .await?;
1839
1840 broadcast(
1841 Some(session.connection_id),
1842 guest_connection_ids.iter().copied(),
1843 |connection_id| {
1844 session
1845 .peer
1846 .forward_send(session.connection_id, connection_id, message.clone())
1847 },
1848 );
1849
1850 Ok(())
1851}
1852
1853/// Updates other participants with changes to the worktree settings
1854async fn update_worktree_settings(
1855 message: proto::UpdateWorktreeSettings,
1856 session: Session,
1857) -> Result<()> {
1858 let guest_connection_ids = session
1859 .db()
1860 .await
1861 .update_worktree_settings(&message, session.connection_id)
1862 .await?;
1863
1864 broadcast(
1865 Some(session.connection_id),
1866 guest_connection_ids.iter().copied(),
1867 |connection_id| {
1868 session
1869 .peer
1870 .forward_send(session.connection_id, connection_id, message.clone())
1871 },
1872 );
1873
1874 Ok(())
1875}
1876
1877/// Notify other participants that a language server has started.
1878async fn start_language_server(
1879 request: proto::StartLanguageServer,
1880 session: Session,
1881) -> Result<()> {
1882 let guest_connection_ids = session
1883 .db()
1884 .await
1885 .start_language_server(&request, session.connection_id)
1886 .await?;
1887
1888 broadcast(
1889 Some(session.connection_id),
1890 guest_connection_ids.iter().copied(),
1891 |connection_id| {
1892 session
1893 .peer
1894 .forward_send(session.connection_id, connection_id, request.clone())
1895 },
1896 );
1897 Ok(())
1898}
1899
1900/// Notify other participants that a language server has changed.
1901async fn update_language_server(
1902 request: proto::UpdateLanguageServer,
1903 session: Session,
1904) -> Result<()> {
1905 let project_id = ProjectId::from_proto(request.project_id);
1906 let project_connection_ids = session
1907 .db()
1908 .await
1909 .project_connection_ids(project_id, session.connection_id)
1910 .await?;
1911 broadcast(
1912 Some(session.connection_id),
1913 project_connection_ids.iter().copied(),
1914 |connection_id| {
1915 session
1916 .peer
1917 .forward_send(session.connection_id, connection_id, request.clone())
1918 },
1919 );
1920 Ok(())
1921}
1922
1923/// forward a project request to the host. These requests should be read only
1924/// as guests are allowed to send them.
1925async fn forward_read_only_project_request<T>(
1926 request: T,
1927 response: Response<T>,
1928 session: Session,
1929) -> Result<()>
1930where
1931 T: EntityMessage + RequestMessage,
1932{
1933 let project_id = ProjectId::from_proto(request.remote_entity_id());
1934 let host_connection_id = session
1935 .db()
1936 .await
1937 .host_for_read_only_project_request(project_id, session.connection_id)
1938 .await?;
1939 let payload = session
1940 .peer
1941 .forward_request(session.connection_id, host_connection_id, request)
1942 .await?;
1943 response.send(payload)?;
1944 Ok(())
1945}
1946
1947/// forward a project request to the host. These requests are disallowed
1948/// for guests.
1949async fn forward_mutating_project_request<T>(
1950 request: T,
1951 response: Response<T>,
1952 session: Session,
1953) -> Result<()>
1954where
1955 T: EntityMessage + RequestMessage,
1956{
1957 let project_id = ProjectId::from_proto(request.remote_entity_id());
1958 let host_connection_id = session
1959 .db()
1960 .await
1961 .host_for_mutating_project_request(project_id, session.connection_id)
1962 .await?;
1963 let payload = session
1964 .peer
1965 .forward_request(session.connection_id, host_connection_id, request)
1966 .await?;
1967 response.send(payload)?;
1968 Ok(())
1969}
1970
1971/// Notify other participants that a new buffer has been created
1972async fn create_buffer_for_peer(
1973 request: proto::CreateBufferForPeer,
1974 session: Session,
1975) -> Result<()> {
1976 session
1977 .db()
1978 .await
1979 .check_user_is_project_host(
1980 ProjectId::from_proto(request.project_id),
1981 session.connection_id,
1982 )
1983 .await?;
1984 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1985 session
1986 .peer
1987 .forward_send(session.connection_id, peer_id.into(), request)?;
1988 Ok(())
1989}
1990
1991/// Notify other participants that a buffer has been updated. This is
1992/// allowed for guests as long as the update is limited to selections.
1993async fn update_buffer(
1994 request: proto::UpdateBuffer,
1995 response: Response<proto::UpdateBuffer>,
1996 session: Session,
1997) -> Result<()> {
1998 let project_id = ProjectId::from_proto(request.project_id);
1999 let mut guest_connection_ids;
2000 let mut host_connection_id = None;
2001
2002 let mut requires_write_permission = false;
2003
2004 for op in request.operations.iter() {
2005 match op.variant {
2006 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2007 Some(_) => requires_write_permission = true,
2008 }
2009 }
2010
2011 {
2012 let collaborators = session
2013 .db()
2014 .await
2015 .project_collaborators_for_buffer_update(
2016 project_id,
2017 session.connection_id,
2018 requires_write_permission,
2019 )
2020 .await?;
2021 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2022 for collaborator in collaborators.iter() {
2023 if collaborator.is_host {
2024 host_connection_id = Some(collaborator.connection_id);
2025 } else {
2026 guest_connection_ids.push(collaborator.connection_id);
2027 }
2028 }
2029 }
2030 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2031
2032 broadcast(
2033 Some(session.connection_id),
2034 guest_connection_ids,
2035 |connection_id| {
2036 session
2037 .peer
2038 .forward_send(session.connection_id, connection_id, request.clone())
2039 },
2040 );
2041 if host_connection_id != session.connection_id {
2042 session
2043 .peer
2044 .forward_request(session.connection_id, host_connection_id, request.clone())
2045 .await?;
2046 }
2047
2048 response.send(proto::Ack {})?;
2049 Ok(())
2050}
2051
2052/// Notify other participants that a project has been updated.
2053async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2054 request: T,
2055 session: Session,
2056) -> Result<()> {
2057 let project_id = ProjectId::from_proto(request.remote_entity_id());
2058 let project_connection_ids = session
2059 .db()
2060 .await
2061 .project_connection_ids(project_id, session.connection_id)
2062 .await?;
2063
2064 broadcast(
2065 Some(session.connection_id),
2066 project_connection_ids.iter().copied(),
2067 |connection_id| {
2068 session
2069 .peer
2070 .forward_send(session.connection_id, connection_id, request.clone())
2071 },
2072 );
2073 Ok(())
2074}
2075
2076/// Start following another user in a call.
2077async fn follow(
2078 request: proto::Follow,
2079 response: Response<proto::Follow>,
2080 session: Session,
2081) -> Result<()> {
2082 let room_id = RoomId::from_proto(request.room_id);
2083 let project_id = request.project_id.map(ProjectId::from_proto);
2084 let leader_id = request
2085 .leader_id
2086 .ok_or_else(|| anyhow!("invalid leader id"))?
2087 .into();
2088 let follower_id = session.connection_id;
2089
2090 session
2091 .db()
2092 .await
2093 .check_room_participants(room_id, leader_id, session.connection_id)
2094 .await?;
2095
2096 let response_payload = session
2097 .peer
2098 .forward_request(session.connection_id, leader_id, request)
2099 .await?;
2100 response.send(response_payload)?;
2101
2102 if let Some(project_id) = project_id {
2103 let room = session
2104 .db()
2105 .await
2106 .follow(room_id, project_id, leader_id, follower_id)
2107 .await?;
2108 room_updated(&room, &session.peer);
2109 }
2110
2111 Ok(())
2112}
2113
2114/// Stop following another user in a call.
2115async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2116 let room_id = RoomId::from_proto(request.room_id);
2117 let project_id = request.project_id.map(ProjectId::from_proto);
2118 let leader_id = request
2119 .leader_id
2120 .ok_or_else(|| anyhow!("invalid leader id"))?
2121 .into();
2122 let follower_id = session.connection_id;
2123
2124 session
2125 .db()
2126 .await
2127 .check_room_participants(room_id, leader_id, session.connection_id)
2128 .await?;
2129
2130 session
2131 .peer
2132 .forward_send(session.connection_id, leader_id, request)?;
2133
2134 if let Some(project_id) = project_id {
2135 let room = session
2136 .db()
2137 .await
2138 .unfollow(room_id, project_id, leader_id, follower_id)
2139 .await?;
2140 room_updated(&room, &session.peer);
2141 }
2142
2143 Ok(())
2144}
2145
2146/// Notify everyone following you of your current location.
2147async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2148 let room_id = RoomId::from_proto(request.room_id);
2149 let database = session.db.lock().await;
2150
2151 let connection_ids = if let Some(project_id) = request.project_id {
2152 let project_id = ProjectId::from_proto(project_id);
2153 database
2154 .project_connection_ids(project_id, session.connection_id)
2155 .await?
2156 } else {
2157 database
2158 .room_connection_ids(room_id, session.connection_id)
2159 .await?
2160 };
2161
2162 // For now, don't send view update messages back to that view's current leader.
2163 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2164 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2165 _ => None,
2166 });
2167
2168 for connection_id in connection_ids.iter().cloned() {
2169 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2170 session
2171 .peer
2172 .forward_send(session.connection_id, connection_id, request.clone())?;
2173 }
2174 }
2175 Ok(())
2176}
2177
2178/// Get public data about users.
2179async fn get_users(
2180 request: proto::GetUsers,
2181 response: Response<proto::GetUsers>,
2182 session: Session,
2183) -> Result<()> {
2184 let user_ids = request
2185 .user_ids
2186 .into_iter()
2187 .map(UserId::from_proto)
2188 .collect();
2189 let users = session
2190 .db()
2191 .await
2192 .get_users_by_ids(user_ids)
2193 .await?
2194 .into_iter()
2195 .map(|user| proto::User {
2196 id: user.id.to_proto(),
2197 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2198 github_login: user.github_login,
2199 })
2200 .collect();
2201 response.send(proto::UsersResponse { users })?;
2202 Ok(())
2203}
2204
2205/// Search for users (to invite) buy Github login
2206async fn fuzzy_search_users(
2207 request: proto::FuzzySearchUsers,
2208 response: Response<proto::FuzzySearchUsers>,
2209 session: Session,
2210) -> Result<()> {
2211 let query = request.query;
2212 let users = match query.len() {
2213 0 => vec![],
2214 1 | 2 => session
2215 .db()
2216 .await
2217 .get_user_by_github_login(&query)
2218 .await?
2219 .into_iter()
2220 .collect(),
2221 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2222 };
2223 let users = users
2224 .into_iter()
2225 .filter(|user| user.id != session.user_id)
2226 .map(|user| proto::User {
2227 id: user.id.to_proto(),
2228 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2229 github_login: user.github_login,
2230 })
2231 .collect();
2232 response.send(proto::UsersResponse { users })?;
2233 Ok(())
2234}
2235
2236/// Send a contact request to another user.
2237async fn request_contact(
2238 request: proto::RequestContact,
2239 response: Response<proto::RequestContact>,
2240 session: Session,
2241) -> Result<()> {
2242 let requester_id = session.user_id;
2243 let responder_id = UserId::from_proto(request.responder_id);
2244 if requester_id == responder_id {
2245 return Err(anyhow!("cannot add yourself as a contact"))?;
2246 }
2247
2248 let notifications = session
2249 .db()
2250 .await
2251 .send_contact_request(requester_id, responder_id)
2252 .await?;
2253
2254 // Update outgoing contact requests of requester
2255 let mut update = proto::UpdateContacts::default();
2256 update.outgoing_requests.push(responder_id.to_proto());
2257 for connection_id in session
2258 .connection_pool()
2259 .await
2260 .user_connection_ids(requester_id)
2261 {
2262 session.peer.send(connection_id, update.clone())?;
2263 }
2264
2265 // Update incoming contact requests of responder
2266 let mut update = proto::UpdateContacts::default();
2267 update
2268 .incoming_requests
2269 .push(proto::IncomingContactRequest {
2270 requester_id: requester_id.to_proto(),
2271 });
2272 let connection_pool = session.connection_pool().await;
2273 for connection_id in connection_pool.user_connection_ids(responder_id) {
2274 session.peer.send(connection_id, update.clone())?;
2275 }
2276
2277 send_notifications(&connection_pool, &session.peer, notifications);
2278
2279 response.send(proto::Ack {})?;
2280 Ok(())
2281}
2282
2283/// Accept or decline a contact request
2284async fn respond_to_contact_request(
2285 request: proto::RespondToContactRequest,
2286 response: Response<proto::RespondToContactRequest>,
2287 session: Session,
2288) -> Result<()> {
2289 let responder_id = session.user_id;
2290 let requester_id = UserId::from_proto(request.requester_id);
2291 let db = session.db().await;
2292 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2293 db.dismiss_contact_notification(responder_id, requester_id)
2294 .await?;
2295 } else {
2296 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2297
2298 let notifications = db
2299 .respond_to_contact_request(responder_id, requester_id, accept)
2300 .await?;
2301 let requester_busy = db.is_user_busy(requester_id).await?;
2302 let responder_busy = db.is_user_busy(responder_id).await?;
2303
2304 let pool = session.connection_pool().await;
2305 // Update responder with new contact
2306 let mut update = proto::UpdateContacts::default();
2307 if accept {
2308 update
2309 .contacts
2310 .push(contact_for_user(requester_id, requester_busy, &pool));
2311 }
2312 update
2313 .remove_incoming_requests
2314 .push(requester_id.to_proto());
2315 for connection_id in pool.user_connection_ids(responder_id) {
2316 session.peer.send(connection_id, update.clone())?;
2317 }
2318
2319 // Update requester with new contact
2320 let mut update = proto::UpdateContacts::default();
2321 if accept {
2322 update
2323 .contacts
2324 .push(contact_for_user(responder_id, responder_busy, &pool));
2325 }
2326 update
2327 .remove_outgoing_requests
2328 .push(responder_id.to_proto());
2329
2330 for connection_id in pool.user_connection_ids(requester_id) {
2331 session.peer.send(connection_id, update.clone())?;
2332 }
2333
2334 send_notifications(&pool, &session.peer, notifications);
2335 }
2336
2337 response.send(proto::Ack {})?;
2338 Ok(())
2339}
2340
2341/// Remove a contact.
2342async fn remove_contact(
2343 request: proto::RemoveContact,
2344 response: Response<proto::RemoveContact>,
2345 session: Session,
2346) -> Result<()> {
2347 let requester_id = session.user_id;
2348 let responder_id = UserId::from_proto(request.user_id);
2349 let db = session.db().await;
2350 let (contact_accepted, deleted_notification_id) =
2351 db.remove_contact(requester_id, responder_id).await?;
2352
2353 let pool = session.connection_pool().await;
2354 // Update outgoing contact requests of requester
2355 let mut update = proto::UpdateContacts::default();
2356 if contact_accepted {
2357 update.remove_contacts.push(responder_id.to_proto());
2358 } else {
2359 update
2360 .remove_outgoing_requests
2361 .push(responder_id.to_proto());
2362 }
2363 for connection_id in pool.user_connection_ids(requester_id) {
2364 session.peer.send(connection_id, update.clone())?;
2365 }
2366
2367 // Update incoming contact requests of responder
2368 let mut update = proto::UpdateContacts::default();
2369 if contact_accepted {
2370 update.remove_contacts.push(requester_id.to_proto());
2371 } else {
2372 update
2373 .remove_incoming_requests
2374 .push(requester_id.to_proto());
2375 }
2376 for connection_id in pool.user_connection_ids(responder_id) {
2377 session.peer.send(connection_id, update.clone())?;
2378 if let Some(notification_id) = deleted_notification_id {
2379 session.peer.send(
2380 connection_id,
2381 proto::DeleteNotification {
2382 notification_id: notification_id.to_proto(),
2383 },
2384 )?;
2385 }
2386 }
2387
2388 response.send(proto::Ack {})?;
2389 Ok(())
2390}
2391
2392/// Creates a new channel.
2393async fn create_channel(
2394 request: proto::CreateChannel,
2395 response: Response<proto::CreateChannel>,
2396 session: Session,
2397) -> Result<()> {
2398 let db = session.db().await;
2399
2400 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2401 let (channel, owner, channel_members) = db
2402 .create_channel(&request.name, parent_id, session.user_id)
2403 .await?;
2404
2405 response.send(proto::CreateChannelResponse {
2406 channel: Some(channel.to_proto()),
2407 parent_id: request.parent_id,
2408 })?;
2409
2410 let connection_pool = session.connection_pool().await;
2411 if let Some(owner) = owner {
2412 let update = proto::UpdateUserChannels {
2413 channel_memberships: vec![proto::ChannelMembership {
2414 channel_id: owner.channel_id.to_proto(),
2415 role: owner.role.into(),
2416 }],
2417 ..Default::default()
2418 };
2419 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2420 session.peer.send(connection_id, update.clone())?;
2421 }
2422 }
2423
2424 for channel_member in channel_members {
2425 if !channel_member.role.can_see_channel(channel.visibility) {
2426 continue;
2427 }
2428
2429 let update = proto::UpdateChannels {
2430 channels: vec![channel.to_proto()],
2431 ..Default::default()
2432 };
2433 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2434 session.peer.send(connection_id, update.clone())?;
2435 }
2436 }
2437
2438 Ok(())
2439}
2440
2441/// Delete a channel
2442async fn delete_channel(
2443 request: proto::DeleteChannel,
2444 response: Response<proto::DeleteChannel>,
2445 session: Session,
2446) -> Result<()> {
2447 let db = session.db().await;
2448
2449 let channel_id = request.channel_id;
2450 let (removed_channels, member_ids) = db
2451 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2452 .await?;
2453 response.send(proto::Ack {})?;
2454
2455 // Notify members of removed channels
2456 let mut update = proto::UpdateChannels::default();
2457 update
2458 .delete_channels
2459 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2460
2461 let connection_pool = session.connection_pool().await;
2462 for member_id in member_ids {
2463 for connection_id in connection_pool.user_connection_ids(member_id) {
2464 session.peer.send(connection_id, update.clone())?;
2465 }
2466 }
2467
2468 Ok(())
2469}
2470
2471/// Invite someone to join a channel.
2472async fn invite_channel_member(
2473 request: proto::InviteChannelMember,
2474 response: Response<proto::InviteChannelMember>,
2475 session: Session,
2476) -> Result<()> {
2477 let db = session.db().await;
2478 let channel_id = ChannelId::from_proto(request.channel_id);
2479 let invitee_id = UserId::from_proto(request.user_id);
2480 let InviteMemberResult {
2481 channel,
2482 notifications,
2483 } = db
2484 .invite_channel_member(
2485 channel_id,
2486 invitee_id,
2487 session.user_id,
2488 request.role().into(),
2489 )
2490 .await?;
2491
2492 let update = proto::UpdateChannels {
2493 channel_invitations: vec![channel.to_proto()],
2494 ..Default::default()
2495 };
2496
2497 let connection_pool = session.connection_pool().await;
2498 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2499 session.peer.send(connection_id, update.clone())?;
2500 }
2501
2502 send_notifications(&connection_pool, &session.peer, notifications);
2503
2504 response.send(proto::Ack {})?;
2505 Ok(())
2506}
2507
2508/// remove someone from a channel
2509async fn remove_channel_member(
2510 request: proto::RemoveChannelMember,
2511 response: Response<proto::RemoveChannelMember>,
2512 session: Session,
2513) -> Result<()> {
2514 let db = session.db().await;
2515 let channel_id = ChannelId::from_proto(request.channel_id);
2516 let member_id = UserId::from_proto(request.user_id);
2517
2518 let RemoveChannelMemberResult {
2519 membership_update,
2520 notification_id,
2521 } = db
2522 .remove_channel_member(channel_id, member_id, session.user_id)
2523 .await?;
2524
2525 let connection_pool = &session.connection_pool().await;
2526 notify_membership_updated(
2527 &connection_pool,
2528 membership_update,
2529 member_id,
2530 &session.peer,
2531 );
2532 for connection_id in connection_pool.user_connection_ids(member_id) {
2533 if let Some(notification_id) = notification_id {
2534 session
2535 .peer
2536 .send(
2537 connection_id,
2538 proto::DeleteNotification {
2539 notification_id: notification_id.to_proto(),
2540 },
2541 )
2542 .trace_err();
2543 }
2544 }
2545
2546 response.send(proto::Ack {})?;
2547 Ok(())
2548}
2549
2550/// Toggle the channel between public and private.
2551/// Care is taken to maintain the invariant that public channels only descend from public channels,
2552/// (though members-only channels can appear at any point in the hierarchy).
2553async fn set_channel_visibility(
2554 request: proto::SetChannelVisibility,
2555 response: Response<proto::SetChannelVisibility>,
2556 session: Session,
2557) -> Result<()> {
2558 let db = session.db().await;
2559 let channel_id = ChannelId::from_proto(request.channel_id);
2560 let visibility = request.visibility().into();
2561
2562 let (channel, channel_members) = db
2563 .set_channel_visibility(channel_id, visibility, session.user_id)
2564 .await?;
2565
2566 let connection_pool = session.connection_pool().await;
2567 for member in channel_members {
2568 let update = if member.role.can_see_channel(channel.visibility) {
2569 proto::UpdateChannels {
2570 channels: vec![channel.to_proto()],
2571 ..Default::default()
2572 }
2573 } else {
2574 proto::UpdateChannels {
2575 delete_channels: vec![channel.id.to_proto()],
2576 ..Default::default()
2577 }
2578 };
2579
2580 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2581 session.peer.send(connection_id, update.clone())?;
2582 }
2583 }
2584
2585 response.send(proto::Ack {})?;
2586 Ok(())
2587}
2588
2589/// Alter the role for a user in the channel.
2590async fn set_channel_member_role(
2591 request: proto::SetChannelMemberRole,
2592 response: Response<proto::SetChannelMemberRole>,
2593 session: Session,
2594) -> Result<()> {
2595 let db = session.db().await;
2596 let channel_id = ChannelId::from_proto(request.channel_id);
2597 let member_id = UserId::from_proto(request.user_id);
2598 let result = db
2599 .set_channel_member_role(
2600 channel_id,
2601 session.user_id,
2602 member_id,
2603 request.role().into(),
2604 )
2605 .await?;
2606
2607 match result {
2608 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2609 let connection_pool = session.connection_pool().await;
2610 notify_membership_updated(
2611 &connection_pool,
2612 membership_update,
2613 member_id,
2614 &session.peer,
2615 )
2616 }
2617 db::SetMemberRoleResult::InviteUpdated(channel) => {
2618 let update = proto::UpdateChannels {
2619 channel_invitations: vec![channel.to_proto()],
2620 ..Default::default()
2621 };
2622
2623 for connection_id in session
2624 .connection_pool()
2625 .await
2626 .user_connection_ids(member_id)
2627 {
2628 session.peer.send(connection_id, update.clone())?;
2629 }
2630 }
2631 }
2632
2633 response.send(proto::Ack {})?;
2634 Ok(())
2635}
2636
2637/// Change the name of a channel
2638async fn rename_channel(
2639 request: proto::RenameChannel,
2640 response: Response<proto::RenameChannel>,
2641 session: Session,
2642) -> Result<()> {
2643 let db = session.db().await;
2644 let channel_id = ChannelId::from_proto(request.channel_id);
2645 let (channel, channel_members) = db
2646 .rename_channel(channel_id, session.user_id, &request.name)
2647 .await?;
2648
2649 response.send(proto::RenameChannelResponse {
2650 channel: Some(channel.to_proto()),
2651 })?;
2652
2653 let connection_pool = session.connection_pool().await;
2654 for channel_member in channel_members {
2655 if !channel_member.role.can_see_channel(channel.visibility) {
2656 continue;
2657 }
2658 let update = proto::UpdateChannels {
2659 channels: vec![channel.to_proto()],
2660 ..Default::default()
2661 };
2662 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2663 session.peer.send(connection_id, update.clone())?;
2664 }
2665 }
2666
2667 Ok(())
2668}
2669
2670/// Move a channel to a new parent.
2671async fn move_channel(
2672 request: proto::MoveChannel,
2673 response: Response<proto::MoveChannel>,
2674 session: Session,
2675) -> Result<()> {
2676 let channel_id = ChannelId::from_proto(request.channel_id);
2677 let to = ChannelId::from_proto(request.to);
2678
2679 let (channels, channel_members) = session
2680 .db()
2681 .await
2682 .move_channel(channel_id, to, session.user_id)
2683 .await?;
2684
2685 let connection_pool = session.connection_pool().await;
2686 for member in channel_members {
2687 let channels = channels
2688 .iter()
2689 .filter_map(|channel| {
2690 if member.role.can_see_channel(channel.visibility) {
2691 Some(channel.to_proto())
2692 } else {
2693 None
2694 }
2695 })
2696 .collect::<Vec<_>>();
2697 if channels.is_empty() {
2698 continue;
2699 }
2700
2701 let update = proto::UpdateChannels {
2702 channels,
2703 ..Default::default()
2704 };
2705
2706 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2707 session.peer.send(connection_id, update.clone())?;
2708 }
2709 }
2710
2711 response.send(Ack {})?;
2712 Ok(())
2713}
2714
2715/// Get the list of channel members
2716async fn get_channel_members(
2717 request: proto::GetChannelMembers,
2718 response: Response<proto::GetChannelMembers>,
2719 session: Session,
2720) -> Result<()> {
2721 let db = session.db().await;
2722 let channel_id = ChannelId::from_proto(request.channel_id);
2723 let members = db
2724 .get_channel_participant_details(channel_id, session.user_id)
2725 .await?;
2726 response.send(proto::GetChannelMembersResponse { members })?;
2727 Ok(())
2728}
2729
2730/// Accept or decline a channel invitation.
2731async fn respond_to_channel_invite(
2732 request: proto::RespondToChannelInvite,
2733 response: Response<proto::RespondToChannelInvite>,
2734 session: Session,
2735) -> Result<()> {
2736 let db = session.db().await;
2737 let channel_id = ChannelId::from_proto(request.channel_id);
2738 let RespondToChannelInvite {
2739 membership_update,
2740 notifications,
2741 } = db
2742 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2743 .await?;
2744
2745 let connection_pool = session.connection_pool().await;
2746 if let Some(membership_update) = membership_update {
2747 notify_membership_updated(
2748 &connection_pool,
2749 membership_update,
2750 session.user_id,
2751 &session.peer,
2752 );
2753 } else {
2754 let update = proto::UpdateChannels {
2755 remove_channel_invitations: vec![channel_id.to_proto()],
2756 ..Default::default()
2757 };
2758
2759 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2760 session.peer.send(connection_id, update.clone())?;
2761 }
2762 };
2763
2764 send_notifications(&connection_pool, &session.peer, notifications);
2765
2766 response.send(proto::Ack {})?;
2767
2768 Ok(())
2769}
2770
2771/// Join the channels' room
2772async fn join_channel(
2773 request: proto::JoinChannel,
2774 response: Response<proto::JoinChannel>,
2775 session: Session,
2776) -> Result<()> {
2777 let channel_id = ChannelId::from_proto(request.channel_id);
2778 join_channel_internal(channel_id, Box::new(response), session).await
2779}
2780
/// A response handle that can complete a channel join.
///
/// Lets `join_channel_internal` accept either a `JoinChannel` or a `JoinRoom`
/// response handle; both reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` back to the requester.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Both impls call `Response::send` through the type path. In path resolution
// inherent associated functions take precedence over trait methods, so this
// delegates to the inherent `send` rather than recursing into this trait.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2794
/// Shared implementation for joining a channel's room.
///
/// In order, this:
/// 1. leaves any room this session is currently in;
/// 2. joins the channel's room in the database;
/// 3. replies through `response` with the room, the channel id and, when
///    available, LiveKit connection info;
/// 4. broadcasts the membership change (if any) and the room update;
/// 5. sends the channel update and runs `update_user_contacts` for this user.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This inner block scopes the `db` and `connection_pool` guards taken
    // below, so both are dropped before `channel_updated` acquires the
    // connection pool again further down.
    let joined_room = {
        // Runs before the db guard is taken below. NOTE(review): presumably
        // because leaving a room needs the db itself; confirm before reordering.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token
        // fails (`trace_err()?` turns the error into an early `None`).
        // Guests get a guest token and `can_publish: false`; every other role
        // gets a room token and `can_publish: true`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The guards from the block above are gone here, so the pool can be
    // acquired again.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2870
2871/// Start editing the channel notes
2872async fn join_channel_buffer(
2873 request: proto::JoinChannelBuffer,
2874 response: Response<proto::JoinChannelBuffer>,
2875 session: Session,
2876) -> Result<()> {
2877 let db = session.db().await;
2878 let channel_id = ChannelId::from_proto(request.channel_id);
2879
2880 let open_response = db
2881 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2882 .await?;
2883
2884 let collaborators = open_response.collaborators.clone();
2885 response.send(open_response)?;
2886
2887 let update = UpdateChannelBufferCollaborators {
2888 channel_id: channel_id.to_proto(),
2889 collaborators: collaborators.clone(),
2890 };
2891 channel_buffer_updated(
2892 session.connection_id,
2893 collaborators
2894 .iter()
2895 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2896 &update,
2897 &session.peer,
2898 );
2899
2900 Ok(())
2901}
2902
2903/// Edit the channel notes
2904async fn update_channel_buffer(
2905 request: proto::UpdateChannelBuffer,
2906 session: Session,
2907) -> Result<()> {
2908 let db = session.db().await;
2909 let channel_id = ChannelId::from_proto(request.channel_id);
2910
2911 let (collaborators, non_collaborators, epoch, version) = db
2912 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2913 .await?;
2914
2915 channel_buffer_updated(
2916 session.connection_id,
2917 collaborators,
2918 &proto::UpdateChannelBuffer {
2919 channel_id: channel_id.to_proto(),
2920 operations: request.operations,
2921 },
2922 &session.peer,
2923 );
2924
2925 let pool = &*session.connection_pool().await;
2926
2927 broadcast(
2928 None,
2929 non_collaborators
2930 .iter()
2931 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2932 |peer_id| {
2933 session.peer.send(
2934 peer_id,
2935 proto::UpdateChannels {
2936 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2937 channel_id: channel_id.to_proto(),
2938 epoch: epoch as u64,
2939 version: version.clone(),
2940 }],
2941 ..Default::default()
2942 },
2943 )
2944 },
2945 );
2946
2947 Ok(())
2948}
2949
2950/// Rejoin the channel notes after a connection blip
2951async fn rejoin_channel_buffers(
2952 request: proto::RejoinChannelBuffers,
2953 response: Response<proto::RejoinChannelBuffers>,
2954 session: Session,
2955) -> Result<()> {
2956 let db = session.db().await;
2957 let buffers = db
2958 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2959 .await?;
2960
2961 for rejoined_buffer in &buffers {
2962 let collaborators_to_notify = rejoined_buffer
2963 .buffer
2964 .collaborators
2965 .iter()
2966 .filter_map(|c| Some(c.peer_id?.into()));
2967 channel_buffer_updated(
2968 session.connection_id,
2969 collaborators_to_notify,
2970 &proto::UpdateChannelBufferCollaborators {
2971 channel_id: rejoined_buffer.buffer.channel_id,
2972 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2973 },
2974 &session.peer,
2975 );
2976 }
2977
2978 response.send(proto::RejoinChannelBuffersResponse {
2979 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2980 })?;
2981
2982 Ok(())
2983}
2984
2985/// Stop editing the channel notes
2986async fn leave_channel_buffer(
2987 request: proto::LeaveChannelBuffer,
2988 response: Response<proto::LeaveChannelBuffer>,
2989 session: Session,
2990) -> Result<()> {
2991 let db = session.db().await;
2992 let channel_id = ChannelId::from_proto(request.channel_id);
2993
2994 let left_buffer = db
2995 .leave_channel_buffer(channel_id, session.connection_id)
2996 .await?;
2997
2998 response.send(Ack {})?;
2999
3000 channel_buffer_updated(
3001 session.connection_id,
3002 left_buffer.connections,
3003 &proto::UpdateChannelBufferCollaborators {
3004 channel_id: channel_id.to_proto(),
3005 collaborators: left_buffer.collaborators,
3006 },
3007 &session.peer,
3008 );
3009
3010 Ok(())
3011}
3012
3013fn channel_buffer_updated<T: EnvelopedMessage>(
3014 sender_id: ConnectionId,
3015 collaborators: impl IntoIterator<Item = ConnectionId>,
3016 message: &T,
3017 peer: &Peer,
3018) {
3019 broadcast(Some(sender_id), collaborators, |peer_id| {
3020 peer.send(peer_id, message.clone())
3021 });
3022}
3023
3024fn send_notifications(
3025 connection_pool: &ConnectionPool,
3026 peer: &Peer,
3027 notifications: db::NotificationBatch,
3028) {
3029 for (user_id, notification) in notifications {
3030 for connection_id in connection_pool.user_connection_ids(user_id) {
3031 if let Err(error) = peer.send(
3032 connection_id,
3033 proto::AddNotification {
3034 notification: Some(notification.clone()),
3035 },
3036 ) {
3037 tracing::error!(
3038 "failed to send notification to {:?} {}",
3039 connection_id,
3040 error
3041 );
3042 }
3043 }
3044 }
3045}
3046
3047/// Send a message to the channel
3048async fn send_channel_message(
3049 request: proto::SendChannelMessage,
3050 response: Response<proto::SendChannelMessage>,
3051 session: Session,
3052) -> Result<()> {
3053 // Validate the message body.
3054 let body = request.body.trim().to_string();
3055 if body.len() > MAX_MESSAGE_LEN {
3056 return Err(anyhow!("message is too long"))?;
3057 }
3058 if body.is_empty() {
3059 return Err(anyhow!("message can't be blank"))?;
3060 }
3061
3062 // TODO: adjust mentions if body is trimmed
3063
3064 let timestamp = OffsetDateTime::now_utc();
3065 let nonce = request
3066 .nonce
3067 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3068
3069 let channel_id = ChannelId::from_proto(request.channel_id);
3070 let CreatedChannelMessage {
3071 message_id,
3072 participant_connection_ids,
3073 channel_members,
3074 notifications,
3075 } = session
3076 .db()
3077 .await
3078 .create_channel_message(
3079 channel_id,
3080 session.user_id,
3081 &body,
3082 &request.mentions,
3083 timestamp,
3084 nonce.clone().into(),
3085 match request.reply_to_message_id {
3086 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3087 None => None,
3088 },
3089 )
3090 .await?;
3091 let message = proto::ChannelMessage {
3092 sender_id: session.user_id.to_proto(),
3093 id: message_id.to_proto(),
3094 body,
3095 mentions: request.mentions,
3096 timestamp: timestamp.unix_timestamp() as u64,
3097 nonce: Some(nonce),
3098 reply_to_message_id: request.reply_to_message_id,
3099 };
3100 broadcast(
3101 Some(session.connection_id),
3102 participant_connection_ids,
3103 |connection| {
3104 session.peer.send(
3105 connection,
3106 proto::ChannelMessageSent {
3107 channel_id: channel_id.to_proto(),
3108 message: Some(message.clone()),
3109 },
3110 )
3111 },
3112 );
3113 response.send(proto::SendChannelMessageResponse {
3114 message: Some(message),
3115 })?;
3116
3117 let pool = &*session.connection_pool().await;
3118 broadcast(
3119 None,
3120 channel_members
3121 .iter()
3122 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3123 |peer_id| {
3124 session.peer.send(
3125 peer_id,
3126 proto::UpdateChannels {
3127 latest_channel_message_ids: vec![proto::ChannelMessageId {
3128 channel_id: channel_id.to_proto(),
3129 message_id: message_id.to_proto(),
3130 }],
3131 ..Default::default()
3132 },
3133 )
3134 },
3135 );
3136 send_notifications(pool, &session.peer, notifications);
3137
3138 Ok(())
3139}
3140
3141/// Delete a channel message
3142async fn remove_channel_message(
3143 request: proto::RemoveChannelMessage,
3144 response: Response<proto::RemoveChannelMessage>,
3145 session: Session,
3146) -> Result<()> {
3147 let channel_id = ChannelId::from_proto(request.channel_id);
3148 let message_id = MessageId::from_proto(request.message_id);
3149 let connection_ids = session
3150 .db()
3151 .await
3152 .remove_channel_message(channel_id, message_id, session.user_id)
3153 .await?;
3154 broadcast(Some(session.connection_id), connection_ids, |connection| {
3155 session.peer.send(connection, request.clone())
3156 });
3157 response.send(proto::Ack {})?;
3158 Ok(())
3159}
3160
3161/// Mark a channel message as read
3162async fn acknowledge_channel_message(
3163 request: proto::AckChannelMessage,
3164 session: Session,
3165) -> Result<()> {
3166 let channel_id = ChannelId::from_proto(request.channel_id);
3167 let message_id = MessageId::from_proto(request.message_id);
3168 let notifications = session
3169 .db()
3170 .await
3171 .observe_channel_message(channel_id, session.user_id, message_id)
3172 .await?;
3173 send_notifications(
3174 &*session.connection_pool().await,
3175 &session.peer,
3176 notifications,
3177 );
3178 Ok(())
3179}
3180
3181/// Mark a buffer version as synced
3182async fn acknowledge_buffer_version(
3183 request: proto::AckBufferOperation,
3184 session: Session,
3185) -> Result<()> {
3186 let buffer_id = BufferId::from_proto(request.buffer_id);
3187 session
3188 .db()
3189 .await
3190 .observe_buffer_version(
3191 buffer_id,
3192 session.user_id,
3193 request.epoch as i32,
3194 &request.version,
3195 )
3196 .await?;
3197 Ok(())
3198}
3199
3200/// Start receiving chat updates for a channel
3201async fn join_channel_chat(
3202 request: proto::JoinChannelChat,
3203 response: Response<proto::JoinChannelChat>,
3204 session: Session,
3205) -> Result<()> {
3206 let channel_id = ChannelId::from_proto(request.channel_id);
3207
3208 let db = session.db().await;
3209 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3210 .await?;
3211 let messages = db
3212 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3213 .await?;
3214 response.send(proto::JoinChannelChatResponse {
3215 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3216 messages,
3217 })?;
3218 Ok(())
3219}
3220
3221/// Stop receiving chat updates for a channel
3222async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3223 let channel_id = ChannelId::from_proto(request.channel_id);
3224 session
3225 .db()
3226 .await
3227 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3228 .await?;
3229 Ok(())
3230}
3231
3232/// Retrieve the chat history for a channel
3233async fn get_channel_messages(
3234 request: proto::GetChannelMessages,
3235 response: Response<proto::GetChannelMessages>,
3236 session: Session,
3237) -> Result<()> {
3238 let channel_id = ChannelId::from_proto(request.channel_id);
3239 let messages = session
3240 .db()
3241 .await
3242 .get_channel_messages(
3243 channel_id,
3244 session.user_id,
3245 MESSAGE_COUNT_PER_PAGE,
3246 Some(MessageId::from_proto(request.before_message_id)),
3247 )
3248 .await?;
3249 response.send(proto::GetChannelMessagesResponse {
3250 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3251 messages,
3252 })?;
3253 Ok(())
3254}
3255
3256/// Retrieve specific chat messages
3257async fn get_channel_messages_by_id(
3258 request: proto::GetChannelMessagesById,
3259 response: Response<proto::GetChannelMessagesById>,
3260 session: Session,
3261) -> Result<()> {
3262 let message_ids = request
3263 .message_ids
3264 .iter()
3265 .map(|id| MessageId::from_proto(*id))
3266 .collect::<Vec<_>>();
3267 let messages = session
3268 .db()
3269 .await
3270 .get_channel_messages_by_id(session.user_id, &message_ids)
3271 .await?;
3272 response.send(proto::GetChannelMessagesResponse {
3273 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3274 messages,
3275 })?;
3276 Ok(())
3277}
3278
3279/// Retrieve the current users notifications
3280async fn get_notifications(
3281 request: proto::GetNotifications,
3282 response: Response<proto::GetNotifications>,
3283 session: Session,
3284) -> Result<()> {
3285 let notifications = session
3286 .db()
3287 .await
3288 .get_notifications(
3289 session.user_id,
3290 NOTIFICATION_COUNT_PER_PAGE,
3291 request
3292 .before_id
3293 .map(|id| db::NotificationId::from_proto(id)),
3294 )
3295 .await?;
3296 response.send(proto::GetNotificationsResponse {
3297 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3298 notifications,
3299 })?;
3300 Ok(())
3301}
3302
3303/// Mark notifications as read
3304async fn mark_notification_as_read(
3305 request: proto::MarkNotificationRead,
3306 response: Response<proto::MarkNotificationRead>,
3307 session: Session,
3308) -> Result<()> {
3309 let database = &session.db().await;
3310 let notifications = database
3311 .mark_notification_as_read_by_id(
3312 session.user_id,
3313 NotificationId::from_proto(request.notification_id),
3314 )
3315 .await?;
3316 send_notifications(
3317 &*session.connection_pool().await,
3318 &session.peer,
3319 notifications,
3320 );
3321 response.send(proto::Ack {})?;
3322 Ok(())
3323}
3324
3325/// Get the current users information
3326async fn get_private_user_info(
3327 _request: proto::GetPrivateUserInfo,
3328 response: Response<proto::GetPrivateUserInfo>,
3329 session: Session,
3330) -> Result<()> {
3331 let db = session.db().await;
3332
3333 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3334 let user = db
3335 .get_user_by_id(session.user_id)
3336 .await?
3337 .ok_or_else(|| anyhow!("user not found"))?;
3338 let flags = db.get_user_flags(session.user_id).await?;
3339
3340 response.send(proto::GetPrivateUserInfoResponse {
3341 metrics_id,
3342 staff: user.admin,
3343 flags,
3344 })?;
3345 Ok(())
3346}
3347
3348fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3349 match message {
3350 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3351 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3352 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3353 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3354 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3355 code: frame.code.into(),
3356 reason: frame.reason,
3357 })),
3358 }
3359}
3360
3361fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3362 match message {
3363 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3364 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3365 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3366 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3367 AxumMessage::Close(frame) => {
3368 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3369 code: frame.code.into(),
3370 reason: frame.reason,
3371 }))
3372 }
3373 }
3374}
3375
3376fn notify_membership_updated(
3377 connection_pool: &ConnectionPool,
3378 result: MembershipUpdated,
3379 user_id: UserId,
3380 peer: &Peer,
3381) {
3382 let user_channels_update = proto::UpdateUserChannels {
3383 channel_memberships: result
3384 .new_channels
3385 .channel_memberships
3386 .iter()
3387 .map(|cm| proto::ChannelMembership {
3388 channel_id: cm.channel_id.to_proto(),
3389 role: cm.role.into(),
3390 })
3391 .collect(),
3392 ..Default::default()
3393 };
3394 let mut update = build_channels_update(result.new_channels, vec![]);
3395 update.delete_channels = result
3396 .removed_channels
3397 .into_iter()
3398 .map(|id| id.to_proto())
3399 .collect();
3400 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3401
3402 for connection_id in connection_pool.user_connection_ids(user_id) {
3403 peer.send(connection_id, user_channels_update.clone())
3404 .trace_err();
3405 peer.send(connection_id, update.clone()).trace_err();
3406 }
3407}
3408
3409fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3410 proto::UpdateUserChannels {
3411 channel_memberships: channels
3412 .channel_memberships
3413 .iter()
3414 .map(|m| proto::ChannelMembership {
3415 channel_id: m.channel_id.to_proto(),
3416 role: m.role.into(),
3417 })
3418 .collect(),
3419 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3420 observed_channel_message_id: channels.observed_channel_messages.clone(),
3421 }
3422}
3423
3424fn build_channels_update(
3425 channels: ChannelsForUser,
3426 channel_invites: Vec<db::Channel>,
3427) -> proto::UpdateChannels {
3428 let mut update = proto::UpdateChannels::default();
3429
3430 for channel in channels.channels {
3431 update.channels.push(channel.to_proto());
3432 }
3433
3434 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3435 update.latest_channel_message_ids = channels.latest_channel_messages;
3436
3437 for (channel_id, participants) in channels.channel_participants {
3438 update
3439 .channel_participants
3440 .push(proto::ChannelParticipants {
3441 channel_id: channel_id.to_proto(),
3442 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3443 });
3444 }
3445
3446 for channel in channel_invites {
3447 update.channel_invitations.push(channel.to_proto());
3448 }
3449 for project in channels.hosted_projects {
3450 update.hosted_projects.push(project);
3451 }
3452
3453 update
3454}
3455
3456fn build_initial_contacts_update(
3457 contacts: Vec<db::Contact>,
3458 pool: &ConnectionPool,
3459) -> proto::UpdateContacts {
3460 let mut update = proto::UpdateContacts::default();
3461
3462 for contact in contacts {
3463 match contact {
3464 db::Contact::Accepted { user_id, busy } => {
3465 update.contacts.push(contact_for_user(user_id, busy, &pool));
3466 }
3467 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3468 db::Contact::Incoming { user_id } => {
3469 update
3470 .incoming_requests
3471 .push(proto::IncomingContactRequest {
3472 requester_id: user_id.to_proto(),
3473 })
3474 }
3475 }
3476 }
3477
3478 update
3479}
3480
3481fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3482 proto::Contact {
3483 user_id: user_id.to_proto(),
3484 online: pool.is_user_online(user_id),
3485 busy,
3486 }
3487}
3488
3489fn room_updated(room: &proto::Room, peer: &Peer) {
3490 broadcast(
3491 None,
3492 room.participants
3493 .iter()
3494 .filter_map(|participant| Some(participant.peer_id?.into())),
3495 |peer_id| {
3496 peer.send(
3497 peer_id,
3498 proto::RoomUpdated {
3499 room: Some(room.clone()),
3500 },
3501 )
3502 },
3503 );
3504}
3505
3506fn channel_updated(
3507 channel_id: ChannelId,
3508 room: &proto::Room,
3509 channel_members: &[UserId],
3510 peer: &Peer,
3511 pool: &ConnectionPool,
3512) {
3513 let participants = room
3514 .participants
3515 .iter()
3516 .map(|p| p.user_id)
3517 .collect::<Vec<_>>();
3518
3519 broadcast(
3520 None,
3521 channel_members
3522 .iter()
3523 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3524 |peer_id| {
3525 peer.send(
3526 peer_id,
3527 proto::UpdateChannels {
3528 channel_participants: vec![proto::ChannelParticipants {
3529 channel_id: channel_id.to_proto(),
3530 participant_user_ids: participants.clone(),
3531 }],
3532 ..Default::default()
3533 },
3534 )
3535 },
3536 );
3537}
3538
3539async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3540 let db = session.db().await;
3541
3542 let contacts = db.get_contacts(user_id).await?;
3543 let busy = db.is_user_busy(user_id).await?;
3544
3545 let pool = session.connection_pool().await;
3546 let updated_contact = contact_for_user(user_id, busy, &pool);
3547 for contact in contacts {
3548 if let db::Contact::Accepted {
3549 user_id: contact_user_id,
3550 ..
3551 } = contact
3552 {
3553 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3554 session
3555 .peer
3556 .send(
3557 contact_conn_id,
3558 proto::UpdateContacts {
3559 contacts: vec![updated_contact.clone()],
3560 remove_contacts: Default::default(),
3561 incoming_requests: Default::default(),
3562 remove_incoming_requests: Default::default(),
3563 outgoing_requests: Default::default(),
3564 remove_outgoing_requests: Default::default(),
3565 },
3566 )
3567 .trace_err();
3568 }
3569 }
3570 }
3571 Ok(())
3572}
3573
/// Remove this session's connection from the room it is in and notify
/// everyone affected.
///
/// Returns `Ok(())` immediately when `leave_room` reports no room for this
/// connection. Otherwise it:
/// - notifies the connections of each project that was left (`project_left`),
/// - broadcasts the updated room to its participants (`room_updated`) and,
///   when the room belongs to a channel, the new participant list to the
///   channel's members (`channel_updated`),
/// - sends `CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`,
/// - refreshes contact state for the leaving user and those users, and
/// - if a LiveKit client is configured, removes the user from the LiveKit
///   room, deleting the room when `left_room.deleted` is set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. Failures
/// from sends and LiveKit calls are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; declared up front so the
    // values outlive it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // The database guard from `session.db().await` is a temporary of this
    // `if let` scrutinee, so it stays held for the whole `if let` body. The
    // deferred-assignment pattern above lets the remaining work run after it
    // has been released.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself, so the room broadcast below carries a
        // defaulted `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before
    // `update_user_contacts`, which locks the pool again. Presumably this
    // avoids contending on or re-entering that lock; confirm.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?` and skips the LiveKit
    // cleanup below; confirm that is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `delete_room` below takes the room name by value.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3650
3651async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3652 let left_channel_buffers = session
3653 .db()
3654 .await
3655 .leave_channel_buffers(session.connection_id)
3656 .await?;
3657
3658 for left_buffer in left_channel_buffers {
3659 channel_buffer_updated(
3660 session.connection_id,
3661 left_buffer.connections,
3662 &proto::UpdateChannelBufferCollaborators {
3663 channel_id: left_buffer.channel_id.to_proto(),
3664 collaborators: left_buffer.collaborators,
3665 },
3666 &session.peer,
3667 );
3668 }
3669
3670 Ok(())
3671}
3672
3673fn project_left(project: &db::LeftProject, session: &Session) {
3674 for connection_id in &project.connection_ids {
3675 if project.host_user_id == Some(session.user_id) {
3676 session
3677 .peer
3678 .send(
3679 *connection_id,
3680 proto::UnshareProject {
3681 project_id: project.id.to_proto(),
3682 },
3683 )
3684 .trace_err();
3685 } else {
3686 session
3687 .peer
3688 .send(
3689 *connection_id,
3690 proto::RemoveProjectCollaborator {
3691 project_id: project.id.to_proto(),
3692 peer_id: Some(session.connection_id.into()),
3693 },
3694 )
3695 .trace_err();
3696 }
3697 }
3698}
3699
/// Extension for fallible values whose errors should be logged and dropped
/// rather than propagated.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Convert `self` into an `Option` of the success value. The `Result`
    /// implementation in this module logs the error with `tracing::error!`
    /// and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
3705
3706impl<T, E> ResultExt for Result<T, E>
3707where
3708 E: std::fmt::Debug,
3709{
3710 type Ok = T;
3711
3712 fn trace_err(self) -> Option<T> {
3713 match self {
3714 Ok(value) => Some(value),
3715 Err(error) => {
3716 tracing::error!("{:?}", error);
3717 None
3718 }
3719 }
3720 }
3721}