1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
7 Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
8 ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
9 User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc, OnceLock,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// How long `connection_lost` waits after a connection drops before releasing the
/// user's room, channel buffers and pending calls (unless teardown happens first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up resources owned by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone
/// we can safely clean up the resources they left behind.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for message listings (presumably channel-chat history; usage not visible here).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on message length (presumably channel-chat messages; usage not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification listings (usage not visible here).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
78type MessageHandler =
79 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
80
81struct Response<R> {
82 peer: Arc<Peer>,
83 receipt: Receipt<R>,
84 responded: Arc<AtomicBool>,
85}
86
87impl<R: RequestMessage> Response<R> {
88 fn send(self, payload: R::Response) -> Result<()> {
89 self.responded.store(true, SeqCst);
90 self.peer.respond(self.receipt, payload)?;
91 Ok(())
92 }
93}
94
95#[derive(Clone)]
96struct Session {
97 user_id: UserId,
98 connection_id: ConnectionId,
99 db: Arc<tokio::sync::Mutex<DbHandle>>,
100 peer: Arc<Peer>,
101 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
102 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
103 _executor: Executor,
104}
105
106impl Session {
107 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
108 #[cfg(test)]
109 tokio::task::yield_now().await;
110 let guard = self.db.lock().await;
111 #[cfg(test)]
112 tokio::task::yield_now().await;
113 guard
114 }
115
116 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 let guard = self.connection_pool.lock();
120 ConnectionPoolGuard {
121 guard,
122 _not_send: PhantomData,
123 }
124 }
125}
126
127impl fmt::Debug for Session {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 f.debug_struct("Session")
130 .field("user_id", &self.user_id)
131 .field("connection_id", &self.connection_id)
132 .finish()
133 }
134}
135
136struct DbHandle(Arc<Database>);
137
138impl Deref for DbHandle {
139 type Target = Database;
140
141 fn deref(&self) -> &Self::Target {
142 self.0.as_ref()
143 }
144}
145
146pub struct Server {
147 id: parking_lot::Mutex<ServerId>,
148 peer: Arc<Peer>,
149 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
150 app_state: Arc<AppState>,
151 executor: Executor,
152 handlers: HashMap<TypeId, MessageHandler>,
153 teardown: watch::Sender<bool>,
154}
155
156pub(crate) struct ConnectionPoolGuard<'a> {
157 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
158 _not_send: PhantomData<Rc<()>>,
159}
160
161#[derive(Serialize)]
162pub struct ServerSnapshot<'a> {
163 peer: &'a Peer,
164 #[serde(serialize_with = "serialize_deref")]
165 connection_pool: ConnectionPoolGuard<'a>,
166}
167
168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
169where
170 S: Serializer,
171 T: Deref<Target = U>,
172 U: Serialize,
173{
174 Serialize::serialize(value.deref(), serializer)
175}
176
177impl Server {
178 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
179 let mut server = Self {
180 id: parking_lot::Mutex::new(id),
181 peer: Peer::new(id.0 as u32),
182 app_state,
183 executor,
184 connection_pool: Default::default(),
185 handlers: Default::default(),
186 teardown: watch::channel(false).0,
187 };
188
189 server
190 .add_request_handler(ping)
191 .add_request_handler(create_room)
192 .add_request_handler(join_room)
193 .add_request_handler(rejoin_room)
194 .add_request_handler(leave_room)
195 .add_request_handler(set_room_participant_role)
196 .add_request_handler(call)
197 .add_request_handler(cancel_call)
198 .add_message_handler(decline_call)
199 .add_request_handler(update_participant_location)
200 .add_request_handler(share_project)
201 .add_message_handler(unshare_project)
202 .add_request_handler(join_project)
203 .add_request_handler(join_hosted_project)
204 .add_message_handler(leave_project)
205 .add_request_handler(update_project)
206 .add_request_handler(update_worktree)
207 .add_message_handler(start_language_server)
208 .add_message_handler(update_language_server)
209 .add_message_handler(update_diagnostic_summary)
210 .add_message_handler(update_worktree_settings)
211 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
213 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
214 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
215 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
216 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
217 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
218 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
219 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
220 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
221 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
222 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
223 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
224 .add_request_handler(
225 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
226 )
227 .add_request_handler(
228 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
229 )
230 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
231 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
232 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
233 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
234 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
235 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
236 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
238 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
239 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
240 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
241 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
242 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
243 .add_message_handler(create_buffer_for_peer)
244 .add_request_handler(update_buffer)
245 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
246 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
247 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
248 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
249 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
250 .add_request_handler(get_users)
251 .add_request_handler(fuzzy_search_users)
252 .add_request_handler(request_contact)
253 .add_request_handler(remove_contact)
254 .add_request_handler(respond_to_contact_request)
255 .add_request_handler(create_channel)
256 .add_request_handler(delete_channel)
257 .add_request_handler(invite_channel_member)
258 .add_request_handler(remove_channel_member)
259 .add_request_handler(set_channel_member_role)
260 .add_request_handler(set_channel_visibility)
261 .add_request_handler(rename_channel)
262 .add_request_handler(join_channel_buffer)
263 .add_request_handler(leave_channel_buffer)
264 .add_message_handler(update_channel_buffer)
265 .add_request_handler(rejoin_channel_buffers)
266 .add_request_handler(get_channel_members)
267 .add_request_handler(respond_to_channel_invite)
268 .add_request_handler(join_channel)
269 .add_request_handler(join_channel_chat)
270 .add_message_handler(leave_channel_chat)
271 .add_request_handler(send_channel_message)
272 .add_request_handler(remove_channel_message)
273 .add_request_handler(get_channel_messages)
274 .add_request_handler(get_channel_messages_by_id)
275 .add_request_handler(get_notifications)
276 .add_request_handler(mark_notification_as_read)
277 .add_request_handler(move_channel)
278 .add_request_handler(follow)
279 .add_message_handler(unfollow)
280 .add_message_handler(update_followers)
281 .add_request_handler(get_private_user_info)
282 .add_message_handler(acknowledge_channel_message)
283 .add_message_handler(acknowledge_buffer_version);
284
285 Arc::new(server)
286 }
287
288 pub async fn start(&self) -> Result<()> {
289 let server_id = *self.id.lock();
290 let app_state = self.app_state.clone();
291 let peer = self.peer.clone();
292 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
293 let pool = self.connection_pool.clone();
294 let live_kit_client = self.app_state.live_kit_client.clone();
295
296 let span = info_span!("start server");
297 self.executor.spawn_detached(
298 async move {
299 tracing::info!("waiting for cleanup timeout");
300 timeout.await;
301 tracing::info!("cleanup timeout expired, retrieving stale rooms");
302 if let Some((room_ids, channel_ids)) = app_state
303 .db
304 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
305 .await
306 .trace_err()
307 {
308 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
309 tracing::info!(
310 stale_channel_buffer_count = channel_ids.len(),
311 "retrieved stale channel buffers"
312 );
313
314 for channel_id in channel_ids {
315 if let Some(refreshed_channel_buffer) = app_state
316 .db
317 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
318 .await
319 .trace_err()
320 {
321 for connection_id in refreshed_channel_buffer.connection_ids {
322 peer.send(
323 connection_id,
324 proto::UpdateChannelBufferCollaborators {
325 channel_id: channel_id.to_proto(),
326 collaborators: refreshed_channel_buffer
327 .collaborators
328 .clone(),
329 },
330 )
331 .trace_err();
332 }
333 }
334 }
335
336 for room_id in room_ids {
337 let mut contacts_to_update = HashSet::default();
338 let mut canceled_calls_to_user_ids = Vec::new();
339 let mut live_kit_room = String::new();
340 let mut delete_live_kit_room = false;
341
342 if let Some(mut refreshed_room) = app_state
343 .db
344 .clear_stale_room_participants(room_id, server_id)
345 .await
346 .trace_err()
347 {
348 tracing::info!(
349 room_id = room_id.0,
350 new_participant_count = refreshed_room.room.participants.len(),
351 "refreshed room"
352 );
353 room_updated(&refreshed_room.room, &peer);
354 if let Some(channel) = refreshed_room.channel.as_ref() {
355 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
356 }
357 contacts_to_update
358 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
359 contacts_to_update
360 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
361 canceled_calls_to_user_ids =
362 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
363 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
364 delete_live_kit_room = refreshed_room.room.participants.is_empty();
365 }
366
367 {
368 let pool = pool.lock();
369 for canceled_user_id in canceled_calls_to_user_ids {
370 for connection_id in pool.user_connection_ids(canceled_user_id) {
371 peer.send(
372 connection_id,
373 proto::CallCanceled {
374 room_id: room_id.to_proto(),
375 },
376 )
377 .trace_err();
378 }
379 }
380 }
381
382 for user_id in contacts_to_update {
383 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
384 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
385 if let Some((busy, contacts)) = busy.zip(contacts) {
386 let pool = pool.lock();
387 let updated_contact = contact_for_user(user_id, busy, &pool);
388 for contact in contacts {
389 if let db::Contact::Accepted {
390 user_id: contact_user_id,
391 ..
392 } = contact
393 {
394 for contact_conn_id in
395 pool.user_connection_ids(contact_user_id)
396 {
397 peer.send(
398 contact_conn_id,
399 proto::UpdateContacts {
400 contacts: vec![updated_contact.clone()],
401 remove_contacts: Default::default(),
402 incoming_requests: Default::default(),
403 remove_incoming_requests: Default::default(),
404 outgoing_requests: Default::default(),
405 remove_outgoing_requests: Default::default(),
406 },
407 )
408 .trace_err();
409 }
410 }
411 }
412 }
413 }
414
415 if let Some(live_kit) = live_kit_client.as_ref() {
416 if delete_live_kit_room {
417 live_kit.delete_room(live_kit_room).await.trace_err();
418 }
419 }
420 }
421 }
422
423 app_state
424 .db
425 .delete_stale_servers(&app_state.config.zed_environment, server_id)
426 .await
427 .trace_err();
428 }
429 .instrument(span),
430 );
431 Ok(())
432 }
433
434 pub fn teardown(&self) {
435 self.peer.teardown();
436 self.connection_pool.lock().reset();
437 let _ = self.teardown.send(true);
438 }
439
440 #[cfg(test)]
441 pub fn reset(&self, id: ServerId) {
442 self.teardown();
443 *self.id.lock() = id;
444 self.peer.reset(id.0 as u32);
445 let _ = self.teardown.send(false);
446 }
447
448 #[cfg(test)]
449 pub fn id(&self) -> ServerId {
450 *self.id.lock()
451 }
452
453 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
454 where
455 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
456 Fut: 'static + Send + Future<Output = Result<()>>,
457 M: EnvelopedMessage,
458 {
459 let prev_handler = self.handlers.insert(
460 TypeId::of::<M>(),
461 Box::new(move |envelope, session| {
462 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
463 let received_at = envelope.received_at;
464 tracing::info!(
465 "message received"
466 );
467 let start_time = Instant::now();
468 let future = (handler)(*envelope, session);
469 async move {
470 let result = future.await;
471 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
472 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
473 let queue_duration_ms = total_duration_ms - processing_duration_ms;
474 match result {
475 Err(error) => {
476 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
477 }
478 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
479 }
480 }
481 .boxed()
482 }),
483 );
484 if prev_handler.is_some() {
485 panic!("registered a handler for the same message twice");
486 }
487 self
488 }
489
490 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
491 where
492 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
493 Fut: 'static + Send + Future<Output = Result<()>>,
494 M: EnvelopedMessage,
495 {
496 self.add_handler(move |envelope, session| handler(envelope.payload, session));
497 self
498 }
499
500 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
501 where
502 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
503 Fut: Send + Future<Output = Result<()>>,
504 M: RequestMessage,
505 {
506 let handler = Arc::new(handler);
507 self.add_handler(move |envelope, session| {
508 let receipt = envelope.receipt();
509 let handler = handler.clone();
510 async move {
511 let peer = session.peer.clone();
512 let responded = Arc::new(AtomicBool::default());
513 let response = Response {
514 peer: peer.clone(),
515 responded: responded.clone(),
516 receipt,
517 };
518 match (handler)(envelope.payload, response, session).await {
519 Ok(()) => {
520 if responded.load(std::sync::atomic::Ordering::SeqCst) {
521 Ok(())
522 } else {
523 Err(anyhow!("handler did not send a response"))?
524 }
525 }
526 Err(error) => {
527 let proto_err = match &error {
528 Error::Internal(err) => err.to_proto(),
529 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
530 };
531 peer.respond_with_error(receipt, proto_err)?;
532 Err(error)
533 }
534 }
535 }
536 })
537 }
538
539 #[allow(clippy::too_many_arguments)]
540 pub fn handle_connection(
541 self: &Arc<Self>,
542 connection: Connection,
543 address: String,
544 user: User,
545 zed_version: ZedVersion,
546 impersonator: Option<User>,
547 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
548 executor: Executor,
549 ) -> impl Future<Output = ()> {
550 let this = self.clone();
551 let user_id = user.id;
552 let login = user.github_login.clone();
553 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
554 if let Some(impersonator) = impersonator {
555 span.record("impersonator", &impersonator.github_login);
556 }
557 let mut teardown = self.teardown.subscribe();
558 async move {
559 if *teardown.borrow() {
560 tracing::error!("server is tearing down");
561 return
562 }
563 let (connection_id, handle_io, mut incoming_rx) = this
564 .peer
565 .add_connection(connection, {
566 let executor = executor.clone();
567 move |duration| executor.sleep(duration)
568 });
569 tracing::Span::current().record("connection_id", format!("{}", connection_id));
570 tracing::info!("connection opened");
571
572 let session = Session {
573 user_id,
574 connection_id,
575 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
576 peer: this.peer.clone(),
577 connection_pool: this.connection_pool.clone(),
578 live_kit_client: this.app_state.live_kit_client.clone(),
579 _executor: executor.clone()
580 };
581
582 if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
583 tracing::error!(?error, "failed to send initial client update");
584 return;
585 }
586
587 let handle_io = handle_io.fuse();
588 futures::pin_mut!(handle_io);
589
590 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
591 // This prevents deadlocks when e.g., client A performs a request to client B and
592 // client B performs a request to client A. If both clients stop processing further
593 // messages until their respective request completes, they won't have a chance to
594 // respond to the other client's request and cause a deadlock.
595 //
596 // This arrangement ensures we will attempt to process earlier messages first, but fall
597 // back to processing messages arrived later in the spirit of making progress.
598 let mut foreground_message_handlers = FuturesUnordered::new();
599 let concurrent_handlers = Arc::new(Semaphore::new(256));
600 loop {
601 let next_message = async {
602 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
603 let message = incoming_rx.next().await;
604 (permit, message)
605 }.fuse();
606 futures::pin_mut!(next_message);
607 futures::select_biased! {
608 _ = teardown.changed().fuse() => return,
609 result = handle_io => {
610 if let Err(error) = result {
611 tracing::error!(?error, "error handling I/O");
612 }
613 break;
614 }
615 _ = foreground_message_handlers.next() => {}
616 next_message = next_message => {
617 let (permit, message) = next_message;
618 if let Some(message) = message {
619 let type_name = message.payload_type_name();
620 // note: we copy all the fields from the parent span so we can query them in the logs.
621 // (https://github.com/tokio-rs/tracing/issues/2670).
622 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
623 let span_enter = span.enter();
624 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
625 let is_background = message.is_background();
626 let handle_message = (handler)(message, session.clone());
627 drop(span_enter);
628
629 let handle_message = async move {
630 handle_message.await;
631 drop(permit);
632 }.instrument(span);
633 if is_background {
634 executor.spawn_detached(handle_message);
635 } else {
636 foreground_message_handlers.push(handle_message);
637 }
638 } else {
639 tracing::error!("no message handler");
640 }
641 } else {
642 tracing::info!("connection closed");
643 break;
644 }
645 }
646 }
647 }
648
649 drop(foreground_message_handlers);
650 tracing::info!("signing out");
651 if let Err(error) = connection_lost(session, teardown, executor).await {
652 tracing::error!(?error, "error signing out");
653 }
654
655 }.instrument(span)
656 }
657
658 async fn send_initial_client_update(
659 &self,
660 connection_id: ConnectionId,
661 user: User,
662 zed_version: ZedVersion,
663 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
664 session: &Session,
665 ) -> Result<()> {
666 self.peer.send(
667 connection_id,
668 proto::Hello {
669 peer_id: Some(connection_id.into()),
670 },
671 )?;
672 tracing::info!("sent hello message");
673
674 if let Some(send_connection_id) = send_connection_id.take() {
675 let _ = send_connection_id.send(connection_id);
676 }
677
678 if !user.connected_once {
679 self.peer.send(connection_id, proto::ShowContacts {})?;
680 self.app_state
681 .db
682 .set_user_connected_once(user.id, true)
683 .await?;
684 }
685
686 let (contacts, channels_for_user, channel_invites) = future::try_join3(
687 self.app_state.db.get_contacts(user.id),
688 self.app_state.db.get_channels_for_user(user.id),
689 self.app_state.db.get_channel_invites_for_user(user.id),
690 )
691 .await?;
692
693 {
694 let mut pool = self.connection_pool.lock();
695 pool.add_connection(connection_id, user.id, user.admin, zed_version);
696 for membership in &channels_for_user.channel_memberships {
697 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
698 }
699 self.peer.send(
700 connection_id,
701 build_initial_contacts_update(contacts, &pool),
702 )?;
703 self.peer.send(
704 connection_id,
705 build_update_user_channels(&channels_for_user),
706 )?;
707 self.peer.send(
708 connection_id,
709 build_channels_update(channels_for_user, channel_invites),
710 )?;
711 }
712
713 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
714 self.peer.send(connection_id, incoming_call)?;
715 }
716
717 update_user_contacts(user.id, &session).await?;
718 Ok(())
719 }
720
721 pub async fn invite_code_redeemed(
722 self: &Arc<Self>,
723 inviter_id: UserId,
724 invitee_id: UserId,
725 ) -> Result<()> {
726 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
727 if let Some(code) = &user.invite_code {
728 let pool = self.connection_pool.lock();
729 let invitee_contact = contact_for_user(invitee_id, false, &pool);
730 for connection_id in pool.user_connection_ids(inviter_id) {
731 self.peer.send(
732 connection_id,
733 proto::UpdateContacts {
734 contacts: vec![invitee_contact.clone()],
735 ..Default::default()
736 },
737 )?;
738 self.peer.send(
739 connection_id,
740 proto::UpdateInviteInfo {
741 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
742 count: user.invite_count as u32,
743 },
744 )?;
745 }
746 }
747 }
748 Ok(())
749 }
750
751 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
752 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
753 if let Some(invite_code) = &user.invite_code {
754 let pool = self.connection_pool.lock();
755 for connection_id in pool.user_connection_ids(user_id) {
756 self.peer.send(
757 connection_id,
758 proto::UpdateInviteInfo {
759 url: format!(
760 "{}{}",
761 self.app_state.config.invite_link_prefix, invite_code
762 ),
763 count: user.invite_count as u32,
764 },
765 )?;
766 }
767 }
768 }
769 Ok(())
770 }
771
772 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
773 ServerSnapshot {
774 connection_pool: ConnectionPoolGuard {
775 guard: self.connection_pool.lock(),
776 _not_send: PhantomData,
777 },
778 peer: &self.peer,
779 }
780 }
781}
782
783impl<'a> Deref for ConnectionPoolGuard<'a> {
784 type Target = ConnectionPool;
785
786 fn deref(&self) -> &Self::Target {
787 &self.guard
788 }
789}
790
791impl<'a> DerefMut for ConnectionPoolGuard<'a> {
792 fn deref_mut(&mut self) -> &mut Self::Target {
793 &mut self.guard
794 }
795}
796
797impl<'a> Drop for ConnectionPoolGuard<'a> {
798 fn drop(&mut self) {
799 #[cfg(test)]
800 self.check_invariants();
801 }
802}
803
804fn broadcast<F>(
805 sender_id: Option<ConnectionId>,
806 receiver_ids: impl IntoIterator<Item = ConnectionId>,
807 mut f: F,
808) where
809 F: FnMut(ConnectionId) -> anyhow::Result<()>,
810{
811 for receiver_id in receiver_ids {
812 if Some(receiver_id) != sender_id {
813 if let Err(error) = f(receiver_id) {
814 tracing::error!("failed to send to {:?} {}", receiver_id, error);
815 }
816 }
817 }
818}
819
820pub struct ProtocolVersion(u32);
821
822impl Header for ProtocolVersion {
823 fn name() -> &'static HeaderName {
824 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
825 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
826 }
827
828 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
829 where
830 Self: Sized,
831 I: Iterator<Item = &'i axum::http::HeaderValue>,
832 {
833 let version = values
834 .next()
835 .ok_or_else(axum::headers::Error::invalid)?
836 .to_str()
837 .map_err(|_| axum::headers::Error::invalid())?
838 .parse()
839 .map_err(|_| axum::headers::Error::invalid())?;
840 Ok(Self(version))
841 }
842
843 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
844 values.extend([self.0.to_string().parse().unwrap()]);
845 }
846}
847
848pub struct AppVersionHeader(SemanticVersion);
849impl Header for AppVersionHeader {
850 fn name() -> &'static HeaderName {
851 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
852 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
853 }
854
855 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
856 where
857 Self: Sized,
858 I: Iterator<Item = &'i axum::http::HeaderValue>,
859 {
860 let version = values
861 .next()
862 .ok_or_else(axum::headers::Error::invalid)?
863 .to_str()
864 .map_err(|_| axum::headers::Error::invalid())?
865 .parse()
866 .map_err(|_| axum::headers::Error::invalid())?;
867 Ok(Self(version))
868 }
869
870 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
871 values.extend([self.0.to_string().parse().unwrap()]);
872 }
873}
874
875pub fn routes(server: Arc<Server>) -> Router<(), Body> {
876 Router::new()
877 .route("/rpc", get(handle_websocket_request))
878 .layer(
879 ServiceBuilder::new()
880 .layer(Extension(server.app_state.clone()))
881 .layer(middleware::from_fn(auth::validate_header)),
882 )
883 .route("/metrics", get(handle_metrics))
884 .layer(Extension(server))
885}
886
887pub async fn handle_websocket_request(
888 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
889 app_version_header: Option<TypedHeader<AppVersionHeader>>,
890 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
891 Extension(server): Extension<Arc<Server>>,
892 Extension(user): Extension<User>,
893 Extension(impersonator): Extension<Impersonator>,
894 ws: WebSocketUpgrade,
895) -> axum::response::Response {
896 if protocol_version != rpc::PROTOCOL_VERSION {
897 return (
898 StatusCode::UPGRADE_REQUIRED,
899 "client must be upgraded".to_string(),
900 )
901 .into_response();
902 }
903
904 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
905 return (
906 StatusCode::UPGRADE_REQUIRED,
907 "no version header found".to_string(),
908 )
909 .into_response();
910 };
911
912 if !version.is_supported() {
913 return (
914 StatusCode::UPGRADE_REQUIRED,
915 "client must be upgraded".to_string(),
916 )
917 .into_response();
918 }
919
920 let socket_address = socket_address.to_string();
921 ws.on_upgrade(move |socket| {
922 let socket = socket
923 .map_ok(to_tungstenite_message)
924 .err_into()
925 .with(|message| async move { Ok(to_axum_message(message)) });
926 let connection = Connection::new(Box::pin(socket));
927 async move {
928 server
929 .handle_connection(
930 connection,
931 socket_address,
932 user,
933 version,
934 impersonator.0,
935 None,
936 Executor::Production,
937 )
938 .await;
939 }
940 })
941}
942
943pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
944 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
945 let connections_metric = CONNECTIONS_METRIC
946 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
947
948 let connections = server
949 .connection_pool
950 .lock()
951 .connections()
952 .filter(|connection| !connection.admin)
953 .count();
954 connections_metric.set(connections as _);
955
956 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
957 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
958 register_int_gauge!(
959 "shared_projects",
960 "number of open projects with one or more guests"
961 )
962 .unwrap()
963 });
964
965 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
966 shared_projects_metric.set(shared_projects as _);
967
968 let encoder = prometheus::TextEncoder::new();
969 let metric_families = prometheus::gather();
970 let encoded_metrics = encoder
971 .encode_to_string(&metric_families)
972 .map_err(|err| anyhow!("{}", err))?;
973 Ok(encoded_metrics)
974}
975
/// Cleans up after a client connection drops.
///
/// Steps performed immediately:
/// 1. Disconnect the peer.
/// 2. Remove the connection from the pool.
/// 3. Record the loss in the database.
///
/// It then races `RECONNECT_TIMEOUT` against a change on the `teardown` watch.
/// `select_biased!` polls the timeout branch first. Only when the timeout wins
/// are the remaining resources released:
/// - the user's room and channel-buffer memberships;
/// - any pending call, declined when the pool no longer reports the user as
///   online;
/// - the user's contact state, which is refreshed.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure here is reported through `trace_err`, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown signalled before the timeout: skip the deferred cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1021
1022/// Acknowledges a ping from a client, used to keep the connection alive.
1023async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1024 response.send(proto::Ack {})?;
1025 Ok(())
1026}
1027
1028/// Creates a new room for calling (outside of channels)
1029async fn create_room(
1030 _request: proto::CreateRoom,
1031 response: Response<proto::CreateRoom>,
1032 session: Session,
1033) -> Result<()> {
1034 let live_kit_room = nanoid::nanoid!(30);
1035
1036 let live_kit_connection_info = {
1037 let live_kit_room = live_kit_room.clone();
1038 let live_kit = session.live_kit_client.as_ref();
1039
1040 util::async_maybe!({
1041 let live_kit = live_kit?;
1042
1043 let token = live_kit
1044 .room_token(&live_kit_room, &session.user_id.to_string())
1045 .trace_err()?;
1046
1047 Some(proto::LiveKitConnectionInfo {
1048 server_url: live_kit.url().into(),
1049 token,
1050 can_publish: true,
1051 })
1052 })
1053 }
1054 .await;
1055
1056 let room = session
1057 .db()
1058 .await
1059 .create_room(session.user_id, session.connection_id, &live_kit_room)
1060 .await?;
1061
1062 response.send(proto::CreateRoomResponse {
1063 room: Some(room.clone()),
1064 live_kit_connection_info,
1065 })?;
1066
1067 update_user_contacts(session.user_id, &session).await?;
1068 Ok(())
1069}
1070
1071/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1072async fn join_room(
1073 request: proto::JoinRoom,
1074 response: Response<proto::JoinRoom>,
1075 session: Session,
1076) -> Result<()> {
1077 let room_id = RoomId::from_proto(request.id);
1078
1079 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1080
1081 if let Some(channel_id) = channel_id {
1082 return join_channel_internal(channel_id, Box::new(response), session).await;
1083 }
1084
1085 let joined_room = {
1086 let room = session
1087 .db()
1088 .await
1089 .join_room(room_id, session.user_id, session.connection_id)
1090 .await?;
1091 room_updated(&room.room, &session.peer);
1092 room.into_inner()
1093 };
1094
1095 for connection_id in session
1096 .connection_pool()
1097 .await
1098 .user_connection_ids(session.user_id)
1099 {
1100 session
1101 .peer
1102 .send(
1103 connection_id,
1104 proto::CallCanceled {
1105 room_id: room_id.to_proto(),
1106 },
1107 )
1108 .trace_err();
1109 }
1110
1111 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1112 if let Some(token) = live_kit
1113 .room_token(
1114 &joined_room.room.live_kit_room,
1115 &session.user_id.to_string(),
1116 )
1117 .trace_err()
1118 {
1119 Some(proto::LiveKitConnectionInfo {
1120 server_url: live_kit.url().into(),
1121 token,
1122 can_publish: true,
1123 })
1124 } else {
1125 None
1126 }
1127 } else {
1128 None
1129 };
1130
1131 response.send(proto::JoinRoomResponse {
1132 room: Some(joined_room.room),
1133 channel_id: None,
1134 live_kit_connection_info,
1135 })?;
1136
1137 update_user_contacts(session.user_id, &session).await?;
1138 Ok(())
1139}
1140
1141/// Rejoin room is used to reconnect to a room after connection errors.
1142async fn rejoin_room(
1143 request: proto::RejoinRoom,
1144 response: Response<proto::RejoinRoom>,
1145 session: Session,
1146) -> Result<()> {
1147 let room;
1148 let channel;
1149 {
1150 let mut rejoined_room = session
1151 .db()
1152 .await
1153 .rejoin_room(request, session.user_id, session.connection_id)
1154 .await?;
1155
1156 response.send(proto::RejoinRoomResponse {
1157 room: Some(rejoined_room.room.clone()),
1158 reshared_projects: rejoined_room
1159 .reshared_projects
1160 .iter()
1161 .map(|project| proto::ResharedProject {
1162 id: project.id.to_proto(),
1163 collaborators: project
1164 .collaborators
1165 .iter()
1166 .map(|collaborator| collaborator.to_proto())
1167 .collect(),
1168 })
1169 .collect(),
1170 rejoined_projects: rejoined_room
1171 .rejoined_projects
1172 .iter()
1173 .map(|rejoined_project| proto::RejoinedProject {
1174 id: rejoined_project.id.to_proto(),
1175 worktrees: rejoined_project
1176 .worktrees
1177 .iter()
1178 .map(|worktree| proto::WorktreeMetadata {
1179 id: worktree.id,
1180 root_name: worktree.root_name.clone(),
1181 visible: worktree.visible,
1182 abs_path: worktree.abs_path.clone(),
1183 })
1184 .collect(),
1185 collaborators: rejoined_project
1186 .collaborators
1187 .iter()
1188 .map(|collaborator| collaborator.to_proto())
1189 .collect(),
1190 language_servers: rejoined_project.language_servers.clone(),
1191 })
1192 .collect(),
1193 })?;
1194 room_updated(&rejoined_room.room, &session.peer);
1195
1196 for project in &rejoined_room.reshared_projects {
1197 for collaborator in &project.collaborators {
1198 session
1199 .peer
1200 .send(
1201 collaborator.connection_id,
1202 proto::UpdateProjectCollaborator {
1203 project_id: project.id.to_proto(),
1204 old_peer_id: Some(project.old_connection_id.into()),
1205 new_peer_id: Some(session.connection_id.into()),
1206 },
1207 )
1208 .trace_err();
1209 }
1210
1211 broadcast(
1212 Some(session.connection_id),
1213 project
1214 .collaborators
1215 .iter()
1216 .map(|collaborator| collaborator.connection_id),
1217 |connection_id| {
1218 session.peer.forward_send(
1219 session.connection_id,
1220 connection_id,
1221 proto::UpdateProject {
1222 project_id: project.id.to_proto(),
1223 worktrees: project.worktrees.clone(),
1224 },
1225 )
1226 },
1227 );
1228 }
1229
1230 for project in &rejoined_room.rejoined_projects {
1231 for collaborator in &project.collaborators {
1232 session
1233 .peer
1234 .send(
1235 collaborator.connection_id,
1236 proto::UpdateProjectCollaborator {
1237 project_id: project.id.to_proto(),
1238 old_peer_id: Some(project.old_connection_id.into()),
1239 new_peer_id: Some(session.connection_id.into()),
1240 },
1241 )
1242 .trace_err();
1243 }
1244 }
1245
1246 for project in &mut rejoined_room.rejoined_projects {
1247 for worktree in mem::take(&mut project.worktrees) {
1248 #[cfg(any(test, feature = "test-support"))]
1249 const MAX_CHUNK_SIZE: usize = 2;
1250 #[cfg(not(any(test, feature = "test-support")))]
1251 const MAX_CHUNK_SIZE: usize = 256;
1252
1253 // Stream this worktree's entries.
1254 let message = proto::UpdateWorktree {
1255 project_id: project.id.to_proto(),
1256 worktree_id: worktree.id,
1257 abs_path: worktree.abs_path.clone(),
1258 root_name: worktree.root_name,
1259 updated_entries: worktree.updated_entries,
1260 removed_entries: worktree.removed_entries,
1261 scan_id: worktree.scan_id,
1262 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1263 updated_repositories: worktree.updated_repositories,
1264 removed_repositories: worktree.removed_repositories,
1265 };
1266 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1267 session.peer.send(session.connection_id, update.clone())?;
1268 }
1269
1270 // Stream this worktree's diagnostics.
1271 for summary in worktree.diagnostic_summaries {
1272 session.peer.send(
1273 session.connection_id,
1274 proto::UpdateDiagnosticSummary {
1275 project_id: project.id.to_proto(),
1276 worktree_id: worktree.id,
1277 summary: Some(summary),
1278 },
1279 )?;
1280 }
1281
1282 for settings_file in worktree.settings_files {
1283 session.peer.send(
1284 session.connection_id,
1285 proto::UpdateWorktreeSettings {
1286 project_id: project.id.to_proto(),
1287 worktree_id: worktree.id,
1288 path: settings_file.path,
1289 content: Some(settings_file.content),
1290 },
1291 )?;
1292 }
1293 }
1294
1295 for language_server in &project.language_servers {
1296 session.peer.send(
1297 session.connection_id,
1298 proto::UpdateLanguageServer {
1299 project_id: project.id.to_proto(),
1300 language_server_id: language_server.id,
1301 variant: Some(
1302 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1303 proto::LspDiskBasedDiagnosticsUpdated {},
1304 ),
1305 ),
1306 },
1307 )?;
1308 }
1309 }
1310
1311 let rejoined_room = rejoined_room.into_inner();
1312
1313 room = rejoined_room.room;
1314 channel = rejoined_room.channel;
1315 }
1316
1317 if let Some(channel) = channel {
1318 channel_updated(
1319 &channel,
1320 &room,
1321 &session.peer,
1322 &*session.connection_pool().await,
1323 );
1324 }
1325
1326 update_user_contacts(session.user_id, &session).await?;
1327 Ok(())
1328}
1329
1330/// leave room disconnects from the room.
1331async fn leave_room(
1332 _: proto::LeaveRoom,
1333 response: Response<proto::LeaveRoom>,
1334 session: Session,
1335) -> Result<()> {
1336 leave_room_for_session(&session).await?;
1337 response.send(proto::Ack {})?;
1338 Ok(())
1339}
1340
1341/// Updates the permissions of someone else in the room.
1342async fn set_room_participant_role(
1343 request: proto::SetRoomParticipantRole,
1344 response: Response<proto::SetRoomParticipantRole>,
1345 session: Session,
1346) -> Result<()> {
1347 let user_id = UserId::from_proto(request.user_id);
1348 let role = ChannelRole::from(request.role());
1349
1350 if role == ChannelRole::Talker {
1351 let pool = session.connection_pool().await;
1352
1353 for connection in pool.user_connections(user_id) {
1354 if !connection.zed_version.supports_talker_role() {
1355 Err(anyhow!(
1356 "This user is on zed {} which does not support unmute",
1357 connection.zed_version
1358 ))?;
1359 }
1360 }
1361 }
1362
1363 let (live_kit_room, can_publish) = {
1364 let room = session
1365 .db()
1366 .await
1367 .set_room_participant_role(
1368 session.user_id,
1369 RoomId::from_proto(request.room_id),
1370 user_id,
1371 role,
1372 )
1373 .await?;
1374
1375 let live_kit_room = room.live_kit_room.clone();
1376 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1377 room_updated(&room, &session.peer);
1378 (live_kit_room, can_publish)
1379 };
1380
1381 if let Some(live_kit) = session.live_kit_client.as_ref() {
1382 live_kit
1383 .update_participant(
1384 live_kit_room.clone(),
1385 request.user_id.to_string(),
1386 live_kit_server::proto::ParticipantPermission {
1387 can_subscribe: true,
1388 can_publish,
1389 can_publish_data: can_publish,
1390 hidden: false,
1391 recorder: false,
1392 },
1393 )
1394 .await
1395 .trace_err();
1396 }
1397
1398 response.send(proto::Ack {})?;
1399 Ok(())
1400}
1401
/// Calls another user into the current room.
///
/// The called user must be one of the caller's contacts. The handler then:
/// 1. Records the call in the database, re-broadcasts the room, and refreshes
///    the called user's contacts.
/// 2. Sends the resulting incoming-call request to every connection of the
///    called user. The first connection that answers successfully completes
///    this request with an `Ack`.
///
/// If no connection answers successfully, the call is marked as failed, the
/// room and the called user's contacts are refreshed again, and
/// "failed to ring user" is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The database result is scoped to this block; the incoming call is moved out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful answer wins; failures are only reported via trace_err.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: record the failure and refresh state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1470
1471/// Cancel an outgoing call.
1472async fn cancel_call(
1473 request: proto::CancelCall,
1474 response: Response<proto::CancelCall>,
1475 session: Session,
1476) -> Result<()> {
1477 let called_user_id = UserId::from_proto(request.called_user_id);
1478 let room_id = RoomId::from_proto(request.room_id);
1479 {
1480 let room = session
1481 .db()
1482 .await
1483 .cancel_call(room_id, session.connection_id, called_user_id)
1484 .await?;
1485 room_updated(&room, &session.peer);
1486 }
1487
1488 for connection_id in session
1489 .connection_pool()
1490 .await
1491 .user_connection_ids(called_user_id)
1492 {
1493 session
1494 .peer
1495 .send(
1496 connection_id,
1497 proto::CallCanceled {
1498 room_id: room_id.to_proto(),
1499 },
1500 )
1501 .trace_err();
1502 }
1503 response.send(proto::Ack {})?;
1504
1505 update_user_contacts(called_user_id, &session).await?;
1506 Ok(())
1507}
1508
1509/// Decline an incoming call.
1510async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1511 let room_id = RoomId::from_proto(message.room_id);
1512 {
1513 let room = session
1514 .db()
1515 .await
1516 .decline_call(Some(room_id), session.user_id)
1517 .await?
1518 .ok_or_else(|| anyhow!("failed to decline call"))?;
1519 room_updated(&room, &session.peer);
1520 }
1521
1522 for connection_id in session
1523 .connection_pool()
1524 .await
1525 .user_connection_ids(session.user_id)
1526 {
1527 session
1528 .peer
1529 .send(
1530 connection_id,
1531 proto::CallCanceled {
1532 room_id: room_id.to_proto(),
1533 },
1534 )
1535 .trace_err();
1536 }
1537 update_user_contacts(session.user_id, &session).await?;
1538 Ok(())
1539}
1540
1541/// Updates other participants in the room with your current location.
1542async fn update_participant_location(
1543 request: proto::UpdateParticipantLocation,
1544 response: Response<proto::UpdateParticipantLocation>,
1545 session: Session,
1546) -> Result<()> {
1547 let room_id = RoomId::from_proto(request.room_id);
1548 let location = request
1549 .location
1550 .ok_or_else(|| anyhow!("invalid location"))?;
1551
1552 let db = session.db().await;
1553 let room = db
1554 .update_room_participant_location(room_id, session.connection_id, location)
1555 .await?;
1556
1557 room_updated(&room, &session.peer);
1558 response.send(proto::Ack {})?;
1559 Ok(())
1560}
1561
1562/// Share a project into the room.
1563async fn share_project(
1564 request: proto::ShareProject,
1565 response: Response<proto::ShareProject>,
1566 session: Session,
1567) -> Result<()> {
1568 let (project_id, room) = &*session
1569 .db()
1570 .await
1571 .share_project(
1572 RoomId::from_proto(request.room_id),
1573 session.connection_id,
1574 &request.worktrees,
1575 )
1576 .await?;
1577 response.send(proto::ShareProjectResponse {
1578 project_id: project_id.to_proto(),
1579 })?;
1580 room_updated(&room, &session.peer);
1581
1582 Ok(())
1583}
1584
1585/// Unshare a project from the room.
1586async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1587 let project_id = ProjectId::from_proto(message.project_id);
1588
1589 let (room, guest_connection_ids) = &*session
1590 .db()
1591 .await
1592 .unshare_project(project_id, session.connection_id)
1593 .await?;
1594
1595 broadcast(
1596 Some(session.connection_id),
1597 guest_connection_ids.iter().copied(),
1598 |conn_id| session.peer.send(conn_id, message.clone()),
1599 );
1600 room_updated(&room, &session.peer);
1601
1602 Ok(())
1603}
1604
1605/// Join someone elses shared project.
1606async fn join_project(
1607 request: proto::JoinProject,
1608 response: Response<proto::JoinProject>,
1609 session: Session,
1610) -> Result<()> {
1611 let project_id = ProjectId::from_proto(request.project_id);
1612
1613 tracing::info!(%project_id, "join project");
1614
1615 let (project, replica_id) = &mut *session
1616 .db()
1617 .await
1618 .join_project_in_room(project_id, session.connection_id)
1619 .await?;
1620
1621 join_project_internal(response, session, project, replica_id)
1622}
1623
/// Abstracts over the two request types answered with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends the join response back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Explicitly qualified so this resolves to `Response::send`
        // (presumably the inherent method) rather than recursing into this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same explicit qualification as above, for the hosted-project request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1637
