1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, ProjectId,
8 RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, User,
9 UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc, OnceLock,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// Grace period a dropped connection gets before `connection_lost` releases the
/// user's room, channel buffers, and pending call state.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources still attributed to stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone,
/// their old resources can be cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (used outside this view — presumably by
/// `get_channel_messages`; TODO confirm).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted chat message length (unit — bytes vs chars — not visible here; TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (used outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
78type MessageHandler =
79 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
80
81struct Response<R> {
82 peer: Arc<Peer>,
83 receipt: Receipt<R>,
84 responded: Arc<AtomicBool>,
85}
86
87impl<R: RequestMessage> Response<R> {
88 fn send(self, payload: R::Response) -> Result<()> {
89 self.responded.store(true, SeqCst);
90 self.peer.respond(self.receipt, payload)?;
91 Ok(())
92 }
93}
94
95#[derive(Clone)]
96struct Session {
97 user_id: UserId,
98 connection_id: ConnectionId,
99 db: Arc<tokio::sync::Mutex<DbHandle>>,
100 peer: Arc<Peer>,
101 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
102 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
103 _executor: Executor,
104}
105
106impl Session {
107 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
108 #[cfg(test)]
109 tokio::task::yield_now().await;
110 let guard = self.db.lock().await;
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 guard
114 }
115
116 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 let guard = self.connection_pool.lock();
120 ConnectionPoolGuard {
121 guard,
122 _not_send: PhantomData,
123 }
124 }
125}
126
127impl fmt::Debug for Session {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.debug_struct("Session")
130 .field("user_id", &self.user_id)
131 .field("connection_id", &self.connection_id)
132 .finish()
133 }
134}
135
136struct DbHandle(Arc<Database>);
137
138impl Deref for DbHandle {
139 type Target = Database;
140
141 fn deref(&self) -> &Self::Target {
142 self.0.as_ref()
143 }
144}
145
146pub struct Server {
147 id: parking_lot::Mutex<ServerId>,
148 peer: Arc<Peer>,
149 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
150 app_state: Arc<AppState>,
151 executor: Executor,
152 handlers: HashMap<TypeId, MessageHandler>,
153 teardown: watch::Sender<bool>,
154}
155
156pub(crate) struct ConnectionPoolGuard<'a> {
157 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
158 _not_send: PhantomData<Rc<()>>,
159}
160
161#[derive(Serialize)]
162pub struct ServerSnapshot<'a> {
163 peer: &'a Peer,
164 #[serde(serialize_with = "serialize_deref")]
165 connection_pool: ConnectionPoolGuard<'a>,
166}
167
168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
169where
170 S: Serializer,
171 T: Deref<Target = U>,
172 U: Serialize,
173{
174 Serialize::serialize(value.deref(), serializer)
175}
176
177impl Server {
178 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
179 let mut server = Self {
180 id: parking_lot::Mutex::new(id),
181 peer: Peer::new(id.0 as u32),
182 app_state,
183 executor,
184 connection_pool: Default::default(),
185 handlers: Default::default(),
186 teardown: watch::channel(false).0,
187 };
188
189 server
190 .add_request_handler(ping)
191 .add_request_handler(create_room)
192 .add_request_handler(join_room)
193 .add_request_handler(rejoin_room)
194 .add_request_handler(leave_room)
195 .add_request_handler(set_room_participant_role)
196 .add_request_handler(call)
197 .add_request_handler(cancel_call)
198 .add_message_handler(decline_call)
199 .add_request_handler(update_participant_location)
200 .add_request_handler(share_project)
201 .add_message_handler(unshare_project)
202 .add_request_handler(join_project)
203 .add_request_handler(join_hosted_project)
204 .add_message_handler(leave_project)
205 .add_request_handler(update_project)
206 .add_request_handler(update_worktree)
207 .add_message_handler(start_language_server)
208 .add_message_handler(update_language_server)
209 .add_message_handler(update_diagnostic_summary)
210 .add_message_handler(update_worktree_settings)
211 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
213 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
214 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
215 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
216 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
217 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
218 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
219 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
220 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
222 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
223 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
224 .add_request_handler(
225 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
226 )
227 .add_request_handler(
228 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
229 )
230 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
231 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
232 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
233 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
234 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
235 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
236 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
238 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
239 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
240 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
241 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
242 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
243 .add_message_handler(create_buffer_for_peer)
244 .add_request_handler(update_buffer)
245 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
246 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
247 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
248 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
249 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
250 .add_request_handler(get_users)
251 .add_request_handler(fuzzy_search_users)
252 .add_request_handler(request_contact)
253 .add_request_handler(remove_contact)
254 .add_request_handler(respond_to_contact_request)
255 .add_request_handler(create_channel)
256 .add_request_handler(delete_channel)
257 .add_request_handler(invite_channel_member)
258 .add_request_handler(remove_channel_member)
259 .add_request_handler(set_channel_member_role)
260 .add_request_handler(set_channel_visibility)
261 .add_request_handler(rename_channel)
262 .add_request_handler(join_channel_buffer)
263 .add_request_handler(leave_channel_buffer)
264 .add_message_handler(update_channel_buffer)
265 .add_request_handler(rejoin_channel_buffers)
266 .add_request_handler(get_channel_members)
267 .add_request_handler(respond_to_channel_invite)
268 .add_request_handler(join_channel)
269 .add_request_handler(join_channel_chat)
270 .add_message_handler(leave_channel_chat)
271 .add_request_handler(send_channel_message)
272 .add_request_handler(remove_channel_message)
273 .add_request_handler(get_channel_messages)
274 .add_request_handler(get_channel_messages_by_id)
275 .add_request_handler(get_notifications)
276 .add_request_handler(mark_notification_as_read)
277 .add_request_handler(move_channel)
278 .add_request_handler(follow)
279 .add_message_handler(unfollow)
280 .add_message_handler(update_followers)
281 .add_request_handler(get_private_user_info)
282 .add_message_handler(acknowledge_channel_message)
283 .add_message_handler(acknowledge_buffer_version);
284
285 Arc::new(server)
286 }
287
/// Schedules cleanup of resources still attributed to stale server instances.
///
/// Spawns a detached task, instrumented with a "start server" span, that:
/// 1. sleeps for `CLEANUP_TIMEOUT`;
/// 2. fetches the room and channel ids that `stale_server_resource_ids` reports for
///    this environment and server id;
/// 3. clears stale collaborators from each channel buffer and sends the refreshed
///    collaborator list to the buffer's connections;
/// 4. for each room:
///    - clears stale participants;
///    - broadcasts the updated room, and its channel if the room belongs to one;
///    - sends `CallCanceled` to users whose calls were canceled;
///    - pushes the affected users' refreshed contact entries to their accepted
///      contacts;
///    - deletes the LiveKit room when no participants remain and a LiveKit client
///      is configured;
/// 5. deletes the stale server records.
///
/// Failures inside the task are logged via `trace_err` and skipped. This function
/// returns `Ok(())` as soon as the task has been spawned.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    // Clone everything the detached task needs so it does not borrow `self`.
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: drop stale collaborators, then tell the remaining
                // connections about the new collaborator list.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Defaults mean "nothing to do" if clearing this room fails.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel_id) = refreshed_room.channel_id {
                            channel_updated(
                                channel_id,
                                &refreshed_room.room,
                                &refreshed_room.channel_members,
                                &peer,
                                &pool.lock(),
                            );
                        }
                        // Both evicted participants and canceled callees get refreshed
                        // contact entries below.
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Tell every connection of each canceled callee that the call is gone.
                    // The pool lock is scoped to this block, which contains no `.await`.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // For each affected user, push their refreshed contact entry (with
                    // current busy status) to every connection of their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Only delete the LiveKit room once nobody is left in it.
                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
