1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::{ConnectionPool, ZedVersion};
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
42 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc, OnceLock,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{field, info_span, instrument, Instrument};
66use util::SemanticVersion;
67
/// How long `connection_lost` waits for a client to come back before releasing the room,
/// channel buffers and calls associated with its session.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay between `Server::start` and the cleanup of resources the database still attributes
/// to this server id.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);

// NOTE(review): the limits below are not referenced in this part of the file; their
// descriptions are inferred from the names — confirm against the call sites.
/// Page size, presumably for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size, presumably for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
74
75type MessageHandler =
76 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
77
78struct Response<R> {
79 peer: Arc<Peer>,
80 receipt: Receipt<R>,
81 responded: Arc<AtomicBool>,
82}
83
84impl<R: RequestMessage> Response<R> {
85 fn send(self, payload: R::Response) -> Result<()> {
86 self.responded.store(true, SeqCst);
87 self.peer.respond(self.receipt, payload)?;
88 Ok(())
89 }
90}
91
/// Per-connection state handed to every message handler.
///
/// A clone is passed to each handler invocation; clones share the same database-handle
/// mutex, peer and connection pool through `Arc`s.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool of live connections shared with the server; acquire it with
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// `None` when no LiveKit client is available.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}

impl Session {
    /// Locks this session's database handle.
    ///
    /// In test builds the task yields both before and after acquiring the lock, giving other
    /// tasks a chance to run at those points.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and wraps the guard in a `ConnectionPoolGuard`.
    ///
    /// In test builds the task yields once before locking. The lock is a synchronous
    /// `parking_lot` mutex and the wrapper is `!Send`, so the guard should be dropped before
    /// the next `.await`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
123
124impl fmt::Debug for Session {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("Session")
127 .field("user_id", &self.user_id)
128 .field("connection_id", &self.connection_id)
129 .finish()
130 }
131}
132
133struct DbHandle(Arc<Database>);
134
135impl Deref for DbHandle {
136 type Target = Database;
137
138 fn deref(&self) -> &Self::Target {
139 self.0.as_ref()
140 }
141}
142
/// The collaboration RPC server: owns the peer, the pool of live connections and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// Identifier of this server instance; replaced by `reset` in tests.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live client connections; the same `Arc` is handed to every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection subscribes to it and stops when it changes.
    teardown: watch::Sender<bool>,
}
152
/// Lock guard over the shared `ConnectionPool`.
///
/// `_not_send` (a `PhantomData<Rc<()>>`) makes the guard `!Send`, so it cannot be held across
/// an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
157
/// Serializable snapshot of the server's peer and connection pool.
///
/// The snapshot holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
164
165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
166where
167 S: Serializer,
168 T: Deref<Target = U>,
169 U: Serialize,
170{
171 Serialize::serialize(value.deref(), serializer)
172}
173
174impl Server {
175 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
176 let mut server = Self {
177 id: parking_lot::Mutex::new(id),
178 peer: Peer::new(id.0 as u32),
179 app_state,
180 executor,
181 connection_pool: Default::default(),
182 handlers: Default::default(),
183 teardown: watch::channel(false).0,
184 };
185
186 server
187 .add_request_handler(ping)
188 .add_request_handler(create_room)
189 .add_request_handler(join_room)
190 .add_request_handler(rejoin_room)
191 .add_request_handler(leave_room)
192 .add_request_handler(set_room_participant_role)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
208 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
209 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
210 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
211 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
213 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
214 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
215 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
216 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
217 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
218 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
219 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
220 .add_request_handler(
221 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
222 )
223 .add_request_handler(
224 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
225 )
226 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
227 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
228 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
229 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
230 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
231 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
232 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
233 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
234 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
235 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
236 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
238 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
239 .add_message_handler(create_buffer_for_peer)
240 .add_request_handler(update_buffer)
241 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
242 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
243 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
244 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
245 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
246 .add_request_handler(get_users)
247 .add_request_handler(fuzzy_search_users)
248 .add_request_handler(request_contact)
249 .add_request_handler(remove_contact)
250 .add_request_handler(respond_to_contact_request)
251 .add_request_handler(create_channel)
252 .add_request_handler(delete_channel)
253 .add_request_handler(invite_channel_member)
254 .add_request_handler(remove_channel_member)
255 .add_request_handler(set_channel_member_role)
256 .add_request_handler(set_channel_visibility)
257 .add_request_handler(rename_channel)
258 .add_request_handler(join_channel_buffer)
259 .add_request_handler(leave_channel_buffer)
260 .add_message_handler(update_channel_buffer)
261 .add_request_handler(rejoin_channel_buffers)
262 .add_request_handler(get_channel_members)
263 .add_request_handler(respond_to_channel_invite)
264 .add_request_handler(join_channel)
265 .add_request_handler(join_channel_chat)
266 .add_message_handler(leave_channel_chat)
267 .add_request_handler(send_channel_message)
268 .add_request_handler(remove_channel_message)
269 .add_request_handler(get_channel_messages)
270 .add_request_handler(get_channel_messages_by_id)
271 .add_request_handler(get_notifications)
272 .add_request_handler(mark_notification_as_read)
273 .add_request_handler(move_channel)
274 .add_request_handler(follow)
275 .add_message_handler(unfollow)
276 .add_message_handler(update_followers)
277 .add_request_handler(get_private_user_info)
278 .add_message_handler(acknowledge_channel_message)
279 .add_message_handler(acknowledge_buffer_version);
280
281 Arc::new(server)
282 }
283
284 pub async fn start(&self) -> Result<()> {
285 let server_id = *self.id.lock();
286 let app_state = self.app_state.clone();
287 let peer = self.peer.clone();
288 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
289 let pool = self.connection_pool.clone();
290 let live_kit_client = self.app_state.live_kit_client.clone();
291
292 let span = info_span!("start server");
293 self.executor.spawn_detached(
294 async move {
295 tracing::info!("waiting for cleanup timeout");
296 timeout.await;
297 tracing::info!("cleanup timeout expired, retrieving stale rooms");
298 if let Some((room_ids, channel_ids)) = app_state
299 .db
300 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
301 .await
302 .trace_err()
303 {
304 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
305 tracing::info!(
306 stale_channel_buffer_count = channel_ids.len(),
307 "retrieved stale channel buffers"
308 );
309
310 for channel_id in channel_ids {
311 if let Some(refreshed_channel_buffer) = app_state
312 .db
313 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
314 .await
315 .trace_err()
316 {
317 for connection_id in refreshed_channel_buffer.connection_ids {
318 peer.send(
319 connection_id,
320 proto::UpdateChannelBufferCollaborators {
321 channel_id: channel_id.to_proto(),
322 collaborators: refreshed_channel_buffer
323 .collaborators
324 .clone(),
325 },
326 )
327 .trace_err();
328 }
329 }
330 }
331
332 for room_id in room_ids {
333 let mut contacts_to_update = HashSet::default();
334 let mut canceled_calls_to_user_ids = Vec::new();
335 let mut live_kit_room = String::new();
336 let mut delete_live_kit_room = false;
337
338 if let Some(mut refreshed_room) = app_state
339 .db
340 .clear_stale_room_participants(room_id, server_id)
341 .await
342 .trace_err()
343 {
344 tracing::info!(
345 room_id = room_id.0,
346 new_participant_count = refreshed_room.room.participants.len(),
347 "refreshed room"
348 );
349 room_updated(&refreshed_room.room, &peer);
350 if let Some(channel_id) = refreshed_room.channel_id {
351 channel_updated(
352 channel_id,
353 &refreshed_room.room,
354 &refreshed_room.channel_members,
355 &peer,
356 &*pool.lock(),
357 );
358 }
359 contacts_to_update
360 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
361 contacts_to_update
362 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
363 canceled_calls_to_user_ids =
364 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
365 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
366 delete_live_kit_room = refreshed_room.room.participants.is_empty();
367 }
368
369 {
370 let pool = pool.lock();
371 for canceled_user_id in canceled_calls_to_user_ids {
372 for connection_id in pool.user_connection_ids(canceled_user_id) {
373 peer.send(
374 connection_id,
375 proto::CallCanceled {
376 room_id: room_id.to_proto(),
377 },
378 )
379 .trace_err();
380 }
381 }
382 }
383
384 for user_id in contacts_to_update {
385 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
386 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
387 if let Some((busy, contacts)) = busy.zip(contacts) {
388 let pool = pool.lock();
389 let updated_contact = contact_for_user(user_id, busy, &pool);
390 for contact in contacts {
391 if let db::Contact::Accepted {
392 user_id: contact_user_id,
393 ..
394 } = contact
395 {
396 for contact_conn_id in
397 pool.user_connection_ids(contact_user_id)
398 {
399 peer.send(
400 contact_conn_id,
401 proto::UpdateContacts {
402 contacts: vec![updated_contact.clone()],
403 remove_contacts: Default::default(),
404 incoming_requests: Default::default(),
405 remove_incoming_requests: Default::default(),
406 outgoing_requests: Default::default(),
407 remove_outgoing_requests: Default::default(),
408 },
409 )
410 .trace_err();
411 }
412 }
413 }
414 }
415 }
416
417 if let Some(live_kit) = live_kit_client.as_ref() {
418 if delete_live_kit_room {
419 live_kit.delete_room(live_kit_room).await.trace_err();
420 }
421 }
422 }
423 }
424
425 app_state
426 .db
427 .delete_stale_servers(&app_state.config.zed_environment, server_id)
428 .await
429 .trace_err();
430 }
431 .instrument(span),
432 );
433 Ok(())
434 }
435
436 pub fn teardown(&self) {
437 self.peer.teardown();
438 self.connection_pool.lock().reset();
439 let _ = self.teardown.send(true);
440 }
441
442 #[cfg(test)]
443 pub fn reset(&self, id: ServerId) {
444 self.teardown();
445 *self.id.lock() = id;
446 self.peer.reset(id.0 as u32);
447 let _ = self.teardown.send(false);
448 }
449
450 #[cfg(test)]
451 pub fn id(&self) -> ServerId {
452 *self.id.lock()
453 }
454
455 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
456 where
457 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
458 Fut: 'static + Send + Future<Output = Result<()>>,
459 M: EnvelopedMessage,
460 {
461 let prev_handler = self.handlers.insert(
462 TypeId::of::<M>(),
463 Box::new(move |envelope, session| {
464 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
465 let span = info_span!(
466 "handle message",
467 payload_type = envelope.payload_type_name()
468 );
469 span.in_scope(|| {
470 tracing::info!(
471 payload_type = envelope.payload_type_name(),
472 "message received"
473 );
474 });
475 let start_time = Instant::now();
476 let future = (handler)(*envelope, session);
477 async move {
478 let result = future.await;
479 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
480 match result {
481 Err(error) => {
482 tracing::error!(%error, ?duration_ms, "error handling message")
483 }
484 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
485 }
486 }
487 .instrument(span)
488 .boxed()
489 }),
490 );
491 if prev_handler.is_some() {
492 panic!("registered a handler for the same message twice");
493 }
494 self
495 }
496
497 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
498 where
499 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
500 Fut: 'static + Send + Future<Output = Result<()>>,
501 M: EnvelopedMessage,
502 {
503 self.add_handler(move |envelope, session| handler(envelope.payload, session));
504 self
505 }
506
507 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
508 where
509 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
510 Fut: Send + Future<Output = Result<()>>,
511 M: RequestMessage,
512 {
513 let handler = Arc::new(handler);
514 self.add_handler(move |envelope, session| {
515 let receipt = envelope.receipt();
516 let handler = handler.clone();
517 async move {
518 let peer = session.peer.clone();
519 let responded = Arc::new(AtomicBool::default());
520 let response = Response {
521 peer: peer.clone(),
522 responded: responded.clone(),
523 receipt,
524 };
525 match (handler)(envelope.payload, response, session).await {
526 Ok(()) => {
527 if responded.load(std::sync::atomic::Ordering::SeqCst) {
528 Ok(())
529 } else {
530 Err(anyhow!("handler did not send a response"))?
531 }
532 }
533 Err(error) => {
534 let proto_err = match &error {
535 Error::Internal(err) => err.to_proto(),
536 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
537 };
538 peer.respond_with_error(receipt, proto_err)?;
539 Err(error)
540 }
541 }
542 }
543 })
544 }
545
546 pub fn handle_connection(
547 self: &Arc<Self>,
548 connection: Connection,
549 address: String,
550 user: User,
551 zed_version: ZedVersion,
552 impersonator: Option<User>,
553 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
554 executor: Executor,
555 ) -> impl Future<Output = Result<()>> {
556 let this = self.clone();
557 let user_id = user.id;
558 let login = user.github_login;
559 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
560 if let Some(impersonator) = impersonator {
561 span.record("impersonator", &impersonator.github_login);
562 }
563 let mut teardown = self.teardown.subscribe();
564 async move {
565 if *teardown.borrow() {
566 return Err(anyhow!("server is tearing down"))?;
567 }
568 let (connection_id, handle_io, mut incoming_rx) = this
569 .peer
570 .add_connection(connection, {
571 let executor = executor.clone();
572 move |duration| executor.sleep(duration)
573 });
574
575 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
576 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
577 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
578
579 if let Some(send_connection_id) = send_connection_id.take() {
580 let _ = send_connection_id.send(connection_id);
581 }
582
583 if !user.connected_once {
584 this.peer.send(connection_id, proto::ShowContacts {})?;
585 this.app_state.db.set_user_connected_once(user_id, true).await?;
586 }
587
588 let (contacts, channels_for_user, channel_invites) = future::try_join3(
589 this.app_state.db.get_contacts(user_id),
590 this.app_state.db.get_channels_for_user(user_id),
591 this.app_state.db.get_channel_invites_for_user(user_id),
592 ).await?;
593
594 {
595 let mut pool = this.connection_pool.lock();
596 pool.add_connection(connection_id, user_id, user.admin, zed_version);
597 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
598 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
599 this.peer.send(connection_id, build_channels_update(
600 channels_for_user,
601 channel_invites
602 ))?;
603 }
604
605 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
606 this.peer.send(connection_id, incoming_call)?;
607 }
608
609 let session = Session {
610 user_id,
611 connection_id,
612 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
613 peer: this.peer.clone(),
614 connection_pool: this.connection_pool.clone(),
615 live_kit_client: this.app_state.live_kit_client.clone(),
616 _executor: executor.clone()
617 };
618 update_user_contacts(user_id, &session).await?;
619
620 let handle_io = handle_io.fuse();
621 futures::pin_mut!(handle_io);
622
623 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
624 // This prevents deadlocks when e.g., client A performs a request to client B and
625 // client B performs a request to client A. If both clients stop processing further
626 // messages until their respective request completes, they won't have a chance to
627 // respond to the other client's request and cause a deadlock.
628 //
629 // This arrangement ensures we will attempt to process earlier messages first, but fall
630 // back to processing messages arrived later in the spirit of making progress.
631 let mut foreground_message_handlers = FuturesUnordered::new();
632 let concurrent_handlers = Arc::new(Semaphore::new(256));
633 loop {
634 let next_message = async {
635 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
636 let message = incoming_rx.next().await;
637 (permit, message)
638 }.fuse();
639 futures::pin_mut!(next_message);
640 futures::select_biased! {
641 _ = teardown.changed().fuse() => return Ok(()),
642 result = handle_io => {
643 if let Err(error) = result {
644 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
645 }
646 break;
647 }
648 _ = foreground_message_handlers.next() => {}
649 next_message = next_message => {
650 let (permit, message) = next_message;
651 if let Some(message) = message {
652 let type_name = message.payload_type_name();
653 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
654 let span_enter = span.enter();
655 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
656 let is_background = message.is_background();
657 let handle_message = (handler)(message, session.clone());
658 drop(span_enter);
659
660 let handle_message = async move {
661 handle_message.await;
662 drop(permit);
663 }.instrument(span);
664 if is_background {
665 executor.spawn_detached(handle_message);
666 } else {
667 foreground_message_handlers.push(handle_message);
668 }
669 } else {
670 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
671 }
672 } else {
673 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
674 break;
675 }
676 }
677 }
678 }
679
680 drop(foreground_message_handlers);
681 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
682 if let Err(error) = connection_lost(session, teardown, executor).await {
683 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
684 }
685
686 Ok(())
687 }.instrument(span)
688 }
689
690 pub async fn invite_code_redeemed(
691 self: &Arc<Self>,
692 inviter_id: UserId,
693 invitee_id: UserId,
694 ) -> Result<()> {
695 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
696 if let Some(code) = &user.invite_code {
697 let pool = self.connection_pool.lock();
698 let invitee_contact = contact_for_user(invitee_id, false, &pool);
699 for connection_id in pool.user_connection_ids(inviter_id) {
700 self.peer.send(
701 connection_id,
702 proto::UpdateContacts {
703 contacts: vec![invitee_contact.clone()],
704 ..Default::default()
705 },
706 )?;
707 self.peer.send(
708 connection_id,
709 proto::UpdateInviteInfo {
710 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
711 count: user.invite_count as u32,
712 },
713 )?;
714 }
715 }
716 }
717 Ok(())
718 }
719
720 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
721 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
722 if let Some(invite_code) = &user.invite_code {
723 let pool = self.connection_pool.lock();
724 for connection_id in pool.user_connection_ids(user_id) {
725 self.peer.send(
726 connection_id,
727 proto::UpdateInviteInfo {
728 url: format!(
729 "{}{}",
730 self.app_state.config.invite_link_prefix, invite_code
731 ),
732 count: user.invite_count as u32,
733 },
734 )?;
735 }
736 }
737 }
738 Ok(())
739 }
740
741 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
742 ServerSnapshot {
743 connection_pool: ConnectionPoolGuard {
744 guard: self.connection_pool.lock(),
745 _not_send: PhantomData,
746 },
747 peer: &self.peer,
748 }
749 }
750}
751
752impl<'a> Deref for ConnectionPoolGuard<'a> {
753 type Target = ConnectionPool;
754
755 fn deref(&self) -> &Self::Target {
756 &*self.guard
757 }
758}
759
760impl<'a> DerefMut for ConnectionPoolGuard<'a> {
761 fn deref_mut(&mut self) -> &mut Self::Target {
762 &mut *self.guard
763 }
764}
765
766impl<'a> Drop for ConnectionPoolGuard<'a> {
767 fn drop(&mut self) {
768 #[cfg(test)]
769 self.check_invariants();
770 }
771}
772
773fn broadcast<F>(
774 sender_id: Option<ConnectionId>,
775 receiver_ids: impl IntoIterator<Item = ConnectionId>,
776 mut f: F,
777) where
778 F: FnMut(ConnectionId) -> anyhow::Result<()>,
779{
780 for receiver_id in receiver_ids {
781 if Some(receiver_id) != sender_id {
782 if let Err(error) = f(receiver_id) {
783 tracing::error!("failed to send to {:?} {}", receiver_id, error);
784 }
785 }
786 }
787}
788
/// Typed `x-zed-protocol-version` request header: the RPC protocol version the client speaks.
pub struct ProtocolVersion(u32);
790
791impl Header for ProtocolVersion {
792 fn name() -> &'static HeaderName {
793 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
794 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
795 }
796
797 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
798 where
799 Self: Sized,
800 I: Iterator<Item = &'i axum::http::HeaderValue>,
801 {
802 let version = values
803 .next()
804 .ok_or_else(axum::headers::Error::invalid)?
805 .to_str()
806 .map_err(|_| axum::headers::Error::invalid())?
807 .parse()
808 .map_err(|_| axum::headers::Error::invalid())?;
809 Ok(Self(version))
810 }
811
812 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
813 values.extend([self.0.to_string().parse().unwrap()]);
814 }
815}
816
/// Typed `x-zed-app-version` request header: the client application's semantic version.
pub struct AppVersionHeader(SemanticVersion);
818impl Header for AppVersionHeader {
819 fn name() -> &'static HeaderName {
820 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
821 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
822 }
823
824 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
825 where
826 Self: Sized,
827 I: Iterator<Item = &'i axum::http::HeaderValue>,
828 {
829 let version = values
830 .next()
831 .ok_or_else(axum::headers::Error::invalid)?
832 .to_str()
833 .map_err(|_| axum::headers::Error::invalid())?
834 .parse()
835 .map_err(|_| axum::headers::Error::invalid())?;
836 Ok(Self(version))
837 }
838
839 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
840 values.extend([self.0.to_string().parse().unwrap()]);
841 }
842}
843
844pub fn routes(server: Arc<Server>) -> Router<Body> {
845 Router::new()
846 .route("/rpc", get(handle_websocket_request))
847 .layer(
848 ServiceBuilder::new()
849 .layer(Extension(server.app_state.clone()))
850 .layer(middleware::from_fn(auth::validate_header)),
851 )
852 .route("/metrics", get(handle_metrics))
853 .layer(Extension(server))
854}
855
856pub async fn handle_websocket_request(
857 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
858 app_version_header: Option<TypedHeader<AppVersionHeader>>,
859 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
860 Extension(server): Extension<Arc<Server>>,
861 Extension(user): Extension<User>,
862 Extension(impersonator): Extension<Impersonator>,
863 ws: WebSocketUpgrade,
864) -> axum::response::Response {
865 if protocol_version != rpc::PROTOCOL_VERSION {
866 return (
867 StatusCode::UPGRADE_REQUIRED,
868 "client must be upgraded".to_string(),
869 )
870 .into_response();
871 }
872
873 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
874 return (
875 StatusCode::UPGRADE_REQUIRED,
876 "no version header found".to_string(),
877 )
878 .into_response();
879 };
880
881 if !version.is_supported() {
882 return (
883 StatusCode::UPGRADE_REQUIRED,
884 "client must be upgraded".to_string(),
885 )
886 .into_response();
887 }
888
889 let socket_address = socket_address.to_string();
890 ws.on_upgrade(move |socket| {
891 use util::ResultExt;
892 let socket = socket
893 .map_ok(to_tungstenite_message)
894 .err_into()
895 .with(|message| async move { Ok(to_axum_message(message)) });
896 let connection = Connection::new(Box::pin(socket));
897 async move {
898 server
899 .handle_connection(
900 connection,
901 socket_address,
902 user,
903 version,
904 impersonator.0,
905 None,
906 Executor::Production,
907 )
908 .await
909 .log_err();
910 }
911 })
912}
913
914pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
915 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
916 let connections_metric = CONNECTIONS_METRIC
917 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
918
919 let connections = server
920 .connection_pool
921 .lock()
922 .connections()
923 .filter(|connection| !connection.admin)
924 .count();
925 connections_metric.set(connections as _);
926
927 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
928 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
929 register_int_gauge!(
930 "shared_projects",
931 "number of open projects with one or more guests"
932 )
933 .unwrap()
934 });
935
936 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
937 shared_projects_metric.set(shared_projects as _);
938
939 let encoder = prometheus::TextEncoder::new();
940 let metric_families = prometheus::gather();
941 let encoded_metrics = encoder
942 .encode_to_string(&metric_families)
943 .map_err(|err| anyhow!("{}", err))?;
944 Ok(encoded_metrics)
945}
946
/// Cleans up after a client connection has gone away.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool (an error here is returned);
/// - record the lost connection in the database (errors are only traced).
///
/// It then waits for `RECONNECT_TIMEOUT` or a teardown signal, whichever comes first. On
/// teardown nothing more happens. Once the timeout elapses, the session:
/// - leaves its room and channel buffers;
/// - calls `decline_call(None, user)` when the user has no remaining online connection,
///   broadcasting any updated room;
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call once every connection of this user is gone.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
992
993/// Acknowledges a ping from a client, used to keep the connection alive.
994async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
995 response.send(proto::Ack {})?;
996 Ok(())
997}
998
999/// Creates a new room for calling (outside of channels)
1000async fn create_room(
1001 _request: proto::CreateRoom,
1002 response: Response<proto::CreateRoom>,
1003 session: Session,
1004) -> Result<()> {
1005 let live_kit_room = nanoid::nanoid!(30);
1006
1007 let live_kit_connection_info = {
1008 let live_kit_room = live_kit_room.clone();
1009 let live_kit = session.live_kit_client.as_ref();
1010
1011 util::async_maybe!({
1012 let live_kit = live_kit?;
1013
1014 let token = live_kit
1015 .room_token(&live_kit_room, &session.user_id.to_string())
1016 .trace_err()?;
1017
1018 Some(proto::LiveKitConnectionInfo {
1019 server_url: live_kit.url().into(),
1020 token,
1021 can_publish: true,
1022 })
1023 })
1024 }
1025 .await;
1026
1027 let room = session
1028 .db()
1029 .await
1030 .create_room(session.user_id, session.connection_id, &live_kit_room)
1031 .await?;
1032
1033 response.send(proto::CreateRoomResponse {
1034 room: Some(room.clone()),
1035 live_kit_connection_info,
1036 })?;
1037
1038 update_user_contacts(session.user_id, &session).await?;
1039 Ok(())
1040}
1041
1042/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1043async fn join_room(
1044 request: proto::JoinRoom,
1045 response: Response<proto::JoinRoom>,
1046 session: Session,
1047) -> Result<()> {
1048 let room_id = RoomId::from_proto(request.id);
1049
1050 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1051
1052 if let Some(channel_id) = channel_id {
1053 return join_channel_internal(channel_id, Box::new(response), session).await;
1054 }
1055
1056 let joined_room = {
1057 let room = session
1058 .db()
1059 .await
1060 .join_room(room_id, session.user_id, session.connection_id)
1061 .await?;
1062 room_updated(&room.room, &session.peer);
1063 room.into_inner()
1064 };
1065
1066 for connection_id in session
1067 .connection_pool()
1068 .await
1069 .user_connection_ids(session.user_id)
1070 {
1071 session
1072 .peer
1073 .send(
1074 connection_id,
1075 proto::CallCanceled {
1076 room_id: room_id.to_proto(),
1077 },
1078 )
1079 .trace_err();
1080 }
1081
1082 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1083 if let Some(token) = live_kit
1084 .room_token(
1085 &joined_room.room.live_kit_room,
1086 &session.user_id.to_string(),
1087 )
1088 .trace_err()
1089 {
1090 Some(proto::LiveKitConnectionInfo {
1091 server_url: live_kit.url().into(),
1092 token,
1093 can_publish: true,
1094 })
1095 } else {
1096 None
1097 }
1098 } else {
1099 None
1100 };
1101
1102 response.send(proto::JoinRoomResponse {
1103 room: Some(joined_room.room),
1104 channel_id: None,
1105 live_kit_connection_info,
1106 })?;
1107
1108 update_user_contacts(session.user_id, &session).await?;
1109 Ok(())
1110}
1111
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database call returns the room together with "reshared" and "rejoined"
/// projects, each carrying the `old_connection_id` the caller previously used.
/// The caller is sent a `RejoinRoomResponse` describing those projects, and the room
/// is rebroadcast. Every collaborator on those projects is then told (best effort)
/// that the caller's peer id changed from the old connection to
/// `session.connection_id`. Reshared projects also have an `UpdateProject` with
/// their worktrees forwarded to the other collaborators. For rejoined projects, the
/// caller is re-sent each worktree's entries (in chunks), diagnostic summaries and
/// settings files, plus a disk-based-diagnostics-updated notice per language server.
/// Finally the channel (if any) and the caller's contacts are refreshed.
///
/// NOTE(review): "reshared" projects look like ones the caller hosts and "rejoined"
/// ones where the caller is a guest, judging by which side receives worktree data.
/// Confirm against `Database::rejoin_room`.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Values that escape the scope below; everything else from the database result
    // is consumed or dropped inside it.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator about the caller's new peer id
        // (send failures are only logged), then forward the project's worktrees to
        // everyone except the caller.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: tell each collaborator about the caller's new peer id
        // (send failures are only logged).
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-send worktree state for rejoined projects to the caller. Worktrees are
        // moved out with `mem::take` so their fields can be consumed without cloning.
        // Unlike the notifications above, send failures here abort the handler.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size, presumably so multi-chunk
                // streaming gets exercised by tests.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send one disk-based-diagnostics-updated notice per language server.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Channel-backed rooms also refresh the channel for its members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1303