1638fn join_project_internal(
1639 response: impl JoinProjectInternalResponse,
1640 session: Session,
1641 project: &mut Project,
1642 replica_id: &ReplicaId,
1643) -> Result<()> {
1644 let collaborators = project
1645 .collaborators
1646 .iter()
1647 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1648 .map(|collaborator| collaborator.to_proto())
1649 .collect::<Vec<_>>();
1650 let project_id = project.id;
1651 let guest_user_id = session.user_id;
1652
1653 let worktrees = project
1654 .worktrees
1655 .iter()
1656 .map(|(id, worktree)| proto::WorktreeMetadata {
1657 id: *id,
1658 root_name: worktree.root_name.clone(),
1659 visible: worktree.visible,
1660 abs_path: worktree.abs_path.clone(),
1661 })
1662 .collect::<Vec<_>>();
1663
1664 for collaborator in &collaborators {
1665 session
1666 .peer
1667 .send(
1668 collaborator.peer_id.unwrap().into(),
1669 proto::AddProjectCollaborator {
1670 project_id: project_id.to_proto(),
1671 collaborator: Some(proto::Collaborator {
1672 peer_id: Some(session.connection_id.into()),
1673 replica_id: replica_id.0 as u32,
1674 user_id: guest_user_id.to_proto(),
1675 }),
1676 },
1677 )
1678 .trace_err();
1679 }
1680
1681 // First, we send the metadata associated with each worktree.
1682 response.send(proto::JoinProjectResponse {
1683 project_id: project.id.0 as u64,
1684 worktrees: worktrees.clone(),
1685 replica_id: replica_id.0 as u32,
1686 collaborators: collaborators.clone(),
1687 language_servers: project.language_servers.clone(),
1688 role: project.role.into(), // todo
1689 })?;
1690
1691 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1692 #[cfg(any(test, feature = "test-support"))]
1693 const MAX_CHUNK_SIZE: usize = 2;
1694 #[cfg(not(any(test, feature = "test-support")))]
1695 const MAX_CHUNK_SIZE: usize = 256;
1696
1697 // Stream this worktree's entries.
1698 let message = proto::UpdateWorktree {
1699 project_id: project_id.to_proto(),
1700 worktree_id,
1701 abs_path: worktree.abs_path.clone(),
1702 root_name: worktree.root_name,
1703 updated_entries: worktree.entries,
1704 removed_entries: Default::default(),
1705 scan_id: worktree.scan_id,
1706 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1707 updated_repositories: worktree.repository_entries.into_values().collect(),
1708 removed_repositories: Default::default(),
1709 };
1710 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1711 session.peer.send(session.connection_id, update.clone())?;
1712 }
1713
1714 // Stream this worktree's diagnostics.
1715 for summary in worktree.diagnostic_summaries {
1716 session.peer.send(
1717 session.connection_id,
1718 proto::UpdateDiagnosticSummary {
1719 project_id: project_id.to_proto(),
1720 worktree_id: worktree.id,
1721 summary: Some(summary),
1722 },
1723 )?;
1724 }
1725
1726 for settings_file in worktree.settings_files {
1727 session.peer.send(
1728 session.connection_id,
1729 proto::UpdateWorktreeSettings {
1730 project_id: project_id.to_proto(),
1731 worktree_id: worktree.id,
1732 path: settings_file.path,
1733 content: Some(settings_file.content),
1734 },
1735 )?;
1736 }
1737 }
1738
1739 for language_server in &project.language_servers {
1740 session.peer.send(
1741 session.connection_id,
1742 proto::UpdateLanguageServer {
1743 project_id: project_id.to_proto(),
1744 language_server_id: language_server.id,
1745 variant: Some(
1746 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1747 proto::LspDiskBasedDiagnosticsUpdated {},
1748 ),
1749 ),
1750 },
1751 )?;
1752 }
1753
1754 Ok(())
1755}
1756
1757/// Leave someone elses shared project.
1758async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1759 let sender_id = session.connection_id;
1760 let project_id = ProjectId::from_proto(request.project_id);
1761 let db = session.db().await;
1762 if db.is_hosted_project(project_id).await? {
1763 let project = db.leave_hosted_project(project_id, sender_id).await?;
1764 project_left(&project, &session);
1765 return Ok(());
1766 }
1767
1768 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1769 tracing::info!(
1770 %project_id,
1771 host_user_id = ?project.host_user_id,
1772 host_connection_id = ?project.host_connection_id,
1773 "leave project"
1774 );
1775
1776 project_left(&project, &session);
1777 room_updated(&room, &session.peer);
1778
1779 Ok(())
1780}
1781
1782async fn join_hosted_project(
1783 request: proto::JoinHostedProject,
1784 response: Response<proto::JoinHostedProject>,
1785 session: Session,
1786) -> Result<()> {
1787 let (mut project, replica_id) = session
1788 .db()
1789 .await
1790 .join_hosted_project(
1791 ProjectId(request.project_id as i32),
1792 session.user_id,
1793 session.connection_id,
1794 )
1795 .await?;
1796
1797 join_project_internal(response, session, &mut project, &replica_id)
1798}
1799
1800/// Updates other participants with changes to the project
1801async fn update_project(
1802 request: proto::UpdateProject,
1803 response: Response<proto::UpdateProject>,
1804 session: Session,
1805) -> Result<()> {
1806 let project_id = ProjectId::from_proto(request.project_id);
1807 let (room, guest_connection_ids) = &*session
1808 .db()
1809 .await
1810 .update_project(project_id, session.connection_id, &request.worktrees)
1811 .await?;
1812 broadcast(
1813 Some(session.connection_id),
1814 guest_connection_ids.iter().copied(),
1815 |connection_id| {
1816 session
1817 .peer
1818 .forward_send(session.connection_id, connection_id, request.clone())
1819 },
1820 );
1821 room_updated(&room, &session.peer);
1822 response.send(proto::Ack {})?;
1823
1824 Ok(())
1825}
1826
1827/// Updates other participants with changes to the worktree
1828async fn update_worktree(
1829 request: proto::UpdateWorktree,
1830 response: Response<proto::UpdateWorktree>,
1831 session: Session,
1832) -> Result<()> {
1833 let guest_connection_ids = session
1834 .db()
1835 .await
1836 .update_worktree(&request, session.connection_id)
1837 .await?;
1838
1839 broadcast(
1840 Some(session.connection_id),
1841 guest_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 response.send(proto::Ack {})?;
1849 Ok(())
1850}
1851
1852/// Updates other participants with changes to the diagnostics
1853async fn update_diagnostic_summary(
1854 message: proto::UpdateDiagnosticSummary,
1855 session: Session,
1856) -> Result<()> {
1857 let guest_connection_ids = session
1858 .db()
1859 .await
1860 .update_diagnostic_summary(&message, session.connection_id)
1861 .await?;
1862
1863 broadcast(
1864 Some(session.connection_id),
1865 guest_connection_ids.iter().copied(),
1866 |connection_id| {
1867 session
1868 .peer
1869 .forward_send(session.connection_id, connection_id, message.clone())
1870 },
1871 );
1872
1873 Ok(())
1874}
1875
1876/// Updates other participants with changes to the worktree settings
1877async fn update_worktree_settings(
1878 message: proto::UpdateWorktreeSettings,
1879 session: Session,
1880) -> Result<()> {
1881 let guest_connection_ids = session
1882 .db()
1883 .await
1884 .update_worktree_settings(&message, session.connection_id)
1885 .await?;
1886
1887 broadcast(
1888 Some(session.connection_id),
1889 guest_connection_ids.iter().copied(),
1890 |connection_id| {
1891 session
1892 .peer
1893 .forward_send(session.connection_id, connection_id, message.clone())
1894 },
1895 );
1896
1897 Ok(())
1898}
1899
1900/// Notify other participants that a language server has started.
1901async fn start_language_server(
1902 request: proto::StartLanguageServer,
1903 session: Session,
1904) -> Result<()> {
1905 let guest_connection_ids = session
1906 .db()
1907 .await
1908 .start_language_server(&request, session.connection_id)
1909 .await?;
1910
1911 broadcast(
1912 Some(session.connection_id),
1913 guest_connection_ids.iter().copied(),
1914 |connection_id| {
1915 session
1916 .peer
1917 .forward_send(session.connection_id, connection_id, request.clone())
1918 },
1919 );
1920 Ok(())
1921}
1922
1923/// Notify other participants that a language server has changed.
1924async fn update_language_server(
1925 request: proto::UpdateLanguageServer,
1926 session: Session,
1927) -> Result<()> {
1928 let project_id = ProjectId::from_proto(request.project_id);
1929 let project_connection_ids = session
1930 .db()
1931 .await
1932 .project_connection_ids(project_id, session.connection_id)
1933 .await?;
1934 broadcast(
1935 Some(session.connection_id),
1936 project_connection_ids.iter().copied(),
1937 |connection_id| {
1938 session
1939 .peer
1940 .forward_send(session.connection_id, connection_id, request.clone())
1941 },
1942 );
1943 Ok(())
1944}
1945
1946/// forward a project request to the host. These requests should be read only
1947/// as guests are allowed to send them.
1948async fn forward_read_only_project_request<T>(
1949 request: T,
1950 response: Response<T>,
1951 session: Session,
1952) -> Result<()>
1953where
1954 T: EntityMessage + RequestMessage,
1955{
1956 let project_id = ProjectId::from_proto(request.remote_entity_id());
1957 let host_connection_id = session
1958 .db()
1959 .await
1960 .host_for_read_only_project_request(project_id, session.connection_id)
1961 .await?;
1962 let payload = session
1963 .peer
1964 .forward_request(session.connection_id, host_connection_id, request)
1965 .await?;
1966 response.send(payload)?;
1967 Ok(())
1968}
1969
1970/// forward a project request to the host. These requests are disallowed
1971/// for guests.
1972async fn forward_mutating_project_request<T>(
1973 request: T,
1974 response: Response<T>,
1975 session: Session,
1976) -> Result<()>
1977where
1978 T: EntityMessage + RequestMessage,
1979{
1980 let project_id = ProjectId::from_proto(request.remote_entity_id());
1981 let host_connection_id = session
1982 .db()
1983 .await
1984 .host_for_mutating_project_request(project_id, session.connection_id)
1985 .await?;
1986 let payload = session
1987 .peer
1988 .forward_request(session.connection_id, host_connection_id, request)
1989 .await?;
1990 response.send(payload)?;
1991 Ok(())
1992}
1993
1994/// Notify other participants that a new buffer has been created
1995async fn create_buffer_for_peer(
1996 request: proto::CreateBufferForPeer,
1997 session: Session,
1998) -> Result<()> {
1999 session
2000 .db()
2001 .await
2002 .check_user_is_project_host(
2003 ProjectId::from_proto(request.project_id),
2004 session.connection_id,
2005 )
2006 .await?;
2007 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2008 session
2009 .peer
2010 .forward_send(session.connection_id, peer_id.into(), request)?;
2011 Ok(())
2012}
2013
2014/// Notify other participants that a buffer has been updated. This is
2015/// allowed for guests as long as the update is limited to selections.
2016async fn update_buffer(
2017 request: proto::UpdateBuffer,
2018 response: Response<proto::UpdateBuffer>,
2019 session: Session,
2020) -> Result<()> {
2021 let project_id = ProjectId::from_proto(request.project_id);
2022 let mut guest_connection_ids;
2023 let mut host_connection_id = None;
2024
2025 let mut requires_write_permission = false;
2026
2027 for op in request.operations.iter() {
2028 match op.variant {
2029 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2030 Some(_) => requires_write_permission = true,
2031 }
2032 }
2033
2034 {
2035 let collaborators = session
2036 .db()
2037 .await
2038 .project_collaborators_for_buffer_update(
2039 project_id,
2040 session.connection_id,
2041 requires_write_permission,
2042 )
2043 .await?;
2044 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2045 for collaborator in collaborators.iter() {
2046 if collaborator.is_host {
2047 host_connection_id = Some(collaborator.connection_id);
2048 } else {
2049 guest_connection_ids.push(collaborator.connection_id);
2050 }
2051 }
2052 }
2053 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2054
2055 broadcast(
2056 Some(session.connection_id),
2057 guest_connection_ids,
2058 |connection_id| {
2059 session
2060 .peer
2061 .forward_send(session.connection_id, connection_id, request.clone())
2062 },
2063 );
2064 if host_connection_id != session.connection_id {
2065 session
2066 .peer
2067 .forward_request(session.connection_id, host_connection_id, request.clone())
2068 .await?;
2069 }
2070
2071 response.send(proto::Ack {})?;
2072 Ok(())
2073}
2074
2075/// Notify other participants that a project has been updated.
2076async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2077 request: T,
2078 session: Session,
2079) -> Result<()> {
2080 let project_id = ProjectId::from_proto(request.remote_entity_id());
2081 let project_connection_ids = session
2082 .db()
2083 .await
2084 .project_connection_ids(project_id, session.connection_id)
2085 .await?;
2086
2087 broadcast(
2088 Some(session.connection_id),
2089 project_connection_ids.iter().copied(),
2090 |connection_id| {
2091 session
2092 .peer
2093 .forward_send(session.connection_id, connection_id, request.clone())
2094 },
2095 );
2096 Ok(())
2097}
2098
2099/// Start following another user in a call.
2100async fn follow(
2101 request: proto::Follow,
2102 response: Response<proto::Follow>,
2103 session: Session,
2104) -> Result<()> {
2105 let room_id = RoomId::from_proto(request.room_id);
2106 let project_id = request.project_id.map(ProjectId::from_proto);
2107 let leader_id = request
2108 .leader_id
2109 .ok_or_else(|| anyhow!("invalid leader id"))?
2110 .into();
2111 let follower_id = session.connection_id;
2112
2113 session
2114 .db()
2115 .await
2116 .check_room_participants(room_id, leader_id, session.connection_id)
2117 .await?;
2118
2119 let response_payload = session
2120 .peer
2121 .forward_request(session.connection_id, leader_id, request)
2122 .await?;
2123 response.send(response_payload)?;
2124
2125 if let Some(project_id) = project_id {
2126 let room = session
2127 .db()
2128 .await
2129 .follow(room_id, project_id, leader_id, follower_id)
2130 .await?;
2131 room_updated(&room, &session.peer);
2132 }
2133
2134 Ok(())
2135}
2136
2137/// Stop following another user in a call.
2138async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2139 let room_id = RoomId::from_proto(request.room_id);
2140 let project_id = request.project_id.map(ProjectId::from_proto);
2141 let leader_id = request
2142 .leader_id
2143 .ok_or_else(|| anyhow!("invalid leader id"))?
2144 .into();
2145 let follower_id = session.connection_id;
2146
2147 session
2148 .db()
2149 .await
2150 .check_room_participants(room_id, leader_id, session.connection_id)
2151 .await?;
2152
2153 session
2154 .peer
2155 .forward_send(session.connection_id, leader_id, request)?;
2156
2157 if let Some(project_id) = project_id {
2158 let room = session
2159 .db()
2160 .await
2161 .unfollow(room_id, project_id, leader_id, follower_id)
2162 .await?;
2163 room_updated(&room, &session.peer);
2164 }
2165
2166 Ok(())
2167}
2168
2169/// Notify everyone following you of your current location.
2170async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2171 let room_id = RoomId::from_proto(request.room_id);
2172 let database = session.db.lock().await;
2173
2174 let connection_ids = if let Some(project_id) = request.project_id {
2175 let project_id = ProjectId::from_proto(project_id);
2176 database
2177 .project_connection_ids(project_id, session.connection_id)
2178 .await?
2179 } else {
2180 database
2181 .room_connection_ids(room_id, session.connection_id)
2182 .await?
2183 };
2184
2185 // For now, don't send view update messages back to that view's current leader.
2186 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2187 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2188 _ => None,
2189 });
2190
2191 for connection_id in connection_ids.iter().cloned() {
2192 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2193 session
2194 .peer
2195 .forward_send(session.connection_id, connection_id, request.clone())?;
2196 }
2197 }
2198 Ok(())
2199}
2200
2201/// Get public data about users.
2202async fn get_users(
2203 request: proto::GetUsers,
2204 response: Response<proto::GetUsers>,
2205 session: Session,
2206) -> Result<()> {
2207 let user_ids = request
2208 .user_ids
2209 .into_iter()
2210 .map(UserId::from_proto)
2211 .collect();
2212 let users = session
2213 .db()
2214 .await
2215 .get_users_by_ids(user_ids)
2216 .await?
2217 .into_iter()
2218 .map(|user| proto::User {
2219 id: user.id.to_proto(),
2220 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2221 github_login: user.github_login,
2222 })
2223 .collect();
2224 response.send(proto::UsersResponse { users })?;
2225 Ok(())
2226}
2227
2228/// Search for users (to invite) buy Github login
2229async fn fuzzy_search_users(
2230 request: proto::FuzzySearchUsers,
2231 response: Response<proto::FuzzySearchUsers>,
2232 session: Session,
2233) -> Result<()> {
2234 let query = request.query;
2235 let users = match query.len() {
2236 0 => vec![],
2237 1 | 2 => session
2238 .db()
2239 .await
2240 .get_user_by_github_login(&query)
2241 .await?
2242 .into_iter()
2243 .collect(),
2244 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2245 };
2246 let users = users
2247 .into_iter()
2248 .filter(|user| user.id != session.user_id)
2249 .map(|user| proto::User {
2250 id: user.id.to_proto(),
2251 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2252 github_login: user.github_login,
2253 })
2254 .collect();
2255 response.send(proto::UsersResponse { users })?;
2256 Ok(())
2257}
2258
2259/// Send a contact request to another user.
2260async fn request_contact(
2261 request: proto::RequestContact,
2262 response: Response<proto::RequestContact>,
2263 session: Session,
2264) -> Result<()> {
2265 let requester_id = session.user_id;
2266 let responder_id = UserId::from_proto(request.responder_id);
2267 if requester_id == responder_id {
2268 return Err(anyhow!("cannot add yourself as a contact"))?;
2269 }
2270
2271 let notifications = session
2272 .db()
2273 .await
2274 .send_contact_request(requester_id, responder_id)
2275 .await?;
2276
2277 // Update outgoing contact requests of requester
2278 let mut update = proto::UpdateContacts::default();
2279 update.outgoing_requests.push(responder_id.to_proto());
2280 for connection_id in session
2281 .connection_pool()
2282 .await
2283 .user_connection_ids(requester_id)
2284 {
2285 session.peer.send(connection_id, update.clone())?;
2286 }
2287
2288 // Update incoming contact requests of responder
2289 let mut update = proto::UpdateContacts::default();
2290 update
2291 .incoming_requests
2292 .push(proto::IncomingContactRequest {
2293 requester_id: requester_id.to_proto(),
2294 });
2295 let connection_pool = session.connection_pool().await;
2296 for connection_id in connection_pool.user_connection_ids(responder_id) {
2297 session.peer.send(connection_id, update.clone())?;
2298 }
2299
2300 send_notifications(&connection_pool, &session.peer, notifications);
2301
2302 response.send(proto::Ack {})?;
2303 Ok(())
2304}
2305
2306/// Accept or decline a contact request
2307async fn respond_to_contact_request(
2308 request: proto::RespondToContactRequest,
2309 response: Response<proto::RespondToContactRequest>,
2310 session: Session,
2311) -> Result<()> {
2312 let responder_id = session.user_id;
2313 let requester_id = UserId::from_proto(request.requester_id);
2314 let db = session.db().await;
2315 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2316 db.dismiss_contact_notification(responder_id, requester_id)
2317 .await?;
2318 } else {
2319 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2320
2321 let notifications = db
2322 .respond_to_contact_request(responder_id, requester_id, accept)
2323 .await?;
2324 let requester_busy = db.is_user_busy(requester_id).await?;
2325 let responder_busy = db.is_user_busy(responder_id).await?;
2326
2327 let pool = session.connection_pool().await;
2328 // Update responder with new contact
2329 let mut update = proto::UpdateContacts::default();
2330 if accept {
2331 update
2332 .contacts
2333 .push(contact_for_user(requester_id, requester_busy, &pool));
2334 }
2335 update
2336 .remove_incoming_requests
2337 .push(requester_id.to_proto());
2338 for connection_id in pool.user_connection_ids(responder_id) {
2339 session.peer.send(connection_id, update.clone())?;
2340 }
2341
2342 // Update requester with new contact
2343 let mut update = proto::UpdateContacts::default();
2344 if accept {
2345 update
2346 .contacts
2347 .push(contact_for_user(responder_id, responder_busy, &pool));
2348 }
2349 update
2350 .remove_outgoing_requests
2351 .push(responder_id.to_proto());
2352
2353 for connection_id in pool.user_connection_ids(requester_id) {
2354 session.peer.send(connection_id, update.clone())?;
2355 }
2356
2357 send_notifications(&pool, &session.peer, notifications);
2358 }
2359
2360 response.send(proto::Ack {})?;
2361 Ok(())
2362}
2363
2364/// Remove a contact.
2365async fn remove_contact(
2366 request: proto::RemoveContact,
2367 response: Response<proto::RemoveContact>,
2368 session: Session,
2369) -> Result<()> {
2370 let requester_id = session.user_id;
2371 let responder_id = UserId::from_proto(request.user_id);
2372 let db = session.db().await;
2373 let (contact_accepted, deleted_notification_id) =
2374 db.remove_contact(requester_id, responder_id).await?;
2375
2376 let pool = session.connection_pool().await;
2377 // Update outgoing contact requests of requester
2378 let mut update = proto::UpdateContacts::default();
2379 if contact_accepted {
2380 update.remove_contacts.push(responder_id.to_proto());
2381 } else {
2382 update
2383 .remove_outgoing_requests
2384 .push(responder_id.to_proto());
2385 }
2386 for connection_id in pool.user_connection_ids(requester_id) {
2387 session.peer.send(connection_id, update.clone())?;
2388 }
2389
2390 // Update incoming contact requests of responder
2391 let mut update = proto::UpdateContacts::default();
2392 if contact_accepted {
2393 update.remove_contacts.push(requester_id.to_proto());
2394 } else {
2395 update
2396 .remove_incoming_requests
2397 .push(requester_id.to_proto());
2398 }
2399 for connection_id in pool.user_connection_ids(responder_id) {
2400 session.peer.send(connection_id, update.clone())?;
2401 if let Some(notification_id) = deleted_notification_id {
2402 session.peer.send(
2403 connection_id,
2404 proto::DeleteNotification {
2405 notification_id: notification_id.to_proto(),
2406 },
2407 )?;
2408 }
2409 }
2410
2411 response.send(proto::Ack {})?;
2412 Ok(())
2413}
2414
2415/// Creates a new channel.
2416async fn create_channel(
2417 request: proto::CreateChannel,
2418 response: Response<proto::CreateChannel>,
2419 session: Session,
2420) -> Result<()> {
2421 let db = session.db().await;
2422
2423 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2424 let (channel, membership) = db
2425 .create_channel(&request.name, parent_id, session.user_id)
2426 .await?;
2427
2428 let root_id = channel.root_id();
2429 let channel = Channel::from_model(channel);
2430
2431 response.send(proto::CreateChannelResponse {
2432 channel: Some(channel.to_proto()),
2433 parent_id: request.parent_id,
2434 })?;
2435
2436 let mut connection_pool = session.connection_pool().await;
2437 if let Some(membership) = membership {
2438 connection_pool.subscribe_to_channel(
2439 membership.user_id,
2440 membership.channel_id,
2441 membership.role,
2442 );
2443 let update = proto::UpdateUserChannels {
2444 channel_memberships: vec![proto::ChannelMembership {
2445 channel_id: membership.channel_id.to_proto(),
2446 role: membership.role.into(),
2447 }],
2448 ..Default::default()
2449 };
2450 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2451 session.peer.send(connection_id, update.clone())?;
2452 }
2453 }
2454
2455 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2456 if !role.can_see_channel(channel.visibility) {
2457 continue;
2458 }
2459
2460 let update = proto::UpdateChannels {
2461 channels: vec![channel.to_proto()],
2462 ..Default::default()
2463 };
2464 session.peer.send(connection_id, update.clone())?;
2465 }
2466
2467 Ok(())
2468}
2469
2470/// Delete a channel
2471async fn delete_channel(
2472 request: proto::DeleteChannel,
2473 response: Response<proto::DeleteChannel>,
2474 session: Session,
2475) -> Result<()> {
2476 let db = session.db().await;
2477
2478 let channel_id = request.channel_id;
2479 let (root_channel, removed_channels) = db
2480 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2481 .await?;
2482 response.send(proto::Ack {})?;
2483
2484 // Notify members of removed channels
2485 let mut update = proto::UpdateChannels::default();
2486 update
2487 .delete_channels
2488 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2489
2490 let connection_pool = session.connection_pool().await;
2491 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2492 session.peer.send(connection_id, update.clone())?;
2493 }
2494
2495 Ok(())
2496}
2497
2498/// Invite someone to join a channel.
2499async fn invite_channel_member(
2500 request: proto::InviteChannelMember,
2501 response: Response<proto::InviteChannelMember>,
2502 session: Session,
2503) -> Result<()> {
2504 let db = session.db().await;
2505 let channel_id = ChannelId::from_proto(request.channel_id);
2506 let invitee_id = UserId::from_proto(request.user_id);
2507 let InviteMemberResult {
2508 channel,
2509 notifications,
2510 } = db
2511 .invite_channel_member(
2512 channel_id,
2513 invitee_id,
2514 session.user_id,
2515 request.role().into(),
2516 )
2517 .await?;
2518
2519 let update = proto::UpdateChannels {
2520 channel_invitations: vec![channel.to_proto()],
2521 ..Default::default()
2522 };
2523
2524 let connection_pool = session.connection_pool().await;
2525 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2526 session.peer.send(connection_id, update.clone())?;
2527 }
2528
2529 send_notifications(&connection_pool, &session.peer, notifications);
2530
2531 response.send(proto::Ack {})?;
2532 Ok(())
2533}
2534
2535/// remove someone from a channel
2536async fn remove_channel_member(
2537 request: proto::RemoveChannelMember,
2538 response: Response<proto::RemoveChannelMember>,
2539 session: Session,
2540) -> Result<()> {
2541 let db = session.db().await;
2542 let channel_id = ChannelId::from_proto(request.channel_id);
2543 let member_id = UserId::from_proto(request.user_id);
2544
2545 let RemoveChannelMemberResult {
2546 membership_update,
2547 notification_id,
2548 } = db
2549 .remove_channel_member(channel_id, member_id, session.user_id)
2550 .await?;
2551
2552 let mut connection_pool = session.connection_pool().await;
2553 notify_membership_updated(
2554 &mut connection_pool,
2555 membership_update,
2556 member_id,
2557 &session.peer,
2558 );
2559 for connection_id in connection_pool.user_connection_ids(member_id) {
2560 if let Some(notification_id) = notification_id {
2561 session
2562 .peer
2563 .send(
2564 connection_id,
2565 proto::DeleteNotification {
2566 notification_id: notification_id.to_proto(),
2567 },
2568 )
2569 .trace_err();
2570 }
2571 }
2572
2573 response.send(proto::Ack {})?;
2574 Ok(())
2575}
2576
2577/// Toggle the channel between public and private.
2578/// Care is taken to maintain the invariant that public channels only descend from public channels,
2579/// (though members-only channels can appear at any point in the hierarchy).
2580async fn set_channel_visibility(
2581 request: proto::SetChannelVisibility,
2582 response: Response<proto::SetChannelVisibility>,
2583 session: Session,
2584) -> Result<()> {
2585 let db = session.db().await;
2586 let channel_id = ChannelId::from_proto(request.channel_id);
2587 let visibility = request.visibility().into();
2588
2589 let channel_model = db
2590 .set_channel_visibility(channel_id, visibility, session.user_id)
2591 .await?;
2592 let root_id = channel_model.root_id();
2593 let channel = Channel::from_model(channel_model);
2594
2595 let mut connection_pool = session.connection_pool().await;
2596 for (user_id, role) in connection_pool
2597 .channel_user_ids(root_id)
2598 .collect::<Vec<_>>()
2599 .into_iter()
2600 {
2601 let update = if role.can_see_channel(channel.visibility) {
2602 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2603 proto::UpdateChannels {
2604 channels: vec![channel.to_proto()],
2605 ..Default::default()
2606 }
2607 } else {
2608 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2609 proto::UpdateChannels {
2610 delete_channels: vec![channel.id.to_proto()],
2611 ..Default::default()
2612 }
2613 };
2614
2615 for connection_id in connection_pool.user_connection_ids(user_id) {
2616 session.peer.send(connection_id, update.clone())?;
2617 }
2618 }
2619
2620 response.send(proto::Ack {})?;
2621 Ok(())
2622}
2623
2624/// Alter the role for a user in the channel.
2625async fn set_channel_member_role(
2626 request: proto::SetChannelMemberRole,
2627 response: Response<proto::SetChannelMemberRole>,
2628 session: Session,
2629) -> Result<()> {
2630 let db = session.db().await;
2631 let channel_id = ChannelId::from_proto(request.channel_id);
2632 let member_id = UserId::from_proto(request.user_id);
2633 let result = db
2634 .set_channel_member_role(
2635 channel_id,
2636 session.user_id,
2637 member_id,
2638 request.role().into(),
2639 )
2640 .await?;
2641
2642 match result {
2643 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2644 let mut connection_pool = session.connection_pool().await;
2645 notify_membership_updated(
2646 &mut connection_pool,
2647 membership_update,
2648 member_id,
2649 &session.peer,
2650 )
2651 }
2652 db::SetMemberRoleResult::InviteUpdated(channel) => {
2653 let update = proto::UpdateChannels {
2654 channel_invitations: vec![channel.to_proto()],
2655 ..Default::default()
2656 };
2657
2658 for connection_id in session
2659 .connection_pool()
2660 .await
2661 .user_connection_ids(member_id)
2662 {
2663 session.peer.send(connection_id, update.clone())?;
2664 }
2665 }
2666 }
2667
2668 response.send(proto::Ack {})?;
2669 Ok(())
2670}
2671
2672/// Change the name of a channel
2673async fn rename_channel(
2674 request: proto::RenameChannel,
2675 response: Response<proto::RenameChannel>,
2676 session: Session,
2677) -> Result<()> {
2678 let db = session.db().await;
2679 let channel_id = ChannelId::from_proto(request.channel_id);
2680 let channel_model = db
2681 .rename_channel(channel_id, session.user_id, &request.name)
2682 .await?;
2683 let root_id = channel_model.root_id();
2684 let channel = Channel::from_model(channel_model);
2685
2686 response.send(proto::RenameChannelResponse {
2687 channel: Some(channel.to_proto()),
2688 })?;
2689
2690 let connection_pool = session.connection_pool().await;
2691 let update = proto::UpdateChannels {
2692 channels: vec![channel.to_proto()],
2693 ..Default::default()
2694 };
2695 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2696 if role.can_see_channel(channel.visibility) {
2697 session.peer.send(connection_id, update.clone())?;
2698 }
2699 }
2700
2701 Ok(())
2702}
2703
2704/// Move a channel to a new parent.
2705async fn move_channel(
2706 request: proto::MoveChannel,
2707 response: Response<proto::MoveChannel>,
2708 session: Session,
2709) -> Result<()> {
2710 let channel_id = ChannelId::from_proto(request.channel_id);
2711 let to = ChannelId::from_proto(request.to);
2712
2713 let (root_id, channels) = session
2714 .db()
2715 .await
2716 .move_channel(channel_id, to, session.user_id)
2717 .await?;
2718
2719 let connection_pool = session.connection_pool().await;
2720 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2721 let channels = channels
2722 .iter()
2723 .filter_map(|channel| {
2724 if role.can_see_channel(channel.visibility) {
2725 Some(channel.to_proto())
2726 } else {
2727 None
2728 }
2729 })
2730 .collect::<Vec<_>>();
2731 if channels.is_empty() {
2732 continue;
2733 }
2734
2735 let update = proto::UpdateChannels {
2736 channels,
2737 ..Default::default()
2738 };
2739
2740 session.peer.send(connection_id, update.clone())?;
2741 }
2742
2743 response.send(Ack {})?;
2744 Ok(())
2745}
2746
2747/// Get the list of channel members
2748async fn get_channel_members(
2749 request: proto::GetChannelMembers,
2750 response: Response<proto::GetChannelMembers>,
2751 session: Session,
2752) -> Result<()> {
2753 let db = session.db().await;
2754 let channel_id = ChannelId::from_proto(request.channel_id);
2755 let members = db
2756 .get_channel_participant_details(channel_id, session.user_id)
2757 .await?;
2758 response.send(proto::GetChannelMembersResponse { members })?;
2759 Ok(())
2760}
2761
2762/// Accept or decline a channel invitation.
2763async fn respond_to_channel_invite(
2764 request: proto::RespondToChannelInvite,
2765 response: Response<proto::RespondToChannelInvite>,
2766 session: Session,
2767) -> Result<()> {
2768 let db = session.db().await;
2769 let channel_id = ChannelId::from_proto(request.channel_id);
2770 let RespondToChannelInvite {
2771 membership_update,
2772 notifications,
2773 } = db
2774 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2775 .await?;
2776
2777 let mut connection_pool = session.connection_pool().await;
2778 if let Some(membership_update) = membership_update {
2779 notify_membership_updated(
2780 &mut connection_pool,
2781 membership_update,
2782 session.user_id,
2783 &session.peer,
2784 );
2785 } else {
2786 let update = proto::UpdateChannels {
2787 remove_channel_invitations: vec![channel_id.to_proto()],
2788 ..Default::default()
2789 };
2790
2791 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2792 session.peer.send(connection_id, update.clone())?;
2793 }
2794 };
2795
2796 send_notifications(&connection_pool, &session.peer, notifications);
2797
2798 response.send(proto::Ack {})?;
2799
2800 Ok(())
2801}
2802
2803/// Join the channels' room
2804async fn join_channel(
2805 request: proto::JoinChannel,
2806 response: Response<proto::JoinChannel>,
2807 session: Session,
2808) -> Result<()> {
2809 let channel_id = ChannelId::from_proto(request.channel_id);
2810 join_channel_internal(channel_id, Box::new(response), session).await
2811}
2812
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for the response handles of both `proto::JoinChannel` and
/// `proto::JoinRoom`, so `join_channel_internal` can serve either request.
trait JoinChannelInternalResponse {
    /// Reply to the originating request with `result`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call, intended to reach the inherent `Response::send`
        // rather than this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call, intended to reach the inherent `Response::send`
        // rather than this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2826
/// Shared implementation for joining a channel's room.
///
/// Generic over `JoinChannelInternalResponse`, so it can answer either a
/// `proto::JoinChannel` or a `proto::JoinRoom` request. In order, it:
/// 1. calls `leave_room_for_session` for whatever room the session is in;
/// 2. joins the channel in the database. This yields the joined room, an
///    optional membership change, and the caller's `ChannelRole`;
/// 3. builds LiveKit connection info when `session.live_kit_client` is set:
///    `ChannelRole::Guest` gets a guest token with `can_publish: false`, every
///    other role gets a room token with `can_publish: true`;
/// 4. replies to the request, passes any membership change to
///    `notify_membership_updated`, and calls `room_updated`;
/// 5. calls `channel_updated` with a freshly locked connection pool, then
///    `update_user_contacts`.
///
/// # Errors
/// Propagates errors from leaving the previous room, the database join,
/// sending the response, and `update_user_contacts`. It also fails with
/// "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and `connection_pool` guards taken inside this block are dropped
    // when it ends. That happens before the pool is locked again below and
    // before `update_user_contacts` runs.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (the error is passed to `trace_err` and `?` exits the closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; all other roles get
            // a full room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to return the channel; treat its absence as
    // an error instead of skipping the channel update.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2906
2907/// Start editing the channel notes
2908async fn join_channel_buffer(
2909 request: proto::JoinChannelBuffer,
2910 response: Response<proto::JoinChannelBuffer>,
2911 session: Session,
2912) -> Result<()> {
2913 let db = session.db().await;
2914 let channel_id = ChannelId::from_proto(request.channel_id);
2915
2916 let open_response = db
2917 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2918 .await?;
2919
2920 let collaborators = open_response.collaborators.clone();
2921 response.send(open_response)?;
2922
2923 let update = UpdateChannelBufferCollaborators {
2924 channel_id: channel_id.to_proto(),
2925 collaborators: collaborators.clone(),
2926 };
2927 channel_buffer_updated(
2928 session.connection_id,
2929 collaborators
2930 .iter()
2931 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2932 &update,
2933 &session.peer,
2934 );
2935
2936 Ok(())
2937}
2938
2939/// Edit the channel notes
2940async fn update_channel_buffer(
2941 request: proto::UpdateChannelBuffer,
2942 session: Session,
2943) -> Result<()> {
2944 let db = session.db().await;
2945 let channel_id = ChannelId::from_proto(request.channel_id);
2946
2947 let (collaborators, non_collaborators, epoch, version) = db
2948 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2949 .await?;
2950
2951 channel_buffer_updated(
2952 session.connection_id,
2953 collaborators,
2954 &proto::UpdateChannelBuffer {
2955 channel_id: channel_id.to_proto(),
2956 operations: request.operations,
2957 },
2958 &session.peer,
2959 );
2960
2961 let pool = &*session.connection_pool().await;
2962
2963 broadcast(
2964 None,
2965 non_collaborators
2966 .iter()
2967 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2968 |peer_id| {
2969 session.peer.send(
2970 peer_id,
2971 proto::UpdateChannels {
2972 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2973 channel_id: channel_id.to_proto(),
2974 epoch: epoch as u64,
2975 version: version.clone(),
2976 }],
2977 ..Default::default()
2978 },
2979 )
2980 },
2981 );
2982
2983 Ok(())
2984}
2985
2986/// Rejoin the channel notes after a connection blip
2987async fn rejoin_channel_buffers(
2988 request: proto::RejoinChannelBuffers,
2989 response: Response<proto::RejoinChannelBuffers>,
2990 session: Session,
2991) -> Result<()> {
2992 let db = session.db().await;
2993 let buffers = db
2994 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2995 .await?;
2996
2997 for rejoined_buffer in &buffers {
2998 let collaborators_to_notify = rejoined_buffer
2999 .buffer
3000 .collaborators
3001 .iter()
3002 .filter_map(|c| Some(c.peer_id?.into()));
3003 channel_buffer_updated(
3004 session.connection_id,
3005 collaborators_to_notify,
3006 &proto::UpdateChannelBufferCollaborators {
3007 channel_id: rejoined_buffer.buffer.channel_id,
3008 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3009 },
3010 &session.peer,
3011 );
3012 }
3013
3014 response.send(proto::RejoinChannelBuffersResponse {
3015 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3016 })?;
3017
3018 Ok(())
3019}
3020
3021/// Stop editing the channel notes
3022async fn leave_channel_buffer(
3023 request: proto::LeaveChannelBuffer,
3024 response: Response<proto::LeaveChannelBuffer>,
3025 session: Session,
3026) -> Result<()> {
3027 let db = session.db().await;
3028 let channel_id = ChannelId::from_proto(request.channel_id);
3029
3030 let left_buffer = db
3031 .leave_channel_buffer(channel_id, session.connection_id)
3032 .await?;
3033
3034 response.send(Ack {})?;
3035
3036 channel_buffer_updated(
3037 session.connection_id,
3038 left_buffer.connections,
3039 &proto::UpdateChannelBufferCollaborators {
3040 channel_id: channel_id.to_proto(),
3041 collaborators: left_buffer.collaborators,
3042 },
3043 &session.peer,
3044 );
3045
3046 Ok(())
3047}
3048
3049fn channel_buffer_updated<T: EnvelopedMessage>(
3050 sender_id: ConnectionId,
3051 collaborators: impl IntoIterator<Item = ConnectionId>,
3052 message: &T,
3053 peer: &Peer,
3054) {
3055 broadcast(Some(sender_id), collaborators, |peer_id| {
3056 peer.send(peer_id, message.clone())
3057 });
3058}
3059
3060fn send_notifications(
3061 connection_pool: &ConnectionPool,
3062 peer: &Peer,
3063 notifications: db::NotificationBatch,
3064) {
3065 for (user_id, notification) in notifications {
3066 for connection_id in connection_pool.user_connection_ids(user_id) {
3067 if let Err(error) = peer.send(
3068 connection_id,
3069 proto::AddNotification {
3070 notification: Some(notification.clone()),
3071 },
3072 ) {
3073 tracing::error!(
3074 "failed to send notification to {:?} {}",
3075 connection_id,
3076 error
3077 );
3078 }
3079 }
3080 }
3081}
3082
3083/// Send a message to the channel
3084async fn send_channel_message(
3085 request: proto::SendChannelMessage,
3086 response: Response<proto::SendChannelMessage>,
3087 session: Session,
3088) -> Result<()> {
3089 // Validate the message body.
3090 let body = request.body.trim().to_string();
3091 if body.len() > MAX_MESSAGE_LEN {
3092 return Err(anyhow!("message is too long"))?;
3093 }
3094 if body.is_empty() {
3095 return Err(anyhow!("message can't be blank"))?;
3096 }
3097
3098 // TODO: adjust mentions if body is trimmed
3099
3100 let timestamp = OffsetDateTime::now_utc();
3101 let nonce = request
3102 .nonce
3103 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3104
3105 let channel_id = ChannelId::from_proto(request.channel_id);
3106 let CreatedChannelMessage {
3107 message_id,
3108 participant_connection_ids,
3109 channel_members,
3110 notifications,
3111 } = session
3112 .db()
3113 .await
3114 .create_channel_message(
3115 channel_id,
3116 session.user_id,
3117 &body,
3118 &request.mentions,
3119 timestamp,
3120 nonce.clone().into(),
3121 match request.reply_to_message_id {
3122 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3123 None => None,
3124 },
3125 )
3126 .await?;
3127 let message = proto::ChannelMessage {
3128 sender_id: session.user_id.to_proto(),
3129 id: message_id.to_proto(),
3130 body,
3131 mentions: request.mentions,
3132 timestamp: timestamp.unix_timestamp() as u64,
3133 nonce: Some(nonce),
3134 reply_to_message_id: request.reply_to_message_id,
3135 };
3136 broadcast(
3137 Some(session.connection_id),
3138 participant_connection_ids,
3139 |connection| {
3140 session.peer.send(
3141 connection,
3142 proto::ChannelMessageSent {
3143 channel_id: channel_id.to_proto(),
3144 message: Some(message.clone()),
3145 },
3146 )
3147 },
3148 );
3149 response.send(proto::SendChannelMessageResponse {
3150 message: Some(message),
3151 })?;
3152
3153 let pool = &*session.connection_pool().await;
3154 broadcast(
3155 None,
3156 channel_members
3157 .iter()
3158 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3159 |peer_id| {
3160 session.peer.send(
3161 peer_id,
3162 proto::UpdateChannels {
3163 latest_channel_message_ids: vec![proto::ChannelMessageId {
3164 channel_id: channel_id.to_proto(),
3165 message_id: message_id.to_proto(),
3166 }],
3167 ..Default::default()
3168 },
3169 )
3170 },
3171 );
3172 send_notifications(pool, &session.peer, notifications);
3173
3174 Ok(())
3175}
3176
3177/// Delete a channel message
3178async fn remove_channel_message(
3179 request: proto::RemoveChannelMessage,
3180 response: Response<proto::RemoveChannelMessage>,
3181 session: Session,
3182) -> Result<()> {
3183 let channel_id = ChannelId::from_proto(request.channel_id);
3184 let message_id = MessageId::from_proto(request.message_id);
3185 let connection_ids = session
3186 .db()
3187 .await
3188 .remove_channel_message(channel_id, message_id, session.user_id)
3189 .await?;
3190 broadcast(Some(session.connection_id), connection_ids, |connection| {
3191 session.peer.send(connection, request.clone())
3192 });
3193 response.send(proto::Ack {})?;
3194 Ok(())
3195}
3196
3197/// Mark a channel message as read
3198async fn acknowledge_channel_message(
3199 request: proto::AckChannelMessage,
3200 session: Session,
3201) -> Result<()> {
3202 let channel_id = ChannelId::from_proto(request.channel_id);
3203 let message_id = MessageId::from_proto(request.message_id);
3204 let notifications = session
3205 .db()
3206 .await
3207 .observe_channel_message(channel_id, session.user_id, message_id)
3208 .await?;
3209 send_notifications(
3210 &*session.connection_pool().await,
3211 &session.peer,
3212 notifications,
3213 );
3214 Ok(())
3215}
3216
3217/// Mark a buffer version as synced
3218async fn acknowledge_buffer_version(
3219 request: proto::AckBufferOperation,
3220 session: Session,
3221) -> Result<()> {
3222 let buffer_id = BufferId::from_proto(request.buffer_id);
3223 session
3224 .db()
3225 .await
3226 .observe_buffer_version(
3227 buffer_id,
3228 session.user_id,
3229 request.epoch as i32,
3230 &request.version,
3231 )
3232 .await?;
3233 Ok(())
3234}
3235
3236/// Start receiving chat updates for a channel
3237async fn join_channel_chat(
3238 request: proto::JoinChannelChat,
3239 response: Response<proto::JoinChannelChat>,
3240 session: Session,
3241) -> Result<()> {
3242 let channel_id = ChannelId::from_proto(request.channel_id);
3243
3244 let db = session.db().await;
3245 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3246 .await?;
3247 let messages = db
3248 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3249 .await?;
3250 response.send(proto::JoinChannelChatResponse {
3251 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3252 messages,
3253 })?;
3254 Ok(())
3255}
3256
3257/// Stop receiving chat updates for a channel
3258async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3259 let channel_id = ChannelId::from_proto(request.channel_id);
3260 session
3261 .db()
3262 .await
3263 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3264 .await?;
3265 Ok(())
3266}
3267
3268/// Retrieve the chat history for a channel
3269async fn get_channel_messages(
3270 request: proto::GetChannelMessages,
3271 response: Response<proto::GetChannelMessages>,
3272 session: Session,
3273) -> Result<()> {
3274 let channel_id = ChannelId::from_proto(request.channel_id);
3275 let messages = session
3276 .db()
3277 .await
3278 .get_channel_messages(
3279 channel_id,
3280 session.user_id,
3281 MESSAGE_COUNT_PER_PAGE,
3282 Some(MessageId::from_proto(request.before_message_id)),
3283 )
3284 .await?;
3285 response.send(proto::GetChannelMessagesResponse {
3286 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3287 messages,
3288 })?;
3289 Ok(())
3290}
3291
3292/// Retrieve specific chat messages
3293async fn get_channel_messages_by_id(
3294 request: proto::GetChannelMessagesById,
3295 response: Response<proto::GetChannelMessagesById>,
3296 session: Session,
3297) -> Result<()> {
3298 let message_ids = request
3299 .message_ids
3300 .iter()
3301 .map(|id| MessageId::from_proto(*id))
3302 .collect::<Vec<_>>();
3303 let messages = session
3304 .db()
3305 .await
3306 .get_channel_messages_by_id(session.user_id, &message_ids)
3307 .await?;
3308 response.send(proto::GetChannelMessagesResponse {
3309 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3310 messages,
3311 })?;
3312 Ok(())
3313}
3314
3315/// Retrieve the current users notifications
3316async fn get_notifications(
3317 request: proto::GetNotifications,
3318 response: Response<proto::GetNotifications>,
3319 session: Session,
3320) -> Result<()> {
3321 let notifications = session
3322 .db()
3323 .await
3324 .get_notifications(
3325 session.user_id,
3326 NOTIFICATION_COUNT_PER_PAGE,
3327 request
3328 .before_id
3329 .map(|id| db::NotificationId::from_proto(id)),
3330 )
3331 .await?;
3332 response.send(proto::GetNotificationsResponse {
3333 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3334 notifications,
3335 })?;
3336 Ok(())
3337}
3338
3339/// Mark notifications as read
3340async fn mark_notification_as_read(
3341 request: proto::MarkNotificationRead,
3342 response: Response<proto::MarkNotificationRead>,
3343 session: Session,
3344) -> Result<()> {
3345 let database = &session.db().await;
3346 let notifications = database
3347 .mark_notification_as_read_by_id(
3348 session.user_id,
3349 NotificationId::from_proto(request.notification_id),
3350 )
3351 .await?;
3352 send_notifications(
3353 &*session.connection_pool().await,
3354 &session.peer,
3355 notifications,
3356 );
3357 response.send(proto::Ack {})?;
3358 Ok(())
3359}
3360
3361/// Get the current users information
3362async fn get_private_user_info(
3363 _request: proto::GetPrivateUserInfo,
3364 response: Response<proto::GetPrivateUserInfo>,
3365 session: Session,
3366) -> Result<()> {
3367 let db = session.db().await;
3368
3369 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3370 let user = db
3371 .get_user_by_id(session.user_id)
3372 .await?
3373 .ok_or_else(|| anyhow!("user not found"))?;
3374 let flags = db.get_user_flags(session.user_id).await?;
3375
3376 response.send(proto::GetPrivateUserInfoResponse {
3377 metrics_id,
3378 staff: user.admin,
3379 flags,
3380 })?;
3381 Ok(())
3382}
3383
3384fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3385 match message {
3386 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3387 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3388 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3389 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3390 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3391 code: frame.code.into(),
3392 reason: frame.reason,
3393 })),
3394 }
3395}
3396
3397fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3398 match message {
3399 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3400 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3401 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3402 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3403 AxumMessage::Close(frame) => {
3404 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3405 code: frame.code.into(),
3406 reason: frame.reason,
3407 }))
3408 }
3409 }
3410}
3411
3412fn notify_membership_updated(
3413 connection_pool: &mut ConnectionPool,
3414 result: MembershipUpdated,
3415 user_id: UserId,
3416 peer: &Peer,
3417) {
3418 for membership in &result.new_channels.channel_memberships {
3419 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3420 }
3421 for channel_id in &result.removed_channels {
3422 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3423 }
3424
3425 let user_channels_update = proto::UpdateUserChannels {
3426 channel_memberships: result
3427 .new_channels
3428 .channel_memberships
3429 .iter()
3430 .map(|cm| proto::ChannelMembership {
3431 channel_id: cm.channel_id.to_proto(),
3432 role: cm.role.into(),
3433 })
3434 .collect(),
3435 ..Default::default()
3436 };
3437
3438 let mut update = build_channels_update(result.new_channels, vec![]);
3439 update.delete_channels = result
3440 .removed_channels
3441 .into_iter()
3442 .map(|id| id.to_proto())
3443 .collect();
3444 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3445
3446 for connection_id in connection_pool.user_connection_ids(user_id) {
3447 peer.send(connection_id, user_channels_update.clone())
3448 .trace_err();
3449 peer.send(connection_id, update.clone()).trace_err();
3450 }
3451}
3452
3453fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3454 proto::UpdateUserChannels {
3455 channel_memberships: channels
3456 .channel_memberships
3457 .iter()
3458 .map(|m| proto::ChannelMembership {
3459 channel_id: m.channel_id.to_proto(),
3460 role: m.role.into(),
3461 })
3462 .collect(),
3463 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3464 observed_channel_message_id: channels.observed_channel_messages.clone(),
3465 }
3466}
3467
3468fn build_channels_update(
3469 channels: ChannelsForUser,
3470 channel_invites: Vec<db::Channel>,
3471) -> proto::UpdateChannels {
3472 let mut update = proto::UpdateChannels::default();
3473
3474 for channel in channels.channels {
3475 update.channels.push(channel.to_proto());
3476 }
3477
3478 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3479 update.latest_channel_message_ids = channels.latest_channel_messages;
3480
3481 for (channel_id, participants) in channels.channel_participants {
3482 update
3483 .channel_participants
3484 .push(proto::ChannelParticipants {
3485 channel_id: channel_id.to_proto(),
3486 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3487 });
3488 }
3489
3490 for channel in channel_invites {
3491 update.channel_invitations.push(channel.to_proto());
3492 }
3493 for project in channels.hosted_projects {
3494 update.hosted_projects.push(project);
3495 }
3496
3497 update
3498}
3499
3500fn build_initial_contacts_update(
3501 contacts: Vec<db::Contact>,
3502 pool: &ConnectionPool,
3503) -> proto::UpdateContacts {
3504 let mut update = proto::UpdateContacts::default();
3505
3506 for contact in contacts {
3507 match contact {
3508 db::Contact::Accepted { user_id, busy } => {
3509 update.contacts.push(contact_for_user(user_id, busy, &pool));
3510 }
3511 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3512 db::Contact::Incoming { user_id } => {
3513 update
3514 .incoming_requests
3515 .push(proto::IncomingContactRequest {
3516 requester_id: user_id.to_proto(),
3517 })
3518 }
3519 }
3520 }
3521
3522 update
3523}
3524
3525fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3526 proto::Contact {
3527 user_id: user_id.to_proto(),
3528 online: pool.is_user_online(user_id),
3529 busy,
3530 }
3531}
3532
3533fn room_updated(room: &proto::Room, peer: &Peer) {
3534 broadcast(
3535 None,
3536 room.participants
3537 .iter()
3538 .filter_map(|participant| Some(participant.peer_id?.into())),
3539 |peer_id| {
3540 peer.send(
3541 peer_id,
3542 proto::RoomUpdated {
3543 room: Some(room.clone()),
3544 },
3545 )
3546 },
3547 );
3548}
3549
3550fn channel_updated(
3551 channel: &db::channel::Model,
3552 room: &proto::Room,
3553 peer: &Peer,
3554 pool: &ConnectionPool,
3555) {
3556 let participants = room
3557 .participants
3558 .iter()
3559 .map(|p| p.user_id)
3560 .collect::<Vec<_>>();
3561
3562 broadcast(
3563 None,
3564 pool.channel_connection_ids(channel.root_id())
3565 .filter_map(|(channel_id, role)| {
3566 role.can_see_channel(channel.visibility).then(|| channel_id)
3567 }),
3568 |peer_id| {
3569 peer.send(
3570 peer_id,
3571 proto::UpdateChannels {
3572 channel_participants: vec![proto::ChannelParticipants {
3573 channel_id: channel.id.to_proto(),
3574 participant_user_ids: participants.clone(),
3575 }],
3576 ..Default::default()
3577 },
3578 )
3579 },
3580 );
3581}
3582
3583async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3584 let db = session.db().await;
3585
3586 let contacts = db.get_contacts(user_id).await?;
3587 let busy = db.is_user_busy(user_id).await?;
3588
3589 let pool = session.connection_pool().await;
3590 let updated_contact = contact_for_user(user_id, busy, &pool);
3591 for contact in contacts {
3592 if let db::Contact::Accepted {
3593 user_id: contact_user_id,
3594 ..
3595 } = contact
3596 {
3597 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3598 session
3599 .peer
3600 .send(
3601 contact_conn_id,
3602 proto::UpdateContacts {
3603 contacts: vec![updated_contact.clone()],
3604 remove_contacts: Default::default(),
3605 incoming_requests: Default::default(),
3606 remove_incoming_requests: Default::default(),
3607 outgoing_requests: Default::default(),
3608 remove_outgoing_requests: Default::default(),
3609 },
3610 )
3611 .trace_err();
3612 }
3613 }
3614 }
3615 Ok(())
3616}
3617
3618async fn leave_room_for_session(session: &Session) -> Result<()> {
3619 let mut contacts_to_update = HashSet::default();
3620
3621 let room_id;
3622 let canceled_calls_to_user_ids;
3623 let live_kit_room;
3624 let delete_live_kit_room;
3625 let room;
3626 let channel;
3627
3628 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3629 contacts_to_update.insert(session.user_id);
3630
3631 for project in left_room.left_projects.values() {
3632 project_left(project, session);
3633 }
3634
3635 room_id = RoomId::from_proto(left_room.room.id);
3636 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3637 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3638 delete_live_kit_room = left_room.deleted;
3639 room = mem::take(&mut left_room.room);
3640 channel = mem::take(&mut left_room.channel);
3641
3642 room_updated(&room, &session.peer);
3643 } else {
3644 return Ok(());
3645 }
3646
3647 if let Some(channel) = channel {
3648 channel_updated(
3649 &channel,
3650 &room,
3651 &session.peer,
3652 &*session.connection_pool().await,
3653 );
3654 }
3655
3656 {
3657 let pool = session.connection_pool().await;
3658 for canceled_user_id in canceled_calls_to_user_ids {
3659 for connection_id in pool.user_connection_ids(canceled_user_id) {
3660 session
3661 .peer
3662 .send(
3663 connection_id,
3664 proto::CallCanceled {
3665 room_id: room_id.to_proto(),
3666 },
3667 )
3668 .trace_err();
3669 }
3670 contacts_to_update.insert(canceled_user_id);
3671 }
3672 }
3673
3674 for contact_user_id in contacts_to_update {
3675 update_user_contacts(contact_user_id, &session).await?;
3676 }
3677
3678 if let Some(live_kit) = session.live_kit_client.as_ref() {
3679 live_kit
3680 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3681 .await
3682 .trace_err();
3683
3684 if delete_live_kit_room {
3685 live_kit.delete_room(live_kit_room).await.trace_err();
3686 }
3687 }
3688
3689 Ok(())
3690}
3691
3692async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3693 let left_channel_buffers = session
3694 .db()
3695 .await
3696 .leave_channel_buffers(session.connection_id)
3697 .await?;
3698
3699 for left_buffer in left_channel_buffers {
3700 channel_buffer_updated(
3701 session.connection_id,
3702 left_buffer.connections,
3703 &proto::UpdateChannelBufferCollaborators {
3704 channel_id: left_buffer.channel_id.to_proto(),
3705 collaborators: left_buffer.collaborators,
3706 },
3707 &session.peer,
3708 );
3709 }
3710
3711 Ok(())
3712}
3713
3714fn project_left(project: &db::LeftProject, session: &Session) {
3715 for connection_id in &project.connection_ids {
3716 if project.host_user_id == Some(session.user_id) {
3717 session
3718 .peer
3719 .send(
3720 *connection_id,
3721 proto::UnshareProject {
3722 project_id: project.id.to_proto(),
3723 },
3724 )
3725 .trace_err();
3726 } else {
3727 session
3728 .peer
3729 .send(
3730 *connection_id,
3731 proto::RemoveProjectCollaborator {
3732 project_id: project.id.to_proto(),
3733 peer_id: Some(session.connection_id.into()),
3734 },
3735 )
3736 .trace_err();
3737 }
3738 }
3739}
3740
3741pub trait ResultExt {
3742 type Ok;
3743
3744 fn trace_err(self) -> Option<Self::Ok>;
3745}
3746
3747impl<T, E> ResultExt for Result<T, E>
3748where
3749 E: std::fmt::Debug,
3750{
3751 type Ok = T;
3752
3753 fn trace_err(self) -> Option<T> {
3754 match self {
3755 Ok(value) => Some(value),
3756 Err(error) => {
3757 tracing::error!("{:?}", error);
3758 None
3759 }
3760 }
3761 }
3762}