1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::{ConnectionPool, ZedVersion};
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
42 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc, OnceLock,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{field, info_span, instrument, Instrument};
66use util::SemanticVersion;
67
/// How long `connection_lost` waits after a disconnect before it releases the
/// session's rooms, channel buffers and pending calls.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `Server::start` waits before clearing stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// Paging and size limits; presumably consumed by the channel-chat and
// notification handlers further down the file.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_MESSAGE_LEN: usize = 1024;
74
75type MessageHandler =
76 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
77
78struct Response<R> {
79 peer: Arc<Peer>,
80 receipt: Receipt<R>,
81 responded: Arc<AtomicBool>,
82}
83
84impl<R: RequestMessage> Response<R> {
85 fn send(self, payload: R::Response) -> Result<()> {
86 self.responded.store(true, SeqCst);
87 self.peer.respond(self.receipt, payload)?;
88 Ok(())
89 }
90}
91
92#[derive(Clone)]
93struct Session {
94 user_id: UserId,
95 connection_id: ConnectionId,
96 db: Arc<tokio::sync::Mutex<DbHandle>>,
97 peer: Arc<Peer>,
98 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
99 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
100 _executor: Executor,
101}
102
103impl Session {
104 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 let guard = self.db.lock().await;
108 #[cfg(test)]
109 tokio::task::yield_now().await;
110 guard
111 }
112
113 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
114 #[cfg(test)]
115 tokio::task::yield_now().await;
116 let guard = self.connection_pool.lock();
117 ConnectionPoolGuard {
118 guard,
119 _not_send: PhantomData,
120 }
121 }
122}
123
124impl fmt::Debug for Session {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("Session")
127 .field("user_id", &self.user_id)
128 .field("connection_id", &self.connection_id)
129 .finish()
130 }
131}
132
133struct DbHandle(Arc<Database>);
134
135impl Deref for DbHandle {
136 type Target = Database;
137
138 fn deref(&self) -> &Self::Target {
139 self.0.as_ref()
140 }
141}
142
/// The collaboration RPC server: owns the peer, the shared connection pool, the
/// application state and the handler table keyed by message type.
pub struct Server {
    // Behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Keyed by the `TypeId` of the message payload; populated via `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown`; each connection subscribes in `handle_connection`.
    teardown: watch::Sender<bool>,
}
152
153pub(crate) struct ConnectionPoolGuard<'a> {
154 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
155 _not_send: PhantomData<Rc<()>>,
156}
157
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the `ConnectionPool` the guard dereferences to, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
164
165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
166where
167 S: Serializer,
168 T: Deref<Target = U>,
169 U: Serialize,
170{
171 Serialize::serialize(value.deref(), serializer)
172}
173
174impl Server {
175 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
176 let mut server = Self {
177 id: parking_lot::Mutex::new(id),
178 peer: Peer::new(id.0 as u32),
179 app_state,
180 executor,
181 connection_pool: Default::default(),
182 handlers: Default::default(),
183 teardown: watch::channel(false).0,
184 };
185
186 server
187 .add_request_handler(ping)
188 .add_request_handler(create_room)
189 .add_request_handler(join_room)
190 .add_request_handler(rejoin_room)
191 .add_request_handler(leave_room)
192 .add_request_handler(set_room_participant_role)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
208 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
209 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
210 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
211 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
213 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
214 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
215 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
216 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
217 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
218 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
219 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
220 .add_request_handler(
221 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
222 )
223 .add_request_handler(
224 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
225 )
226 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
227 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
228 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
229 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
230 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
231 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
232 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
233 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
234 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
235 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
236 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
238 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
239 .add_message_handler(create_buffer_for_peer)
240 .add_request_handler(update_buffer)
241 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
242 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
243 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
244 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
245 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
246 .add_request_handler(get_users)
247 .add_request_handler(fuzzy_search_users)
248 .add_request_handler(request_contact)
249 .add_request_handler(remove_contact)
250 .add_request_handler(respond_to_contact_request)
251 .add_request_handler(create_channel)
252 .add_request_handler(delete_channel)
253 .add_request_handler(invite_channel_member)
254 .add_request_handler(remove_channel_member)
255 .add_request_handler(set_channel_member_role)
256 .add_request_handler(set_channel_visibility)
257 .add_request_handler(rename_channel)
258 .add_request_handler(join_channel_buffer)
259 .add_request_handler(leave_channel_buffer)
260 .add_message_handler(update_channel_buffer)
261 .add_request_handler(rejoin_channel_buffers)
262 .add_request_handler(get_channel_members)
263 .add_request_handler(respond_to_channel_invite)
264 .add_request_handler(join_channel)
265 .add_request_handler(join_channel_chat)
266 .add_message_handler(leave_channel_chat)
267 .add_request_handler(send_channel_message)
268 .add_request_handler(remove_channel_message)
269 .add_request_handler(get_channel_messages)
270 .add_request_handler(get_channel_messages_by_id)
271 .add_request_handler(get_notifications)
272 .add_request_handler(mark_notification_as_read)
273 .add_request_handler(move_channel)
274 .add_request_handler(follow)
275 .add_message_handler(unfollow)
276 .add_message_handler(update_followers)
277 .add_request_handler(get_private_user_info)
278 .add_message_handler(acknowledge_channel_message)
279 .add_message_handler(acknowledge_buffer_version);
280
281 Arc::new(server)
282 }
283
284 pub async fn start(&self) -> Result<()> {
285 let server_id = *self.id.lock();
286 let app_state = self.app_state.clone();
287 let peer = self.peer.clone();
288 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
289 let pool = self.connection_pool.clone();
290 let live_kit_client = self.app_state.live_kit_client.clone();
291
292 let span = info_span!("start server");
293 self.executor.spawn_detached(
294 async move {
295 tracing::info!("waiting for cleanup timeout");
296 timeout.await;
297 tracing::info!("cleanup timeout expired, retrieving stale rooms");
298 if let Some((room_ids, channel_ids)) = app_state
299 .db
300 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
301 .await
302 .trace_err()
303 {
304 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
305 tracing::info!(
306 stale_channel_buffer_count = channel_ids.len(),
307 "retrieved stale channel buffers"
308 );
309
310 for channel_id in channel_ids {
311 if let Some(refreshed_channel_buffer) = app_state
312 .db
313 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
314 .await
315 .trace_err()
316 {
317 for connection_id in refreshed_channel_buffer.connection_ids {
318 peer.send(
319 connection_id,
320 proto::UpdateChannelBufferCollaborators {
321 channel_id: channel_id.to_proto(),
322 collaborators: refreshed_channel_buffer
323 .collaborators
324 .clone(),
325 },
326 )
327 .trace_err();
328 }
329 }
330 }
331
332 for room_id in room_ids {
333 let mut contacts_to_update = HashSet::default();
334 let mut canceled_calls_to_user_ids = Vec::new();
335 let mut live_kit_room = String::new();
336 let mut delete_live_kit_room = false;
337
338 if let Some(mut refreshed_room) = app_state
339 .db
340 .clear_stale_room_participants(room_id, server_id)
341 .await
342 .trace_err()
343 {
344 tracing::info!(
345 room_id = room_id.0,
346 new_participant_count = refreshed_room.room.participants.len(),
347 "refreshed room"
348 );
349 room_updated(&refreshed_room.room, &peer);
350 if let Some(channel_id) = refreshed_room.channel_id {
351 channel_updated(
352 channel_id,
353 &refreshed_room.room,
354 &refreshed_room.channel_members,
355 &peer,
356 &pool.lock(),
357 );
358 }
359 contacts_to_update
360 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
361 contacts_to_update
362 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
363 canceled_calls_to_user_ids =
364 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
365 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
366 delete_live_kit_room = refreshed_room.room.participants.is_empty();
367 }
368
369 {
370 let pool = pool.lock();
371 for canceled_user_id in canceled_calls_to_user_ids {
372 for connection_id in pool.user_connection_ids(canceled_user_id) {
373 peer.send(
374 connection_id,
375 proto::CallCanceled {
376 room_id: room_id.to_proto(),
377 },
378 )
379 .trace_err();
380 }
381 }
382 }
383
384 for user_id in contacts_to_update {
385 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
386 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
387 if let Some((busy, contacts)) = busy.zip(contacts) {
388 let pool = pool.lock();
389 let updated_contact = contact_for_user(user_id, busy, &pool);
390 for contact in contacts {
391 if let db::Contact::Accepted {
392 user_id: contact_user_id,
393 ..
394 } = contact
395 {
396 for contact_conn_id in
397 pool.user_connection_ids(contact_user_id)
398 {
399 peer.send(
400 contact_conn_id,
401 proto::UpdateContacts {
402 contacts: vec![updated_contact.clone()],
403 remove_contacts: Default::default(),
404 incoming_requests: Default::default(),
405 remove_incoming_requests: Default::default(),
406 outgoing_requests: Default::default(),
407 remove_outgoing_requests: Default::default(),
408 },
409 )
410 .trace_err();
411 }
412 }
413 }
414 }
415 }
416
417 if let Some(live_kit) = live_kit_client.as_ref() {
418 if delete_live_kit_room {
419 live_kit.delete_room(live_kit_room).await.trace_err();
420 }
421 }
422 }
423 }
424
425 app_state
426 .db
427 .delete_stale_servers(&app_state.config.zed_environment, server_id)
428 .await
429 .trace_err();
430 }
431 .instrument(span),
432 );
433 Ok(())
434 }
435
436 pub fn teardown(&self) {
437 self.peer.teardown();
438 self.connection_pool.lock().reset();
439 let _ = self.teardown.send(true);
440 }
441
442 #[cfg(test)]
443 pub fn reset(&self, id: ServerId) {
444 self.teardown();
445 *self.id.lock() = id;
446 self.peer.reset(id.0 as u32);
447 let _ = self.teardown.send(false);
448 }
449
450 #[cfg(test)]
451 pub fn id(&self) -> ServerId {
452 *self.id.lock()
453 }
454
455 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
456 where
457 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
458 Fut: 'static + Send + Future<Output = Result<()>>,
459 M: EnvelopedMessage,
460 {
461 let prev_handler = self.handlers.insert(
462 TypeId::of::<M>(),
463 Box::new(move |envelope, session| {
464 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
465 let span = info_span!(
466 "handle message",
467 payload_type = envelope.payload_type_name()
468 );
469 span.in_scope(|| {
470 tracing::info!(
471 payload_type = envelope.payload_type_name(),
472 "message received"
473 );
474 });
475 let start_time = Instant::now();
476 let future = (handler)(*envelope, session);
477 async move {
478 let result = future.await;
479 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
480 match result {
481 Err(error) => {
482 tracing::error!(%error, ?duration_ms, "error handling message")
483 }
484 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
485 }
486 }
487 .instrument(span)
488 .boxed()
489 }),
490 );
491 if prev_handler.is_some() {
492 panic!("registered a handler for the same message twice");
493 }
494 self
495 }
496
497 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
498 where
499 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
500 Fut: 'static + Send + Future<Output = Result<()>>,
501 M: EnvelopedMessage,
502 {
503 self.add_handler(move |envelope, session| handler(envelope.payload, session));
504 self
505 }
506
507 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
508 where
509 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
510 Fut: Send + Future<Output = Result<()>>,
511 M: RequestMessage,
512 {
513 let handler = Arc::new(handler);
514 self.add_handler(move |envelope, session| {
515 let receipt = envelope.receipt();
516 let handler = handler.clone();
517 async move {
518 let peer = session.peer.clone();
519 let responded = Arc::new(AtomicBool::default());
520 let response = Response {
521 peer: peer.clone(),
522 responded: responded.clone(),
523 receipt,
524 };
525 match (handler)(envelope.payload, response, session).await {
526 Ok(()) => {
527 if responded.load(std::sync::atomic::Ordering::SeqCst) {
528 Ok(())
529 } else {
530 Err(anyhow!("handler did not send a response"))?
531 }
532 }
533 Err(error) => {
534 let proto_err = match &error {
535 Error::Internal(err) => err.to_proto(),
536 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
537 };
538 peer.respond_with_error(receipt, proto_err)?;
539 Err(error)
540 }
541 }
542 }
543 })
544 }
545
546 #[allow(clippy::too_many_arguments)]
547 pub fn handle_connection(
548 self: &Arc<Self>,
549 connection: Connection,
550 address: String,
551 user: User,
552 zed_version: ZedVersion,
553 impersonator: Option<User>,
554 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
555 executor: Executor,
556 ) -> impl Future<Output = Result<()>> {
557 let this = self.clone();
558 let user_id = user.id;
559 let login = user.github_login;
560 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
561 if let Some(impersonator) = impersonator {
562 span.record("impersonator", &impersonator.github_login);
563 }
564 let mut teardown = self.teardown.subscribe();
565 async move {
566 if *teardown.borrow() {
567 return Err(anyhow!("server is tearing down"))?;
568 }
569 let (connection_id, handle_io, mut incoming_rx) = this
570 .peer
571 .add_connection(connection, {
572 let executor = executor.clone();
573 move |duration| executor.sleep(duration)
574 });
575
576 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
577 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
578 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
579
580 if let Some(send_connection_id) = send_connection_id.take() {
581 let _ = send_connection_id.send(connection_id);
582 }
583
584 if !user.connected_once {
585 this.peer.send(connection_id, proto::ShowContacts {})?;
586 this.app_state.db.set_user_connected_once(user_id, true).await?;
587 }
588
589 let (contacts, channels_for_user, channel_invites) = future::try_join3(
590 this.app_state.db.get_contacts(user_id),
591 this.app_state.db.get_channels_for_user(user_id),
592 this.app_state.db.get_channel_invites_for_user(user_id),
593 ).await?;
594
595 {
596 let mut pool = this.connection_pool.lock();
597 pool.add_connection(connection_id, user_id, user.admin, zed_version);
598 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
599 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
600 this.peer.send(connection_id, build_channels_update(
601 channels_for_user,
602 channel_invites
603 ))?;
604 }
605
606 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
607 this.peer.send(connection_id, incoming_call)?;
608 }
609
610 let session = Session {
611 user_id,
612 connection_id,
613 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
614 peer: this.peer.clone(),
615 connection_pool: this.connection_pool.clone(),
616 live_kit_client: this.app_state.live_kit_client.clone(),
617 _executor: executor.clone()
618 };
619 update_user_contacts(user_id, &session).await?;
620
621 let handle_io = handle_io.fuse();
622 futures::pin_mut!(handle_io);
623
624 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
625 // This prevents deadlocks when e.g., client A performs a request to client B and
626 // client B performs a request to client A. If both clients stop processing further
627 // messages until their respective request completes, they won't have a chance to
628 // respond to the other client's request and cause a deadlock.
629 //
630 // This arrangement ensures we will attempt to process earlier messages first, but fall
631 // back to processing messages arrived later in the spirit of making progress.
632 let mut foreground_message_handlers = FuturesUnordered::new();
633 let concurrent_handlers = Arc::new(Semaphore::new(256));
634 loop {
635 let next_message = async {
636 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
637 let message = incoming_rx.next().await;
638 (permit, message)
639 }.fuse();
640 futures::pin_mut!(next_message);
641 futures::select_biased! {
642 _ = teardown.changed().fuse() => return Ok(()),
643 result = handle_io => {
644 if let Err(error) = result {
645 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
646 }
647 break;
648 }
649 _ = foreground_message_handlers.next() => {}
650 next_message = next_message => {
651 let (permit, message) = next_message;
652 if let Some(message) = message {
653 let type_name = message.payload_type_name();
654 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
655 let span_enter = span.enter();
656 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
657 let is_background = message.is_background();
658 let handle_message = (handler)(message, session.clone());
659 drop(span_enter);
660
661 let handle_message = async move {
662 handle_message.await;
663 drop(permit);
664 }.instrument(span);
665 if is_background {
666 executor.spawn_detached(handle_message);
667 } else {
668 foreground_message_handlers.push(handle_message);
669 }
670 } else {
671 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
672 }
673 } else {
674 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
675 break;
676 }
677 }
678 }
679 }
680
681 drop(foreground_message_handlers);
682 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
683 if let Err(error) = connection_lost(session, teardown, executor).await {
684 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
685 }
686
687 Ok(())
688 }.instrument(span)
689 }
690
691 pub async fn invite_code_redeemed(
692 self: &Arc<Self>,
693 inviter_id: UserId,
694 invitee_id: UserId,
695 ) -> Result<()> {
696 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
697 if let Some(code) = &user.invite_code {
698 let pool = self.connection_pool.lock();
699 let invitee_contact = contact_for_user(invitee_id, false, &pool);
700 for connection_id in pool.user_connection_ids(inviter_id) {
701 self.peer.send(
702 connection_id,
703 proto::UpdateContacts {
704 contacts: vec![invitee_contact.clone()],
705 ..Default::default()
706 },
707 )?;
708 self.peer.send(
709 connection_id,
710 proto::UpdateInviteInfo {
711 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
712 count: user.invite_count as u32,
713 },
714 )?;
715 }
716 }
717 }
718 Ok(())
719 }
720
721 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
722 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
723 if let Some(invite_code) = &user.invite_code {
724 let pool = self.connection_pool.lock();
725 for connection_id in pool.user_connection_ids(user_id) {
726 self.peer.send(
727 connection_id,
728 proto::UpdateInviteInfo {
729 url: format!(
730 "{}{}",
731 self.app_state.config.invite_link_prefix, invite_code
732 ),
733 count: user.invite_count as u32,
734 },
735 )?;
736 }
737 }
738 }
739 Ok(())
740 }
741
742 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
743 ServerSnapshot {
744 connection_pool: ConnectionPoolGuard {
745 guard: self.connection_pool.lock(),
746 _not_send: PhantomData,
747 },
748 peer: &self.peer,
749 }
750 }
751}
752
753impl<'a> Deref for ConnectionPoolGuard<'a> {
754 type Target = ConnectionPool;
755
756 fn deref(&self) -> &Self::Target {
757 &self.guard
758 }
759}
760
761impl<'a> DerefMut for ConnectionPoolGuard<'a> {
762 fn deref_mut(&mut self) -> &mut Self::Target {
763 &mut self.guard
764 }
765}
766
767impl<'a> Drop for ConnectionPoolGuard<'a> {
768 fn drop(&mut self) {
769 #[cfg(test)]
770 self.check_invariants();
771 }
772}
773
774fn broadcast<F>(
775 sender_id: Option<ConnectionId>,
776 receiver_ids: impl IntoIterator<Item = ConnectionId>,
777 mut f: F,
778) where
779 F: FnMut(ConnectionId) -> anyhow::Result<()>,
780{
781 for receiver_id in receiver_ids {
782 if Some(receiver_id) != sender_id {
783 if let Err(error) = f(receiver_id) {
784 tracing::error!("failed to send to {:?} {}", receiver_id, error);
785 }
786 }
787 }
788}
789
790pub struct ProtocolVersion(u32);
791
792impl Header for ProtocolVersion {
793 fn name() -> &'static HeaderName {
794 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
795 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
796 }
797
798 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
799 where
800 Self: Sized,
801 I: Iterator<Item = &'i axum::http::HeaderValue>,
802 {
803 let version = values
804 .next()
805 .ok_or_else(axum::headers::Error::invalid)?
806 .to_str()
807 .map_err(|_| axum::headers::Error::invalid())?
808 .parse()
809 .map_err(|_| axum::headers::Error::invalid())?;
810 Ok(Self(version))
811 }
812
813 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
814 values.extend([self.0.to_string().parse().unwrap()]);
815 }
816}
817
818pub struct AppVersionHeader(SemanticVersion);
819impl Header for AppVersionHeader {
820 fn name() -> &'static HeaderName {
821 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
822 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
823 }
824
825 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
826 where
827 Self: Sized,
828 I: Iterator<Item = &'i axum::http::HeaderValue>,
829 {
830 let version = values
831 .next()
832 .ok_or_else(axum::headers::Error::invalid)?
833 .to_str()
834 .map_err(|_| axum::headers::Error::invalid())?
835 .parse()
836 .map_err(|_| axum::headers::Error::invalid())?;
837 Ok(Self(version))
838 }
839
840 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
841 values.extend([self.0.to_string().parse().unwrap()]);
842 }
843}
844
845pub fn routes(server: Arc<Server>) -> Router<Body> {
846 Router::new()
847 .route("/rpc", get(handle_websocket_request))
848 .layer(
849 ServiceBuilder::new()
850 .layer(Extension(server.app_state.clone()))
851 .layer(middleware::from_fn(auth::validate_header)),
852 )
853 .route("/metrics", get(handle_metrics))
854 .layer(Extension(server))
855}
856
857pub async fn handle_websocket_request(
858 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
859 app_version_header: Option<TypedHeader<AppVersionHeader>>,
860 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
861 Extension(server): Extension<Arc<Server>>,
862 Extension(user): Extension<User>,
863 Extension(impersonator): Extension<Impersonator>,
864 ws: WebSocketUpgrade,
865) -> axum::response::Response {
866 if protocol_version != rpc::PROTOCOL_VERSION {
867 return (
868 StatusCode::UPGRADE_REQUIRED,
869 "client must be upgraded".to_string(),
870 )
871 .into_response();
872 }
873
874 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
875 return (
876 StatusCode::UPGRADE_REQUIRED,
877 "no version header found".to_string(),
878 )
879 .into_response();
880 };
881
882 if !version.is_supported() {
883 return (
884 StatusCode::UPGRADE_REQUIRED,
885 "client must be upgraded".to_string(),
886 )
887 .into_response();
888 }
889
890 let socket_address = socket_address.to_string();
891 ws.on_upgrade(move |socket| {
892 use util::ResultExt;
893 let socket = socket
894 .map_ok(to_tungstenite_message)
895 .err_into()
896 .with(|message| async move { Ok(to_axum_message(message)) });
897 let connection = Connection::new(Box::pin(socket));
898 async move {
899 server
900 .handle_connection(
901 connection,
902 socket_address,
903 user,
904 version,
905 impersonator.0,
906 None,
907 Executor::Production,
908 )
909 .await
910 .log_err();
911 }
912 })
913}
914
915pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
916 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
917 let connections_metric = CONNECTIONS_METRIC
918 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
919
920 let connections = server
921 .connection_pool
922 .lock()
923 .connections()
924 .filter(|connection| !connection.admin)
925 .count();
926 connections_metric.set(connections as _);
927
928 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
929 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
930 register_int_gauge!(
931 "shared_projects",
932 "number of open projects with one or more guests"
933 )
934 .unwrap()
935 });
936
937 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
938 shared_projects_metric.set(shared_projects as _);
939
940 let encoder = prometheus::TextEncoder::new();
941 let metric_families = prometheus::gather();
942 let encoded_metrics = encoder
943 .encode_to_string(&metric_families)
944 .map_err(|err| anyhow!("{}", err))?;
945 Ok(encoded_metrics)
946}
947
948#[instrument(err, skip(executor))]
949async fn connection_lost(
950 session: Session,
951 mut teardown: watch::Receiver<bool>,
952 executor: Executor,
953) -> Result<()> {
954 session.peer.disconnect(session.connection_id);
955 session
956 .connection_pool()
957 .await
958 .remove_connection(session.connection_id)?;
959
960 session
961 .db()
962 .await
963 .connection_lost(session.connection_id)
964 .await
965 .trace_err();
966
967 futures::select_biased! {
968 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
969 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
970 leave_room_for_session(&session).await.trace_err();
971 leave_channel_buffers_for_session(&session)
972 .await
973 .trace_err();
974
975 if !session
976 .connection_pool()
977 .await
978 .is_user_online(session.user_id)
979 {
980 let db = session.db().await;
981 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
982 room_updated(&room, &session.peer);
983 }
984 }
985
986 update_user_contacts(session.user_id, &session).await?;
987 }
988 _ = teardown.changed().fuse() => {}
989 }
990
991 Ok(())
992}
993
994/// Acknowledges a ping from a client, used to keep the connection alive.
995async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
996 response.send(proto::Ack {})?;
997 Ok(())
998}
999
1000/// Creates a new room for calling (outside of channels)
1001async fn create_room(
1002 _request: proto::CreateRoom,
1003 response: Response<proto::CreateRoom>,
1004 session: Session,
1005) -> Result<()> {
1006 let live_kit_room = nanoid::nanoid!(30);
1007
1008 let live_kit_connection_info = {
1009 let live_kit_room = live_kit_room.clone();
1010 let live_kit = session.live_kit_client.as_ref();
1011
1012 util::async_maybe!({
1013 let live_kit = live_kit?;
1014
1015 let token = live_kit
1016 .room_token(&live_kit_room, &session.user_id.to_string())
1017 .trace_err()?;
1018
1019 Some(proto::LiveKitConnectionInfo {
1020 server_url: live_kit.url().into(),
1021 token,
1022 can_publish: true,
1023 })
1024 })
1025 }
1026 .await;
1027
1028 let room = session
1029 .db()
1030 .await
1031 .create_room(session.user_id, session.connection_id, &live_kit_room)
1032 .await?;
1033
1034 response.send(proto::CreateRoomResponse {
1035 room: Some(room.clone()),
1036 live_kit_connection_info,
1037 })?;
1038
1039 update_user_contacts(session.user_id, &session).await?;
1040 Ok(())
1041}
1042
1043/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1044async fn join_room(
1045 request: proto::JoinRoom,
1046 response: Response<proto::JoinRoom>,
1047 session: Session,
1048) -> Result<()> {
1049 let room_id = RoomId::from_proto(request.id);
1050
1051 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1052
1053 if let Some(channel_id) = channel_id {
1054 return join_channel_internal(channel_id, Box::new(response), session).await;
1055 }
1056
1057 let joined_room = {
1058 let room = session
1059 .db()
1060 .await
1061 .join_room(room_id, session.user_id, session.connection_id)
1062 .await?;
1063 room_updated(&room.room, &session.peer);
1064 room.into_inner()
1065 };
1066
1067 for connection_id in session
1068 .connection_pool()
1069 .await
1070 .user_connection_ids(session.user_id)
1071 {
1072 session
1073 .peer
1074 .send(
1075 connection_id,
1076 proto::CallCanceled {
1077 room_id: room_id.to_proto(),
1078 },
1079 )
1080 .trace_err();
1081 }
1082
1083 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1084 if let Some(token) = live_kit
1085 .room_token(
1086 &joined_room.room.live_kit_room,
1087 &session.user_id.to_string(),
1088 )
1089 .trace_err()
1090 {
1091 Some(proto::LiveKitConnectionInfo {
1092 server_url: live_kit.url().into(),
1093 token,
1094 can_publish: true,
1095 })
1096 } else {
1097 None
1098 }
1099 } else {
1100 None
1101 };
1102
1103 response.send(proto::JoinRoomResponse {
1104 room: Some(joined_room.room),
1105 channel_id: None,
1106 live_kit_connection_info,
1107 })?;
1108
1109 update_user_contacts(session.user_id, &session).await?;
1110 Ok(())
1111}
1112
/// Reconnects the caller to a room after a connection error.
///
/// Steps, in order:
/// 1. Record the rejoin in the database.
/// 2. Reply with the room, the caller's reshared projects (with collaborators),
///    and rejoined projects (worktree metadata, collaborators, language servers).
/// 3. Rebroadcast the room.
/// 4. Tell each collaborator of every reshared and rejoined project that the
///    caller's peer id moved from `old_connection_id` to the new connection.
///    For reshared projects, also forward an `UpdateProject` to those
///    collaborators.
/// 5. Stream each rejoined project's worktree entries (chunked), diagnostic
///    summaries, settings files, and a `DiskBasedDiagnosticsUpdated` marker per
///    language server to the caller.
/// 6. If the room belongs to a channel, broadcast a channel update.
/// 7. Refresh the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values extracted from the rejoin result so the inner scope (and whatever
    // guard `rejoined_room` holds) can end before the channel/contacts updates.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: announce the caller's new peer id and forward the
        // project's current worktree list to each collaborator.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: announce the caller's new peer id to collaborators.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream worktree state back to the rejoining connection. The worktrees
        // are taken out of the result so their fields can be moved into messages.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks under test so that chunking is exercised.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // One `DiskBasedDiagnosticsUpdated` marker per language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1304