1304/// leave room disconnects from the room.
1305async fn leave_room(
1306 _: proto::LeaveRoom,
1307 response: Response<proto::LeaveRoom>,
1308 session: Session,
1309) -> Result<()> {
1310 leave_room_for_session(&session).await?;
1311 response.send(proto::Ack {})?;
1312 Ok(())
1313}
1314
1315/// Updates the permissions of someone else in the room.
1316async fn set_room_participant_role(
1317 request: proto::SetRoomParticipantRole,
1318 response: Response<proto::SetRoomParticipantRole>,
1319 session: Session,
1320) -> Result<()> {
1321 let user_id = UserId::from_proto(request.user_id);
1322 let role = ChannelRole::from(request.role());
1323
1324 if role == ChannelRole::Talker {
1325 let pool = session.connection_pool().await;
1326
1327 for connection in pool.user_connections(user_id) {
1328 if !connection.zed_version.supports_talker_role() {
1329 Err(anyhow!(
1330 "This user is on zed {} which does not support unmute",
1331 connection.zed_version
1332 ))?;
1333 }
1334 }
1335 }
1336
1337 let (live_kit_room, can_publish) = {
1338 let room = session
1339 .db()
1340 .await
1341 .set_room_participant_role(
1342 session.user_id,
1343 RoomId::from_proto(request.room_id),
1344 user_id,
1345 role,
1346 )
1347 .await?;
1348
1349 let live_kit_room = room.live_kit_room.clone();
1350 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1351 room_updated(&room, &session.peer);
1352 (live_kit_room, can_publish)
1353 };
1354
1355 if let Some(live_kit) = session.live_kit_client.as_ref() {
1356 live_kit
1357 .update_participant(
1358 live_kit_room.clone(),
1359 request.user_id.to_string(),
1360 live_kit_server::proto::ParticipantPermission {
1361 can_subscribe: true,
1362 can_publish,
1363 can_publish_data: can_publish,
1364 hidden: false,
1365 recorder: false,
1366 },
1367 )
1368 .await
1369 .trace_err();
1370 }
1371
1372 response.send(proto::Ack {})?;
1373 Ok(())
1374}
1375
1376/// Call someone else into the current room
1377async fn call(
1378 request: proto::Call,
1379 response: Response<proto::Call>,
1380 session: Session,
1381) -> Result<()> {
1382 let room_id = RoomId::from_proto(request.room_id);
1383 let calling_user_id = session.user_id;
1384 let calling_connection_id = session.connection_id;
1385 let called_user_id = UserId::from_proto(request.called_user_id);
1386 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1387 if !session
1388 .db()
1389 .await
1390 .has_contact(calling_user_id, called_user_id)
1391 .await?
1392 {
1393 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1394 }
1395
1396 let incoming_call = {
1397 let (room, incoming_call) = &mut *session
1398 .db()
1399 .await
1400 .call(
1401 room_id,
1402 calling_user_id,
1403 calling_connection_id,
1404 called_user_id,
1405 initial_project_id,
1406 )
1407 .await?;
1408 room_updated(&room, &session.peer);
1409 mem::take(incoming_call)
1410 };
1411 update_user_contacts(called_user_id, &session).await?;
1412
1413 let mut calls = session
1414 .connection_pool()
1415 .await
1416 .user_connection_ids(called_user_id)
1417 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1418 .collect::<FuturesUnordered<_>>();
1419
1420 while let Some(call_response) = calls.next().await {
1421 match call_response.as_ref() {
1422 Ok(_) => {
1423 response.send(proto::Ack {})?;
1424 return Ok(());
1425 }
1426 Err(_) => {
1427 call_response.trace_err();
1428 }
1429 }
1430 }
1431
1432 {
1433 let room = session
1434 .db()
1435 .await
1436 .call_failed(room_id, called_user_id)
1437 .await?;
1438 room_updated(&room, &session.peer);
1439 }
1440 update_user_contacts(called_user_id, &session).await?;
1441
1442 Err(anyhow!("failed to ring user"))?
1443}
1444
1445/// Cancel an outgoing call.
1446async fn cancel_call(
1447 request: proto::CancelCall,
1448 response: Response<proto::CancelCall>,
1449 session: Session,
1450) -> Result<()> {
1451 let called_user_id = UserId::from_proto(request.called_user_id);
1452 let room_id = RoomId::from_proto(request.room_id);
1453 {
1454 let room = session
1455 .db()
1456 .await
1457 .cancel_call(room_id, session.connection_id, called_user_id)
1458 .await?;
1459 room_updated(&room, &session.peer);
1460 }
1461
1462 for connection_id in session
1463 .connection_pool()
1464 .await
1465 .user_connection_ids(called_user_id)
1466 {
1467 session
1468 .peer
1469 .send(
1470 connection_id,
1471 proto::CallCanceled {
1472 room_id: room_id.to_proto(),
1473 },
1474 )
1475 .trace_err();
1476 }
1477 response.send(proto::Ack {})?;
1478
1479 update_user_contacts(called_user_id, &session).await?;
1480 Ok(())
1481}
1482
1483/// Decline an incoming call.
1484async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1485 let room_id = RoomId::from_proto(message.room_id);
1486 {
1487 let room = session
1488 .db()
1489 .await
1490 .decline_call(Some(room_id), session.user_id)
1491 .await?
1492 .ok_or_else(|| anyhow!("failed to decline call"))?;
1493 room_updated(&room, &session.peer);
1494 }
1495
1496 for connection_id in session
1497 .connection_pool()
1498 .await
1499 .user_connection_ids(session.user_id)
1500 {
1501 session
1502 .peer
1503 .send(
1504 connection_id,
1505 proto::CallCanceled {
1506 room_id: room_id.to_proto(),
1507 },
1508 )
1509 .trace_err();
1510 }
1511 update_user_contacts(session.user_id, &session).await?;
1512 Ok(())
1513}
1514
1515/// Updates other participants in the room with your current location.
1516async fn update_participant_location(
1517 request: proto::UpdateParticipantLocation,
1518 response: Response<proto::UpdateParticipantLocation>,
1519 session: Session,
1520) -> Result<()> {
1521 let room_id = RoomId::from_proto(request.room_id);
1522 let location = request
1523 .location
1524 .ok_or_else(|| anyhow!("invalid location"))?;
1525
1526 let db = session.db().await;
1527 let room = db
1528 .update_room_participant_location(room_id, session.connection_id, location)
1529 .await?;
1530
1531 room_updated(&room, &session.peer);
1532 response.send(proto::Ack {})?;
1533 Ok(())
1534}
1535
1536/// Share a project into the room.
1537async fn share_project(
1538 request: proto::ShareProject,
1539 response: Response<proto::ShareProject>,
1540 session: Session,
1541) -> Result<()> {
1542 let (project_id, room) = &*session
1543 .db()
1544 .await
1545 .share_project(
1546 RoomId::from_proto(request.room_id),
1547 session.connection_id,
1548 &request.worktrees,
1549 )
1550 .await?;
1551 response.send(proto::ShareProjectResponse {
1552 project_id: project_id.to_proto(),
1553 })?;
1554 room_updated(&room, &session.peer);
1555
1556 Ok(())
1557}
1558
1559/// Unshare a project from the room.
1560async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1561 let project_id = ProjectId::from_proto(message.project_id);
1562
1563 let (room, guest_connection_ids) = &*session
1564 .db()
1565 .await
1566 .unshare_project(project_id, session.connection_id)
1567 .await?;
1568
1569 broadcast(
1570 Some(session.connection_id),
1571 guest_connection_ids.iter().copied(),
1572 |conn_id| session.peer.send(conn_id, message.clone()),
1573 );
1574 room_updated(&room, &session.peer);
1575
1576 Ok(())
1577}
1578
1579/// Join someone elses shared project.
1580async fn join_project(
1581 request: proto::JoinProject,
1582 response: Response<proto::JoinProject>,
1583 session: Session,
1584) -> Result<()> {
1585 let project_id = ProjectId::from_proto(request.project_id);
1586 let guest_user_id = session.user_id;
1587
1588 tracing::info!(%project_id, "join project");
1589
1590 let (project, replica_id) = &mut *session
1591 .db()
1592 .await
1593 .join_project(project_id, session.connection_id)
1594 .await?;
1595
1596 let collaborators = project
1597 .collaborators
1598 .iter()
1599 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1600 .map(|collaborator| collaborator.to_proto())
1601 .collect::<Vec<_>>();
1602
1603 let worktrees = project
1604 .worktrees
1605 .iter()
1606 .map(|(id, worktree)| proto::WorktreeMetadata {
1607 id: *id,
1608 root_name: worktree.root_name.clone(),
1609 visible: worktree.visible,
1610 abs_path: worktree.abs_path.clone(),
1611 })
1612 .collect::<Vec<_>>();
1613
1614 for collaborator in &collaborators {
1615 session
1616 .peer
1617 .send(
1618 collaborator.peer_id.unwrap().into(),
1619 proto::AddProjectCollaborator {
1620 project_id: project_id.to_proto(),
1621 collaborator: Some(proto::Collaborator {
1622 peer_id: Some(session.connection_id.into()),
1623 replica_id: replica_id.0 as u32,
1624 user_id: guest_user_id.to_proto(),
1625 }),
1626 },
1627 )
1628 .trace_err();
1629 }
1630
1631 // First, we send the metadata associated with each worktree.
1632 response.send(proto::JoinProjectResponse {
1633 worktrees: worktrees.clone(),
1634 replica_id: replica_id.0 as u32,
1635 collaborators: collaborators.clone(),
1636 language_servers: project.language_servers.clone(),
1637 })?;
1638
1639 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1640 #[cfg(any(test, feature = "test-support"))]
1641 const MAX_CHUNK_SIZE: usize = 2;
1642 #[cfg(not(any(test, feature = "test-support")))]
1643 const MAX_CHUNK_SIZE: usize = 256;
1644
1645 // Stream this worktree's entries.
1646 let message = proto::UpdateWorktree {
1647 project_id: project_id.to_proto(),
1648 worktree_id,
1649 abs_path: worktree.abs_path.clone(),
1650 root_name: worktree.root_name,
1651 updated_entries: worktree.entries,
1652 removed_entries: Default::default(),
1653 scan_id: worktree.scan_id,
1654 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1655 updated_repositories: worktree.repository_entries.into_values().collect(),
1656 removed_repositories: Default::default(),
1657 };
1658 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1659 session.peer.send(session.connection_id, update.clone())?;
1660 }
1661
1662 // Stream this worktree's diagnostics.
1663 for summary in worktree.diagnostic_summaries {
1664 session.peer.send(
1665 session.connection_id,
1666 proto::UpdateDiagnosticSummary {
1667 project_id: project_id.to_proto(),
1668 worktree_id: worktree.id,
1669 summary: Some(summary),
1670 },
1671 )?;
1672 }
1673
1674 for settings_file in worktree.settings_files {
1675 session.peer.send(
1676 session.connection_id,
1677 proto::UpdateWorktreeSettings {
1678 project_id: project_id.to_proto(),
1679 worktree_id: worktree.id,
1680 path: settings_file.path,
1681 content: Some(settings_file.content),
1682 },
1683 )?;
1684 }
1685 }
1686
1687 for language_server in &project.language_servers {
1688 session.peer.send(
1689 session.connection_id,
1690 proto::UpdateLanguageServer {
1691 project_id: project_id.to_proto(),
1692 language_server_id: language_server.id,
1693 variant: Some(
1694 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1695 proto::LspDiskBasedDiagnosticsUpdated {},
1696 ),
1697 ),
1698 },
1699 )?;
1700 }
1701
1702 Ok(())
1703}
1704
1705/// Leave someone elses shared project.
1706async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1707 let sender_id = session.connection_id;
1708 let project_id = ProjectId::from_proto(request.project_id);
1709
1710 let (room, project) = &*session
1711 .db()
1712 .await
1713 .leave_project(project_id, sender_id)
1714 .await?;
1715 tracing::info!(
1716 %project_id,
1717 host_user_id = %project.host_user_id,
1718 host_connection_id = ?project.host_connection_id,
1719 "leave project"
1720 );
1721
1722 project_left(&project, &session);
1723 room_updated(&room, &session.peer);
1724
1725 Ok(())
1726}
1727
1728/// Updates other participants with changes to the project
1729async fn update_project(
1730 request: proto::UpdateProject,
1731 response: Response<proto::UpdateProject>,
1732 session: Session,
1733) -> Result<()> {
1734 let project_id = ProjectId::from_proto(request.project_id);
1735 let (room, guest_connection_ids) = &*session
1736 .db()
1737 .await
1738 .update_project(project_id, session.connection_id, &request.worktrees)
1739 .await?;
1740 broadcast(
1741 Some(session.connection_id),
1742 guest_connection_ids.iter().copied(),
1743 |connection_id| {
1744 session
1745 .peer
1746 .forward_send(session.connection_id, connection_id, request.clone())
1747 },
1748 );
1749 room_updated(&room, &session.peer);
1750 response.send(proto::Ack {})?;
1751
1752 Ok(())
1753}
1754
1755/// Updates other participants with changes to the worktree
1756async fn update_worktree(
1757 request: proto::UpdateWorktree,
1758 response: Response<proto::UpdateWorktree>,
1759 session: Session,
1760) -> Result<()> {
1761 let guest_connection_ids = session
1762 .db()
1763 .await
1764 .update_worktree(&request, session.connection_id)
1765 .await?;
1766
1767 broadcast(
1768 Some(session.connection_id),
1769 guest_connection_ids.iter().copied(),
1770 |connection_id| {
1771 session
1772 .peer
1773 .forward_send(session.connection_id, connection_id, request.clone())
1774 },
1775 );
1776 response.send(proto::Ack {})?;
1777 Ok(())
1778}
1779
1780/// Updates other participants with changes to the diagnostics
1781async fn update_diagnostic_summary(
1782 message: proto::UpdateDiagnosticSummary,
1783 session: Session,
1784) -> Result<()> {
1785 let guest_connection_ids = session
1786 .db()
1787 .await
1788 .update_diagnostic_summary(&message, session.connection_id)
1789 .await?;
1790
1791 broadcast(
1792 Some(session.connection_id),
1793 guest_connection_ids.iter().copied(),
1794 |connection_id| {
1795 session
1796 .peer
1797 .forward_send(session.connection_id, connection_id, message.clone())
1798 },
1799 );
1800
1801 Ok(())
1802}
1803
1804/// Updates other participants with changes to the worktree settings
1805async fn update_worktree_settings(
1806 message: proto::UpdateWorktreeSettings,
1807 session: Session,
1808) -> Result<()> {
1809 let guest_connection_ids = session
1810 .db()
1811 .await
1812 .update_worktree_settings(&message, session.connection_id)
1813 .await?;
1814
1815 broadcast(
1816 Some(session.connection_id),
1817 guest_connection_ids.iter().copied(),
1818 |connection_id| {
1819 session
1820 .peer
1821 .forward_send(session.connection_id, connection_id, message.clone())
1822 },
1823 );
1824
1825 Ok(())
1826}
1827
1828/// Notify other participants that a language server has started.
1829async fn start_language_server(
1830 request: proto::StartLanguageServer,
1831 session: Session,
1832) -> Result<()> {
1833 let guest_connection_ids = session
1834 .db()
1835 .await
1836 .start_language_server(&request, session.connection_id)
1837 .await?;
1838
1839 broadcast(
1840 Some(session.connection_id),
1841 guest_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851/// Notify other participants that a language server has changed.
1852async fn update_language_server(
1853 request: proto::UpdateLanguageServer,
1854 session: Session,
1855) -> Result<()> {
1856 let project_id = ProjectId::from_proto(request.project_id);
1857 let project_connection_ids = session
1858 .db()
1859 .await
1860 .project_connection_ids(project_id, session.connection_id)
1861 .await?;
1862 broadcast(
1863 Some(session.connection_id),
1864 project_connection_ids.iter().copied(),
1865 |connection_id| {
1866 session
1867 .peer
1868 .forward_send(session.connection_id, connection_id, request.clone())
1869 },
1870 );
1871 Ok(())
1872}
1873
1874/// forward a project request to the host. These requests should be read only
1875/// as guests are allowed to send them.
1876async fn forward_read_only_project_request<T>(
1877 request: T,
1878 response: Response<T>,
1879 session: Session,
1880) -> Result<()>
1881where
1882 T: EntityMessage + RequestMessage,
1883{
1884 let project_id = ProjectId::from_proto(request.remote_entity_id());
1885 let host_connection_id = session
1886 .db()
1887 .await
1888 .host_for_read_only_project_request(project_id, session.connection_id)
1889 .await?;
1890 let payload = session
1891 .peer
1892 .forward_request(session.connection_id, host_connection_id, request)
1893 .await?;
1894 response.send(payload)?;
1895 Ok(())
1896}
1897
1898/// forward a project request to the host. These requests are disallowed
1899/// for guests.
1900async fn forward_mutating_project_request<T>(
1901 request: T,
1902 response: Response<T>,
1903 session: Session,
1904) -> Result<()>
1905where
1906 T: EntityMessage + RequestMessage,
1907{
1908 let project_id = ProjectId::from_proto(request.remote_entity_id());
1909 let host_connection_id = session
1910 .db()
1911 .await
1912 .host_for_mutating_project_request(project_id, session.connection_id)
1913 .await?;
1914 let payload = session
1915 .peer
1916 .forward_request(session.connection_id, host_connection_id, request)
1917 .await?;
1918 response.send(payload)?;
1919 Ok(())
1920}
1921
1922/// Notify other participants that a new buffer has been created
1923async fn create_buffer_for_peer(
1924 request: proto::CreateBufferForPeer,
1925 session: Session,
1926) -> Result<()> {
1927 session
1928 .db()
1929 .await
1930 .check_user_is_project_host(
1931 ProjectId::from_proto(request.project_id),
1932 session.connection_id,
1933 )
1934 .await?;
1935 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1936 session
1937 .peer
1938 .forward_send(session.connection_id, peer_id.into(), request)?;
1939 Ok(())
1940}
1941
1942/// Notify other participants that a buffer has been updated. This is
1943/// allowed for guests as long as the update is limited to selections.
1944async fn update_buffer(
1945 request: proto::UpdateBuffer,
1946 response: Response<proto::UpdateBuffer>,
1947 session: Session,
1948) -> Result<()> {
1949 let project_id = ProjectId::from_proto(request.project_id);
1950 let mut guest_connection_ids;
1951 let mut host_connection_id = None;
1952
1953 let mut requires_write_permission = false;
1954
1955 for op in request.operations.iter() {
1956 match op.variant {
1957 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1958 Some(_) => requires_write_permission = true,
1959 }
1960 }
1961
1962 {
1963 let collaborators = session
1964 .db()
1965 .await
1966 .project_collaborators_for_buffer_update(
1967 project_id,
1968 session.connection_id,
1969 requires_write_permission,
1970 )
1971 .await?;
1972 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1973 for collaborator in collaborators.iter() {
1974 if collaborator.is_host {
1975 host_connection_id = Some(collaborator.connection_id);
1976 } else {
1977 guest_connection_ids.push(collaborator.connection_id);
1978 }
1979 }
1980 }
1981 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1982
1983 broadcast(
1984 Some(session.connection_id),
1985 guest_connection_ids,
1986 |connection_id| {
1987 session
1988 .peer
1989 .forward_send(session.connection_id, connection_id, request.clone())
1990 },
1991 );
1992 if host_connection_id != session.connection_id {
1993 session
1994 .peer
1995 .forward_request(session.connection_id, host_connection_id, request.clone())
1996 .await?;
1997 }
1998
1999 response.send(proto::Ack {})?;
2000 Ok(())
2001}
2002
2003/// Notify other participants that a project has been updated.
2004async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2005 request: T,
2006 session: Session,
2007) -> Result<()> {
2008 let project_id = ProjectId::from_proto(request.remote_entity_id());
2009 let project_connection_ids = session
2010 .db()
2011 .await
2012 .project_connection_ids(project_id, session.connection_id)
2013 .await?;
2014
2015 broadcast(
2016 Some(session.connection_id),
2017 project_connection_ids.iter().copied(),
2018 |connection_id| {
2019 session
2020 .peer
2021 .forward_send(session.connection_id, connection_id, request.clone())
2022 },
2023 );
2024 Ok(())
2025}
2026
2027/// Start following another user in a call.
2028async fn follow(
2029 request: proto::Follow,
2030 response: Response<proto::Follow>,
2031 session: Session,
2032) -> Result<()> {
2033 let room_id = RoomId::from_proto(request.room_id);
2034 let project_id = request.project_id.map(ProjectId::from_proto);
2035 let leader_id = request
2036 .leader_id
2037 .ok_or_else(|| anyhow!("invalid leader id"))?
2038 .into();
2039 let follower_id = session.connection_id;
2040
2041 session
2042 .db()
2043 .await
2044 .check_room_participants(room_id, leader_id, session.connection_id)
2045 .await?;
2046
2047 let response_payload = session
2048 .peer
2049 .forward_request(session.connection_id, leader_id, request)
2050 .await?;
2051 response.send(response_payload)?;
2052
2053 if let Some(project_id) = project_id {
2054 let room = session
2055 .db()
2056 .await
2057 .follow(room_id, project_id, leader_id, follower_id)
2058 .await?;
2059 room_updated(&room, &session.peer);
2060 }
2061
2062 Ok(())
2063}
2064
2065/// Stop following another user in a call.
2066async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2067 let room_id = RoomId::from_proto(request.room_id);
2068 let project_id = request.project_id.map(ProjectId::from_proto);
2069 let leader_id = request
2070 .leader_id
2071 .ok_or_else(|| anyhow!("invalid leader id"))?
2072 .into();
2073 let follower_id = session.connection_id;
2074
2075 session
2076 .db()
2077 .await
2078 .check_room_participants(room_id, leader_id, session.connection_id)
2079 .await?;
2080
2081 session
2082 .peer
2083 .forward_send(session.connection_id, leader_id, request)?;
2084
2085 if let Some(project_id) = project_id {
2086 let room = session
2087 .db()
2088 .await
2089 .unfollow(room_id, project_id, leader_id, follower_id)
2090 .await?;
2091 room_updated(&room, &session.peer);
2092 }
2093
2094 Ok(())
2095}
2096
2097/// Notify everyone following you of your current location.
2098async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2099 let room_id = RoomId::from_proto(request.room_id);
2100 let database = session.db.lock().await;
2101
2102 let connection_ids = if let Some(project_id) = request.project_id {
2103 let project_id = ProjectId::from_proto(project_id);
2104 database
2105 .project_connection_ids(project_id, session.connection_id)
2106 .await?
2107 } else {
2108 database
2109 .room_connection_ids(room_id, session.connection_id)
2110 .await?
2111 };
2112
2113 // For now, don't send view update messages back to that view's current leader.
2114 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2115 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2116 _ => None,
2117 });
2118
2119 for connection_id in connection_ids.iter().cloned() {
2120 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2121 session
2122 .peer
2123 .forward_send(session.connection_id, connection_id, request.clone())?;
2124 }
2125 }
2126 Ok(())
2127}
2128
2129/// Get public data about users.
2130async fn get_users(
2131 request: proto::GetUsers,
2132 response: Response<proto::GetUsers>,
2133 session: Session,
2134) -> Result<()> {
2135 let user_ids = request
2136 .user_ids
2137 .into_iter()
2138 .map(UserId::from_proto)
2139 .collect();
2140 let users = session
2141 .db()
2142 .await
2143 .get_users_by_ids(user_ids)
2144 .await?
2145 .into_iter()
2146 .map(|user| proto::User {
2147 id: user.id.to_proto(),
2148 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2149 github_login: user.github_login,
2150 })
2151 .collect();
2152 response.send(proto::UsersResponse { users })?;
2153 Ok(())
2154}
2155
2156/// Search for users (to invite) buy Github login
2157async fn fuzzy_search_users(
2158 request: proto::FuzzySearchUsers,
2159 response: Response<proto::FuzzySearchUsers>,
2160 session: Session,
2161) -> Result<()> {
2162 let query = request.query;
2163 let users = match query.len() {
2164 0 => vec![],
2165 1 | 2 => session
2166 .db()
2167 .await
2168 .get_user_by_github_login(&query)
2169 .await?
2170 .into_iter()
2171 .collect(),
2172 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2173 };
2174 let users = users
2175 .into_iter()
2176 .filter(|user| user.id != session.user_id)
2177 .map(|user| proto::User {
2178 id: user.id.to_proto(),
2179 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2180 github_login: user.github_login,
2181 })
2182 .collect();
2183 response.send(proto::UsersResponse { users })?;
2184 Ok(())
2185}
2186
2187/// Send a contact request to another user.
2188async fn request_contact(
2189 request: proto::RequestContact,
2190 response: Response<proto::RequestContact>,
2191 session: Session,
2192) -> Result<()> {
2193 let requester_id = session.user_id;
2194 let responder_id = UserId::from_proto(request.responder_id);
2195 if requester_id == responder_id {
2196 return Err(anyhow!("cannot add yourself as a contact"))?;
2197 }
2198
2199 let notifications = session
2200 .db()
2201 .await
2202 .send_contact_request(requester_id, responder_id)
2203 .await?;
2204
2205 // Update outgoing contact requests of requester
2206 let mut update = proto::UpdateContacts::default();
2207 update.outgoing_requests.push(responder_id.to_proto());
2208 for connection_id in session
2209 .connection_pool()
2210 .await
2211 .user_connection_ids(requester_id)
2212 {
2213 session.peer.send(connection_id, update.clone())?;
2214 }
2215
2216 // Update incoming contact requests of responder
2217 let mut update = proto::UpdateContacts::default();
2218 update
2219 .incoming_requests
2220 .push(proto::IncomingContactRequest {
2221 requester_id: requester_id.to_proto(),
2222 });
2223 let connection_pool = session.connection_pool().await;
2224 for connection_id in connection_pool.user_connection_ids(responder_id) {
2225 session.peer.send(connection_id, update.clone())?;
2226 }
2227
2228 send_notifications(&*connection_pool, &session.peer, notifications);
2229
2230 response.send(proto::Ack {})?;
2231 Ok(())
2232}
2233
2234/// Accept or decline a contact request
2235async fn respond_to_contact_request(
2236 request: proto::RespondToContactRequest,
2237 response: Response<proto::RespondToContactRequest>,
2238 session: Session,
2239) -> Result<()> {
2240 let responder_id = session.user_id;
2241 let requester_id = UserId::from_proto(request.requester_id);
2242 let db = session.db().await;
2243 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2244 db.dismiss_contact_notification(responder_id, requester_id)
2245 .await?;
2246 } else {
2247 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2248
2249 let notifications = db
2250 .respond_to_contact_request(responder_id, requester_id, accept)
2251 .await?;
2252 let requester_busy = db.is_user_busy(requester_id).await?;
2253 let responder_busy = db.is_user_busy(responder_id).await?;
2254
2255 let pool = session.connection_pool().await;
2256 // Update responder with new contact
2257 let mut update = proto::UpdateContacts::default();
2258 if accept {
2259 update
2260 .contacts
2261 .push(contact_for_user(requester_id, requester_busy, &pool));
2262 }
2263 update
2264 .remove_incoming_requests
2265 .push(requester_id.to_proto());
2266 for connection_id in pool.user_connection_ids(responder_id) {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269
2270 // Update requester with new contact
2271 let mut update = proto::UpdateContacts::default();
2272 if accept {
2273 update
2274 .contacts
2275 .push(contact_for_user(responder_id, responder_busy, &pool));
2276 }
2277 update
2278 .remove_outgoing_requests
2279 .push(responder_id.to_proto());
2280
2281 for connection_id in pool.user_connection_ids(requester_id) {
2282 session.peer.send(connection_id, update.clone())?;
2283 }
2284
2285 send_notifications(&*pool, &session.peer, notifications);
2286 }
2287
2288 response.send(proto::Ack {})?;
2289 Ok(())
2290}
2291
2292/// Remove a contact.
2293async fn remove_contact(
2294 request: proto::RemoveContact,
2295 response: Response<proto::RemoveContact>,
2296 session: Session,
2297) -> Result<()> {
2298 let requester_id = session.user_id;
2299 let responder_id = UserId::from_proto(request.user_id);
2300 let db = session.db().await;
2301 let (contact_accepted, deleted_notification_id) =
2302 db.remove_contact(requester_id, responder_id).await?;
2303
2304 let pool = session.connection_pool().await;
2305 // Update outgoing contact requests of requester
2306 let mut update = proto::UpdateContacts::default();
2307 if contact_accepted {
2308 update.remove_contacts.push(responder_id.to_proto());
2309 } else {
2310 update
2311 .remove_outgoing_requests
2312 .push(responder_id.to_proto());
2313 }
2314 for connection_id in pool.user_connection_ids(requester_id) {
2315 session.peer.send(connection_id, update.clone())?;
2316 }
2317
2318 // Update incoming contact requests of responder
2319 let mut update = proto::UpdateContacts::default();
2320 if contact_accepted {
2321 update.remove_contacts.push(requester_id.to_proto());
2322 } else {
2323 update
2324 .remove_incoming_requests
2325 .push(requester_id.to_proto());
2326 }
2327 for connection_id in pool.user_connection_ids(responder_id) {
2328 session.peer.send(connection_id, update.clone())?;
2329 if let Some(notification_id) = deleted_notification_id {
2330 session.peer.send(
2331 connection_id,
2332 proto::DeleteNotification {
2333 notification_id: notification_id.to_proto(),
2334 },
2335 )?;
2336 }
2337 }
2338
2339 response.send(proto::Ack {})?;
2340 Ok(())
2341}
2342
2343/// Creates a new channel.
2344async fn create_channel(
2345 request: proto::CreateChannel,
2346 response: Response<proto::CreateChannel>,
2347 session: Session,
2348) -> Result<()> {
2349 let db = session.db().await;
2350
2351 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2352 let (channel, owner, channel_members) = db
2353 .create_channel(&request.name, parent_id, session.user_id)
2354 .await?;
2355
2356 response.send(proto::CreateChannelResponse {
2357 channel: Some(channel.to_proto()),
2358 parent_id: request.parent_id,
2359 })?;
2360
2361 let connection_pool = session.connection_pool().await;
2362 if let Some(owner) = owner {
2363 let update = proto::UpdateUserChannels {
2364 channel_memberships: vec![proto::ChannelMembership {
2365 channel_id: owner.channel_id.to_proto(),
2366 role: owner.role.into(),
2367 }],
2368 ..Default::default()
2369 };
2370 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2371 session.peer.send(connection_id, update.clone())?;
2372 }
2373 }
2374
2375 for channel_member in channel_members {
2376 if !channel_member.role.can_see_channel(channel.visibility) {
2377 continue;
2378 }
2379
2380 let update = proto::UpdateChannels {
2381 channels: vec![channel.to_proto()],
2382 ..Default::default()
2383 };
2384 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2385 session.peer.send(connection_id, update.clone())?;
2386 }
2387 }
2388
2389 Ok(())
2390}
2391
2392/// Delete a channel
2393async fn delete_channel(
2394 request: proto::DeleteChannel,
2395 response: Response<proto::DeleteChannel>,
2396 session: Session,
2397) -> Result<()> {
2398 let db = session.db().await;
2399
2400 let channel_id = request.channel_id;
2401 let (removed_channels, member_ids) = db
2402 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2403 .await?;
2404 response.send(proto::Ack {})?;
2405
2406 // Notify members of removed channels
2407 let mut update = proto::UpdateChannels::default();
2408 update
2409 .delete_channels
2410 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2411
2412 let connection_pool = session.connection_pool().await;
2413 for member_id in member_ids {
2414 for connection_id in connection_pool.user_connection_ids(member_id) {
2415 session.peer.send(connection_id, update.clone())?;
2416 }
2417 }
2418
2419 Ok(())
2420}
2421
2422/// Invite someone to join a channel.
2423async fn invite_channel_member(
2424 request: proto::InviteChannelMember,
2425 response: Response<proto::InviteChannelMember>,
2426 session: Session,
2427) -> Result<()> {
2428 let db = session.db().await;
2429 let channel_id = ChannelId::from_proto(request.channel_id);
2430 let invitee_id = UserId::from_proto(request.user_id);
2431 let InviteMemberResult {
2432 channel,
2433 notifications,
2434 } = db
2435 .invite_channel_member(
2436 channel_id,
2437 invitee_id,
2438 session.user_id,
2439 request.role().into(),
2440 )
2441 .await?;
2442
2443 let update = proto::UpdateChannels {
2444 channel_invitations: vec![channel.to_proto()],
2445 ..Default::default()
2446 };
2447
2448 let connection_pool = session.connection_pool().await;
2449 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2450 session.peer.send(connection_id, update.clone())?;
2451 }
2452
2453 send_notifications(&*connection_pool, &session.peer, notifications);
2454
2455 response.send(proto::Ack {})?;
2456 Ok(())
2457}
2458
2459/// remove someone from a channel
2460async fn remove_channel_member(
2461 request: proto::RemoveChannelMember,
2462 response: Response<proto::RemoveChannelMember>,
2463 session: Session,
2464) -> Result<()> {
2465 let db = session.db().await;
2466 let channel_id = ChannelId::from_proto(request.channel_id);
2467 let member_id = UserId::from_proto(request.user_id);
2468
2469 let RemoveChannelMemberResult {
2470 membership_update,
2471 notification_id,
2472 } = db
2473 .remove_channel_member(channel_id, member_id, session.user_id)
2474 .await?;
2475
2476 let connection_pool = &session.connection_pool().await;
2477 notify_membership_updated(
2478 &connection_pool,
2479 membership_update,
2480 member_id,
2481 &session.peer,
2482 );
2483 for connection_id in connection_pool.user_connection_ids(member_id) {
2484 if let Some(notification_id) = notification_id {
2485 session
2486 .peer
2487 .send(
2488 connection_id,
2489 proto::DeleteNotification {
2490 notification_id: notification_id.to_proto(),
2491 },
2492 )
2493 .trace_err();
2494 }
2495 }
2496
2497 response.send(proto::Ack {})?;
2498 Ok(())
2499}
2500
2501/// Toggle the channel between public and private.
2502/// Care is taken to maintain the invariant that public channels only descend from public channels,
2503/// (though members-only channels can appear at any point in the hierarchy).
2504async fn set_channel_visibility(
2505 request: proto::SetChannelVisibility,
2506 response: Response<proto::SetChannelVisibility>,
2507 session: Session,
2508) -> Result<()> {
2509 let db = session.db().await;
2510 let channel_id = ChannelId::from_proto(request.channel_id);
2511 let visibility = request.visibility().into();
2512
2513 let (channel, channel_members) = db
2514 .set_channel_visibility(channel_id, visibility, session.user_id)
2515 .await?;
2516
2517 let connection_pool = session.connection_pool().await;
2518 for member in channel_members {
2519 let update = if member.role.can_see_channel(channel.visibility) {
2520 proto::UpdateChannels {
2521 channels: vec![channel.to_proto()],
2522 ..Default::default()
2523 }
2524 } else {
2525 proto::UpdateChannels {
2526 delete_channels: vec![channel.id.to_proto()],
2527 ..Default::default()
2528 }
2529 };
2530
2531 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2532 session.peer.send(connection_id, update.clone())?;
2533 }
2534 }
2535
2536 response.send(proto::Ack {})?;
2537 Ok(())
2538}
2539
2540/// Alter the role for a user in the channel.
2541async fn set_channel_member_role(
2542 request: proto::SetChannelMemberRole,
2543 response: Response<proto::SetChannelMemberRole>,
2544 session: Session,
2545) -> Result<()> {
2546 let db = session.db().await;
2547 let channel_id = ChannelId::from_proto(request.channel_id);
2548 let member_id = UserId::from_proto(request.user_id);
2549 let result = db
2550 .set_channel_member_role(
2551 channel_id,
2552 session.user_id,
2553 member_id,
2554 request.role().into(),
2555 )
2556 .await?;
2557
2558 match result {
2559 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2560 let connection_pool = session.connection_pool().await;
2561 notify_membership_updated(
2562 &connection_pool,
2563 membership_update,
2564 member_id,
2565 &session.peer,
2566 )
2567 }
2568 db::SetMemberRoleResult::InviteUpdated(channel) => {
2569 let update = proto::UpdateChannels {
2570 channel_invitations: vec![channel.to_proto()],
2571 ..Default::default()
2572 };
2573
2574 for connection_id in session
2575 .connection_pool()
2576 .await
2577 .user_connection_ids(member_id)
2578 {
2579 session.peer.send(connection_id, update.clone())?;
2580 }
2581 }
2582 }
2583
2584 response.send(proto::Ack {})?;
2585 Ok(())
2586}
2587
2588/// Change the name of a channel
2589async fn rename_channel(
2590 request: proto::RenameChannel,
2591 response: Response<proto::RenameChannel>,
2592 session: Session,
2593) -> Result<()> {
2594 let db = session.db().await;
2595 let channel_id = ChannelId::from_proto(request.channel_id);
2596 let (channel, channel_members) = db
2597 .rename_channel(channel_id, session.user_id, &request.name)
2598 .await?;
2599
2600 response.send(proto::RenameChannelResponse {
2601 channel: Some(channel.to_proto()),
2602 })?;
2603
2604 let connection_pool = session.connection_pool().await;
2605 for channel_member in channel_members {
2606 if !channel_member.role.can_see_channel(channel.visibility) {
2607 continue;
2608 }
2609 let update = proto::UpdateChannels {
2610 channels: vec![channel.to_proto()],
2611 ..Default::default()
2612 };
2613 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2614 session.peer.send(connection_id, update.clone())?;
2615 }
2616 }
2617
2618 Ok(())
2619}
2620
2621/// Move a channel to a new parent.
2622async fn move_channel(
2623 request: proto::MoveChannel,
2624 response: Response<proto::MoveChannel>,
2625 session: Session,
2626) -> Result<()> {
2627 let channel_id = ChannelId::from_proto(request.channel_id);
2628 let to = ChannelId::from_proto(request.to);
2629
2630 let (channels, channel_members) = session
2631 .db()
2632 .await
2633 .move_channel(channel_id, to, session.user_id)
2634 .await?;
2635
2636 let connection_pool = session.connection_pool().await;
2637 for member in channel_members {
2638 let channels = channels
2639 .iter()
2640 .filter_map(|channel| {
2641 if member.role.can_see_channel(channel.visibility) {
2642 Some(channel.to_proto())
2643 } else {
2644 None
2645 }
2646 })
2647 .collect::<Vec<_>>();
2648 if channels.is_empty() {
2649 continue;
2650 }
2651
2652 let update = proto::UpdateChannels {
2653 channels,
2654 ..Default::default()
2655 };
2656
2657 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2658 session.peer.send(connection_id, update.clone())?;
2659 }
2660 }
2661
2662 response.send(Ack {})?;
2663 Ok(())
2664}
2665
2666/// Get the list of channel members
2667async fn get_channel_members(
2668 request: proto::GetChannelMembers,
2669 response: Response<proto::GetChannelMembers>,
2670 session: Session,
2671) -> Result<()> {
2672 let db = session.db().await;
2673 let channel_id = ChannelId::from_proto(request.channel_id);
2674 let members = db
2675 .get_channel_participant_details(channel_id, session.user_id)
2676 .await?;
2677 response.send(proto::GetChannelMembersResponse { members })?;
2678 Ok(())
2679}
2680
2681/// Accept or decline a channel invitation.
2682async fn respond_to_channel_invite(
2683 request: proto::RespondToChannelInvite,
2684 response: Response<proto::RespondToChannelInvite>,
2685 session: Session,
2686) -> Result<()> {
2687 let db = session.db().await;
2688 let channel_id = ChannelId::from_proto(request.channel_id);
2689 let RespondToChannelInvite {
2690 membership_update,
2691 notifications,
2692 } = db
2693 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2694 .await?;
2695
2696 let connection_pool = session.connection_pool().await;
2697 if let Some(membership_update) = membership_update {
2698 notify_membership_updated(
2699 &connection_pool,
2700 membership_update,
2701 session.user_id,
2702 &session.peer,
2703 );
2704 } else {
2705 let update = proto::UpdateChannels {
2706 remove_channel_invitations: vec![channel_id.to_proto()],
2707 ..Default::default()
2708 };
2709
2710 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2711 session.peer.send(connection_id, update.clone())?;
2712 }
2713 };
2714
2715 send_notifications(&*connection_pool, &session.peer, notifications);
2716
2717 response.send(proto::Ack {})?;
2718
2719 Ok(())
2720}
2721
2722/// Join the channels' room
2723async fn join_channel(
2724 request: proto::JoinChannel,
2725 response: Response<proto::JoinChannel>,
2726 session: Session,
2727) -> Result<()> {
2728 let channel_id = ChannelId::from_proto(request.channel_id);
2729 join_channel_internal(channel_id, Box::new(response), session).await
2730}
2731
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can reply regardless of which request started the join.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Answers a `JoinChannel` request by delegating to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Answers a `JoinRoom` request (for a channel's room) by delegating to
/// `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2745
/// Shared implementation for joining a channel's room, used by `join_channel` and by any
/// caller holding a `JoinChannelInternalResponse` handle.
///
/// Steps, in order:
/// 1. Leave whatever room the session is currently in.
/// 2. Join the channel in the database, which yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest token and
///    `can_publish: false`; everyone else gets a room token and `can_publish: true`.
/// 4. Reply to the caller, broadcast any membership update, and push the room update.
/// 5. Finally, broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if leaving the previous room, the database join, sending the reply,
/// or updating contacts fails. A LiveKit token failure is not an error: it is logged via
/// `trace_err` and the reply simply carries no LiveKit connection info.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block scopes the `db` guard and the first connection-pool guard, so both
    // are released before `channel_updated` and `update_user_contacts` run below.
    // (Presumably because `update_user_contacts` takes these locks itself — TODO confirm.)
    let joined_room = {
        // Leave any current room before taking the db lock for the join.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `?` inside this closure turns a token error (already logged by `trace_err`)
        // into `None`, so the join still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                // Guests get a guest token and are not allowed to publish.
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply while the db guard is still held (it is dropped at the end of this block).
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // This takes a fresh connection-pool guard; the one from the block above was dropped.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2821
2822/// Start editing the channel notes
2823async fn join_channel_buffer(
2824 request: proto::JoinChannelBuffer,
2825 response: Response<proto::JoinChannelBuffer>,
2826 session: Session,
2827) -> Result<()> {
2828 let db = session.db().await;
2829 let channel_id = ChannelId::from_proto(request.channel_id);
2830
2831 let open_response = db
2832 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2833 .await?;
2834
2835 let collaborators = open_response.collaborators.clone();
2836 response.send(open_response)?;
2837
2838 let update = UpdateChannelBufferCollaborators {
2839 channel_id: channel_id.to_proto(),
2840 collaborators: collaborators.clone(),
2841 };
2842 channel_buffer_updated(
2843 session.connection_id,
2844 collaborators
2845 .iter()
2846 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2847 &update,
2848 &session.peer,
2849 );
2850
2851 Ok(())
2852}
2853
2854/// Edit the channel notes
2855async fn update_channel_buffer(
2856 request: proto::UpdateChannelBuffer,
2857 session: Session,
2858) -> Result<()> {
2859 let db = session.db().await;
2860 let channel_id = ChannelId::from_proto(request.channel_id);
2861
2862 let (collaborators, non_collaborators, epoch, version) = db
2863 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2864 .await?;
2865
2866 channel_buffer_updated(
2867 session.connection_id,
2868 collaborators,
2869 &proto::UpdateChannelBuffer {
2870 channel_id: channel_id.to_proto(),
2871 operations: request.operations,
2872 },
2873 &session.peer,
2874 );
2875
2876 let pool = &*session.connection_pool().await;
2877
2878 broadcast(
2879 None,
2880 non_collaborators
2881 .iter()
2882 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2883 |peer_id| {
2884 session.peer.send(
2885 peer_id.into(),
2886 proto::UpdateChannels {
2887 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2888 channel_id: channel_id.to_proto(),
2889 epoch: epoch as u64,
2890 version: version.clone(),
2891 }],
2892 ..Default::default()
2893 },
2894 )
2895 },
2896 );
2897
2898 Ok(())
2899}
2900
2901/// Rejoin the channel notes after a connection blip
2902async fn rejoin_channel_buffers(
2903 request: proto::RejoinChannelBuffers,
2904 response: Response<proto::RejoinChannelBuffers>,
2905 session: Session,
2906) -> Result<()> {
2907 let db = session.db().await;
2908 let buffers = db
2909 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2910 .await?;
2911
2912 for rejoined_buffer in &buffers {
2913 let collaborators_to_notify = rejoined_buffer
2914 .buffer
2915 .collaborators
2916 .iter()
2917 .filter_map(|c| Some(c.peer_id?.into()));
2918 channel_buffer_updated(
2919 session.connection_id,
2920 collaborators_to_notify,
2921 &proto::UpdateChannelBufferCollaborators {
2922 channel_id: rejoined_buffer.buffer.channel_id,
2923 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2924 },
2925 &session.peer,
2926 );
2927 }
2928
2929 response.send(proto::RejoinChannelBuffersResponse {
2930 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2931 })?;
2932
2933 Ok(())
2934}
2935
2936/// Stop editing the channel notes
2937async fn leave_channel_buffer(
2938 request: proto::LeaveChannelBuffer,
2939 response: Response<proto::LeaveChannelBuffer>,
2940 session: Session,
2941) -> Result<()> {
2942 let db = session.db().await;
2943 let channel_id = ChannelId::from_proto(request.channel_id);
2944
2945 let left_buffer = db
2946 .leave_channel_buffer(channel_id, session.connection_id)
2947 .await?;
2948
2949 response.send(Ack {})?;
2950
2951 channel_buffer_updated(
2952 session.connection_id,
2953 left_buffer.connections,
2954 &proto::UpdateChannelBufferCollaborators {
2955 channel_id: channel_id.to_proto(),
2956 collaborators: left_buffer.collaborators,
2957 },
2958 &session.peer,
2959 );
2960
2961 Ok(())
2962}
2963
2964fn channel_buffer_updated<T: EnvelopedMessage>(
2965 sender_id: ConnectionId,
2966 collaborators: impl IntoIterator<Item = ConnectionId>,
2967 message: &T,
2968 peer: &Peer,
2969) {
2970 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2971 peer.send(peer_id.into(), message.clone())
2972 });
2973}
2974
2975fn send_notifications(
2976 connection_pool: &ConnectionPool,
2977 peer: &Peer,
2978 notifications: db::NotificationBatch,
2979) {
2980 for (user_id, notification) in notifications {
2981 for connection_id in connection_pool.user_connection_ids(user_id) {
2982 if let Err(error) = peer.send(
2983 connection_id,
2984 proto::AddNotification {
2985 notification: Some(notification.clone()),
2986 },
2987 ) {
2988 tracing::error!(
2989 "failed to send notification to {:?} {}",
2990 connection_id,
2991 error
2992 );
2993 }
2994 }
2995 }
2996}
2997
2998/// Send a message to the channel
2999async fn send_channel_message(
3000 request: proto::SendChannelMessage,
3001 response: Response<proto::SendChannelMessage>,
3002 session: Session,
3003) -> Result<()> {
3004 // Validate the message body.
3005 let body = request.body.trim().to_string();
3006 if body.len() > MAX_MESSAGE_LEN {
3007 return Err(anyhow!("message is too long"))?;
3008 }
3009 if body.is_empty() {
3010 return Err(anyhow!("message can't be blank"))?;
3011 }
3012
3013 // TODO: adjust mentions if body is trimmed
3014
3015 let timestamp = OffsetDateTime::now_utc();
3016 let nonce = request
3017 .nonce
3018 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3019
3020 let channel_id = ChannelId::from_proto(request.channel_id);
3021 let CreatedChannelMessage {
3022 message_id,
3023 participant_connection_ids,
3024 channel_members,
3025 notifications,
3026 } = session
3027 .db()
3028 .await
3029 .create_channel_message(
3030 channel_id,
3031 session.user_id,
3032 &body,
3033 &request.mentions,
3034 timestamp,
3035 nonce.clone().into(),
3036 match request.reply_to_message_id {
3037 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3038 None => None,
3039 },
3040 )
3041 .await?;
3042 let message = proto::ChannelMessage {
3043 sender_id: session.user_id.to_proto(),
3044 id: message_id.to_proto(),
3045 body,
3046 mentions: request.mentions,
3047 timestamp: timestamp.unix_timestamp() as u64,
3048 nonce: Some(nonce),
3049 reply_to_message_id: request.reply_to_message_id,
3050 };
3051 broadcast(
3052 Some(session.connection_id),
3053 participant_connection_ids,
3054 |connection| {
3055 session.peer.send(
3056 connection,
3057 proto::ChannelMessageSent {
3058 channel_id: channel_id.to_proto(),
3059 message: Some(message.clone()),
3060 },
3061 )
3062 },
3063 );
3064 response.send(proto::SendChannelMessageResponse {
3065 message: Some(message),
3066 })?;
3067
3068 let pool = &*session.connection_pool().await;
3069 broadcast(
3070 None,
3071 channel_members
3072 .iter()
3073 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3074 |peer_id| {
3075 session.peer.send(
3076 peer_id.into(),
3077 proto::UpdateChannels {
3078 latest_channel_message_ids: vec![proto::ChannelMessageId {
3079 channel_id: channel_id.to_proto(),
3080 message_id: message_id.to_proto(),
3081 }],
3082 ..Default::default()
3083 },
3084 )
3085 },
3086 );
3087 send_notifications(pool, &session.peer, notifications);
3088
3089 Ok(())
3090}
3091
3092/// Delete a channel message
3093async fn remove_channel_message(
3094 request: proto::RemoveChannelMessage,
3095 response: Response<proto::RemoveChannelMessage>,
3096 session: Session,
3097) -> Result<()> {
3098 let channel_id = ChannelId::from_proto(request.channel_id);
3099 let message_id = MessageId::from_proto(request.message_id);
3100 let connection_ids = session
3101 .db()
3102 .await
3103 .remove_channel_message(channel_id, message_id, session.user_id)
3104 .await?;
3105 broadcast(Some(session.connection_id), connection_ids, |connection| {
3106 session.peer.send(connection, request.clone())
3107 });
3108 response.send(proto::Ack {})?;
3109 Ok(())
3110}
3111
3112/// Mark a channel message as read
3113async fn acknowledge_channel_message(
3114 request: proto::AckChannelMessage,
3115 session: Session,
3116) -> Result<()> {
3117 let channel_id = ChannelId::from_proto(request.channel_id);
3118 let message_id = MessageId::from_proto(request.message_id);
3119 let notifications = session
3120 .db()
3121 .await
3122 .observe_channel_message(channel_id, session.user_id, message_id)
3123 .await?;
3124 send_notifications(
3125 &*session.connection_pool().await,
3126 &session.peer,
3127 notifications,
3128 );
3129 Ok(())
3130}
3131
3132/// Mark a buffer version as synced
3133async fn acknowledge_buffer_version(
3134 request: proto::AckBufferOperation,
3135 session: Session,
3136) -> Result<()> {
3137 let buffer_id = BufferId::from_proto(request.buffer_id);
3138 session
3139 .db()
3140 .await
3141 .observe_buffer_version(
3142 buffer_id,
3143 session.user_id,
3144 request.epoch as i32,
3145 &request.version,
3146 )
3147 .await?;
3148 Ok(())
3149}
3150
3151/// Start receiving chat updates for a channel
3152async fn join_channel_chat(
3153 request: proto::JoinChannelChat,
3154 response: Response<proto::JoinChannelChat>,
3155 session: Session,
3156) -> Result<()> {
3157 let channel_id = ChannelId::from_proto(request.channel_id);
3158
3159 let db = session.db().await;
3160 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3161 .await?;
3162 let messages = db
3163 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3164 .await?;
3165 response.send(proto::JoinChannelChatResponse {
3166 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3167 messages,
3168 })?;
3169 Ok(())
3170}
3171
3172/// Stop receiving chat updates for a channel
3173async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3174 let channel_id = ChannelId::from_proto(request.channel_id);
3175 session
3176 .db()
3177 .await
3178 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3179 .await?;
3180 Ok(())
3181}
3182
3183/// Retrieve the chat history for a channel
3184async fn get_channel_messages(
3185 request: proto::GetChannelMessages,
3186 response: Response<proto::GetChannelMessages>,
3187 session: Session,
3188) -> Result<()> {
3189 let channel_id = ChannelId::from_proto(request.channel_id);
3190 let messages = session
3191 .db()
3192 .await
3193 .get_channel_messages(
3194 channel_id,
3195 session.user_id,
3196 MESSAGE_COUNT_PER_PAGE,
3197 Some(MessageId::from_proto(request.before_message_id)),
3198 )
3199 .await?;
3200 response.send(proto::GetChannelMessagesResponse {
3201 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3202 messages,
3203 })?;
3204 Ok(())
3205}
3206
3207/// Retrieve specific chat messages
3208async fn get_channel_messages_by_id(
3209 request: proto::GetChannelMessagesById,
3210 response: Response<proto::GetChannelMessagesById>,
3211 session: Session,
3212) -> Result<()> {
3213 let message_ids = request
3214 .message_ids
3215 .iter()
3216 .map(|id| MessageId::from_proto(*id))
3217 .collect::<Vec<_>>();
3218 let messages = session
3219 .db()
3220 .await
3221 .get_channel_messages_by_id(session.user_id, &message_ids)
3222 .await?;
3223 response.send(proto::GetChannelMessagesResponse {
3224 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3225 messages,
3226 })?;
3227 Ok(())
3228}
3229
3230/// Retrieve the current users notifications
3231async fn get_notifications(
3232 request: proto::GetNotifications,
3233 response: Response<proto::GetNotifications>,
3234 session: Session,
3235) -> Result<()> {
3236 let notifications = session
3237 .db()
3238 .await
3239 .get_notifications(
3240 session.user_id,
3241 NOTIFICATION_COUNT_PER_PAGE,
3242 request
3243 .before_id
3244 .map(|id| db::NotificationId::from_proto(id)),
3245 )
3246 .await?;
3247 response.send(proto::GetNotificationsResponse {
3248 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3249 notifications,
3250 })?;
3251 Ok(())
3252}
3253
3254/// Mark notifications as read
3255async fn mark_notification_as_read(
3256 request: proto::MarkNotificationRead,
3257 response: Response<proto::MarkNotificationRead>,
3258 session: Session,
3259) -> Result<()> {
3260 let database = &session.db().await;
3261 let notifications = database
3262 .mark_notification_as_read_by_id(
3263 session.user_id,
3264 NotificationId::from_proto(request.notification_id),
3265 )
3266 .await?;
3267 send_notifications(
3268 &*session.connection_pool().await,
3269 &session.peer,
3270 notifications,
3271 );
3272 response.send(proto::Ack {})?;
3273 Ok(())
3274}
3275
3276/// Get the current users information
3277async fn get_private_user_info(
3278 _request: proto::GetPrivateUserInfo,
3279 response: Response<proto::GetPrivateUserInfo>,
3280 session: Session,
3281) -> Result<()> {
3282 let db = session.db().await;
3283
3284 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3285 let user = db
3286 .get_user_by_id(session.user_id)
3287 .await?
3288 .ok_or_else(|| anyhow!("user not found"))?;
3289 let flags = db.get_user_flags(session.user_id).await?;
3290
3291 response.send(proto::GetPrivateUserInfoResponse {
3292 metrics_id,
3293 staff: user.admin,
3294 flags,
3295 })?;
3296 Ok(())
3297}
3298
3299fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3300 match message {
3301 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3302 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3303 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3304 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3305 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3306 code: frame.code.into(),
3307 reason: frame.reason,
3308 })),
3309 }
3310}
3311
3312fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3313 match message {
3314 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3315 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3316 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3317 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3318 AxumMessage::Close(frame) => {
3319 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3320 code: frame.code.into(),
3321 reason: frame.reason,
3322 }))
3323 }
3324 }
3325}
3326
3327fn notify_membership_updated(
3328 connection_pool: &ConnectionPool,
3329 result: MembershipUpdated,
3330 user_id: UserId,
3331 peer: &Peer,
3332) {
3333 let user_channels_update = proto::UpdateUserChannels {
3334 channel_memberships: result
3335 .new_channels
3336 .channel_memberships
3337 .iter()
3338 .map(|cm| proto::ChannelMembership {
3339 channel_id: cm.channel_id.to_proto(),
3340 role: cm.role.into(),
3341 })
3342 .collect(),
3343 ..Default::default()
3344 };
3345 let mut update = build_channels_update(result.new_channels, vec![]);
3346 update.delete_channels = result
3347 .removed_channels
3348 .into_iter()
3349 .map(|id| id.to_proto())
3350 .collect();
3351 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3352
3353 for connection_id in connection_pool.user_connection_ids(user_id) {
3354 peer.send(connection_id, user_channels_update.clone())
3355 .trace_err();
3356 peer.send(connection_id, update.clone()).trace_err();
3357 }
3358}
3359
3360fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3361 proto::UpdateUserChannels {
3362 channel_memberships: channels
3363 .channel_memberships
3364 .iter()
3365 .map(|m| proto::ChannelMembership {
3366 channel_id: m.channel_id.to_proto(),
3367 role: m.role.into(),
3368 })
3369 .collect(),
3370 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3371 observed_channel_message_id: channels.observed_channel_messages.clone(),
3372 ..Default::default()
3373 }
3374}
3375
3376fn build_channels_update(
3377 channels: ChannelsForUser,
3378 channel_invites: Vec<db::Channel>,
3379) -> proto::UpdateChannels {
3380 let mut update = proto::UpdateChannels::default();
3381
3382 for channel in channels.channels {
3383 update.channels.push(channel.to_proto());
3384 }
3385
3386 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3387 update.latest_channel_message_ids = channels.latest_channel_messages;
3388
3389 for (channel_id, participants) in channels.channel_participants {
3390 update
3391 .channel_participants
3392 .push(proto::ChannelParticipants {
3393 channel_id: channel_id.to_proto(),
3394 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3395 });
3396 }
3397
3398 for channel in channel_invites {
3399 update.channel_invitations.push(channel.to_proto());
3400 }
3401 for project in channels.hosted_projects {
3402 update.hosted_projects.push(project);
3403 }
3404
3405 update
3406}
3407
3408fn build_initial_contacts_update(
3409 contacts: Vec<db::Contact>,
3410 pool: &ConnectionPool,
3411) -> proto::UpdateContacts {
3412 let mut update = proto::UpdateContacts::default();
3413
3414 for contact in contacts {
3415 match contact {
3416 db::Contact::Accepted { user_id, busy } => {
3417 update.contacts.push(contact_for_user(user_id, busy, &pool));
3418 }
3419 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3420 db::Contact::Incoming { user_id } => {
3421 update
3422 .incoming_requests
3423 .push(proto::IncomingContactRequest {
3424 requester_id: user_id.to_proto(),
3425 })
3426 }
3427 }
3428 }
3429
3430 update
3431}
3432
3433fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3434 proto::Contact {
3435 user_id: user_id.to_proto(),
3436 online: pool.is_user_online(user_id),
3437 busy,
3438 }
3439}
3440
3441fn room_updated(room: &proto::Room, peer: &Peer) {
3442 broadcast(
3443 None,
3444 room.participants
3445 .iter()
3446 .filter_map(|participant| Some(participant.peer_id?.into())),
3447 |peer_id| {
3448 peer.send(
3449 peer_id.into(),
3450 proto::RoomUpdated {
3451 room: Some(room.clone()),
3452 },
3453 )
3454 },
3455 );
3456}
3457
3458fn channel_updated(
3459 channel_id: ChannelId,
3460 room: &proto::Room,
3461 channel_members: &[UserId],
3462 peer: &Peer,
3463 pool: &ConnectionPool,
3464) {
3465 let participants = room
3466 .participants
3467 .iter()
3468 .map(|p| p.user_id)
3469 .collect::<Vec<_>>();
3470
3471 broadcast(
3472 None,
3473 channel_members
3474 .iter()
3475 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3476 |peer_id| {
3477 peer.send(
3478 peer_id.into(),
3479 proto::UpdateChannels {
3480 channel_participants: vec![proto::ChannelParticipants {
3481 channel_id: channel_id.to_proto(),
3482 participant_user_ids: participants.clone(),
3483 }],
3484 ..Default::default()
3485 },
3486 )
3487 },
3488 );
3489}
3490
3491async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3492 let db = session.db().await;
3493
3494 let contacts = db.get_contacts(user_id).await?;
3495 let busy = db.is_user_busy(user_id).await?;
3496
3497 let pool = session.connection_pool().await;
3498 let updated_contact = contact_for_user(user_id, busy, &pool);
3499 for contact in contacts {
3500 if let db::Contact::Accepted {
3501 user_id: contact_user_id,
3502 ..
3503 } = contact
3504 {
3505 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3506 session
3507 .peer
3508 .send(
3509 contact_conn_id,
3510 proto::UpdateContacts {
3511 contacts: vec![updated_contact.clone()],
3512 remove_contacts: Default::default(),
3513 incoming_requests: Default::default(),
3514 remove_incoming_requests: Default::default(),
3515 outgoing_requests: Default::default(),
3516 remove_outgoing_requests: Default::default(),
3517 },
3518 )
3519 .trace_err();
3520 }
3521 }
3522 }
3523 Ok(())
3524}
3525
/// Remove this session's connection from the room it is in, then notify everyone affected.
///
/// If `leave_room` returns `None`, this returns `Ok(())` without doing anything else.
/// Otherwise it performs these steps in order:
/// - calls `project_left` for each project in `left_room.left_projects`;
/// - broadcasts the updated room (`room_updated`);
/// - if the room has a channel id, broadcasts that channel's participants (`channel_updated`);
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - calls `update_user_contacts` for this user and for each of those users;
/// - if a LiveKit client is configured, removes this user from the LiveKit room, and
///   deletes the LiveKit room when `left_room.deleted` is set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. The contact updates
/// run before the LiveKit cleanup and use `?`, so an error there returns before the
/// LiveKit calls are made. Errors from the LiveKit calls themselves are only logged
/// (`trace_err`).
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // Users whose contact entries are re-sent at the end via `update_user_contacts`.
    let mut contacts_to_update = HashSet::default();

    // Declared uninitialized so the `if let` below can assign them and they stay usable
    // after it; the `else` branch returns early, so they are always initialized past it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): if `session.db().await` yields a lock guard, it is a temporary in this
    // `if let` scrutinee. It therefore stays alive until the end of the whole `if let`
    // statement, i.e. across `project_left` and `room_updated` below. Confirm that this
    // is intended.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out of the room before `room` itself is taken. As a result, the room
        // broadcast below carries the default value for `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so this pool guard is released before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        // `left_room.deleted` from the database decides whether the LiveKit room goes too.
        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3602