439
440 pub fn teardown(&self) {
441 self.peer.teardown();
442 self.connection_pool.lock().reset();
443 let _ = self.teardown.send(true);
444 }
445
446 #[cfg(test)]
447 pub fn reset(&self, id: ServerId) {
448 self.teardown();
449 *self.id.lock() = id;
450 self.peer.reset(id.0 as u32);
451 let _ = self.teardown.send(false);
452 }
453
/// Test-only accessor for the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = *self.id.lock();
    current
}
458
459 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
460 where
461 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
462 Fut: 'static + Send + Future<Output = Result<()>>,
463 M: EnvelopedMessage,
464 {
465 let prev_handler = self.handlers.insert(
466 TypeId::of::<M>(),
467 Box::new(move |envelope, session| {
468 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
469 let received_at = envelope.received_at;
470 tracing::info!(
471 "message received"
472 );
473 let start_time = Instant::now();
474 let future = (handler)(*envelope, session);
475 async move {
476 let result = future.await;
477 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
478 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
479 let queue_duration_ms = total_duration_ms - processing_duration_ms;
480 match result {
481 Err(error) => {
482 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
483 }
484 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
485 }
486 }
487 .boxed()
488 }),
489 );
490 if prev_handler.is_some() {
491 panic!("registered a handler for the same message twice");
492 }
493 self
494 }
495
496 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
497 where
498 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
499 Fut: 'static + Send + Future<Output = Result<()>>,
500 M: EnvelopedMessage,
501 {
502 self.add_handler(move |envelope, session| handler(envelope.payload, session));
503 self
504 }
505
506 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
507 where
508 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
509 Fut: Send + Future<Output = Result<()>>,
510 M: RequestMessage,
511 {
512 let handler = Arc::new(handler);
513 self.add_handler(move |envelope, session| {
514 let receipt = envelope.receipt();
515 let handler = handler.clone();
516 async move {
517 let peer = session.peer.clone();
518 let responded = Arc::new(AtomicBool::default());
519 let response = Response {
520 peer: peer.clone(),
521 responded: responded.clone(),
522 receipt,
523 };
524 match (handler)(envelope.payload, response, session).await {
525 Ok(()) => {
526 if responded.load(std::sync::atomic::Ordering::SeqCst) {
527 Ok(())
528 } else {
529 Err(anyhow!("handler did not send a response"))?
530 }
531 }
532 Err(error) => {
533 let proto_err = match &error {
534 Error::Internal(err) => err.to_proto(),
535 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
536 };
537 peer.respond_with_error(receipt, proto_err)?;
538 Err(error)
539 }
540 }
541 }
542 })
543 }
544
/// Drives one client connection from handshake to sign-out.
///
/// The returned future is instrumented with a "handle connection" span and:
/// 1. returns immediately if the server is already tearing down;
/// 2. registers the connection with the peer and builds its `Session`;
/// 3. sends the initial client update. If `send_connection_id` is provided, it
///    receives the new `ConnectionId` after the hello message is sent;
/// 4. dispatches incoming messages to registered handlers until I/O fails, the
///    client closes the connection, or teardown is signalled;
/// 5. runs `connection_lost`. This step is skipped on teardown, which returns early
///    from the loop.
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    user: User,
    zed_version: ZedVersion,
    impersonator: Option<User>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    let user_id = user.id;
    let login = user.github_login.clone();
    // `impersonator` and `connection_id` are declared empty and recorded below.
    let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
    if let Some(impersonator) = impersonator {
        span.record("impersonator", &impersonator.github_login);
    }
    // Subscribe before the future starts so a later `teardown` is observed.
    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));
        tracing::info!("connection opened");

        // Each connection gets its own mutex around the shared database handle.
        let session = Session {
            user_id,
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            live_kit_client: this.app_state.live_kit_client.clone(),
            _executor: executor.clone()
        };

        if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 256 handlers per connection are in flight. A permit is acquired before
        // the next message is read and released when that message's handler completes.
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls branches in order: teardown, connection I/O,
            // in-flight foreground handlers, then the next incoming message.
            futures::select_biased! {
                // Teardown returns without running `connection_lost`.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                        // The span is entered only while the handler future is created. After
                        // that, the future is instrumented with it instead.
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            // The permit moves into the future and is released when it finishes.
                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run detached on the executor. Foreground
                            // messages are polled by this loop.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set cancels any foreground handlers that are still pending.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
663
664 async fn send_initial_client_update(
665 &self,
666 connection_id: ConnectionId,
667 user: User,
668 zed_version: ZedVersion,
669 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
670 session: &Session,
671 ) -> Result<()> {
672 self.peer.send(
673 connection_id,
674 proto::Hello {
675 peer_id: Some(connection_id.into()),
676 },
677 )?;
678 tracing::info!("sent hello message");
679
680 if let Some(send_connection_id) = send_connection_id.take() {
681 let _ = send_connection_id.send(connection_id);
682 }
683
684 if !user.connected_once {
685 self.peer.send(connection_id, proto::ShowContacts {})?;
686 self.app_state
687 .db
688 .set_user_connected_once(user.id, true)
689 .await?;
690 }
691
692 let (contacts, channels_for_user, channel_invites) = future::try_join3(
693 self.app_state.db.get_contacts(user.id),
694 self.app_state.db.get_channels_for_user(user.id),
695 self.app_state.db.get_channel_invites_for_user(user.id),
696 )
697 .await?;
698
699 {
700 let mut pool = self.connection_pool.lock();
701 pool.add_connection(connection_id, user.id, user.admin, zed_version);
702 self.peer.send(
703 connection_id,
704 build_initial_contacts_update(contacts, &pool),
705 )?;
706 self.peer.send(
707 connection_id,
708 build_update_user_channels(&channels_for_user),
709 )?;
710 self.peer.send(
711 connection_id,
712 build_channels_update(channels_for_user, channel_invites),
713 )?;
714 }
715
716 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
717 self.peer.send(connection_id, incoming_call)?;
718 }
719
720 update_user_contacts(user.id, &session).await?;
721 Ok(())
722 }
723
724 pub async fn invite_code_redeemed(
725 self: &Arc<Self>,
726 inviter_id: UserId,
727 invitee_id: UserId,
728 ) -> Result<()> {
729 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
730 if let Some(code) = &user.invite_code {
731 let pool = self.connection_pool.lock();
732 let invitee_contact = contact_for_user(invitee_id, false, &pool);
733 for connection_id in pool.user_connection_ids(inviter_id) {
734 self.peer.send(
735 connection_id,
736 proto::UpdateContacts {
737 contacts: vec![invitee_contact.clone()],
738 ..Default::default()
739 },
740 )?;
741 self.peer.send(
742 connection_id,
743 proto::UpdateInviteInfo {
744 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
745 count: user.invite_count as u32,
746 },
747 )?;
748 }
749 }
750 }
751 Ok(())
752 }
753
754 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
755 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
756 if let Some(invite_code) = &user.invite_code {
757 let pool = self.connection_pool.lock();
758 for connection_id in pool.user_connection_ids(user_id) {
759 self.peer.send(
760 connection_id,
761 proto::UpdateInviteInfo {
762 url: format!(
763 "{}{}",
764 self.app_state.config.invite_link_prefix, invite_code
765 ),
766 count: user.invite_count as u32,
767 },
768 )?;
769 }
770 }
771 }
772 Ok(())
773 }
774
775 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
776 ServerSnapshot {
777 connection_pool: ConnectionPoolGuard {
778 guard: self.connection_pool.lock(),
779 _not_send: PhantomData,
780 },
781 peer: &self.peer,
782 }
783 }
784}
785
786impl<'a> Deref for ConnectionPoolGuard<'a> {
787 type Target = ConnectionPool;
788
789 fn deref(&self) -> &Self::Target {
790 &self.guard
791 }
792}
793
794impl<'a> DerefMut for ConnectionPoolGuard<'a> {
795 fn deref_mut(&mut self) -> &mut Self::Target {
796 &mut self.guard
797 }
798}
799
800impl<'a> Drop for ConnectionPoolGuard<'a> {
801 fn drop(&mut self) {
802 #[cfg(test)]
803 self.check_invariants();
804 }
805}
806
807fn broadcast<F>(
808 sender_id: Option<ConnectionId>,
809 receiver_ids: impl IntoIterator<Item = ConnectionId>,
810 mut f: F,
811) where
812 F: FnMut(ConnectionId) -> anyhow::Result<()>,
813{
814 for receiver_id in receiver_ids {
815 if Some(receiver_id) != sender_id {
816 if let Err(error) = f(receiver_id) {
817 tracing::error!("failed to send to {:?} {}", receiver_id, error);
818 }
819 }
820 }
821}
822
823pub struct ProtocolVersion(u32);
824
825impl Header for ProtocolVersion {
826 fn name() -> &'static HeaderName {
827 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
828 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
829 }
830
831 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
832 where
833 Self: Sized,
834 I: Iterator<Item = &'i axum::http::HeaderValue>,
835 {
836 let version = values
837 .next()
838 .ok_or_else(axum::headers::Error::invalid)?
839 .to_str()
840 .map_err(|_| axum::headers::Error::invalid())?
841 .parse()
842 .map_err(|_| axum::headers::Error::invalid())?;
843 Ok(Self(version))
844 }
845
846 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
847 values.extend([self.0.to_string().parse().unwrap()]);
848 }
849}
850
851pub struct AppVersionHeader(SemanticVersion);
852impl Header for AppVersionHeader {
853 fn name() -> &'static HeaderName {
854 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
855 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
856 }
857
858 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
859 where
860 Self: Sized,
861 I: Iterator<Item = &'i axum::http::HeaderValue>,
862 {
863 let version = values
864 .next()
865 .ok_or_else(axum::headers::Error::invalid)?
866 .to_str()
867 .map_err(|_| axum::headers::Error::invalid())?
868 .parse()
869 .map_err(|_| axum::headers::Error::invalid())?;
870 Ok(Self(version))
871 }
872
873 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
874 values.extend([self.0.to_string().parse().unwrap()]);
875 }
876}
877
878pub fn routes(server: Arc<Server>) -> Router<(), Body> {
879 Router::new()
880 .route("/rpc", get(handle_websocket_request))
881 .layer(
882 ServiceBuilder::new()
883 .layer(Extension(server.app_state.clone()))
884 .layer(middleware::from_fn(auth::validate_header)),
885 )
886 .route("/metrics", get(handle_metrics))
887 .layer(Extension(server))
888}
889
890pub async fn handle_websocket_request(
891 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
892 app_version_header: Option<TypedHeader<AppVersionHeader>>,
893 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
894 Extension(server): Extension<Arc<Server>>,
895 Extension(user): Extension<User>,
896 Extension(impersonator): Extension<Impersonator>,
897 ws: WebSocketUpgrade,
898) -> axum::response::Response {
899 if protocol_version != rpc::PROTOCOL_VERSION {
900 return (
901 StatusCode::UPGRADE_REQUIRED,
902 "client must be upgraded".to_string(),
903 )
904 .into_response();
905 }
906
907 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
908 return (
909 StatusCode::UPGRADE_REQUIRED,
910 "no version header found".to_string(),
911 )
912 .into_response();
913 };
914
915 if !version.is_supported() {
916 return (
917 StatusCode::UPGRADE_REQUIRED,
918 "client must be upgraded".to_string(),
919 )
920 .into_response();
921 }
922
923 let socket_address = socket_address.to_string();
924 ws.on_upgrade(move |socket| {
925 let socket = socket
926 .map_ok(to_tungstenite_message)
927 .err_into()
928 .with(|message| async move { Ok(to_axum_message(message)) });
929 let connection = Connection::new(Box::pin(socket));
930 async move {
931 server
932 .handle_connection(
933 connection,
934 socket_address,
935 user,
936 version,
937 impersonator.0,
938 None,
939 Executor::Production,
940 )
941 .await;
942 }
943 })
944}
945
946pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
947 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
948 let connections_metric = CONNECTIONS_METRIC
949 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
950
951 let connections = server
952 .connection_pool
953 .lock()
954 .connections()
955 .filter(|connection| !connection.admin)
956 .count();
957 connections_metric.set(connections as _);
958
959 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
960 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
961 register_int_gauge!(
962 "shared_projects",
963 "number of open projects with one or more guests"
964 )
965 .unwrap()
966 });
967
968 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
969 shared_projects_metric.set(shared_projects as _);
970
971 let encoder = prometheus::TextEncoder::new();
972 let metric_families = prometheus::gather();
973 let encoded_metrics = encoder
974 .encode_to_string(&metric_families)
975 .map_err(|err| anyhow!("{}", err))?;
976 Ok(encoded_metrics)
977}
978
/// Cleans up after a connection ends.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - report the lost connection to the database.
///
/// It then waits `RECONNECT_TIMEOUT`. If the server is torn down first, it returns
/// with no further cleanup. Otherwise it:
/// - leaves the session's room and channel buffers;
/// - calls `decline_call(None, user_id)` if the pool no longer lists the user as
///   online (presumably declining a pending incoming call; confirm in `db`);
/// - refreshes the user's contacts.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool or if updating
/// contacts fails. The other cleanup steps log their errors and continue.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased: the reconnect timer is polled before the teardown signal.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline if no other connection for this user remains in the pool.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1024
1025/// Acknowledges a ping from a client, used to keep the connection alive.
1026async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1027 response.send(proto::Ack {})?;
1028 Ok(())
1029}
1030
1031/// Creates a new room for calling (outside of channels)
1032async fn create_room(
1033 _request: proto::CreateRoom,
1034 response: Response<proto::CreateRoom>,
1035 session: Session,
1036) -> Result<()> {
1037 let live_kit_room = nanoid::nanoid!(30);
1038
1039 let live_kit_connection_info = {
1040 let live_kit_room = live_kit_room.clone();
1041 let live_kit = session.live_kit_client.as_ref();
1042
1043 util::async_maybe!({
1044 let live_kit = live_kit?;
1045
1046 let token = live_kit
1047 .room_token(&live_kit_room, &session.user_id.to_string())
1048 .trace_err()?;
1049
1050 Some(proto::LiveKitConnectionInfo {
1051 server_url: live_kit.url().into(),
1052 token,
1053 can_publish: true,
1054 })
1055 })
1056 }
1057 .await;
1058
1059 let room = session
1060 .db()
1061 .await
1062 .create_room(session.user_id, session.connection_id, &live_kit_room)
1063 .await?;
1064
1065 response.send(proto::CreateRoomResponse {
1066 room: Some(room.clone()),
1067 live_kit_connection_info,
1068 })?;
1069
1070 update_user_contacts(session.user_id, &session).await?;
1071 Ok(())
1072}
1073
1074/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1075async fn join_room(
1076 request: proto::JoinRoom,
1077 response: Response<proto::JoinRoom>,
1078 session: Session,
1079) -> Result<()> {
1080 let room_id = RoomId::from_proto(request.id);
1081
1082 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1083
1084 if let Some(channel_id) = channel_id {
1085 return join_channel_internal(channel_id, Box::new(response), session).await;
1086 }
1087
1088 let joined_room = {
1089 let room = session
1090 .db()
1091 .await
1092 .join_room(room_id, session.user_id, session.connection_id)
1093 .await?;
1094 room_updated(&room.room, &session.peer);
1095 room.into_inner()
1096 };
1097
1098 for connection_id in session
1099 .connection_pool()
1100 .await
1101 .user_connection_ids(session.user_id)
1102 {
1103 session
1104 .peer
1105 .send(
1106 connection_id,
1107 proto::CallCanceled {
1108 room_id: room_id.to_proto(),
1109 },
1110 )
1111 .trace_err();
1112 }
1113
1114 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1115 if let Some(token) = live_kit
1116 .room_token(
1117 &joined_room.room.live_kit_room,
1118 &session.user_id.to_string(),
1119 )
1120 .trace_err()
1121 {
1122 Some(proto::LiveKitConnectionInfo {
1123 server_url: live_kit.url().into(),
1124 token,
1125 can_publish: true,
1126 })
1127 } else {
1128 None
1129 }
1130 } else {
1131 None
1132 };
1133
1134 response.send(proto::JoinRoomResponse {
1135 room: Some(joined_room.room),
1136 channel_id: None,
1137 live_kit_connection_info,
1138 })?;
1139
1140 update_user_contacts(session.user_id, &session).await?;
1141 Ok(())
1142}
1143
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call re-associates the caller's new connection with the room and its
/// projects. The handler then:
/// 1. replies with the room, the "reshared" projects and the "rejoined" projects,
/// 2. broadcasts the room update,
/// 3. for every reshared and rejoined project, tells each collaborator that the caller's
///    peer id changed from the old connection to the new one,
/// 4. for reshared projects, also forwards an `UpdateProject` carrying the worktrees,
/// 5. streams each rejoined project's worktree updates (chunked), diagnostic summaries,
///    settings files and language-server status to the caller.
///
/// NOTE(review): "reshared" appears to mean projects the caller hosts and "rejoined" ones
/// the caller is a guest of — confirm against `Database::rejoin_room`.
///
/// After the room guard is released, channel members are notified if the room belongs to
/// a channel, and the caller's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel_id;
    let channel_members;
    {
        // Everything in this scope runs while the rejoined-room guard is held.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point collaborators at the caller's new connection (failures only logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to every collaborator.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            // Point collaborators at the caller's new connection (failures only logged).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream project state back to the rejoining connection. Worktrees are taken out of
        // the guard so their contents can be moved into outgoing messages.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks under test exercise the splitting logic.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Release the guard, keeping only what the channel notification below needs.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1335
1336/// leave room disconnects from the room.
1337async fn leave_room(
1338 _: proto::LeaveRoom,
1339 response: Response<proto::LeaveRoom>,
1340 session: Session,
1341) -> Result<()> {
1342 leave_room_for_session(&session).await?;
1343 response.send(proto::Ack {})?;
1344 Ok(())
1345}
1346
1347/// Updates the permissions of someone else in the room.
1348async fn set_room_participant_role(
1349 request: proto::SetRoomParticipantRole,
1350 response: Response<proto::SetRoomParticipantRole>,
1351 session: Session,
1352) -> Result<()> {
1353 let user_id = UserId::from_proto(request.user_id);
1354 let role = ChannelRole::from(request.role());
1355
1356 if role == ChannelRole::Talker {
1357 let pool = session.connection_pool().await;
1358
1359 for connection in pool.user_connections(user_id) {
1360 if !connection.zed_version.supports_talker_role() {
1361 Err(anyhow!(
1362 "This user is on zed {} which does not support unmute",
1363 connection.zed_version
1364 ))?;
1365 }
1366 }
1367 }
1368
1369 let (live_kit_room, can_publish) = {
1370 let room = session
1371 .db()
1372 .await
1373 .set_room_participant_role(
1374 session.user_id,
1375 RoomId::from_proto(request.room_id),
1376 user_id,
1377 role,
1378 )
1379 .await?;
1380
1381 let live_kit_room = room.live_kit_room.clone();
1382 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1383 room_updated(&room, &session.peer);
1384 (live_kit_room, can_publish)
1385 };
1386
1387 if let Some(live_kit) = session.live_kit_client.as_ref() {
1388 live_kit
1389 .update_participant(
1390 live_kit_room.clone(),
1391 request.user_id.to_string(),
1392 live_kit_server::proto::ParticipantPermission {
1393 can_subscribe: true,
1394 can_publish,
1395 can_publish_data: can_publish,
1396 hidden: false,
1397 recorder: false,
1398 },
1399 )
1400 .await
1401 .trace_err();
1402 }
1403
1404 response.send(proto::Ack {})?;
1405 Ok(())
1406}
1407
1408/// Call someone else into the current room
1409async fn call(
1410 request: proto::Call,
1411 response: Response<proto::Call>,
1412 session: Session,
1413) -> Result<()> {
1414 let room_id = RoomId::from_proto(request.room_id);
1415 let calling_user_id = session.user_id;
1416 let calling_connection_id = session.connection_id;
1417 let called_user_id = UserId::from_proto(request.called_user_id);
1418 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1419 if !session
1420 .db()
1421 .await
1422 .has_contact(calling_user_id, called_user_id)
1423 .await?
1424 {
1425 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1426 }
1427
1428 let incoming_call = {
1429 let (room, incoming_call) = &mut *session
1430 .db()
1431 .await
1432 .call(
1433 room_id,
1434 calling_user_id,
1435 calling_connection_id,
1436 called_user_id,
1437 initial_project_id,
1438 )
1439 .await?;
1440 room_updated(&room, &session.peer);
1441 mem::take(incoming_call)
1442 };
1443 update_user_contacts(called_user_id, &session).await?;
1444
1445 let mut calls = session
1446 .connection_pool()
1447 .await
1448 .user_connection_ids(called_user_id)
1449 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1450 .collect::<FuturesUnordered<_>>();
1451
1452 while let Some(call_response) = calls.next().await {
1453 match call_response.as_ref() {
1454 Ok(_) => {
1455 response.send(proto::Ack {})?;
1456 return Ok(());
1457 }
1458 Err(_) => {
1459 call_response.trace_err();
1460 }
1461 }
1462 }
1463
1464 {
1465 let room = session
1466 .db()
1467 .await
1468 .call_failed(room_id, called_user_id)
1469 .await?;
1470 room_updated(&room, &session.peer);
1471 }
1472 update_user_contacts(called_user_id, &session).await?;
1473
1474 Err(anyhow!("failed to ring user"))?
1475}
1476
1477/// Cancel an outgoing call.
1478async fn cancel_call(
1479 request: proto::CancelCall,
1480 response: Response<proto::CancelCall>,
1481 session: Session,
1482) -> Result<()> {
1483 let called_user_id = UserId::from_proto(request.called_user_id);
1484 let room_id = RoomId::from_proto(request.room_id);
1485 {
1486 let room = session
1487 .db()
1488 .await
1489 .cancel_call(room_id, session.connection_id, called_user_id)
1490 .await?;
1491 room_updated(&room, &session.peer);
1492 }
1493
1494 for connection_id in session
1495 .connection_pool()
1496 .await
1497 .user_connection_ids(called_user_id)
1498 {
1499 session
1500 .peer
1501 .send(
1502 connection_id,
1503 proto::CallCanceled {
1504 room_id: room_id.to_proto(),
1505 },
1506 )
1507 .trace_err();
1508 }
1509 response.send(proto::Ack {})?;
1510
1511 update_user_contacts(called_user_id, &session).await?;
1512 Ok(())
1513}
1514
1515/// Decline an incoming call.
1516async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1517 let room_id = RoomId::from_proto(message.room_id);
1518 {
1519 let room = session
1520 .db()
1521 .await
1522 .decline_call(Some(room_id), session.user_id)
1523 .await?
1524 .ok_or_else(|| anyhow!("failed to decline call"))?;
1525 room_updated(&room, &session.peer);
1526 }
1527
1528 for connection_id in session
1529 .connection_pool()
1530 .await
1531 .user_connection_ids(session.user_id)
1532 {
1533 session
1534 .peer
1535 .send(
1536 connection_id,
1537 proto::CallCanceled {
1538 room_id: room_id.to_proto(),
1539 },
1540 )
1541 .trace_err();
1542 }
1543 update_user_contacts(session.user_id, &session).await?;
1544 Ok(())
1545}
1546
1547/// Updates other participants in the room with your current location.
1548async fn update_participant_location(
1549 request: proto::UpdateParticipantLocation,
1550 response: Response<proto::UpdateParticipantLocation>,
1551 session: Session,
1552) -> Result<()> {
1553 let room_id = RoomId::from_proto(request.room_id);
1554 let location = request
1555 .location
1556 .ok_or_else(|| anyhow!("invalid location"))?;
1557
1558 let db = session.db().await;
1559 let room = db
1560 .update_room_participant_location(room_id, session.connection_id, location)
1561 .await?;
1562
1563 room_updated(&room, &session.peer);
1564 response.send(proto::Ack {})?;
1565 Ok(())
1566}
1567
1568/// Share a project into the room.
1569async fn share_project(
1570 request: proto::ShareProject,
1571 response: Response<proto::ShareProject>,
1572 session: Session,
1573) -> Result<()> {
1574 let (project_id, room) = &*session
1575 .db()
1576 .await
1577 .share_project(
1578 RoomId::from_proto(request.room_id),
1579 session.connection_id,
1580 &request.worktrees,
1581 )
1582 .await?;
1583 response.send(proto::ShareProjectResponse {
1584 project_id: project_id.to_proto(),
1585 })?;
1586 room_updated(&room, &session.peer);
1587
1588 Ok(())
1589}
1590
1591/// Unshare a project from the room.
1592async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1593 let project_id = ProjectId::from_proto(message.project_id);
1594
1595 let (room, guest_connection_ids) = &*session
1596 .db()
1597 .await
1598 .unshare_project(project_id, session.connection_id)
1599 .await?;
1600
1601 broadcast(
1602 Some(session.connection_id),
1603 guest_connection_ids.iter().copied(),
1604 |conn_id| session.peer.send(conn_id, message.clone()),
1605 );
1606 room_updated(&room, &session.peer);
1607
1608 Ok(())
1609}
1610
1611/// Join someone elses shared project.
1612async fn join_project(
1613 request: proto::JoinProject,
1614 response: Response<proto::JoinProject>,
1615 session: Session,
1616) -> Result<()> {
1617 let project_id = ProjectId::from_proto(request.project_id);
1618
1619 tracing::info!(%project_id, "join project");
1620
1621 let (project, replica_id) = &mut *session
1622 .db()
1623 .await
1624 .join_project_in_room(project_id, session.connection_id)
1625 .await?;
1626
1627 join_project_internal(response, session, project, replica_id)
1628}
1629
/// Abstraction over the response handles that reply with a `JoinProjectResponse`, so
/// `join_project_internal` can serve both `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` back to the requesting client, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Qualified path: presumably resolves to the inherent `Response::send` (inherent
        // methods take precedence), not back into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same delegation as above, for hosted-project joins.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1643