1305/// leave room disconnects from the room.
1306async fn leave_room(
1307 _: proto::LeaveRoom,
1308 response: Response<proto::LeaveRoom>,
1309 session: Session,
1310) -> Result<()> {
1311 leave_room_for_session(&session).await?;
1312 response.send(proto::Ack {})?;
1313 Ok(())
1314}
1315
1316/// Updates the permissions of someone else in the room.
1317async fn set_room_participant_role(
1318 request: proto::SetRoomParticipantRole,
1319 response: Response<proto::SetRoomParticipantRole>,
1320 session: Session,
1321) -> Result<()> {
1322 let user_id = UserId::from_proto(request.user_id);
1323 let role = ChannelRole::from(request.role());
1324
1325 if role == ChannelRole::Talker {
1326 let pool = session.connection_pool().await;
1327
1328 for connection in pool.user_connections(user_id) {
1329 if !connection.zed_version.supports_talker_role() {
1330 Err(anyhow!(
1331 "This user is on zed {} which does not support unmute",
1332 connection.zed_version
1333 ))?;
1334 }
1335 }
1336 }
1337
1338 let (live_kit_room, can_publish) = {
1339 let room = session
1340 .db()
1341 .await
1342 .set_room_participant_role(
1343 session.user_id,
1344 RoomId::from_proto(request.room_id),
1345 user_id,
1346 role,
1347 )
1348 .await?;
1349
1350 let live_kit_room = room.live_kit_room.clone();
1351 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1352 room_updated(&room, &session.peer);
1353 (live_kit_room, can_publish)
1354 };
1355
1356 if let Some(live_kit) = session.live_kit_client.as_ref() {
1357 live_kit
1358 .update_participant(
1359 live_kit_room.clone(),
1360 request.user_id.to_string(),
1361 live_kit_server::proto::ParticipantPermission {
1362 can_subscribe: true,
1363 can_publish,
1364 can_publish_data: can_publish,
1365 hidden: false,
1366 recorder: false,
1367 },
1368 )
1369 .await
1370 .trace_err();
1371 }
1372
1373 response.send(proto::Ack {})?;
1374 Ok(())
1375}
1376
1377/// Call someone else into the current room
1378async fn call(
1379 request: proto::Call,
1380 response: Response<proto::Call>,
1381 session: Session,
1382) -> Result<()> {
1383 let room_id = RoomId::from_proto(request.room_id);
1384 let calling_user_id = session.user_id;
1385 let calling_connection_id = session.connection_id;
1386 let called_user_id = UserId::from_proto(request.called_user_id);
1387 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1388 if !session
1389 .db()
1390 .await
1391 .has_contact(calling_user_id, called_user_id)
1392 .await?
1393 {
1394 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1395 }
1396
1397 let incoming_call = {
1398 let (room, incoming_call) = &mut *session
1399 .db()
1400 .await
1401 .call(
1402 room_id,
1403 calling_user_id,
1404 calling_connection_id,
1405 called_user_id,
1406 initial_project_id,
1407 )
1408 .await?;
1409 room_updated(&room, &session.peer);
1410 mem::take(incoming_call)
1411 };
1412 update_user_contacts(called_user_id, &session).await?;
1413
1414 let mut calls = session
1415 .connection_pool()
1416 .await
1417 .user_connection_ids(called_user_id)
1418 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1419 .collect::<FuturesUnordered<_>>();
1420
1421 while let Some(call_response) = calls.next().await {
1422 match call_response.as_ref() {
1423 Ok(_) => {
1424 response.send(proto::Ack {})?;
1425 return Ok(());
1426 }
1427 Err(_) => {
1428 call_response.trace_err();
1429 }
1430 }
1431 }
1432
1433 {
1434 let room = session
1435 .db()
1436 .await
1437 .call_failed(room_id, called_user_id)
1438 .await?;
1439 room_updated(&room, &session.peer);
1440 }
1441 update_user_contacts(called_user_id, &session).await?;
1442
1443 Err(anyhow!("failed to ring user"))?
1444}
1445
1446/// Cancel an outgoing call.
1447async fn cancel_call(
1448 request: proto::CancelCall,
1449 response: Response<proto::CancelCall>,
1450 session: Session,
1451) -> Result<()> {
1452 let called_user_id = UserId::from_proto(request.called_user_id);
1453 let room_id = RoomId::from_proto(request.room_id);
1454 {
1455 let room = session
1456 .db()
1457 .await
1458 .cancel_call(room_id, session.connection_id, called_user_id)
1459 .await?;
1460 room_updated(&room, &session.peer);
1461 }
1462
1463 for connection_id in session
1464 .connection_pool()
1465 .await
1466 .user_connection_ids(called_user_id)
1467 {
1468 session
1469 .peer
1470 .send(
1471 connection_id,
1472 proto::CallCanceled {
1473 room_id: room_id.to_proto(),
1474 },
1475 )
1476 .trace_err();
1477 }
1478 response.send(proto::Ack {})?;
1479
1480 update_user_contacts(called_user_id, &session).await?;
1481 Ok(())
1482}
1483
1484/// Decline an incoming call.
1485async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1486 let room_id = RoomId::from_proto(message.room_id);
1487 {
1488 let room = session
1489 .db()
1490 .await
1491 .decline_call(Some(room_id), session.user_id)
1492 .await?
1493 .ok_or_else(|| anyhow!("failed to decline call"))?;
1494 room_updated(&room, &session.peer);
1495 }
1496
1497 for connection_id in session
1498 .connection_pool()
1499 .await
1500 .user_connection_ids(session.user_id)
1501 {
1502 session
1503 .peer
1504 .send(
1505 connection_id,
1506 proto::CallCanceled {
1507 room_id: room_id.to_proto(),
1508 },
1509 )
1510 .trace_err();
1511 }
1512 update_user_contacts(session.user_id, &session).await?;
1513 Ok(())
1514}
1515
1516/// Updates other participants in the room with your current location.
1517async fn update_participant_location(
1518 request: proto::UpdateParticipantLocation,
1519 response: Response<proto::UpdateParticipantLocation>,
1520 session: Session,
1521) -> Result<()> {
1522 let room_id = RoomId::from_proto(request.room_id);
1523 let location = request
1524 .location
1525 .ok_or_else(|| anyhow!("invalid location"))?;
1526
1527 let db = session.db().await;
1528 let room = db
1529 .update_room_participant_location(room_id, session.connection_id, location)
1530 .await?;
1531
1532 room_updated(&room, &session.peer);
1533 response.send(proto::Ack {})?;
1534 Ok(())
1535}
1536
1537/// Share a project into the room.
1538async fn share_project(
1539 request: proto::ShareProject,
1540 response: Response<proto::ShareProject>,
1541 session: Session,
1542) -> Result<()> {
1543 let (project_id, room) = &*session
1544 .db()
1545 .await
1546 .share_project(
1547 RoomId::from_proto(request.room_id),
1548 session.connection_id,
1549 &request.worktrees,
1550 )
1551 .await?;
1552 response.send(proto::ShareProjectResponse {
1553 project_id: project_id.to_proto(),
1554 })?;
1555 room_updated(&room, &session.peer);
1556
1557 Ok(())
1558}
1559
1560/// Unshare a project from the room.
1561async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1562 let project_id = ProjectId::from_proto(message.project_id);
1563
1564 let (room, guest_connection_ids) = &*session
1565 .db()
1566 .await
1567 .unshare_project(project_id, session.connection_id)
1568 .await?;
1569
1570 broadcast(
1571 Some(session.connection_id),
1572 guest_connection_ids.iter().copied(),
1573 |conn_id| session.peer.send(conn_id, message.clone()),
1574 );
1575 room_updated(&room, &session.peer);
1576
1577 Ok(())
1578}
1579
1580/// Join someone elses shared project.
1581async fn join_project(
1582 request: proto::JoinProject,
1583 response: Response<proto::JoinProject>,
1584 session: Session,
1585) -> Result<()> {
1586 let project_id = ProjectId::from_proto(request.project_id);
1587 let guest_user_id = session.user_id;
1588
1589 tracing::info!(%project_id, "join project");
1590
1591 let (project, replica_id) = &mut *session
1592 .db()
1593 .await
1594 .join_project(project_id, session.connection_id)
1595 .await?;
1596
1597 let collaborators = project
1598 .collaborators
1599 .iter()
1600 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1601 .map(|collaborator| collaborator.to_proto())
1602 .collect::<Vec<_>>();
1603
1604 let worktrees = project
1605 .worktrees
1606 .iter()
1607 .map(|(id, worktree)| proto::WorktreeMetadata {
1608 id: *id,
1609 root_name: worktree.root_name.clone(),
1610 visible: worktree.visible,
1611 abs_path: worktree.abs_path.clone(),
1612 })
1613 .collect::<Vec<_>>();
1614
1615 for collaborator in &collaborators {
1616 session
1617 .peer
1618 .send(
1619 collaborator.peer_id.unwrap().into(),
1620 proto::AddProjectCollaborator {
1621 project_id: project_id.to_proto(),
1622 collaborator: Some(proto::Collaborator {
1623 peer_id: Some(session.connection_id.into()),
1624 replica_id: replica_id.0 as u32,
1625 user_id: guest_user_id.to_proto(),
1626 }),
1627 },
1628 )
1629 .trace_err();
1630 }
1631
1632 // First, we send the metadata associated with each worktree.
1633 response.send(proto::JoinProjectResponse {
1634 worktrees: worktrees.clone(),
1635 replica_id: replica_id.0 as u32,
1636 collaborators: collaborators.clone(),
1637 language_servers: project.language_servers.clone(),
1638 })?;
1639
1640 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1641 #[cfg(any(test, feature = "test-support"))]
1642 const MAX_CHUNK_SIZE: usize = 2;
1643 #[cfg(not(any(test, feature = "test-support")))]
1644 const MAX_CHUNK_SIZE: usize = 256;
1645
1646 // Stream this worktree's entries.
1647 let message = proto::UpdateWorktree {
1648 project_id: project_id.to_proto(),
1649 worktree_id,
1650 abs_path: worktree.abs_path.clone(),
1651 root_name: worktree.root_name,
1652 updated_entries: worktree.entries,
1653 removed_entries: Default::default(),
1654 scan_id: worktree.scan_id,
1655 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1656 updated_repositories: worktree.repository_entries.into_values().collect(),
1657 removed_repositories: Default::default(),
1658 };
1659 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1660 session.peer.send(session.connection_id, update.clone())?;
1661 }
1662
1663 // Stream this worktree's diagnostics.
1664 for summary in worktree.diagnostic_summaries {
1665 session.peer.send(
1666 session.connection_id,
1667 proto::UpdateDiagnosticSummary {
1668 project_id: project_id.to_proto(),
1669 worktree_id: worktree.id,
1670 summary: Some(summary),
1671 },
1672 )?;
1673 }
1674
1675 for settings_file in worktree.settings_files {
1676 session.peer.send(
1677 session.connection_id,
1678 proto::UpdateWorktreeSettings {
1679 project_id: project_id.to_proto(),
1680 worktree_id: worktree.id,
1681 path: settings_file.path,
1682 content: Some(settings_file.content),
1683 },
1684 )?;
1685 }
1686 }
1687
1688 for language_server in &project.language_servers {
1689 session.peer.send(
1690 session.connection_id,
1691 proto::UpdateLanguageServer {
1692 project_id: project_id.to_proto(),
1693 language_server_id: language_server.id,
1694 variant: Some(
1695 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1696 proto::LspDiskBasedDiagnosticsUpdated {},
1697 ),
1698 ),
1699 },
1700 )?;
1701 }
1702
1703 Ok(())
1704}
1705
1706/// Leave someone elses shared project.
1707async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1708 let sender_id = session.connection_id;
1709 let project_id = ProjectId::from_proto(request.project_id);
1710
1711 let (room, project) = &*session
1712 .db()
1713 .await
1714 .leave_project(project_id, sender_id)
1715 .await?;
1716 tracing::info!(
1717 %project_id,
1718 host_user_id = %project.host_user_id,
1719 host_connection_id = ?project.host_connection_id,
1720 "leave project"
1721 );
1722
1723 project_left(&project, &session);
1724 room_updated(&room, &session.peer);
1725
1726 Ok(())
1727}
1728
1729/// Updates other participants with changes to the project
1730async fn update_project(
1731 request: proto::UpdateProject,
1732 response: Response<proto::UpdateProject>,
1733 session: Session,
1734) -> Result<()> {
1735 let project_id = ProjectId::from_proto(request.project_id);
1736 let (room, guest_connection_ids) = &*session
1737 .db()
1738 .await
1739 .update_project(project_id, session.connection_id, &request.worktrees)
1740 .await?;
1741 broadcast(
1742 Some(session.connection_id),
1743 guest_connection_ids.iter().copied(),
1744 |connection_id| {
1745 session
1746 .peer
1747 .forward_send(session.connection_id, connection_id, request.clone())
1748 },
1749 );
1750 room_updated(&room, &session.peer);
1751 response.send(proto::Ack {})?;
1752
1753 Ok(())
1754}
1755
1756/// Updates other participants with changes to the worktree
1757async fn update_worktree(
1758 request: proto::UpdateWorktree,
1759 response: Response<proto::UpdateWorktree>,
1760 session: Session,
1761) -> Result<()> {
1762 let guest_connection_ids = session
1763 .db()
1764 .await
1765 .update_worktree(&request, session.connection_id)
1766 .await?;
1767
1768 broadcast(
1769 Some(session.connection_id),
1770 guest_connection_ids.iter().copied(),
1771 |connection_id| {
1772 session
1773 .peer
1774 .forward_send(session.connection_id, connection_id, request.clone())
1775 },
1776 );
1777 response.send(proto::Ack {})?;
1778 Ok(())
1779}
1780
1781/// Updates other participants with changes to the diagnostics
1782async fn update_diagnostic_summary(
1783 message: proto::UpdateDiagnosticSummary,
1784 session: Session,
1785) -> Result<()> {
1786 let guest_connection_ids = session
1787 .db()
1788 .await
1789 .update_diagnostic_summary(&message, session.connection_id)
1790 .await?;
1791
1792 broadcast(
1793 Some(session.connection_id),
1794 guest_connection_ids.iter().copied(),
1795 |connection_id| {
1796 session
1797 .peer
1798 .forward_send(session.connection_id, connection_id, message.clone())
1799 },
1800 );
1801
1802 Ok(())
1803}
1804
1805/// Updates other participants with changes to the worktree settings
1806async fn update_worktree_settings(
1807 message: proto::UpdateWorktreeSettings,
1808 session: Session,
1809) -> Result<()> {
1810 let guest_connection_ids = session
1811 .db()
1812 .await
1813 .update_worktree_settings(&message, session.connection_id)
1814 .await?;
1815
1816 broadcast(
1817 Some(session.connection_id),
1818 guest_connection_ids.iter().copied(),
1819 |connection_id| {
1820 session
1821 .peer
1822 .forward_send(session.connection_id, connection_id, message.clone())
1823 },
1824 );
1825
1826 Ok(())
1827}
1828
1829/// Notify other participants that a language server has started.
1830async fn start_language_server(
1831 request: proto::StartLanguageServer,
1832 session: Session,
1833) -> Result<()> {
1834 let guest_connection_ids = session
1835 .db()
1836 .await
1837 .start_language_server(&request, session.connection_id)
1838 .await?;
1839
1840 broadcast(
1841 Some(session.connection_id),
1842 guest_connection_ids.iter().copied(),
1843 |connection_id| {
1844 session
1845 .peer
1846 .forward_send(session.connection_id, connection_id, request.clone())
1847 },
1848 );
1849 Ok(())
1850}
1851
1852/// Notify other participants that a language server has changed.
1853async fn update_language_server(
1854 request: proto::UpdateLanguageServer,
1855 session: Session,
1856) -> Result<()> {
1857 let project_id = ProjectId::from_proto(request.project_id);
1858 let project_connection_ids = session
1859 .db()
1860 .await
1861 .project_connection_ids(project_id, session.connection_id)
1862 .await?;
1863 broadcast(
1864 Some(session.connection_id),
1865 project_connection_ids.iter().copied(),
1866 |connection_id| {
1867 session
1868 .peer
1869 .forward_send(session.connection_id, connection_id, request.clone())
1870 },
1871 );
1872 Ok(())
1873}
1874
1875/// forward a project request to the host. These requests should be read only
1876/// as guests are allowed to send them.
1877async fn forward_read_only_project_request<T>(
1878 request: T,
1879 response: Response<T>,
1880 session: Session,
1881) -> Result<()>
1882where
1883 T: EntityMessage + RequestMessage,
1884{
1885 let project_id = ProjectId::from_proto(request.remote_entity_id());
1886 let host_connection_id = session
1887 .db()
1888 .await
1889 .host_for_read_only_project_request(project_id, session.connection_id)
1890 .await?;
1891 let payload = session
1892 .peer
1893 .forward_request(session.connection_id, host_connection_id, request)
1894 .await?;
1895 response.send(payload)?;
1896 Ok(())
1897}
1898
1899/// forward a project request to the host. These requests are disallowed
1900/// for guests.
1901async fn forward_mutating_project_request<T>(
1902 request: T,
1903 response: Response<T>,
1904 session: Session,
1905) -> Result<()>
1906where
1907 T: EntityMessage + RequestMessage,
1908{
1909 let project_id = ProjectId::from_proto(request.remote_entity_id());
1910 let host_connection_id = session
1911 .db()
1912 .await
1913 .host_for_mutating_project_request(project_id, session.connection_id)
1914 .await?;
1915 let payload = session
1916 .peer
1917 .forward_request(session.connection_id, host_connection_id, request)
1918 .await?;
1919 response.send(payload)?;
1920 Ok(())
1921}
1922
1923/// Notify other participants that a new buffer has been created
1924async fn create_buffer_for_peer(
1925 request: proto::CreateBufferForPeer,
1926 session: Session,
1927) -> Result<()> {
1928 session
1929 .db()
1930 .await
1931 .check_user_is_project_host(
1932 ProjectId::from_proto(request.project_id),
1933 session.connection_id,
1934 )
1935 .await?;
1936 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1937 session
1938 .peer
1939 .forward_send(session.connection_id, peer_id.into(), request)?;
1940 Ok(())
1941}
1942
1943/// Notify other participants that a buffer has been updated. This is
1944/// allowed for guests as long as the update is limited to selections.
1945async fn update_buffer(
1946 request: proto::UpdateBuffer,
1947 response: Response<proto::UpdateBuffer>,
1948 session: Session,
1949) -> Result<()> {
1950 let project_id = ProjectId::from_proto(request.project_id);
1951 let mut guest_connection_ids;
1952 let mut host_connection_id = None;
1953
1954 let mut requires_write_permission = false;
1955
1956 for op in request.operations.iter() {
1957 match op.variant {
1958 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1959 Some(_) => requires_write_permission = true,
1960 }
1961 }
1962
1963 {
1964 let collaborators = session
1965 .db()
1966 .await
1967 .project_collaborators_for_buffer_update(
1968 project_id,
1969 session.connection_id,
1970 requires_write_permission,
1971 )
1972 .await?;
1973 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1974 for collaborator in collaborators.iter() {
1975 if collaborator.is_host {
1976 host_connection_id = Some(collaborator.connection_id);
1977 } else {
1978 guest_connection_ids.push(collaborator.connection_id);
1979 }
1980 }
1981 }
1982 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1983
1984 broadcast(
1985 Some(session.connection_id),
1986 guest_connection_ids,
1987 |connection_id| {
1988 session
1989 .peer
1990 .forward_send(session.connection_id, connection_id, request.clone())
1991 },
1992 );
1993 if host_connection_id != session.connection_id {
1994 session
1995 .peer
1996 .forward_request(session.connection_id, host_connection_id, request.clone())
1997 .await?;
1998 }
1999
2000 response.send(proto::Ack {})?;
2001 Ok(())
2002}
2003
2004/// Notify other participants that a project has been updated.
2005async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2006 request: T,
2007 session: Session,
2008) -> Result<()> {
2009 let project_id = ProjectId::from_proto(request.remote_entity_id());
2010 let project_connection_ids = session
2011 .db()
2012 .await
2013 .project_connection_ids(project_id, session.connection_id)
2014 .await?;
2015
2016 broadcast(
2017 Some(session.connection_id),
2018 project_connection_ids.iter().copied(),
2019 |connection_id| {
2020 session
2021 .peer
2022 .forward_send(session.connection_id, connection_id, request.clone())
2023 },
2024 );
2025 Ok(())
2026}
2027
2028/// Start following another user in a call.
2029async fn follow(
2030 request: proto::Follow,
2031 response: Response<proto::Follow>,
2032 session: Session,
2033) -> Result<()> {
2034 let room_id = RoomId::from_proto(request.room_id);
2035 let project_id = request.project_id.map(ProjectId::from_proto);
2036 let leader_id = request
2037 .leader_id
2038 .ok_or_else(|| anyhow!("invalid leader id"))?
2039 .into();
2040 let follower_id = session.connection_id;
2041
2042 session
2043 .db()
2044 .await
2045 .check_room_participants(room_id, leader_id, session.connection_id)
2046 .await?;
2047
2048 let response_payload = session
2049 .peer
2050 .forward_request(session.connection_id, leader_id, request)
2051 .await?;
2052 response.send(response_payload)?;
2053
2054 if let Some(project_id) = project_id {
2055 let room = session
2056 .db()
2057 .await
2058 .follow(room_id, project_id, leader_id, follower_id)
2059 .await?;
2060 room_updated(&room, &session.peer);
2061 }
2062
2063 Ok(())
2064}
2065
2066/// Stop following another user in a call.
2067async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2068 let room_id = RoomId::from_proto(request.room_id);
2069 let project_id = request.project_id.map(ProjectId::from_proto);
2070 let leader_id = request
2071 .leader_id
2072 .ok_or_else(|| anyhow!("invalid leader id"))?
2073 .into();
2074 let follower_id = session.connection_id;
2075
2076 session
2077 .db()
2078 .await
2079 .check_room_participants(room_id, leader_id, session.connection_id)
2080 .await?;
2081
2082 session
2083 .peer
2084 .forward_send(session.connection_id, leader_id, request)?;
2085
2086 if let Some(project_id) = project_id {
2087 let room = session
2088 .db()
2089 .await
2090 .unfollow(room_id, project_id, leader_id, follower_id)
2091 .await?;
2092 room_updated(&room, &session.peer);
2093 }
2094
2095 Ok(())
2096}
2097
2098/// Notify everyone following you of your current location.
2099async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2100 let room_id = RoomId::from_proto(request.room_id);
2101 let database = session.db.lock().await;
2102
2103 let connection_ids = if let Some(project_id) = request.project_id {
2104 let project_id = ProjectId::from_proto(project_id);
2105 database
2106 .project_connection_ids(project_id, session.connection_id)
2107 .await?
2108 } else {
2109 database
2110 .room_connection_ids(room_id, session.connection_id)
2111 .await?
2112 };
2113
2114 // For now, don't send view update messages back to that view's current leader.
2115 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2116 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2117 _ => None,
2118 });
2119
2120 for connection_id in connection_ids.iter().cloned() {
2121 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2122 session
2123 .peer
2124 .forward_send(session.connection_id, connection_id, request.clone())?;
2125 }
2126 }
2127 Ok(())
2128}
2129
2130/// Get public data about users.
2131async fn get_users(
2132 request: proto::GetUsers,
2133 response: Response<proto::GetUsers>,
2134 session: Session,
2135) -> Result<()> {
2136 let user_ids = request
2137 .user_ids
2138 .into_iter()
2139 .map(UserId::from_proto)
2140 .collect();
2141 let users = session
2142 .db()
2143 .await
2144 .get_users_by_ids(user_ids)
2145 .await?
2146 .into_iter()
2147 .map(|user| proto::User {
2148 id: user.id.to_proto(),
2149 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2150 github_login: user.github_login,
2151 })
2152 .collect();
2153 response.send(proto::UsersResponse { users })?;
2154 Ok(())
2155}
2156
2157/// Search for users (to invite) buy Github login
2158async fn fuzzy_search_users(
2159 request: proto::FuzzySearchUsers,
2160 response: Response<proto::FuzzySearchUsers>,
2161 session: Session,
2162) -> Result<()> {
2163 let query = request.query;
2164 let users = match query.len() {
2165 0 => vec![],
2166 1 | 2 => session
2167 .db()
2168 .await
2169 .get_user_by_github_login(&query)
2170 .await?
2171 .into_iter()
2172 .collect(),
2173 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2174 };
2175 let users = users
2176 .into_iter()
2177 .filter(|user| user.id != session.user_id)
2178 .map(|user| proto::User {
2179 id: user.id.to_proto(),
2180 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2181 github_login: user.github_login,
2182 })
2183 .collect();
2184 response.send(proto::UsersResponse { users })?;
2185 Ok(())
2186}
2187
2188/// Send a contact request to another user.
2189async fn request_contact(
2190 request: proto::RequestContact,
2191 response: Response<proto::RequestContact>,
2192 session: Session,
2193) -> Result<()> {
2194 let requester_id = session.user_id;
2195 let responder_id = UserId::from_proto(request.responder_id);
2196 if requester_id == responder_id {
2197 return Err(anyhow!("cannot add yourself as a contact"))?;
2198 }
2199
2200 let notifications = session
2201 .db()
2202 .await
2203 .send_contact_request(requester_id, responder_id)
2204 .await?;
2205
2206 // Update outgoing contact requests of requester
2207 let mut update = proto::UpdateContacts::default();
2208 update.outgoing_requests.push(responder_id.to_proto());
2209 for connection_id in session
2210 .connection_pool()
2211 .await
2212 .user_connection_ids(requester_id)
2213 {
2214 session.peer.send(connection_id, update.clone())?;
2215 }
2216
2217 // Update incoming contact requests of responder
2218 let mut update = proto::UpdateContacts::default();
2219 update
2220 .incoming_requests
2221 .push(proto::IncomingContactRequest {
2222 requester_id: requester_id.to_proto(),
2223 });
2224 let connection_pool = session.connection_pool().await;
2225 for connection_id in connection_pool.user_connection_ids(responder_id) {
2226 session.peer.send(connection_id, update.clone())?;
2227 }
2228
2229 send_notifications(&connection_pool, &session.peer, notifications);
2230
2231 response.send(proto::Ack {})?;
2232 Ok(())
2233}
2234
/// Accept, decline, or dismiss a contact request addressed to the current user.
///
/// - `Dismiss`: only calls `dismiss_contact_notification`; no contact updates are pushed.
/// - Otherwise the response is recorded via `respond_to_contact_request`, with
///   `accept` true only for `Accept`. Both users' connections are then told to
///   drop the pending request. On accept, each user also receives the other as
///   a contact, built with the busy flag from `is_user_busy`. Notifications
///   returned by the database are delivered last.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id;
    let requester_id = UserId::from_proto(request.requester_id);
    // This guard stays alive for the rest of the handler, including while the
    // connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any value other than `Dismiss` or `Accept` is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2292