3603async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3604 let left_channel_buffers = session
3605 .db()
3606 .await
3607 .leave_channel_buffers(session.connection_id)
3608 .await?;
3609
3610 for left_buffer in left_channel_buffers {
3611 channel_buffer_updated(
3612 session.connection_id,
3613 left_buffer.connections,
3614 &proto::UpdateChannelBufferCollaborators {
3615 channel_id: left_buffer.channel_id.to_proto(),
3616 collaborators: left_buffer.collaborators,
3617 },
3618 &session.peer,
3619 );
3620 }
3621
3622 Ok(())
3623}
3624
3625fn project_left(project: &db::LeftProject, session: &Session) {
3626 for connection_id in &project.connection_ids {
3627 if project.host_user_id == session.user_id {
3628 session
3629 .peer
3630 .send(
3631 *connection_id,
3632 proto::UnshareProject {
3633 project_id: project.id.to_proto(),
3634 },
3635 )
3636 .trace_err();
3637 } else {
3638 session
3639 .peer
3640 .send(
3641 *connection_id,
3642 proto::RemoveProjectCollaborator {
3643 project_id: project.id.to_proto(),
3644 peer_id: Some(session.connection_id.into()),
3645 },
3646 )
3647 .trace_err();
3648 }
3649 }
3650}
3651
3652pub trait ResultExt {
3653 type Ok;
3654
3655 fn trace_err(self) -> Option<Self::Ok>;
3656}
3657
3658impl<T, E> ResultExt for Result<T, E>
3659where
3660 E: std::fmt::Debug,
3661{
3662 type Ok = T;
3663
3664 fn trace_err(self) -> Option<T> {
3665 match self {
3666 Ok(value) => Some(value),
3667 Err(error) => {
3668 tracing::error!("{:?}", error);
3669 None
3670 }
3671 }
3672 }
3673}