1644fn join_project_internal(
1645 response: impl JoinProjectInternalResponse,
1646 session: Session,
1647 project: &mut Project,
1648 replica_id: &ReplicaId,
1649) -> Result<()> {
1650 let collaborators = project
1651 .collaborators
1652 .iter()
1653 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1654 .map(|collaborator| collaborator.to_proto())
1655 .collect::<Vec<_>>();
1656 let project_id = project.id;
1657 let guest_user_id = session.user_id;
1658
1659 let worktrees = project
1660 .worktrees
1661 .iter()
1662 .map(|(id, worktree)| proto::WorktreeMetadata {
1663 id: *id,
1664 root_name: worktree.root_name.clone(),
1665 visible: worktree.visible,
1666 abs_path: worktree.abs_path.clone(),
1667 })
1668 .collect::<Vec<_>>();
1669
1670 for collaborator in &collaborators {
1671 session
1672 .peer
1673 .send(
1674 collaborator.peer_id.unwrap().into(),
1675 proto::AddProjectCollaborator {
1676 project_id: project_id.to_proto(),
1677 collaborator: Some(proto::Collaborator {
1678 peer_id: Some(session.connection_id.into()),
1679 replica_id: replica_id.0 as u32,
1680 user_id: guest_user_id.to_proto(),
1681 }),
1682 },
1683 )
1684 .trace_err();
1685 }
1686
1687 // First, we send the metadata associated with each worktree.
1688 response.send(proto::JoinProjectResponse {
1689 project_id: project.id.0 as u64,
1690 worktrees: worktrees.clone(),
1691 replica_id: replica_id.0 as u32,
1692 collaborators: collaborators.clone(),
1693 language_servers: project.language_servers.clone(),
1694 role: project.role.into(), // todo
1695 })?;
1696
1697 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1698 #[cfg(any(test, feature = "test-support"))]
1699 const MAX_CHUNK_SIZE: usize = 2;
1700 #[cfg(not(any(test, feature = "test-support")))]
1701 const MAX_CHUNK_SIZE: usize = 256;
1702
1703 // Stream this worktree's entries.
1704 let message = proto::UpdateWorktree {
1705 project_id: project_id.to_proto(),
1706 worktree_id,
1707 abs_path: worktree.abs_path.clone(),
1708 root_name: worktree.root_name,
1709 updated_entries: worktree.entries,
1710 removed_entries: Default::default(),
1711 scan_id: worktree.scan_id,
1712 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1713 updated_repositories: worktree.repository_entries.into_values().collect(),
1714 removed_repositories: Default::default(),
1715 };
1716 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1717 session.peer.send(session.connection_id, update.clone())?;
1718 }
1719
1720 // Stream this worktree's diagnostics.
1721 for summary in worktree.diagnostic_summaries {
1722 session.peer.send(
1723 session.connection_id,
1724 proto::UpdateDiagnosticSummary {
1725 project_id: project_id.to_proto(),
1726 worktree_id: worktree.id,
1727 summary: Some(summary),
1728 },
1729 )?;
1730 }
1731
1732 for settings_file in worktree.settings_files {
1733 session.peer.send(
1734 session.connection_id,
1735 proto::UpdateWorktreeSettings {
1736 project_id: project_id.to_proto(),
1737 worktree_id: worktree.id,
1738 path: settings_file.path,
1739 content: Some(settings_file.content),
1740 },
1741 )?;
1742 }
1743 }
1744
1745 for language_server in &project.language_servers {
1746 session.peer.send(
1747 session.connection_id,
1748 proto::UpdateLanguageServer {
1749 project_id: project_id.to_proto(),
1750 language_server_id: language_server.id,
1751 variant: Some(
1752 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1753 proto::LspDiskBasedDiagnosticsUpdated {},
1754 ),
1755 ),
1756 },
1757 )?;
1758 }
1759
1760 Ok(())
1761}
1762
1763/// Leave someone elses shared project.
1764async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1765 let sender_id = session.connection_id;
1766 let project_id = ProjectId::from_proto(request.project_id);
1767 let db = session.db().await;
1768 if db.is_hosted_project(project_id).await? {
1769 let project = db.leave_hosted_project(project_id, sender_id).await?;
1770 project_left(&project, &session);
1771 return Ok(());
1772 }
1773
1774 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1775 tracing::info!(
1776 %project_id,
1777 host_user_id = ?project.host_user_id,
1778 host_connection_id = ?project.host_connection_id,
1779 "leave project"
1780 );
1781
1782 project_left(&project, &session);
1783 room_updated(&room, &session.peer);
1784
1785 Ok(())
1786}
1787
1788async fn join_hosted_project(
1789 request: proto::JoinHostedProject,
1790 response: Response<proto::JoinHostedProject>,
1791 session: Session,
1792) -> Result<()> {
1793 let (mut project, replica_id) = session
1794 .db()
1795 .await
1796 .join_hosted_project(
1797 ProjectId(request.project_id as i32),
1798 session.user_id,
1799 session.connection_id,
1800 )
1801 .await?;
1802
1803 join_project_internal(response, session, &mut project, &replica_id)
1804}
1805
1806/// Updates other participants with changes to the project
1807async fn update_project(
1808 request: proto::UpdateProject,
1809 response: Response<proto::UpdateProject>,
1810 session: Session,
1811) -> Result<()> {
1812 let project_id = ProjectId::from_proto(request.project_id);
1813 let (room, guest_connection_ids) = &*session
1814 .db()
1815 .await
1816 .update_project(project_id, session.connection_id, &request.worktrees)
1817 .await?;
1818 broadcast(
1819 Some(session.connection_id),
1820 guest_connection_ids.iter().copied(),
1821 |connection_id| {
1822 session
1823 .peer
1824 .forward_send(session.connection_id, connection_id, request.clone())
1825 },
1826 );
1827 room_updated(&room, &session.peer);
1828 response.send(proto::Ack {})?;
1829
1830 Ok(())
1831}
1832
1833/// Updates other participants with changes to the worktree
1834async fn update_worktree(
1835 request: proto::UpdateWorktree,
1836 response: Response<proto::UpdateWorktree>,
1837 session: Session,
1838) -> Result<()> {
1839 let guest_connection_ids = session
1840 .db()
1841 .await
1842 .update_worktree(&request, session.connection_id)
1843 .await?;
1844
1845 broadcast(
1846 Some(session.connection_id),
1847 guest_connection_ids.iter().copied(),
1848 |connection_id| {
1849 session
1850 .peer
1851 .forward_send(session.connection_id, connection_id, request.clone())
1852 },
1853 );
1854 response.send(proto::Ack {})?;
1855 Ok(())
1856}
1857
1858/// Updates other participants with changes to the diagnostics
1859async fn update_diagnostic_summary(
1860 message: proto::UpdateDiagnosticSummary,
1861 session: Session,
1862) -> Result<()> {
1863 let guest_connection_ids = session
1864 .db()
1865 .await
1866 .update_diagnostic_summary(&message, session.connection_id)
1867 .await?;
1868
1869 broadcast(
1870 Some(session.connection_id),
1871 guest_connection_ids.iter().copied(),
1872 |connection_id| {
1873 session
1874 .peer
1875 .forward_send(session.connection_id, connection_id, message.clone())
1876 },
1877 );
1878
1879 Ok(())
1880}
1881
1882/// Updates other participants with changes to the worktree settings
1883async fn update_worktree_settings(
1884 message: proto::UpdateWorktreeSettings,
1885 session: Session,
1886) -> Result<()> {
1887 let guest_connection_ids = session
1888 .db()
1889 .await
1890 .update_worktree_settings(&message, session.connection_id)
1891 .await?;
1892
1893 broadcast(
1894 Some(session.connection_id),
1895 guest_connection_ids.iter().copied(),
1896 |connection_id| {
1897 session
1898 .peer
1899 .forward_send(session.connection_id, connection_id, message.clone())
1900 },
1901 );
1902
1903 Ok(())
1904}
1905
1906/// Notify other participants that a language server has started.
1907async fn start_language_server(
1908 request: proto::StartLanguageServer,
1909 session: Session,
1910) -> Result<()> {
1911 let guest_connection_ids = session
1912 .db()
1913 .await
1914 .start_language_server(&request, session.connection_id)
1915 .await?;
1916
1917 broadcast(
1918 Some(session.connection_id),
1919 guest_connection_ids.iter().copied(),
1920 |connection_id| {
1921 session
1922 .peer
1923 .forward_send(session.connection_id, connection_id, request.clone())
1924 },
1925 );
1926 Ok(())
1927}
1928
1929/// Notify other participants that a language server has changed.
1930async fn update_language_server(
1931 request: proto::UpdateLanguageServer,
1932 session: Session,
1933) -> Result<()> {
1934 let project_id = ProjectId::from_proto(request.project_id);
1935 let project_connection_ids = session
1936 .db()
1937 .await
1938 .project_connection_ids(project_id, session.connection_id)
1939 .await?;
1940 broadcast(
1941 Some(session.connection_id),
1942 project_connection_ids.iter().copied(),
1943 |connection_id| {
1944 session
1945 .peer
1946 .forward_send(session.connection_id, connection_id, request.clone())
1947 },
1948 );
1949 Ok(())
1950}
1951
1952/// forward a project request to the host. These requests should be read only
1953/// as guests are allowed to send them.
1954async fn forward_read_only_project_request<T>(
1955 request: T,
1956 response: Response<T>,
1957 session: Session,
1958) -> Result<()>
1959where
1960 T: EntityMessage + RequestMessage,
1961{
1962 let project_id = ProjectId::from_proto(request.remote_entity_id());
1963 let host_connection_id = session
1964 .db()
1965 .await
1966 .host_for_read_only_project_request(project_id, session.connection_id)
1967 .await?;
1968 let payload = session
1969 .peer
1970 .forward_request(session.connection_id, host_connection_id, request)
1971 .await?;
1972 response.send(payload)?;
1973 Ok(())
1974}
1975
1976/// forward a project request to the host. These requests are disallowed
1977/// for guests.
1978async fn forward_mutating_project_request<T>(
1979 request: T,
1980 response: Response<T>,
1981 session: Session,
1982) -> Result<()>
1983where
1984 T: EntityMessage + RequestMessage,
1985{
1986 let project_id = ProjectId::from_proto(request.remote_entity_id());
1987 let host_connection_id = session
1988 .db()
1989 .await
1990 .host_for_mutating_project_request(project_id, session.connection_id)
1991 .await?;
1992 let payload = session
1993 .peer
1994 .forward_request(session.connection_id, host_connection_id, request)
1995 .await?;
1996 response.send(payload)?;
1997 Ok(())
1998}
1999
2000/// Notify other participants that a new buffer has been created
2001async fn create_buffer_for_peer(
2002 request: proto::CreateBufferForPeer,
2003 session: Session,
2004) -> Result<()> {
2005 session
2006 .db()
2007 .await
2008 .check_user_is_project_host(
2009 ProjectId::from_proto(request.project_id),
2010 session.connection_id,
2011 )
2012 .await?;
2013 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2014 session
2015 .peer
2016 .forward_send(session.connection_id, peer_id.into(), request)?;
2017 Ok(())
2018}
2019
2020/// Notify other participants that a buffer has been updated. This is
2021/// allowed for guests as long as the update is limited to selections.
2022async fn update_buffer(
2023 request: proto::UpdateBuffer,
2024 response: Response<proto::UpdateBuffer>,
2025 session: Session,
2026) -> Result<()> {
2027 let project_id = ProjectId::from_proto(request.project_id);
2028 let mut guest_connection_ids;
2029 let mut host_connection_id = None;
2030
2031 let mut requires_write_permission = false;
2032
2033 for op in request.operations.iter() {
2034 match op.variant {
2035 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2036 Some(_) => requires_write_permission = true,
2037 }
2038 }
2039
2040 {
2041 let collaborators = session
2042 .db()
2043 .await
2044 .project_collaborators_for_buffer_update(
2045 project_id,
2046 session.connection_id,
2047 requires_write_permission,
2048 )
2049 .await?;
2050 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2051 for collaborator in collaborators.iter() {
2052 if collaborator.is_host {
2053 host_connection_id = Some(collaborator.connection_id);
2054 } else {
2055 guest_connection_ids.push(collaborator.connection_id);
2056 }
2057 }
2058 }
2059 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2060
2061 broadcast(
2062 Some(session.connection_id),
2063 guest_connection_ids,
2064 |connection_id| {
2065 session
2066 .peer
2067 .forward_send(session.connection_id, connection_id, request.clone())
2068 },
2069 );
2070 if host_connection_id != session.connection_id {
2071 session
2072 .peer
2073 .forward_request(session.connection_id, host_connection_id, request.clone())
2074 .await?;
2075 }
2076
2077 response.send(proto::Ack {})?;
2078 Ok(())
2079}
2080
2081/// Notify other participants that a project has been updated.
2082async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2083 request: T,
2084 session: Session,
2085) -> Result<()> {
2086 let project_id = ProjectId::from_proto(request.remote_entity_id());
2087 let project_connection_ids = session
2088 .db()
2089 .await
2090 .project_connection_ids(project_id, session.connection_id)
2091 .await?;
2092
2093 broadcast(
2094 Some(session.connection_id),
2095 project_connection_ids.iter().copied(),
2096 |connection_id| {
2097 session
2098 .peer
2099 .forward_send(session.connection_id, connection_id, request.clone())
2100 },
2101 );
2102 Ok(())
2103}
2104
2105/// Start following another user in a call.
2106async fn follow(
2107 request: proto::Follow,
2108 response: Response<proto::Follow>,
2109 session: Session,
2110) -> Result<()> {
2111 let room_id = RoomId::from_proto(request.room_id);
2112 let project_id = request.project_id.map(ProjectId::from_proto);
2113 let leader_id = request
2114 .leader_id
2115 .ok_or_else(|| anyhow!("invalid leader id"))?
2116 .into();
2117 let follower_id = session.connection_id;
2118
2119 session
2120 .db()
2121 .await
2122 .check_room_participants(room_id, leader_id, session.connection_id)
2123 .await?;
2124
2125 let response_payload = session
2126 .peer
2127 .forward_request(session.connection_id, leader_id, request)
2128 .await?;
2129 response.send(response_payload)?;
2130
2131 if let Some(project_id) = project_id {
2132 let room = session
2133 .db()
2134 .await
2135 .follow(room_id, project_id, leader_id, follower_id)
2136 .await?;
2137 room_updated(&room, &session.peer);
2138 }
2139
2140 Ok(())
2141}
2142
2143/// Stop following another user in a call.
2144async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2145 let room_id = RoomId::from_proto(request.room_id);
2146 let project_id = request.project_id.map(ProjectId::from_proto);
2147 let leader_id = request
2148 .leader_id
2149 .ok_or_else(|| anyhow!("invalid leader id"))?
2150 .into();
2151 let follower_id = session.connection_id;
2152
2153 session
2154 .db()
2155 .await
2156 .check_room_participants(room_id, leader_id, session.connection_id)
2157 .await?;
2158
2159 session
2160 .peer
2161 .forward_send(session.connection_id, leader_id, request)?;
2162
2163 if let Some(project_id) = project_id {
2164 let room = session
2165 .db()
2166 .await
2167 .unfollow(room_id, project_id, leader_id, follower_id)
2168 .await?;
2169 room_updated(&room, &session.peer);
2170 }
2171
2172 Ok(())
2173}
2174
2175/// Notify everyone following you of your current location.
2176async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2177 let room_id = RoomId::from_proto(request.room_id);
2178 let database = session.db.lock().await;
2179
2180 let connection_ids = if let Some(project_id) = request.project_id {
2181 let project_id = ProjectId::from_proto(project_id);
2182 database
2183 .project_connection_ids(project_id, session.connection_id)
2184 .await?
2185 } else {
2186 database
2187 .room_connection_ids(room_id, session.connection_id)
2188 .await?
2189 };
2190
2191 // For now, don't send view update messages back to that view's current leader.
2192 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2193 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2194 _ => None,
2195 });
2196
2197 for connection_id in connection_ids.iter().cloned() {
2198 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2199 session
2200 .peer
2201 .forward_send(session.connection_id, connection_id, request.clone())?;
2202 }
2203 }
2204 Ok(())
2205}
2206
2207/// Get public data about users.
2208async fn get_users(
2209 request: proto::GetUsers,
2210 response: Response<proto::GetUsers>,
2211 session: Session,
2212) -> Result<()> {
2213 let user_ids = request
2214 .user_ids
2215 .into_iter()
2216 .map(UserId::from_proto)
2217 .collect();
2218 let users = session
2219 .db()
2220 .await
2221 .get_users_by_ids(user_ids)
2222 .await?
2223 .into_iter()
2224 .map(|user| proto::User {
2225 id: user.id.to_proto(),
2226 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2227 github_login: user.github_login,
2228 })
2229 .collect();
2230 response.send(proto::UsersResponse { users })?;
2231 Ok(())
2232}
2233
2234/// Search for users (to invite) buy Github login
2235async fn fuzzy_search_users(
2236 request: proto::FuzzySearchUsers,
2237 response: Response<proto::FuzzySearchUsers>,
2238 session: Session,
2239) -> Result<()> {
2240 let query = request.query;
2241 let users = match query.len() {
2242 0 => vec![],
2243 1 | 2 => session
2244 .db()
2245 .await
2246 .get_user_by_github_login(&query)
2247 .await?
2248 .into_iter()
2249 .collect(),
2250 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2251 };
2252 let users = users
2253 .into_iter()
2254 .filter(|user| user.id != session.user_id)
2255 .map(|user| proto::User {
2256 id: user.id.to_proto(),
2257 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2258 github_login: user.github_login,
2259 })
2260 .collect();
2261 response.send(proto::UsersResponse { users })?;
2262 Ok(())
2263}
2264
2265/// Send a contact request to another user.
2266async fn request_contact(
2267 request: proto::RequestContact,
2268 response: Response<proto::RequestContact>,
2269 session: Session,
2270) -> Result<()> {
2271 let requester_id = session.user_id;
2272 let responder_id = UserId::from_proto(request.responder_id);
2273 if requester_id == responder_id {
2274 return Err(anyhow!("cannot add yourself as a contact"))?;
2275 }
2276
2277 let notifications = session
2278 .db()
2279 .await
2280 .send_contact_request(requester_id, responder_id)
2281 .await?;
2282
2283 // Update outgoing contact requests of requester
2284 let mut update = proto::UpdateContacts::default();
2285 update.outgoing_requests.push(responder_id.to_proto());
2286 for connection_id in session
2287 .connection_pool()
2288 .await
2289 .user_connection_ids(requester_id)
2290 {
2291 session.peer.send(connection_id, update.clone())?;
2292 }
2293
2294 // Update incoming contact requests of responder
2295 let mut update = proto::UpdateContacts::default();
2296 update
2297 .incoming_requests
2298 .push(proto::IncomingContactRequest {
2299 requester_id: requester_id.to_proto(),
2300 });
2301 let connection_pool = session.connection_pool().await;
2302 for connection_id in connection_pool.user_connection_ids(responder_id) {
2303 session.peer.send(connection_id, update.clone())?;
2304 }
2305
2306 send_notifications(&connection_pool, &session.peer, notifications);
2307
2308 response.send(proto::Ack {})?;
2309 Ok(())
2310}
2311
2312/// Accept or decline a contact request
2313async fn respond_to_contact_request(
2314 request: proto::RespondToContactRequest,
2315 response: Response<proto::RespondToContactRequest>,
2316 session: Session,
2317) -> Result<()> {
2318 let responder_id = session.user_id;
2319 let requester_id = UserId::from_proto(request.requester_id);
2320 let db = session.db().await;
2321 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2322 db.dismiss_contact_notification(responder_id, requester_id)
2323 .await?;
2324 } else {
2325 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2326
2327 let notifications = db
2328 .respond_to_contact_request(responder_id, requester_id, accept)
2329 .await?;
2330 let requester_busy = db.is_user_busy(requester_id).await?;
2331 let responder_busy = db.is_user_busy(responder_id).await?;
2332
2333 let pool = session.connection_pool().await;
2334 // Update responder with new contact
2335 let mut update = proto::UpdateContacts::default();
2336 if accept {
2337 update
2338 .contacts
2339 .push(contact_for_user(requester_id, requester_busy, &pool));
2340 }
2341 update
2342 .remove_incoming_requests
2343 .push(requester_id.to_proto());
2344 for connection_id in pool.user_connection_ids(responder_id) {
2345 session.peer.send(connection_id, update.clone())?;
2346 }
2347
2348 // Update requester with new contact
2349 let mut update = proto::UpdateContacts::default();
2350 if accept {
2351 update
2352 .contacts
2353 .push(contact_for_user(responder_id, responder_busy, &pool));
2354 }
2355 update
2356 .remove_outgoing_requests
2357 .push(responder_id.to_proto());
2358
2359 for connection_id in pool.user_connection_ids(requester_id) {
2360 session.peer.send(connection_id, update.clone())?;
2361 }
2362
2363 send_notifications(&pool, &session.peer, notifications);
2364 }
2365
2366 response.send(proto::Ack {})?;
2367 Ok(())
2368}
2369
2370/// Remove a contact.
2371async fn remove_contact(
2372 request: proto::RemoveContact,
2373 response: Response<proto::RemoveContact>,
2374 session: Session,
2375) -> Result<()> {
2376 let requester_id = session.user_id;
2377 let responder_id = UserId::from_proto(request.user_id);
2378 let db = session.db().await;
2379 let (contact_accepted, deleted_notification_id) =
2380 db.remove_contact(requester_id, responder_id).await?;
2381
2382 let pool = session.connection_pool().await;
2383 // Update outgoing contact requests of requester
2384 let mut update = proto::UpdateContacts::default();
2385 if contact_accepted {
2386 update.remove_contacts.push(responder_id.to_proto());
2387 } else {
2388 update
2389 .remove_outgoing_requests
2390 .push(responder_id.to_proto());
2391 }
2392 for connection_id in pool.user_connection_ids(requester_id) {
2393 session.peer.send(connection_id, update.clone())?;
2394 }
2395
2396 // Update incoming contact requests of responder
2397 let mut update = proto::UpdateContacts::default();
2398 if contact_accepted {
2399 update.remove_contacts.push(requester_id.to_proto());
2400 } else {
2401 update
2402 .remove_incoming_requests
2403 .push(requester_id.to_proto());
2404 }
2405 for connection_id in pool.user_connection_ids(responder_id) {
2406 session.peer.send(connection_id, update.clone())?;
2407 if let Some(notification_id) = deleted_notification_id {
2408 session.peer.send(
2409 connection_id,
2410 proto::DeleteNotification {
2411 notification_id: notification_id.to_proto(),
2412 },
2413 )?;
2414 }
2415 }
2416
2417 response.send(proto::Ack {})?;
2418 Ok(())
2419}
2420
2421/// Creates a new channel.
2422async fn create_channel(
2423 request: proto::CreateChannel,
2424 response: Response<proto::CreateChannel>,
2425 session: Session,
2426) -> Result<()> {
2427 let db = session.db().await;
2428
2429 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2430 let (channel, owner, channel_members) = db
2431 .create_channel(&request.name, parent_id, session.user_id)
2432 .await?;
2433
2434 response.send(proto::CreateChannelResponse {
2435 channel: Some(channel.to_proto()),
2436 parent_id: request.parent_id,
2437 })?;
2438
2439 let connection_pool = session.connection_pool().await;
2440 if let Some(owner) = owner {
2441 let update = proto::UpdateUserChannels {
2442 channel_memberships: vec![proto::ChannelMembership {
2443 channel_id: owner.channel_id.to_proto(),
2444 role: owner.role.into(),
2445 }],
2446 ..Default::default()
2447 };
2448 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2449 session.peer.send(connection_id, update.clone())?;
2450 }
2451 }
2452
2453 for channel_member in channel_members {
2454 if !channel_member.role.can_see_channel(channel.visibility) {
2455 continue;
2456 }
2457
2458 let update = proto::UpdateChannels {
2459 channels: vec![channel.to_proto()],
2460 ..Default::default()
2461 };
2462 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2463 session.peer.send(connection_id, update.clone())?;
2464 }
2465 }
2466
2467 Ok(())
2468}
2469
2470/// Delete a channel
2471async fn delete_channel(
2472 request: proto::DeleteChannel,
2473 response: Response<proto::DeleteChannel>,
2474 session: Session,
2475) -> Result<()> {
2476 let db = session.db().await;
2477
2478 let channel_id = request.channel_id;
2479 let (removed_channels, member_ids) = db
2480 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2481 .await?;
2482 response.send(proto::Ack {})?;
2483
2484 // Notify members of removed channels
2485 let mut update = proto::UpdateChannels::default();
2486 update
2487 .delete_channels
2488 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2489
2490 let connection_pool = session.connection_pool().await;
2491 for member_id in member_ids {
2492 for connection_id in connection_pool.user_connection_ids(member_id) {
2493 session.peer.send(connection_id, update.clone())?;
2494 }
2495 }
2496
2497 Ok(())
2498}
2499
2500/// Invite someone to join a channel.
2501async fn invite_channel_member(
2502 request: proto::InviteChannelMember,
2503 response: Response<proto::InviteChannelMember>,
2504 session: Session,
2505) -> Result<()> {
2506 let db = session.db().await;
2507 let channel_id = ChannelId::from_proto(request.channel_id);
2508 let invitee_id = UserId::from_proto(request.user_id);
2509 let InviteMemberResult {
2510 channel,
2511 notifications,
2512 } = db
2513 .invite_channel_member(
2514 channel_id,
2515 invitee_id,
2516 session.user_id,
2517 request.role().into(),
2518 )
2519 .await?;
2520
2521 let update = proto::UpdateChannels {
2522 channel_invitations: vec![channel.to_proto()],
2523 ..Default::default()
2524 };
2525
2526 let connection_pool = session.connection_pool().await;
2527 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2528 session.peer.send(connection_id, update.clone())?;
2529 }
2530
2531 send_notifications(&connection_pool, &session.peer, notifications);
2532
2533 response.send(proto::Ack {})?;
2534 Ok(())
2535}
2536
2537/// remove someone from a channel
2538async fn remove_channel_member(
2539 request: proto::RemoveChannelMember,
2540 response: Response<proto::RemoveChannelMember>,
2541 session: Session,
2542) -> Result<()> {
2543 let db = session.db().await;
2544 let channel_id = ChannelId::from_proto(request.channel_id);
2545 let member_id = UserId::from_proto(request.user_id);
2546
2547 let RemoveChannelMemberResult {
2548 membership_update,
2549 notification_id,
2550 } = db
2551 .remove_channel_member(channel_id, member_id, session.user_id)
2552 .await?;
2553
2554 let connection_pool = &session.connection_pool().await;
2555 notify_membership_updated(
2556 &connection_pool,
2557 membership_update,
2558 member_id,
2559 &session.peer,
2560 );
2561 for connection_id in connection_pool.user_connection_ids(member_id) {
2562 if let Some(notification_id) = notification_id {
2563 session
2564 .peer
2565 .send(
2566 connection_id,
2567 proto::DeleteNotification {
2568 notification_id: notification_id.to_proto(),
2569 },
2570 )
2571 .trace_err();
2572 }
2573 }
2574
2575 response.send(proto::Ack {})?;
2576 Ok(())
2577}
2578
2579/// Toggle the channel between public and private.
2580/// Care is taken to maintain the invariant that public channels only descend from public channels,
2581/// (though members-only channels can appear at any point in the hierarchy).
2582async fn set_channel_visibility(
2583 request: proto::SetChannelVisibility,
2584 response: Response<proto::SetChannelVisibility>,
2585 session: Session,
2586) -> Result<()> {
2587 let db = session.db().await;
2588 let channel_id = ChannelId::from_proto(request.channel_id);
2589 let visibility = request.visibility().into();
2590
2591 let (channel, channel_members) = db
2592 .set_channel_visibility(channel_id, visibility, session.user_id)
2593 .await?;
2594
2595 let connection_pool = session.connection_pool().await;
2596 for member in channel_members {
2597 let update = if member.role.can_see_channel(channel.visibility) {
2598 proto::UpdateChannels {
2599 channels: vec![channel.to_proto()],
2600 ..Default::default()
2601 }
2602 } else {
2603 proto::UpdateChannels {
2604 delete_channels: vec![channel.id.to_proto()],
2605 ..Default::default()
2606 }
2607 };
2608
2609 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2610 session.peer.send(connection_id, update.clone())?;
2611 }
2612 }
2613
2614 response.send(proto::Ack {})?;
2615 Ok(())
2616}
2617
2618/// Alter the role for a user in the channel.
2619async fn set_channel_member_role(
2620 request: proto::SetChannelMemberRole,
2621 response: Response<proto::SetChannelMemberRole>,
2622 session: Session,
2623) -> Result<()> {
2624 let db = session.db().await;
2625 let channel_id = ChannelId::from_proto(request.channel_id);
2626 let member_id = UserId::from_proto(request.user_id);
2627 let result = db
2628 .set_channel_member_role(
2629 channel_id,
2630 session.user_id,
2631 member_id,
2632 request.role().into(),
2633 )
2634 .await?;
2635
2636 match result {
2637 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2638 let connection_pool = session.connection_pool().await;
2639 notify_membership_updated(
2640 &connection_pool,
2641 membership_update,
2642 member_id,
2643 &session.peer,
2644 )
2645 }
2646 db::SetMemberRoleResult::InviteUpdated(channel) => {
2647 let update = proto::UpdateChannels {
2648 channel_invitations: vec![channel.to_proto()],
2649 ..Default::default()
2650 };
2651
2652 for connection_id in session
2653 .connection_pool()
2654 .await
2655 .user_connection_ids(member_id)
2656 {
2657 session.peer.send(connection_id, update.clone())?;
2658 }
2659 }
2660 }
2661
2662 response.send(proto::Ack {})?;
2663 Ok(())
2664}
2665
2666/// Change the name of a channel
2667async fn rename_channel(
2668 request: proto::RenameChannel,
2669 response: Response<proto::RenameChannel>,
2670 session: Session,
2671) -> Result<()> {
2672 let db = session.db().await;
2673 let channel_id = ChannelId::from_proto(request.channel_id);
2674 let (channel, channel_members) = db
2675 .rename_channel(channel_id, session.user_id, &request.name)
2676 .await?;
2677
2678 response.send(proto::RenameChannelResponse {
2679 channel: Some(channel.to_proto()),
2680 })?;
2681
2682 let connection_pool = session.connection_pool().await;
2683 for channel_member in channel_members {
2684 if !channel_member.role.can_see_channel(channel.visibility) {
2685 continue;
2686 }
2687 let update = proto::UpdateChannels {
2688 channels: vec![channel.to_proto()],
2689 ..Default::default()
2690 };
2691 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2692 session.peer.send(connection_id, update.clone())?;
2693 }
2694 }
2695
2696 Ok(())
2697}
2698
2699/// Move a channel to a new parent.
2700async fn move_channel(
2701 request: proto::MoveChannel,
2702 response: Response<proto::MoveChannel>,
2703 session: Session,
2704) -> Result<()> {
2705 let channel_id = ChannelId::from_proto(request.channel_id);
2706 let to = ChannelId::from_proto(request.to);
2707
2708 let (channels, channel_members) = session
2709 .db()
2710 .await
2711 .move_channel(channel_id, to, session.user_id)
2712 .await?;
2713
2714 let connection_pool = session.connection_pool().await;
2715 for member in channel_members {
2716 let channels = channels
2717 .iter()
2718 .filter_map(|channel| {
2719 if member.role.can_see_channel(channel.visibility) {
2720 Some(channel.to_proto())
2721 } else {
2722 None
2723 }
2724 })
2725 .collect::<Vec<_>>();
2726 if channels.is_empty() {
2727 continue;
2728 }
2729
2730 let update = proto::UpdateChannels {
2731 channels,
2732 ..Default::default()
2733 };
2734
2735 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2736 session.peer.send(connection_id, update.clone())?;
2737 }
2738 }
2739
2740 response.send(Ack {})?;
2741 Ok(())
2742}
2743
2744/// Get the list of channel members
2745async fn get_channel_members(
2746 request: proto::GetChannelMembers,
2747 response: Response<proto::GetChannelMembers>,
2748 session: Session,
2749) -> Result<()> {
2750 let db = session.db().await;
2751 let channel_id = ChannelId::from_proto(request.channel_id);
2752 let members = db
2753 .get_channel_participant_details(channel_id, session.user_id)
2754 .await?;
2755 response.send(proto::GetChannelMembersResponse { members })?;
2756 Ok(())
2757}
2758
2759/// Accept or decline a channel invitation.
2760async fn respond_to_channel_invite(
2761 request: proto::RespondToChannelInvite,
2762 response: Response<proto::RespondToChannelInvite>,
2763 session: Session,
2764) -> Result<()> {
2765 let db = session.db().await;
2766 let channel_id = ChannelId::from_proto(request.channel_id);
2767 let RespondToChannelInvite {
2768 membership_update,
2769 notifications,
2770 } = db
2771 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2772 .await?;
2773
2774 let connection_pool = session.connection_pool().await;
2775 if let Some(membership_update) = membership_update {
2776 notify_membership_updated(
2777 &connection_pool,
2778 membership_update,
2779 session.user_id,
2780 &session.peer,
2781 );
2782 } else {
2783 let update = proto::UpdateChannels {
2784 remove_channel_invitations: vec![channel_id.to_proto()],
2785 ..Default::default()
2786 };
2787
2788 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2789 session.peer.send(connection_id, update.clone())?;
2790 }
2791 };
2792
2793 send_notifications(&connection_pool, &session.peer, notifications);
2794
2795 response.send(proto::Ack {})?;
2796
2797 Ok(())
2798}
2799
2800/// Join the channels' room
2801async fn join_channel(
2802 request: proto::JoinChannel,
2803 response: Response<proto::JoinChannel>,
2804 session: Session,
2805) -> Result<()> {
2806 let channel_id = ChannelId::from_proto(request.channel_id);
2807 join_channel_internal(channel_id, Box::new(response), session).await
2808}
2809
2810trait JoinChannelInternalResponse {
2811 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2812}
2813impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2814 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2815 Response::<proto::JoinChannel>::send(self, result)
2816 }
2817}
2818impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2819 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2820 Response::<proto::JoinRoom>::send(self, result)
2821 }
2822}
2823
/// Shared implementation for joining a channel's call room.
///
/// In order, it:
/// 1. leaves any room this connection is currently in;
/// 2. joins the channel's room in the database;
/// 3. replies with the room plus optional LiveKit connection info;
/// 4. propagates any membership change and calls `room_updated`;
/// 5. calls `channel_updated` with the room and channel members;
/// 6. refreshes the user's contacts via `update_user_contacts`.
///
/// NOTE(review): the `Response<proto::JoinRoom>` impl of
/// `JoinChannelInternalResponse` suggests the join-room path also routes
/// here — confirm at the call sites.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` guards created inside this block are
    // dropped when it ends. That happens before `channel_updated` takes a
    // fresh pool guard and before `update_user_contacts` is awaited.
    let joined_room = {
        // Leave whatever room this connection is already in before joining.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets
        // a room token and may publish. No info is sent when there is no
        // LiveKit client, or when token creation fails (`trace_err()?` makes
        // the closure return `None`).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2899
2900/// Start editing the channel notes
2901async fn join_channel_buffer(
2902 request: proto::JoinChannelBuffer,
2903 response: Response<proto::JoinChannelBuffer>,
2904 session: Session,
2905) -> Result<()> {
2906 let db = session.db().await;
2907 let channel_id = ChannelId::from_proto(request.channel_id);
2908
2909 let open_response = db
2910 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2911 .await?;
2912
2913 let collaborators = open_response.collaborators.clone();
2914 response.send(open_response)?;
2915
2916 let update = UpdateChannelBufferCollaborators {
2917 channel_id: channel_id.to_proto(),
2918 collaborators: collaborators.clone(),
2919 };
2920 channel_buffer_updated(
2921 session.connection_id,
2922 collaborators
2923 .iter()
2924 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2925 &update,
2926 &session.peer,
2927 );
2928
2929 Ok(())
2930}
2931
2932/// Edit the channel notes
2933async fn update_channel_buffer(
2934 request: proto::UpdateChannelBuffer,
2935 session: Session,
2936) -> Result<()> {
2937 let db = session.db().await;
2938 let channel_id = ChannelId::from_proto(request.channel_id);
2939
2940 let (collaborators, non_collaborators, epoch, version) = db
2941 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2942 .await?;
2943
2944 channel_buffer_updated(
2945 session.connection_id,
2946 collaborators,
2947 &proto::UpdateChannelBuffer {
2948 channel_id: channel_id.to_proto(),
2949 operations: request.operations,
2950 },
2951 &session.peer,
2952 );
2953
2954 let pool = &*session.connection_pool().await;
2955
2956 broadcast(
2957 None,
2958 non_collaborators
2959 .iter()
2960 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2961 |peer_id| {
2962 session.peer.send(
2963 peer_id,
2964 proto::UpdateChannels {
2965 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2966 channel_id: channel_id.to_proto(),
2967 epoch: epoch as u64,
2968 version: version.clone(),
2969 }],
2970 ..Default::default()
2971 },
2972 )
2973 },
2974 );
2975
2976 Ok(())
2977}
2978
2979/// Rejoin the channel notes after a connection blip
2980async fn rejoin_channel_buffers(
2981 request: proto::RejoinChannelBuffers,
2982 response: Response<proto::RejoinChannelBuffers>,
2983 session: Session,
2984) -> Result<()> {
2985 let db = session.db().await;
2986 let buffers = db
2987 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2988 .await?;
2989
2990 for rejoined_buffer in &buffers {
2991 let collaborators_to_notify = rejoined_buffer
2992 .buffer
2993 .collaborators
2994 .iter()
2995 .filter_map(|c| Some(c.peer_id?.into()));
2996 channel_buffer_updated(
2997 session.connection_id,
2998 collaborators_to_notify,
2999 &proto::UpdateChannelBufferCollaborators {
3000 channel_id: rejoined_buffer.buffer.channel_id,
3001 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3002 },
3003 &session.peer,
3004 );
3005 }
3006
3007 response.send(proto::RejoinChannelBuffersResponse {
3008 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3009 })?;
3010
3011 Ok(())
3012}
3013
3014/// Stop editing the channel notes
3015async fn leave_channel_buffer(
3016 request: proto::LeaveChannelBuffer,
3017 response: Response<proto::LeaveChannelBuffer>,
3018 session: Session,
3019) -> Result<()> {
3020 let db = session.db().await;
3021 let channel_id = ChannelId::from_proto(request.channel_id);
3022
3023 let left_buffer = db
3024 .leave_channel_buffer(channel_id, session.connection_id)
3025 .await?;
3026
3027 response.send(Ack {})?;
3028
3029 channel_buffer_updated(
3030 session.connection_id,
3031 left_buffer.connections,
3032 &proto::UpdateChannelBufferCollaborators {
3033 channel_id: channel_id.to_proto(),
3034 collaborators: left_buffer.collaborators,
3035 },
3036 &session.peer,
3037 );
3038
3039 Ok(())
3040}
3041
3042fn channel_buffer_updated<T: EnvelopedMessage>(
3043 sender_id: ConnectionId,
3044 collaborators: impl IntoIterator<Item = ConnectionId>,
3045 message: &T,
3046 peer: &Peer,
3047) {
3048 broadcast(Some(sender_id), collaborators, |peer_id| {
3049 peer.send(peer_id, message.clone())
3050 });
3051}
3052
3053fn send_notifications(
3054 connection_pool: &ConnectionPool,
3055 peer: &Peer,
3056 notifications: db::NotificationBatch,
3057) {
3058 for (user_id, notification) in notifications {
3059 for connection_id in connection_pool.user_connection_ids(user_id) {
3060 if let Err(error) = peer.send(
3061 connection_id,
3062 proto::AddNotification {
3063 notification: Some(notification.clone()),
3064 },
3065 ) {
3066 tracing::error!(
3067 "failed to send notification to {:?} {}",
3068 connection_id,
3069 error
3070 );
3071 }
3072 }
3073 }
3074}
3075
3076/// Send a message to the channel
3077async fn send_channel_message(
3078 request: proto::SendChannelMessage,
3079 response: Response<proto::SendChannelMessage>,
3080 session: Session,
3081) -> Result<()> {
3082 // Validate the message body.
3083 let body = request.body.trim().to_string();
3084 if body.len() > MAX_MESSAGE_LEN {
3085 return Err(anyhow!("message is too long"))?;
3086 }
3087 if body.is_empty() {
3088 return Err(anyhow!("message can't be blank"))?;
3089 }
3090
3091 // TODO: adjust mentions if body is trimmed
3092
3093 let timestamp = OffsetDateTime::now_utc();
3094 let nonce = request
3095 .nonce
3096 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3097
3098 let channel_id = ChannelId::from_proto(request.channel_id);
3099 let CreatedChannelMessage {
3100 message_id,
3101 participant_connection_ids,
3102 channel_members,
3103 notifications,
3104 } = session
3105 .db()
3106 .await
3107 .create_channel_message(
3108 channel_id,
3109 session.user_id,
3110 &body,
3111 &request.mentions,
3112 timestamp,
3113 nonce.clone().into(),
3114 match request.reply_to_message_id {
3115 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3116 None => None,
3117 },
3118 )
3119 .await?;
3120 let message = proto::ChannelMessage {
3121 sender_id: session.user_id.to_proto(),
3122 id: message_id.to_proto(),
3123 body,
3124 mentions: request.mentions,
3125 timestamp: timestamp.unix_timestamp() as u64,
3126 nonce: Some(nonce),
3127 reply_to_message_id: request.reply_to_message_id,
3128 };
3129 broadcast(
3130 Some(session.connection_id),
3131 participant_connection_ids,
3132 |connection| {
3133 session.peer.send(
3134 connection,
3135 proto::ChannelMessageSent {
3136 channel_id: channel_id.to_proto(),
3137 message: Some(message.clone()),
3138 },
3139 )
3140 },
3141 );
3142 response.send(proto::SendChannelMessageResponse {
3143 message: Some(message),
3144 })?;
3145
3146 let pool = &*session.connection_pool().await;
3147 broadcast(
3148 None,
3149 channel_members
3150 .iter()
3151 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3152 |peer_id| {
3153 session.peer.send(
3154 peer_id,
3155 proto::UpdateChannels {
3156 latest_channel_message_ids: vec![proto::ChannelMessageId {
3157 channel_id: channel_id.to_proto(),
3158 message_id: message_id.to_proto(),
3159 }],
3160 ..Default::default()
3161 },
3162 )
3163 },
3164 );
3165 send_notifications(pool, &session.peer, notifications);
3166
3167 Ok(())
3168}
3169
3170/// Delete a channel message
3171async fn remove_channel_message(
3172 request: proto::RemoveChannelMessage,
3173 response: Response<proto::RemoveChannelMessage>,
3174 session: Session,
3175) -> Result<()> {
3176 let channel_id = ChannelId::from_proto(request.channel_id);
3177 let message_id = MessageId::from_proto(request.message_id);
3178 let connection_ids = session
3179 .db()
3180 .await
3181 .remove_channel_message(channel_id, message_id, session.user_id)
3182 .await?;
3183 broadcast(Some(session.connection_id), connection_ids, |connection| {
3184 session.peer.send(connection, request.clone())
3185 });
3186 response.send(proto::Ack {})?;
3187 Ok(())
3188}
3189
3190/// Mark a channel message as read
3191async fn acknowledge_channel_message(
3192 request: proto::AckChannelMessage,
3193 session: Session,
3194) -> Result<()> {
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196 let message_id = MessageId::from_proto(request.message_id);
3197 let notifications = session
3198 .db()
3199 .await
3200 .observe_channel_message(channel_id, session.user_id, message_id)
3201 .await?;
3202 send_notifications(
3203 &*session.connection_pool().await,
3204 &session.peer,
3205 notifications,
3206 );
3207 Ok(())
3208}
3209
3210/// Mark a buffer version as synced
3211async fn acknowledge_buffer_version(
3212 request: proto::AckBufferOperation,
3213 session: Session,
3214) -> Result<()> {
3215 let buffer_id = BufferId::from_proto(request.buffer_id);
3216 session
3217 .db()
3218 .await
3219 .observe_buffer_version(
3220 buffer_id,
3221 session.user_id,
3222 request.epoch as i32,
3223 &request.version,
3224 )
3225 .await?;
3226 Ok(())
3227}
3228
3229/// Start receiving chat updates for a channel
3230async fn join_channel_chat(
3231 request: proto::JoinChannelChat,
3232 response: Response<proto::JoinChannelChat>,
3233 session: Session,
3234) -> Result<()> {
3235 let channel_id = ChannelId::from_proto(request.channel_id);
3236
3237 let db = session.db().await;
3238 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3239 .await?;
3240 let messages = db
3241 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3242 .await?;
3243 response.send(proto::JoinChannelChatResponse {
3244 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3245 messages,
3246 })?;
3247 Ok(())
3248}
3249
3250/// Stop receiving chat updates for a channel
3251async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3252 let channel_id = ChannelId::from_proto(request.channel_id);
3253 session
3254 .db()
3255 .await
3256 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3257 .await?;
3258 Ok(())
3259}
3260
3261/// Retrieve the chat history for a channel
3262async fn get_channel_messages(
3263 request: proto::GetChannelMessages,
3264 response: Response<proto::GetChannelMessages>,
3265 session: Session,
3266) -> Result<()> {
3267 let channel_id = ChannelId::from_proto(request.channel_id);
3268 let messages = session
3269 .db()
3270 .await
3271 .get_channel_messages(
3272 channel_id,
3273 session.user_id,
3274 MESSAGE_COUNT_PER_PAGE,
3275 Some(MessageId::from_proto(request.before_message_id)),
3276 )
3277 .await?;
3278 response.send(proto::GetChannelMessagesResponse {
3279 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3280 messages,
3281 })?;
3282 Ok(())
3283}
3284
3285/// Retrieve specific chat messages
3286async fn get_channel_messages_by_id(
3287 request: proto::GetChannelMessagesById,
3288 response: Response<proto::GetChannelMessagesById>,
3289 session: Session,
3290) -> Result<()> {
3291 let message_ids = request
3292 .message_ids
3293 .iter()
3294 .map(|id| MessageId::from_proto(*id))
3295 .collect::<Vec<_>>();
3296 let messages = session
3297 .db()
3298 .await
3299 .get_channel_messages_by_id(session.user_id, &message_ids)
3300 .await?;
3301 response.send(proto::GetChannelMessagesResponse {
3302 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3303 messages,
3304 })?;
3305 Ok(())
3306}
3307
3308/// Retrieve the current users notifications
3309async fn get_notifications(
3310 request: proto::GetNotifications,
3311 response: Response<proto::GetNotifications>,
3312 session: Session,
3313) -> Result<()> {
3314 let notifications = session
3315 .db()
3316 .await
3317 .get_notifications(
3318 session.user_id,
3319 NOTIFICATION_COUNT_PER_PAGE,
3320 request
3321 .before_id
3322 .map(|id| db::NotificationId::from_proto(id)),
3323 )
3324 .await?;
3325 response.send(proto::GetNotificationsResponse {
3326 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3327 notifications,
3328 })?;
3329 Ok(())
3330}
3331
3332/// Mark notifications as read
3333async fn mark_notification_as_read(
3334 request: proto::MarkNotificationRead,
3335 response: Response<proto::MarkNotificationRead>,
3336 session: Session,
3337) -> Result<()> {
3338 let database = &session.db().await;
3339 let notifications = database
3340 .mark_notification_as_read_by_id(
3341 session.user_id,
3342 NotificationId::from_proto(request.notification_id),
3343 )
3344 .await?;
3345 send_notifications(
3346 &*session.connection_pool().await,
3347 &session.peer,
3348 notifications,
3349 );
3350 response.send(proto::Ack {})?;
3351 Ok(())
3352}
3353
3354/// Get the current users information
3355async fn get_private_user_info(
3356 _request: proto::GetPrivateUserInfo,
3357 response: Response<proto::GetPrivateUserInfo>,
3358 session: Session,
3359) -> Result<()> {
3360 let db = session.db().await;
3361
3362 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3363 let user = db
3364 .get_user_by_id(session.user_id)
3365 .await?
3366 .ok_or_else(|| anyhow!("user not found"))?;
3367 let flags = db.get_user_flags(session.user_id).await?;
3368
3369 response.send(proto::GetPrivateUserInfoResponse {
3370 metrics_id,
3371 staff: user.admin,
3372 flags,
3373 })?;
3374 Ok(())
3375}
3376
3377fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3378 match message {
3379 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3380 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3381 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3382 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3383 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3384 code: frame.code.into(),
3385 reason: frame.reason,
3386 })),
3387 }
3388}
3389
3390fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3391 match message {
3392 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3393 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3394 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3395 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3396 AxumMessage::Close(frame) => {
3397 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3398 code: frame.code.into(),
3399 reason: frame.reason,
3400 }))
3401 }
3402 }
3403}
3404
3405fn notify_membership_updated(
3406 connection_pool: &ConnectionPool,
3407 result: MembershipUpdated,
3408 user_id: UserId,
3409 peer: &Peer,
3410) {
3411 let user_channels_update = proto::UpdateUserChannels {
3412 channel_memberships: result
3413 .new_channels
3414 .channel_memberships
3415 .iter()
3416 .map(|cm| proto::ChannelMembership {
3417 channel_id: cm.channel_id.to_proto(),
3418 role: cm.role.into(),
3419 })
3420 .collect(),
3421 ..Default::default()
3422 };
3423 let mut update = build_channels_update(result.new_channels, vec![]);
3424 update.delete_channels = result
3425 .removed_channels
3426 .into_iter()
3427 .map(|id| id.to_proto())
3428 .collect();
3429 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3430
3431 for connection_id in connection_pool.user_connection_ids(user_id) {
3432 peer.send(connection_id, user_channels_update.clone())
3433 .trace_err();
3434 peer.send(connection_id, update.clone()).trace_err();
3435 }
3436}
3437
3438fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3439 proto::UpdateUserChannels {
3440 channel_memberships: channels
3441 .channel_memberships
3442 .iter()
3443 .map(|m| proto::ChannelMembership {
3444 channel_id: m.channel_id.to_proto(),
3445 role: m.role.into(),
3446 })
3447 .collect(),
3448 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3449 observed_channel_message_id: channels.observed_channel_messages.clone(),
3450 }
3451}
3452
3453fn build_channels_update(
3454 channels: ChannelsForUser,
3455 channel_invites: Vec<db::Channel>,
3456) -> proto::UpdateChannels {
3457 let mut update = proto::UpdateChannels::default();
3458
3459 for channel in channels.channels {
3460 update.channels.push(channel.to_proto());
3461 }
3462
3463 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3464 update.latest_channel_message_ids = channels.latest_channel_messages;
3465
3466 for (channel_id, participants) in channels.channel_participants {
3467 update
3468 .channel_participants
3469 .push(proto::ChannelParticipants {
3470 channel_id: channel_id.to_proto(),
3471 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3472 });
3473 }
3474
3475 for channel in channel_invites {
3476 update.channel_invitations.push(channel.to_proto());
3477 }
3478 for project in channels.hosted_projects {
3479 update.hosted_projects.push(project);
3480 }
3481
3482 update
3483}
3484
3485fn build_initial_contacts_update(
3486 contacts: Vec<db::Contact>,
3487 pool: &ConnectionPool,
3488) -> proto::UpdateContacts {
3489 let mut update = proto::UpdateContacts::default();
3490
3491 for contact in contacts {
3492 match contact {
3493 db::Contact::Accepted { user_id, busy } => {
3494 update.contacts.push(contact_for_user(user_id, busy, &pool));
3495 }
3496 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3497 db::Contact::Incoming { user_id } => {
3498 update
3499 .incoming_requests
3500 .push(proto::IncomingContactRequest {
3501 requester_id: user_id.to_proto(),
3502 })
3503 }
3504 }
3505 }
3506
3507 update
3508}
3509
3510fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3511 proto::Contact {
3512 user_id: user_id.to_proto(),
3513 online: pool.is_user_online(user_id),
3514 busy,
3515 }
3516}
3517
3518fn room_updated(room: &proto::Room, peer: &Peer) {
3519 broadcast(
3520 None,
3521 room.participants
3522 .iter()
3523 .filter_map(|participant| Some(participant.peer_id?.into())),
3524 |peer_id| {
3525 peer.send(
3526 peer_id,
3527 proto::RoomUpdated {
3528 room: Some(room.clone()),
3529 },
3530 )
3531 },
3532 );
3533}
3534
3535fn channel_updated(
3536 channel_id: ChannelId,
3537 room: &proto::Room,
3538 channel_members: &[UserId],
3539 peer: &Peer,
3540 pool: &ConnectionPool,
3541) {
3542 let participants = room
3543 .participants
3544 .iter()
3545 .map(|p| p.user_id)
3546 .collect::<Vec<_>>();
3547
3548 broadcast(
3549 None,
3550 channel_members
3551 .iter()
3552 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3553 |peer_id| {
3554 peer.send(
3555 peer_id,
3556 proto::UpdateChannels {
3557 channel_participants: vec![proto::ChannelParticipants {
3558 channel_id: channel_id.to_proto(),
3559 participant_user_ids: participants.clone(),
3560 }],
3561 ..Default::default()
3562 },
3563 )
3564 },
3565 );
3566}
3567
3568async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3569 let db = session.db().await;
3570
3571 let contacts = db.get_contacts(user_id).await?;
3572 let busy = db.is_user_busy(user_id).await?;
3573
3574 let pool = session.connection_pool().await;
3575 let updated_contact = contact_for_user(user_id, busy, &pool);
3576 for contact in contacts {
3577 if let db::Contact::Accepted {
3578 user_id: contact_user_id,
3579 ..
3580 } = contact
3581 {
3582 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3583 session
3584 .peer
3585 .send(
3586 contact_conn_id,
3587 proto::UpdateContacts {
3588 contacts: vec![updated_contact.clone()],
3589 remove_contacts: Default::default(),
3590 incoming_requests: Default::default(),
3591 remove_incoming_requests: Default::default(),
3592 outgoing_requests: Default::default(),
3593 remove_outgoing_requests: Default::default(),
3594 },
3595 )
3596 .trace_err();
3597 }
3598 }
3599 }
3600 Ok(())
3601}
3602
3603async fn leave_room_for_session(session: &Session) -> Result<()> {
3604 let mut contacts_to_update = HashSet::default();
3605
3606 let room_id;
3607 let canceled_calls_to_user_ids;
3608 let live_kit_room;
3609 let delete_live_kit_room;
3610 let room;
3611 let channel_members;
3612 let channel_id;
3613
3614 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3615 contacts_to_update.insert(session.user_id);
3616
3617 for project in left_room.left_projects.values() {
3618 project_left(project, session);
3619 }
3620
3621 room_id = RoomId::from_proto(left_room.room.id);
3622 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3623 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3624 delete_live_kit_room = left_room.deleted;
3625 room = mem::take(&mut left_room.room);
3626 channel_members = mem::take(&mut left_room.channel_members);
3627 channel_id = left_room.channel_id;
3628
3629 room_updated(&room, &session.peer);
3630 } else {
3631 return Ok(());
3632 }
3633
3634 if let Some(channel_id) = channel_id {
3635 channel_updated(
3636 channel_id,
3637 &room,
3638 &channel_members,
3639 &session.peer,
3640 &*session.connection_pool().await,
3641 );
3642 }
3643
3644 {
3645 let pool = session.connection_pool().await;
3646 for canceled_user_id in canceled_calls_to_user_ids {
3647 for connection_id in pool.user_connection_ids(canceled_user_id) {
3648 session
3649 .peer
3650 .send(
3651 connection_id,
3652 proto::CallCanceled {
3653 room_id: room_id.to_proto(),
3654 },
3655 )
3656 .trace_err();
3657 }
3658 contacts_to_update.insert(canceled_user_id);
3659 }
3660 }
3661
3662 for contact_user_id in contacts_to_update {
3663 update_user_contacts(contact_user_id, &session).await?;
3664 }
3665
3666 if let Some(live_kit) = session.live_kit_client.as_ref() {
3667 live_kit
3668 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3669 .await
3670 .trace_err();
3671
3672 if delete_live_kit_room {
3673 live_kit.delete_room(live_kit_room).await.trace_err();
3674 }
3675 }
3676
3677 Ok(())
3678}
3679
3680async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3681 let left_channel_buffers = session
3682 .db()
3683 .await
3684 .leave_channel_buffers(session.connection_id)
3685 .await?;
3686
3687 for left_buffer in left_channel_buffers {
3688 channel_buffer_updated(
3689 session.connection_id,
3690 left_buffer.connections,
3691 &proto::UpdateChannelBufferCollaborators {
3692 channel_id: left_buffer.channel_id.to_proto(),
3693 collaborators: left_buffer.collaborators,
3694 },
3695 &session.peer,
3696 );
3697 }
3698
3699 Ok(())
3700}
3701
3702fn project_left(project: &db::LeftProject, session: &Session) {
3703 for connection_id in &project.connection_ids {
3704 if project.host_user_id == Some(session.user_id) {
3705 session
3706 .peer
3707 .send(
3708 *connection_id,
3709 proto::UnshareProject {
3710 project_id: project.id.to_proto(),
3711 },
3712 )
3713 .trace_err();
3714 } else {
3715 session
3716 .peer
3717 .send(
3718 *connection_id,
3719 proto::RemoveProjectCollaborator {
3720 project_id: project.id.to_proto(),
3721 peer_id: Some(session.connection_id.into()),
3722 },
3723 )
3724 .trace_err();
3725 }
3726 }
3727}
3728
3729pub trait ResultExt {
3730 type Ok;
3731
3732 fn trace_err(self) -> Option<Self::Ok>;
3733}
3734
3735impl<T, E> ResultExt for Result<T, E>
3736where
3737 E: std::fmt::Debug,
3738{
3739 type Ok = T;
3740
3741 fn trace_err(self) -> Option<T> {
3742 match self {
3743 Ok(value) => Some(value),
3744 Err(error) => {
3745 tracing::error!("{:?}", error);
3746 None
3747 }
3748 }
3749 }
3750}