2293/// Remove a contact.
2294async fn remove_contact(
2295 request: proto::RemoveContact,
2296 response: Response<proto::RemoveContact>,
2297 session: Session,
2298) -> Result<()> {
2299 let requester_id = session.user_id;
2300 let responder_id = UserId::from_proto(request.user_id);
2301 let db = session.db().await;
2302 let (contact_accepted, deleted_notification_id) =
2303 db.remove_contact(requester_id, responder_id).await?;
2304
2305 let pool = session.connection_pool().await;
2306 // Update outgoing contact requests of requester
2307 let mut update = proto::UpdateContacts::default();
2308 if contact_accepted {
2309 update.remove_contacts.push(responder_id.to_proto());
2310 } else {
2311 update
2312 .remove_outgoing_requests
2313 .push(responder_id.to_proto());
2314 }
2315 for connection_id in pool.user_connection_ids(requester_id) {
2316 session.peer.send(connection_id, update.clone())?;
2317 }
2318
2319 // Update incoming contact requests of responder
2320 let mut update = proto::UpdateContacts::default();
2321 if contact_accepted {
2322 update.remove_contacts.push(requester_id.to_proto());
2323 } else {
2324 update
2325 .remove_incoming_requests
2326 .push(requester_id.to_proto());
2327 }
2328 for connection_id in pool.user_connection_ids(responder_id) {
2329 session.peer.send(connection_id, update.clone())?;
2330 if let Some(notification_id) = deleted_notification_id {
2331 session.peer.send(
2332 connection_id,
2333 proto::DeleteNotification {
2334 notification_id: notification_id.to_proto(),
2335 },
2336 )?;
2337 }
2338 }
2339
2340 response.send(proto::Ack {})?;
2341 Ok(())
2342}
2343
2344/// Creates a new channel.
2345async fn create_channel(
2346 request: proto::CreateChannel,
2347 response: Response<proto::CreateChannel>,
2348 session: Session,
2349) -> Result<()> {
2350 let db = session.db().await;
2351
2352 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2353 let (channel, owner, channel_members) = db
2354 .create_channel(&request.name, parent_id, session.user_id)
2355 .await?;
2356
2357 response.send(proto::CreateChannelResponse {
2358 channel: Some(channel.to_proto()),
2359 parent_id: request.parent_id,
2360 })?;
2361
2362 let connection_pool = session.connection_pool().await;
2363 if let Some(owner) = owner {
2364 let update = proto::UpdateUserChannels {
2365 channel_memberships: vec![proto::ChannelMembership {
2366 channel_id: owner.channel_id.to_proto(),
2367 role: owner.role.into(),
2368 }],
2369 ..Default::default()
2370 };
2371 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2372 session.peer.send(connection_id, update.clone())?;
2373 }
2374 }
2375
2376 for channel_member in channel_members {
2377 if !channel_member.role.can_see_channel(channel.visibility) {
2378 continue;
2379 }
2380
2381 let update = proto::UpdateChannels {
2382 channels: vec![channel.to_proto()],
2383 ..Default::default()
2384 };
2385 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2386 session.peer.send(connection_id, update.clone())?;
2387 }
2388 }
2389
2390 Ok(())
2391}
2392
2393/// Delete a channel
2394async fn delete_channel(
2395 request: proto::DeleteChannel,
2396 response: Response<proto::DeleteChannel>,
2397 session: Session,
2398) -> Result<()> {
2399 let db = session.db().await;
2400
2401 let channel_id = request.channel_id;
2402 let (removed_channels, member_ids) = db
2403 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2404 .await?;
2405 response.send(proto::Ack {})?;
2406
2407 // Notify members of removed channels
2408 let mut update = proto::UpdateChannels::default();
2409 update
2410 .delete_channels
2411 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2412
2413 let connection_pool = session.connection_pool().await;
2414 for member_id in member_ids {
2415 for connection_id in connection_pool.user_connection_ids(member_id) {
2416 session.peer.send(connection_id, update.clone())?;
2417 }
2418 }
2419
2420 Ok(())
2421}
2422
2423/// Invite someone to join a channel.
2424async fn invite_channel_member(
2425 request: proto::InviteChannelMember,
2426 response: Response<proto::InviteChannelMember>,
2427 session: Session,
2428) -> Result<()> {
2429 let db = session.db().await;
2430 let channel_id = ChannelId::from_proto(request.channel_id);
2431 let invitee_id = UserId::from_proto(request.user_id);
2432 let InviteMemberResult {
2433 channel,
2434 notifications,
2435 } = db
2436 .invite_channel_member(
2437 channel_id,
2438 invitee_id,
2439 session.user_id,
2440 request.role().into(),
2441 )
2442 .await?;
2443
2444 let update = proto::UpdateChannels {
2445 channel_invitations: vec![channel.to_proto()],
2446 ..Default::default()
2447 };
2448
2449 let connection_pool = session.connection_pool().await;
2450 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2451 session.peer.send(connection_id, update.clone())?;
2452 }
2453
2454 send_notifications(&connection_pool, &session.peer, notifications);
2455
2456 response.send(proto::Ack {})?;
2457 Ok(())
2458}
2459
2460/// remove someone from a channel
2461async fn remove_channel_member(
2462 request: proto::RemoveChannelMember,
2463 response: Response<proto::RemoveChannelMember>,
2464 session: Session,
2465) -> Result<()> {
2466 let db = session.db().await;
2467 let channel_id = ChannelId::from_proto(request.channel_id);
2468 let member_id = UserId::from_proto(request.user_id);
2469
2470 let RemoveChannelMemberResult {
2471 membership_update,
2472 notification_id,
2473 } = db
2474 .remove_channel_member(channel_id, member_id, session.user_id)
2475 .await?;
2476
2477 let connection_pool = &session.connection_pool().await;
2478 notify_membership_updated(
2479 &connection_pool,
2480 membership_update,
2481 member_id,
2482 &session.peer,
2483 );
2484 for connection_id in connection_pool.user_connection_ids(member_id) {
2485 if let Some(notification_id) = notification_id {
2486 session
2487 .peer
2488 .send(
2489 connection_id,
2490 proto::DeleteNotification {
2491 notification_id: notification_id.to_proto(),
2492 },
2493 )
2494 .trace_err();
2495 }
2496 }
2497
2498 response.send(proto::Ack {})?;
2499 Ok(())
2500}
2501
2502/// Toggle the channel between public and private.
2503/// Care is taken to maintain the invariant that public channels only descend from public channels,
2504/// (though members-only channels can appear at any point in the hierarchy).
2505async fn set_channel_visibility(
2506 request: proto::SetChannelVisibility,
2507 response: Response<proto::SetChannelVisibility>,
2508 session: Session,
2509) -> Result<()> {
2510 let db = session.db().await;
2511 let channel_id = ChannelId::from_proto(request.channel_id);
2512 let visibility = request.visibility().into();
2513
2514 let (channel, channel_members) = db
2515 .set_channel_visibility(channel_id, visibility, session.user_id)
2516 .await?;
2517
2518 let connection_pool = session.connection_pool().await;
2519 for member in channel_members {
2520 let update = if member.role.can_see_channel(channel.visibility) {
2521 proto::UpdateChannels {
2522 channels: vec![channel.to_proto()],
2523 ..Default::default()
2524 }
2525 } else {
2526 proto::UpdateChannels {
2527 delete_channels: vec![channel.id.to_proto()],
2528 ..Default::default()
2529 }
2530 };
2531
2532 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2533 session.peer.send(connection_id, update.clone())?;
2534 }
2535 }
2536
2537 response.send(proto::Ack {})?;
2538 Ok(())
2539}
2540
2541/// Alter the role for a user in the channel.
2542async fn set_channel_member_role(
2543 request: proto::SetChannelMemberRole,
2544 response: Response<proto::SetChannelMemberRole>,
2545 session: Session,
2546) -> Result<()> {
2547 let db = session.db().await;
2548 let channel_id = ChannelId::from_proto(request.channel_id);
2549 let member_id = UserId::from_proto(request.user_id);
2550 let result = db
2551 .set_channel_member_role(
2552 channel_id,
2553 session.user_id,
2554 member_id,
2555 request.role().into(),
2556 )
2557 .await?;
2558
2559 match result {
2560 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2561 let connection_pool = session.connection_pool().await;
2562 notify_membership_updated(
2563 &connection_pool,
2564 membership_update,
2565 member_id,
2566 &session.peer,
2567 )
2568 }
2569 db::SetMemberRoleResult::InviteUpdated(channel) => {
2570 let update = proto::UpdateChannels {
2571 channel_invitations: vec![channel.to_proto()],
2572 ..Default::default()
2573 };
2574
2575 for connection_id in session
2576 .connection_pool()
2577 .await
2578 .user_connection_ids(member_id)
2579 {
2580 session.peer.send(connection_id, update.clone())?;
2581 }
2582 }
2583 }
2584
2585 response.send(proto::Ack {})?;
2586 Ok(())
2587}
2588
2589/// Change the name of a channel
2590async fn rename_channel(
2591 request: proto::RenameChannel,
2592 response: Response<proto::RenameChannel>,
2593 session: Session,
2594) -> Result<()> {
2595 let db = session.db().await;
2596 let channel_id = ChannelId::from_proto(request.channel_id);
2597 let (channel, channel_members) = db
2598 .rename_channel(channel_id, session.user_id, &request.name)
2599 .await?;
2600
2601 response.send(proto::RenameChannelResponse {
2602 channel: Some(channel.to_proto()),
2603 })?;
2604
2605 let connection_pool = session.connection_pool().await;
2606 for channel_member in channel_members {
2607 if !channel_member.role.can_see_channel(channel.visibility) {
2608 continue;
2609 }
2610 let update = proto::UpdateChannels {
2611 channels: vec![channel.to_proto()],
2612 ..Default::default()
2613 };
2614 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2615 session.peer.send(connection_id, update.clone())?;
2616 }
2617 }
2618
2619 Ok(())
2620}
2621
2622/// Move a channel to a new parent.
2623async fn move_channel(
2624 request: proto::MoveChannel,
2625 response: Response<proto::MoveChannel>,
2626 session: Session,
2627) -> Result<()> {
2628 let channel_id = ChannelId::from_proto(request.channel_id);
2629 let to = ChannelId::from_proto(request.to);
2630
2631 let (channels, channel_members) = session
2632 .db()
2633 .await
2634 .move_channel(channel_id, to, session.user_id)
2635 .await?;
2636
2637 let connection_pool = session.connection_pool().await;
2638 for member in channel_members {
2639 let channels = channels
2640 .iter()
2641 .filter_map(|channel| {
2642 if member.role.can_see_channel(channel.visibility) {
2643 Some(channel.to_proto())
2644 } else {
2645 None
2646 }
2647 })
2648 .collect::<Vec<_>>();
2649 if channels.is_empty() {
2650 continue;
2651 }
2652
2653 let update = proto::UpdateChannels {
2654 channels,
2655 ..Default::default()
2656 };
2657
2658 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2659 session.peer.send(connection_id, update.clone())?;
2660 }
2661 }
2662
2663 response.send(Ack {})?;
2664 Ok(())
2665}
2666
2667/// Get the list of channel members
2668async fn get_channel_members(
2669 request: proto::GetChannelMembers,
2670 response: Response<proto::GetChannelMembers>,
2671 session: Session,
2672) -> Result<()> {
2673 let db = session.db().await;
2674 let channel_id = ChannelId::from_proto(request.channel_id);
2675 let members = db
2676 .get_channel_participant_details(channel_id, session.user_id)
2677 .await?;
2678 response.send(proto::GetChannelMembersResponse { members })?;
2679 Ok(())
2680}
2681
2682/// Accept or decline a channel invitation.
2683async fn respond_to_channel_invite(
2684 request: proto::RespondToChannelInvite,
2685 response: Response<proto::RespondToChannelInvite>,
2686 session: Session,
2687) -> Result<()> {
2688 let db = session.db().await;
2689 let channel_id = ChannelId::from_proto(request.channel_id);
2690 let RespondToChannelInvite {
2691 membership_update,
2692 notifications,
2693 } = db
2694 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2695 .await?;
2696
2697 let connection_pool = session.connection_pool().await;
2698 if let Some(membership_update) = membership_update {
2699 notify_membership_updated(
2700 &connection_pool,
2701 membership_update,
2702 session.user_id,
2703 &session.peer,
2704 );
2705 } else {
2706 let update = proto::UpdateChannels {
2707 remove_channel_invitations: vec![channel_id.to_proto()],
2708 ..Default::default()
2709 };
2710
2711 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2712 session.peer.send(connection_id, update.clone())?;
2713 }
2714 };
2715
2716 send_notifications(&connection_pool, &session.peer, notifications);
2717
2718 response.send(proto::Ack {})?;
2719
2720 Ok(())
2721}
2722
2723/// Join the channels' room
2724async fn join_channel(
2725 request: proto::JoinChannel,
2726 response: Response<proto::JoinChannel>,
2727 session: Session,
2728) -> Result<()> {
2729 let channel_id = ChannelId::from_proto(request.channel_id);
2730 join_channel_internal(channel_id, Box::new(response), session).await
2731}
2732
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`. This lets `join_channel_internal` reply to
/// either a `JoinChannel` request or (presumably via another handler) a
/// `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path resolves to the inherent `Response::send`,
        // since inherent items take precedence, not back to this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2746
/// Shared implementation behind the channel-join request handlers.
///
/// Steps, in order:
/// 1. leave the session's current room;
/// 2. join the channel's room in the database;
/// 3. reply with the room, plus LiveKit connection info when available;
/// 4. broadcast the membership change (if any) and the room update;
/// 5. broadcast the channel update;
/// 6. refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` guards acquired in this block are dropped
    // when it ends. That happens before the `connection_pool()` and
    // `update_user_contacts` awaits below, presumably the reason for the scope.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // Only produced when a LiveKit client is configured. Guests get a guest
        // token and `can_publish: false`; every other role gets a room token and
        // `can_publish: true`. If minting the token fails, the error goes through
        // `trace_err()` and the `?` makes the closure yield `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Re-acquire the connection pool just for this broadcast.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2822
2823/// Start editing the channel notes
2824async fn join_channel_buffer(
2825 request: proto::JoinChannelBuffer,
2826 response: Response<proto::JoinChannelBuffer>,
2827 session: Session,
2828) -> Result<()> {
2829 let db = session.db().await;
2830 let channel_id = ChannelId::from_proto(request.channel_id);
2831
2832 let open_response = db
2833 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2834 .await?;
2835
2836 let collaborators = open_response.collaborators.clone();
2837 response.send(open_response)?;
2838
2839 let update = UpdateChannelBufferCollaborators {
2840 channel_id: channel_id.to_proto(),
2841 collaborators: collaborators.clone(),
2842 };
2843 channel_buffer_updated(
2844 session.connection_id,
2845 collaborators
2846 .iter()
2847 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2848 &update,
2849 &session.peer,
2850 );
2851
2852 Ok(())
2853}
2854
/// Apply edits to a channel's notes buffer.
///
/// The operations are recorded via `db.update_channel_buffer` and forwarded to
/// the buffer's collaborators. Every connection of the returned
/// non-collaborators is sent the buffer's latest epoch and version in an
/// `UpdateChannels` message.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, non_collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id, &request.operations)
        .await?;

    // Forward the raw operations; this connection is passed as the sender.
    channel_buffer_updated(
        session.connection_id,
        collaborators,
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    // The connection-pool guard is a temporary borrowed by `pool`, so it stays
    // alive for the rest of the function.
    let pool = &*session.connection_pool().await;

    broadcast(
        None,
        non_collaborators
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id,
                proto::UpdateChannels {
                    latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                        channel_id: channel_id.to_proto(),
                        // NOTE(review): `as` cast; assumes `epoch` is non-negative — confirm its type.
                        epoch: epoch as u64,
                        version: version.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );

    Ok(())
}
2901
2902/// Rejoin the channel notes after a connection blip
2903async fn rejoin_channel_buffers(
2904 request: proto::RejoinChannelBuffers,
2905 response: Response<proto::RejoinChannelBuffers>,
2906 session: Session,
2907) -> Result<()> {
2908 let db = session.db().await;
2909 let buffers = db
2910 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2911 .await?;
2912
2913 for rejoined_buffer in &buffers {
2914 let collaborators_to_notify = rejoined_buffer
2915 .buffer
2916 .collaborators
2917 .iter()
2918 .filter_map(|c| Some(c.peer_id?.into()));
2919 channel_buffer_updated(
2920 session.connection_id,
2921 collaborators_to_notify,
2922 &proto::UpdateChannelBufferCollaborators {
2923 channel_id: rejoined_buffer.buffer.channel_id,
2924 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2925 },
2926 &session.peer,
2927 );
2928 }
2929
2930 response.send(proto::RejoinChannelBuffersResponse {
2931 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2932 })?;
2933
2934 Ok(())
2935}
2936
2937/// Stop editing the channel notes
2938async fn leave_channel_buffer(
2939 request: proto::LeaveChannelBuffer,
2940 response: Response<proto::LeaveChannelBuffer>,
2941 session: Session,
2942) -> Result<()> {
2943 let db = session.db().await;
2944 let channel_id = ChannelId::from_proto(request.channel_id);
2945
2946 let left_buffer = db
2947 .leave_channel_buffer(channel_id, session.connection_id)
2948 .await?;
2949
2950 response.send(Ack {})?;
2951
2952 channel_buffer_updated(
2953 session.connection_id,
2954 left_buffer.connections,
2955 &proto::UpdateChannelBufferCollaborators {
2956 channel_id: channel_id.to_proto(),
2957 collaborators: left_buffer.collaborators,
2958 },
2959 &session.peer,
2960 );
2961
2962 Ok(())
2963}
2964
2965fn channel_buffer_updated<T: EnvelopedMessage>(
2966 sender_id: ConnectionId,
2967 collaborators: impl IntoIterator<Item = ConnectionId>,
2968 message: &T,
2969 peer: &Peer,
2970) {
2971 broadcast(Some(sender_id), collaborators, |peer_id| {
2972 peer.send(peer_id, message.clone())
2973 });
2974}
2975
2976fn send_notifications(
2977 connection_pool: &ConnectionPool,
2978 peer: &Peer,
2979 notifications: db::NotificationBatch,
2980) {
2981 for (user_id, notification) in notifications {
2982 for connection_id in connection_pool.user_connection_ids(user_id) {
2983 if let Err(error) = peer.send(
2984 connection_id,
2985 proto::AddNotification {
2986 notification: Some(notification.clone()),
2987 },
2988 ) {
2989 tracing::error!(
2990 "failed to send notification to {:?} {}",
2991 connection_id,
2992 error
2993 );
2994 }
2995 }
2996 }
2997}
2998
2999/// Send a message to the channel
3000async fn send_channel_message(
3001 request: proto::SendChannelMessage,
3002 response: Response<proto::SendChannelMessage>,
3003 session: Session,
3004) -> Result<()> {
3005 // Validate the message body.
3006 let body = request.body.trim().to_string();
3007 if body.len() > MAX_MESSAGE_LEN {
3008 return Err(anyhow!("message is too long"))?;
3009 }
3010 if body.is_empty() {
3011 return Err(anyhow!("message can't be blank"))?;
3012 }
3013
3014 // TODO: adjust mentions if body is trimmed
3015
3016 let timestamp = OffsetDateTime::now_utc();
3017 let nonce = request
3018 .nonce
3019 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3020
3021 let channel_id = ChannelId::from_proto(request.channel_id);
3022 let CreatedChannelMessage {
3023 message_id,
3024 participant_connection_ids,
3025 channel_members,
3026 notifications,
3027 } = session
3028 .db()
3029 .await
3030 .create_channel_message(
3031 channel_id,
3032 session.user_id,
3033 &body,
3034 &request.mentions,
3035 timestamp,
3036 nonce.clone().into(),
3037 match request.reply_to_message_id {
3038 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3039 None => None,
3040 },
3041 )
3042 .await?;
3043 let message = proto::ChannelMessage {
3044 sender_id: session.user_id.to_proto(),
3045 id: message_id.to_proto(),
3046 body,
3047 mentions: request.mentions,
3048 timestamp: timestamp.unix_timestamp() as u64,
3049 nonce: Some(nonce),
3050 reply_to_message_id: request.reply_to_message_id,
3051 };
3052 broadcast(
3053 Some(session.connection_id),
3054 participant_connection_ids,
3055 |connection| {
3056 session.peer.send(
3057 connection,
3058 proto::ChannelMessageSent {
3059 channel_id: channel_id.to_proto(),
3060 message: Some(message.clone()),
3061 },
3062 )
3063 },
3064 );
3065 response.send(proto::SendChannelMessageResponse {
3066 message: Some(message),
3067 })?;
3068
3069 let pool = &*session.connection_pool().await;
3070 broadcast(
3071 None,
3072 channel_members
3073 .iter()
3074 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3075 |peer_id| {
3076 session.peer.send(
3077 peer_id,
3078 proto::UpdateChannels {
3079 latest_channel_message_ids: vec![proto::ChannelMessageId {
3080 channel_id: channel_id.to_proto(),
3081 message_id: message_id.to_proto(),
3082 }],
3083 ..Default::default()
3084 },
3085 )
3086 },
3087 );
3088 send_notifications(pool, &session.peer, notifications);
3089
3090 Ok(())
3091}
3092
3093/// Delete a channel message
3094async fn remove_channel_message(
3095 request: proto::RemoveChannelMessage,
3096 response: Response<proto::RemoveChannelMessage>,
3097 session: Session,
3098) -> Result<()> {
3099 let channel_id = ChannelId::from_proto(request.channel_id);
3100 let message_id = MessageId::from_proto(request.message_id);
3101 let connection_ids = session
3102 .db()
3103 .await
3104 .remove_channel_message(channel_id, message_id, session.user_id)
3105 .await?;
3106 broadcast(Some(session.connection_id), connection_ids, |connection| {
3107 session.peer.send(connection, request.clone())
3108 });
3109 response.send(proto::Ack {})?;
3110 Ok(())
3111}
3112
3113/// Mark a channel message as read
3114async fn acknowledge_channel_message(
3115 request: proto::AckChannelMessage,
3116 session: Session,
3117) -> Result<()> {
3118 let channel_id = ChannelId::from_proto(request.channel_id);
3119 let message_id = MessageId::from_proto(request.message_id);
3120 let notifications = session
3121 .db()
3122 .await
3123 .observe_channel_message(channel_id, session.user_id, message_id)
3124 .await?;
3125 send_notifications(
3126 &*session.connection_pool().await,
3127 &session.peer,
3128 notifications,
3129 );
3130 Ok(())
3131}
3132
3133/// Mark a buffer version as synced
3134async fn acknowledge_buffer_version(
3135 request: proto::AckBufferOperation,
3136 session: Session,
3137) -> Result<()> {
3138 let buffer_id = BufferId::from_proto(request.buffer_id);
3139 session
3140 .db()
3141 .await
3142 .observe_buffer_version(
3143 buffer_id,
3144 session.user_id,
3145 request.epoch as i32,
3146 &request.version,
3147 )
3148 .await?;
3149 Ok(())
3150}
3151
3152/// Start receiving chat updates for a channel
3153async fn join_channel_chat(
3154 request: proto::JoinChannelChat,
3155 response: Response<proto::JoinChannelChat>,
3156 session: Session,
3157) -> Result<()> {
3158 let channel_id = ChannelId::from_proto(request.channel_id);
3159
3160 let db = session.db().await;
3161 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3162 .await?;
3163 let messages = db
3164 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3165 .await?;
3166 response.send(proto::JoinChannelChatResponse {
3167 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3168 messages,
3169 })?;
3170 Ok(())
3171}
3172
3173/// Stop receiving chat updates for a channel
3174async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3175 let channel_id = ChannelId::from_proto(request.channel_id);
3176 session
3177 .db()
3178 .await
3179 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3180 .await?;
3181 Ok(())
3182}
3183
3184/// Retrieve the chat history for a channel
3185async fn get_channel_messages(
3186 request: proto::GetChannelMessages,
3187 response: Response<proto::GetChannelMessages>,
3188 session: Session,
3189) -> Result<()> {
3190 let channel_id = ChannelId::from_proto(request.channel_id);
3191 let messages = session
3192 .db()
3193 .await
3194 .get_channel_messages(
3195 channel_id,
3196 session.user_id,
3197 MESSAGE_COUNT_PER_PAGE,
3198 Some(MessageId::from_proto(request.before_message_id)),
3199 )
3200 .await?;
3201 response.send(proto::GetChannelMessagesResponse {
3202 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3203 messages,
3204 })?;
3205 Ok(())
3206}
3207
3208/// Retrieve specific chat messages
3209async fn get_channel_messages_by_id(
3210 request: proto::GetChannelMessagesById,
3211 response: Response<proto::GetChannelMessagesById>,
3212 session: Session,
3213) -> Result<()> {
3214 let message_ids = request
3215 .message_ids
3216 .iter()
3217 .map(|id| MessageId::from_proto(*id))
3218 .collect::<Vec<_>>();
3219 let messages = session
3220 .db()
3221 .await
3222 .get_channel_messages_by_id(session.user_id, &message_ids)
3223 .await?;
3224 response.send(proto::GetChannelMessagesResponse {
3225 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3226 messages,
3227 })?;
3228 Ok(())
3229}
3230
3231/// Retrieve the current users notifications
3232async fn get_notifications(
3233 request: proto::GetNotifications,
3234 response: Response<proto::GetNotifications>,
3235 session: Session,
3236) -> Result<()> {
3237 let notifications = session
3238 .db()
3239 .await
3240 .get_notifications(
3241 session.user_id,
3242 NOTIFICATION_COUNT_PER_PAGE,
3243 request
3244 .before_id
3245 .map(|id| db::NotificationId::from_proto(id)),
3246 )
3247 .await?;
3248 response.send(proto::GetNotificationsResponse {
3249 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3250 notifications,
3251 })?;
3252 Ok(())
3253}
3254
3255/// Mark notifications as read
3256async fn mark_notification_as_read(
3257 request: proto::MarkNotificationRead,
3258 response: Response<proto::MarkNotificationRead>,
3259 session: Session,
3260) -> Result<()> {
3261 let database = &session.db().await;
3262 let notifications = database
3263 .mark_notification_as_read_by_id(
3264 session.user_id,
3265 NotificationId::from_proto(request.notification_id),
3266 )
3267 .await?;
3268 send_notifications(
3269 &*session.connection_pool().await,
3270 &session.peer,
3271 notifications,
3272 );
3273 response.send(proto::Ack {})?;
3274 Ok(())
3275}
3276
3277/// Get the current users information
3278async fn get_private_user_info(
3279 _request: proto::GetPrivateUserInfo,
3280 response: Response<proto::GetPrivateUserInfo>,
3281 session: Session,
3282) -> Result<()> {
3283 let db = session.db().await;
3284
3285 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3286 let user = db
3287 .get_user_by_id(session.user_id)
3288 .await?
3289 .ok_or_else(|| anyhow!("user not found"))?;
3290 let flags = db.get_user_flags(session.user_id).await?;
3291
3292 response.send(proto::GetPrivateUserInfoResponse {
3293 metrics_id,
3294 staff: user.admin,
3295 flags,
3296 })?;
3297 Ok(())
3298}
3299
3300fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3301 match message {
3302 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3303 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3304 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3305 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3306 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3307 code: frame.code.into(),
3308 reason: frame.reason,
3309 })),
3310 }
3311}
3312
3313fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3314 match message {
3315 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3316 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3317 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3318 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3319 AxumMessage::Close(frame) => {
3320 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3321 code: frame.code.into(),
3322 reason: frame.reason,
3323 }))
3324 }
3325 }
3326}
3327
3328fn notify_membership_updated(
3329 connection_pool: &ConnectionPool,
3330 result: MembershipUpdated,
3331 user_id: UserId,
3332 peer: &Peer,
3333) {
3334 let user_channels_update = proto::UpdateUserChannels {
3335 channel_memberships: result
3336 .new_channels
3337 .channel_memberships
3338 .iter()
3339 .map(|cm| proto::ChannelMembership {
3340 channel_id: cm.channel_id.to_proto(),
3341 role: cm.role.into(),
3342 })
3343 .collect(),
3344 ..Default::default()
3345 };
3346 let mut update = build_channels_update(result.new_channels, vec![]);
3347 update.delete_channels = result
3348 .removed_channels
3349 .into_iter()
3350 .map(|id| id.to_proto())
3351 .collect();
3352 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3353
3354 for connection_id in connection_pool.user_connection_ids(user_id) {
3355 peer.send(connection_id, user_channels_update.clone())
3356 .trace_err();
3357 peer.send(connection_id, update.clone()).trace_err();
3358 }
3359}
3360
3361fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3362 proto::UpdateUserChannels {
3363 channel_memberships: channels
3364 .channel_memberships
3365 .iter()
3366 .map(|m| proto::ChannelMembership {
3367 channel_id: m.channel_id.to_proto(),
3368 role: m.role.into(),
3369 })
3370 .collect(),
3371 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3372 observed_channel_message_id: channels.observed_channel_messages.clone(),
3373 ..Default::default()
3374 }
3375}
3376
3377fn build_channels_update(
3378 channels: ChannelsForUser,
3379 channel_invites: Vec<db::Channel>,
3380) -> proto::UpdateChannels {
3381 let mut update = proto::UpdateChannels::default();
3382
3383 for channel in channels.channels {
3384 update.channels.push(channel.to_proto());
3385 }
3386
3387 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3388 update.latest_channel_message_ids = channels.latest_channel_messages;
3389
3390 for (channel_id, participants) in channels.channel_participants {
3391 update
3392 .channel_participants
3393 .push(proto::ChannelParticipants {
3394 channel_id: channel_id.to_proto(),
3395 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3396 });
3397 }
3398
3399 for channel in channel_invites {
3400 update.channel_invitations.push(channel.to_proto());
3401 }
3402 for project in channels.hosted_projects {
3403 update.hosted_projects.push(project);
3404 }
3405
3406 update
3407}
3408
3409fn build_initial_contacts_update(
3410 contacts: Vec<db::Contact>,
3411 pool: &ConnectionPool,
3412) -> proto::UpdateContacts {
3413 let mut update = proto::UpdateContacts::default();
3414
3415 for contact in contacts {
3416 match contact {
3417 db::Contact::Accepted { user_id, busy } => {
3418 update.contacts.push(contact_for_user(user_id, busy, &pool));
3419 }
3420 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3421 db::Contact::Incoming { user_id } => {
3422 update
3423 .incoming_requests
3424 .push(proto::IncomingContactRequest {
3425 requester_id: user_id.to_proto(),
3426 })
3427 }
3428 }
3429 }
3430
3431 update
3432}
3433
3434fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3435 proto::Contact {
3436 user_id: user_id.to_proto(),
3437 online: pool.is_user_online(user_id),
3438 busy,
3439 }
3440}
3441
3442fn room_updated(room: &proto::Room, peer: &Peer) {
3443 broadcast(
3444 None,
3445 room.participants
3446 .iter()
3447 .filter_map(|participant| Some(participant.peer_id?.into())),
3448 |peer_id| {
3449 peer.send(
3450 peer_id,
3451 proto::RoomUpdated {
3452 room: Some(room.clone()),
3453 },
3454 )
3455 },
3456 );
3457}
3458
3459fn channel_updated(
3460 channel_id: ChannelId,
3461 room: &proto::Room,
3462 channel_members: &[UserId],
3463 peer: &Peer,
3464 pool: &ConnectionPool,
3465) {
3466 let participants = room
3467 .participants
3468 .iter()
3469 .map(|p| p.user_id)
3470 .collect::<Vec<_>>();
3471
3472 broadcast(
3473 None,
3474 channel_members
3475 .iter()
3476 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3477 |peer_id| {
3478 peer.send(
3479 peer_id,
3480 proto::UpdateChannels {
3481 channel_participants: vec![proto::ChannelParticipants {
3482 channel_id: channel_id.to_proto(),
3483 participant_user_ids: participants.clone(),
3484 }],
3485 ..Default::default()
3486 },
3487 )
3488 },
3489 );
3490}
3491
3492async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3493 let db = session.db().await;
3494
3495 let contacts = db.get_contacts(user_id).await?;
3496 let busy = db.is_user_busy(user_id).await?;
3497
3498 let pool = session.connection_pool().await;
3499 let updated_contact = contact_for_user(user_id, busy, &pool);
3500 for contact in contacts {
3501 if let db::Contact::Accepted {
3502 user_id: contact_user_id,
3503 ..
3504 } = contact
3505 {
3506 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3507 session
3508 .peer
3509 .send(
3510 contact_conn_id,
3511 proto::UpdateContacts {
3512 contacts: vec![updated_contact.clone()],
3513 remove_contacts: Default::default(),
3514 incoming_requests: Default::default(),
3515 remove_incoming_requests: Default::default(),
3516 outgoing_requests: Default::default(),
3517 remove_outgoing_requests: Default::default(),
3518 },
3519 )
3520 .trace_err();
3521 }
3522 }
3523 }
3524 Ok(())
3525}
3526
/// Remove this session's connection from the room it is in, if any, and notify
/// everyone affected.
///
/// Returns `Ok(())` without doing anything else when `Database::leave_room`
/// yields `None`. Otherwise it does the following, in order:
/// - calls `project_left` for every project in `left_room.left_projects`;
/// - broadcasts the updated room to its participants (`room_updated`);
/// - if the room has a `channel_id`, sends the new participant list to the
///   channel's members (`channel_updated`);
/// - sends `proto::CallCanceled` to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - refreshes the contact entry of this user and of each of those users
///   (`update_user_contacts`);
/// - if a LiveKit client is configured, removes this user from the LiveKit
///   room. The room itself is deleted when `left_room.deleted` was set.
///
/// # Errors
/// Propagates errors from `Database::leave_room` and `update_user_contacts`.
/// Message-send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // User ids whose contact entry is refreshed near the end of this function.
    let mut contacts_to_update = HashSet::default();

    // These are declared up front and filled in from `left_room` inside the
    // `if let` below. That way `left_room`, and any temporaries in the `if let`
    // scrutinee such as the `session.db().await` guard, are dropped at the end
    // of that block instead of living across the awaits that follow.
    // NOTE(review): presumably this releases a lock held by the value returned
    // from `leave_room`; confirm before restructuring (e.g. into `let ... else`).
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taking `live_kit_room` out of `left_room.room` before `room` itself is
        // taken means the `room` broadcast below carries a default (empty)
        // `live_kit_room`. NOTE(review): confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is released before `update_user_contacts` below, which
    // calls `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3603
3604async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3605 let left_channel_buffers = session
3606 .db()
3607 .await
3608 .leave_channel_buffers(session.connection_id)
3609 .await?;
3610
3611 for left_buffer in left_channel_buffers {
3612 channel_buffer_updated(
3613 session.connection_id,
3614 left_buffer.connections,
3615 &proto::UpdateChannelBufferCollaborators {
3616 channel_id: left_buffer.channel_id.to_proto(),
3617 collaborators: left_buffer.collaborators,
3618 },
3619 &session.peer,
3620 );
3621 }
3622
3623 Ok(())
3624}
3625
3626fn project_left(project: &db::LeftProject, session: &Session) {
3627 for connection_id in &project.connection_ids {
3628 if project.host_user_id == session.user_id {
3629 session
3630 .peer
3631 .send(
3632 *connection_id,
3633 proto::UnshareProject {
3634 project_id: project.id.to_proto(),
3635 },
3636 )
3637 .trace_err();
3638 } else {
3639 session
3640 .peer
3641 .send(
3642 *connection_id,
3643 proto::RemoveProjectCollaborator {
3644 project_id: project.id.to_proto(),
3645 peer_id: Some(session.connection_id.into()),
3646 },
3647 )
3648 .trace_err();
3649 }
3650 }
3651}
3652
3653pub trait ResultExt {
3654 type Ok;
3655
3656 fn trace_err(self) -> Option<Self::Ok>;
3657}
3658
3659impl<T, E> ResultExt for Result<T, E>
3660where
3661 E: std::fmt::Debug,
3662{
3663 type Ok = T;
3664
3665 fn trace_err(self) -> Option<T> {
3666 match self {
3667 Ok(value) => Some(value),
3668 Err(error) => {
3669 tracing::error!("{:?}", error);
3670 None
3671 }
3672 }
3673 }
3674}