1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::ConnectionPool;
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use lazy_static::lazy_static;
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// How long a dropped connection waits (in `connection_lost`) before the user's rooms,
/// channel buffers and pending calls are released.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay after `Server::start` before resources left behind by stale servers are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three limits below are used outside this view; the descriptions are
// inferred from their names — confirm against their call sites.
/// Presumably the page size for channel message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message body.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification history requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
75
76lazy_static! {
77 static ref METRIC_CONNECTIONS: IntGauge =
78 register_int_gauge!("connections", "number of connections").unwrap();
79 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
80 "shared_projects",
81 "number of open projects with one or more guests"
82 )
83 .unwrap();
84}
85
86type MessageHandler =
87 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
89struct Response<R> {
90 peer: Arc<Peer>,
91 receipt: Receipt<R>,
92 responded: Arc<AtomicBool>,
93}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
/// Per-connection context passed to every message handler.
///
/// `Server::handle_connection` builds one `Session` per connection and clones it for each
/// dispatched message.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind a mutex created per connection in `handle_connection`, so
    /// handlers of the same connection take turns using it. Access via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; lock via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// `None` when the app state has no LiveKit client configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
113
114impl Session {
115 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
116 #[cfg(test)]
117 tokio::task::yield_now().await;
118 let guard = self.db.lock().await;
119 #[cfg(test)]
120 tokio::task::yield_now().await;
121 guard
122 }
123
124 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
125 #[cfg(test)]
126 tokio::task::yield_now().await;
127 let guard = self.connection_pool.lock();
128 ConnectionPoolGuard {
129 guard,
130 _not_send: PhantomData,
131 }
132 }
133}
134
135impl fmt::Debug for Session {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.debug_struct("Session")
138 .field("user_id", &self.user_id)
139 .field("connection_id", &self.connection_id)
140 .finish()
141 }
142}
143
144struct DbHandle(Arc<Database>);
145
146impl Deref for DbHandle {
147 type Target = Database;
148
149 fn deref(&self) -> &Self::Target {
150 self.0.as_ref()
151 }
152}
153
/// The collaboration RPC server: owns the peer, the shared connection pool and the table of
/// message handlers.
pub struct Server {
    /// Behind a mutex so the id can be replaced (e.g. by the test-only `reset`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; every connection loop holds a subscribed receiver.
    teardown: watch::Sender<()>,
}
163
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
168
/// Serializable view of the server's peer and connection pool.
///
/// The connection-pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
175
176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
177where
178 S: Serializer,
179 T: Deref<Target = U>,
180 U: Serialize,
181{
182 Serialize::serialize(value.deref(), serializer)
183}
184
185impl Server {
186 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
187 let mut server = Self {
188 id: parking_lot::Mutex::new(id),
189 peer: Peer::new(id.0 as u32),
190 app_state,
191 executor,
192 connection_pool: Default::default(),
193 handlers: Default::default(),
194 teardown: watch::channel(()).0,
195 };
196
197 server
198 .add_request_handler(ping)
199 .add_request_handler(create_room)
200 .add_request_handler(join_room)
201 .add_request_handler(rejoin_room)
202 .add_request_handler(leave_room)
203 .add_request_handler(set_room_participant_role)
204 .add_request_handler(call)
205 .add_request_handler(cancel_call)
206 .add_message_handler(decline_call)
207 .add_request_handler(update_participant_location)
208 .add_request_handler(share_project)
209 .add_message_handler(unshare_project)
210 .add_request_handler(join_project)
211 .add_message_handler(leave_project)
212 .add_request_handler(update_project)
213 .add_request_handler(update_worktree)
214 .add_message_handler(start_language_server)
215 .add_message_handler(update_language_server)
216 .add_message_handler(update_diagnostic_summary)
217 .add_message_handler(update_worktree_settings)
218 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
219 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
220 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
221 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
222 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
223 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
224 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
225 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
226 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
227 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
228 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
229 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
231 .add_request_handler(
232 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
233 )
234 .add_request_handler(
235 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
236 )
237 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
238 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
239 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
240 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
241 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
242 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
243 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
244 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
245 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
246 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
247 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
248 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
249 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
250 .add_message_handler(create_buffer_for_peer)
251 .add_request_handler(update_buffer)
252 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
253 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
254 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
255 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
256 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
257 .add_request_handler(get_users)
258 .add_request_handler(fuzzy_search_users)
259 .add_request_handler(request_contact)
260 .add_request_handler(remove_contact)
261 .add_request_handler(respond_to_contact_request)
262 .add_request_handler(create_channel)
263 .add_request_handler(delete_channel)
264 .add_request_handler(invite_channel_member)
265 .add_request_handler(remove_channel_member)
266 .add_request_handler(set_channel_member_role)
267 .add_request_handler(set_channel_visibility)
268 .add_request_handler(rename_channel)
269 .add_request_handler(join_channel_buffer)
270 .add_request_handler(leave_channel_buffer)
271 .add_message_handler(update_channel_buffer)
272 .add_request_handler(rejoin_channel_buffers)
273 .add_request_handler(get_channel_members)
274 .add_request_handler(respond_to_channel_invite)
275 .add_request_handler(join_channel)
276 .add_request_handler(join_channel_chat)
277 .add_message_handler(leave_channel_chat)
278 .add_request_handler(send_channel_message)
279 .add_request_handler(remove_channel_message)
280 .add_request_handler(get_channel_messages)
281 .add_request_handler(get_channel_messages_by_id)
282 .add_request_handler(get_notifications)
283 .add_request_handler(mark_notification_as_read)
284 .add_request_handler(move_channel)
285 .add_request_handler(follow)
286 .add_message_handler(unfollow)
287 .add_message_handler(update_followers)
288 .add_request_handler(get_private_user_info)
289 .add_message_handler(acknowledge_channel_message)
290 .add_message_handler(acknowledge_buffer_version);
291
292 Arc::new(server)
293 }
294
295 pub async fn start(&self) -> Result<()> {
296 let server_id = *self.id.lock();
297 let app_state = self.app_state.clone();
298 let peer = self.peer.clone();
299 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
300 let pool = self.connection_pool.clone();
301 let live_kit_client = self.app_state.live_kit_client.clone();
302
303 let span = info_span!("start server");
304 self.executor.spawn_detached(
305 async move {
306 tracing::info!("waiting for cleanup timeout");
307 timeout.await;
308 tracing::info!("cleanup timeout expired, retrieving stale rooms");
309 if let Some((room_ids, channel_ids)) = app_state
310 .db
311 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
312 .await
313 .trace_err()
314 {
315 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
316 tracing::info!(
317 stale_channel_buffer_count = channel_ids.len(),
318 "retrieved stale channel buffers"
319 );
320
321 for channel_id in channel_ids {
322 if let Some(refreshed_channel_buffer) = app_state
323 .db
324 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
325 .await
326 .trace_err()
327 {
328 for connection_id in refreshed_channel_buffer.connection_ids {
329 peer.send(
330 connection_id,
331 proto::UpdateChannelBufferCollaborators {
332 channel_id: channel_id.to_proto(),
333 collaborators: refreshed_channel_buffer
334 .collaborators
335 .clone(),
336 },
337 )
338 .trace_err();
339 }
340 }
341 }
342
343 for room_id in room_ids {
344 let mut contacts_to_update = HashSet::default();
345 let mut canceled_calls_to_user_ids = Vec::new();
346 let mut live_kit_room = String::new();
347 let mut delete_live_kit_room = false;
348
349 if let Some(mut refreshed_room) = app_state
350 .db
351 .clear_stale_room_participants(room_id, server_id)
352 .await
353 .trace_err()
354 {
355 tracing::info!(
356 room_id = room_id.0,
357 new_participant_count = refreshed_room.room.participants.len(),
358 "refreshed room"
359 );
360 room_updated(&refreshed_room.room, &peer);
361 if let Some(channel_id) = refreshed_room.channel_id {
362 channel_updated(
363 channel_id,
364 &refreshed_room.room,
365 &refreshed_room.channel_members,
366 &peer,
367 &*pool.lock(),
368 );
369 }
370 contacts_to_update
371 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
372 contacts_to_update
373 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
374 canceled_calls_to_user_ids =
375 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
376 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
377 delete_live_kit_room = refreshed_room.room.participants.is_empty();
378 }
379
380 {
381 let pool = pool.lock();
382 for canceled_user_id in canceled_calls_to_user_ids {
383 for connection_id in pool.user_connection_ids(canceled_user_id) {
384 peer.send(
385 connection_id,
386 proto::CallCanceled {
387 room_id: room_id.to_proto(),
388 },
389 )
390 .trace_err();
391 }
392 }
393 }
394
395 for user_id in contacts_to_update {
396 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
397 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
398 if let Some((busy, contacts)) = busy.zip(contacts) {
399 let pool = pool.lock();
400 let updated_contact = contact_for_user(user_id, busy, &pool);
401 for contact in contacts {
402 if let db::Contact::Accepted {
403 user_id: contact_user_id,
404 ..
405 } = contact
406 {
407 for contact_conn_id in
408 pool.user_connection_ids(contact_user_id)
409 {
410 peer.send(
411 contact_conn_id,
412 proto::UpdateContacts {
413 contacts: vec![updated_contact.clone()],
414 remove_contacts: Default::default(),
415 incoming_requests: Default::default(),
416 remove_incoming_requests: Default::default(),
417 outgoing_requests: Default::default(),
418 remove_outgoing_requests: Default::default(),
419 },
420 )
421 .trace_err();
422 }
423 }
424 }
425 }
426 }
427
428 if let Some(live_kit) = live_kit_client.as_ref() {
429 if delete_live_kit_room {
430 live_kit.delete_room(live_kit_room).await.trace_err();
431 }
432 }
433 }
434 }
435
436 app_state
437 .db
438 .delete_stale_servers(&app_state.config.zed_environment, server_id)
439 .await
440 .trace_err();
441 }
442 .instrument(span),
443 );
444 Ok(())
445 }
446
447 pub fn teardown(&self) {
448 self.peer.teardown();
449 self.connection_pool.lock().reset();
450 let _ = self.teardown.send(());
451 }
452
453 #[cfg(test)]
454 pub fn reset(&self, id: ServerId) {
455 self.teardown();
456 *self.id.lock() = id;
457 self.peer.reset(id.0 as u32);
458 }
459
460 #[cfg(test)]
461 pub fn id(&self) -> ServerId {
462 *self.id.lock()
463 }
464
465 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
466 where
467 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
468 Fut: 'static + Send + Future<Output = Result<()>>,
469 M: EnvelopedMessage,
470 {
471 let prev_handler = self.handlers.insert(
472 TypeId::of::<M>(),
473 Box::new(move |envelope, session| {
474 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
475 let span = info_span!(
476 "handle message",
477 payload_type = envelope.payload_type_name()
478 );
479 span.in_scope(|| {
480 tracing::info!(
481 payload_type = envelope.payload_type_name(),
482 "message received"
483 );
484 });
485 let start_time = Instant::now();
486 let future = (handler)(*envelope, session);
487 async move {
488 let result = future.await;
489 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
490 match result {
491 Err(error) => {
492 tracing::error!(%error, ?duration_ms, "error handling message")
493 }
494 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
495 }
496 }
497 .instrument(span)
498 .boxed()
499 }),
500 );
501 if prev_handler.is_some() {
502 panic!("registered a handler for the same message twice");
503 }
504 self
505 }
506
507 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
508 where
509 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
510 Fut: 'static + Send + Future<Output = Result<()>>,
511 M: EnvelopedMessage,
512 {
513 self.add_handler(move |envelope, session| handler(envelope.payload, session));
514 self
515 }
516
517 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
518 where
519 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
520 Fut: Send + Future<Output = Result<()>>,
521 M: RequestMessage,
522 {
523 let handler = Arc::new(handler);
524 self.add_handler(move |envelope, session| {
525 let receipt = envelope.receipt();
526 let handler = handler.clone();
527 async move {
528 let peer = session.peer.clone();
529 let responded = Arc::new(AtomicBool::default());
530 let response = Response {
531 peer: peer.clone(),
532 responded: responded.clone(),
533 receipt,
534 };
535 match (handler)(envelope.payload, response, session).await {
536 Ok(()) => {
537 if responded.load(std::sync::atomic::Ordering::SeqCst) {
538 Ok(())
539 } else {
540 Err(anyhow!("handler did not send a response"))?
541 }
542 }
543 Err(error) => {
544 let proto_err = match &error {
545 Error::Internal(err) => err.to_proto(),
546 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
547 };
548 peer.respond_with_error(receipt, proto_err)?;
549 Err(error)
550 }
551 }
552 }
553 })
554 }
555
556 pub fn handle_connection(
557 self: &Arc<Self>,
558 connection: Connection,
559 address: String,
560 user: User,
561 impersonator: Option<User>,
562 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
563 executor: Executor,
564 ) -> impl Future<Output = Result<()>> {
565 let this = self.clone();
566 let user_id = user.id;
567 let login = user.github_login;
568 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
569 if let Some(impersonator) = impersonator {
570 span.record("impersonator", &impersonator.github_login);
571 }
572 let mut teardown = self.teardown.subscribe();
573 async move {
574 let (connection_id, handle_io, mut incoming_rx) = this
575 .peer
576 .add_connection(connection, {
577 let executor = executor.clone();
578 move |duration| executor.sleep(duration)
579 });
580
581 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
582 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
583 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
584
585 if let Some(send_connection_id) = send_connection_id.take() {
586 let _ = send_connection_id.send(connection_id);
587 }
588
589 if !user.connected_once {
590 this.peer.send(connection_id, proto::ShowContacts {})?;
591 this.app_state.db.set_user_connected_once(user_id, true).await?;
592 }
593
594 let (contacts, channels_for_user, channel_invites) = future::try_join3(
595 this.app_state.db.get_contacts(user_id),
596 this.app_state.db.get_channels_for_user(user_id),
597 this.app_state.db.get_channel_invites_for_user(user_id),
598 ).await?;
599
600 {
601 let mut pool = this.connection_pool.lock();
602 pool.add_connection(connection_id, user_id, user.admin);
603 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
604 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
605 this.peer.send(connection_id, build_channels_update(
606 channels_for_user,
607 channel_invites
608 ))?;
609 }
610
611 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
612 this.peer.send(connection_id, incoming_call)?;
613 }
614
615 let session = Session {
616 user_id,
617 connection_id,
618 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
619 peer: this.peer.clone(),
620 connection_pool: this.connection_pool.clone(),
621 live_kit_client: this.app_state.live_kit_client.clone(),
622 _executor: executor.clone()
623 };
624 update_user_contacts(user_id, &session).await?;
625
626 let handle_io = handle_io.fuse();
627 futures::pin_mut!(handle_io);
628
629 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
630 // This prevents deadlocks when e.g., client A performs a request to client B and
631 // client B performs a request to client A. If both clients stop processing further
632 // messages until their respective request completes, they won't have a chance to
633 // respond to the other client's request and cause a deadlock.
634 //
635 // This arrangement ensures we will attempt to process earlier messages first, but fall
636 // back to processing messages arrived later in the spirit of making progress.
637 let mut foreground_message_handlers = FuturesUnordered::new();
638 let concurrent_handlers = Arc::new(Semaphore::new(256));
639 loop {
640 let next_message = async {
641 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
642 let message = incoming_rx.next().await;
643 (permit, message)
644 }.fuse();
645 futures::pin_mut!(next_message);
646 futures::select_biased! {
647 _ = teardown.changed().fuse() => return Ok(()),
648 result = handle_io => {
649 if let Err(error) = result {
650 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
651 }
652 break;
653 }
654 _ = foreground_message_handlers.next() => {}
655 next_message = next_message => {
656 let (permit, message) = next_message;
657 if let Some(message) = message {
658 let type_name = message.payload_type_name();
659 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
660 let span_enter = span.enter();
661 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
662 let is_background = message.is_background();
663 let handle_message = (handler)(message, session.clone());
664 drop(span_enter);
665
666 let handle_message = async move {
667 handle_message.await;
668 drop(permit);
669 }.instrument(span);
670 if is_background {
671 executor.spawn_detached(handle_message);
672 } else {
673 foreground_message_handlers.push(handle_message);
674 }
675 } else {
676 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
677 }
678 } else {
679 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
680 break;
681 }
682 }
683 }
684 }
685
686 drop(foreground_message_handlers);
687 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
688 if let Err(error) = connection_lost(session, teardown, executor).await {
689 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
690 }
691
692 Ok(())
693 }.instrument(span)
694 }
695
696 pub async fn invite_code_redeemed(
697 self: &Arc<Self>,
698 inviter_id: UserId,
699 invitee_id: UserId,
700 ) -> Result<()> {
701 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
702 if let Some(code) = &user.invite_code {
703 let pool = self.connection_pool.lock();
704 let invitee_contact = contact_for_user(invitee_id, false, &pool);
705 for connection_id in pool.user_connection_ids(inviter_id) {
706 self.peer.send(
707 connection_id,
708 proto::UpdateContacts {
709 contacts: vec![invitee_contact.clone()],
710 ..Default::default()
711 },
712 )?;
713 self.peer.send(
714 connection_id,
715 proto::UpdateInviteInfo {
716 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
717 count: user.invite_count as u32,
718 },
719 )?;
720 }
721 }
722 }
723 Ok(())
724 }
725
726 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
727 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
728 if let Some(invite_code) = &user.invite_code {
729 let pool = self.connection_pool.lock();
730 for connection_id in pool.user_connection_ids(user_id) {
731 self.peer.send(
732 connection_id,
733 proto::UpdateInviteInfo {
734 url: format!(
735 "{}{}",
736 self.app_state.config.invite_link_prefix, invite_code
737 ),
738 count: user.invite_count as u32,
739 },
740 )?;
741 }
742 }
743 }
744 Ok(())
745 }
746
747 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
748 ServerSnapshot {
749 connection_pool: ConnectionPoolGuard {
750 guard: self.connection_pool.lock(),
751 _not_send: PhantomData,
752 },
753 peer: &self.peer,
754 }
755 }
756}
757
758impl<'a> Deref for ConnectionPoolGuard<'a> {
759 type Target = ConnectionPool;
760
761 fn deref(&self) -> &Self::Target {
762 &*self.guard
763 }
764}
765
766impl<'a> DerefMut for ConnectionPoolGuard<'a> {
767 fn deref_mut(&mut self) -> &mut Self::Target {
768 &mut *self.guard
769 }
770}
771
772impl<'a> Drop for ConnectionPoolGuard<'a> {
773 fn drop(&mut self) {
774 #[cfg(test)]
775 self.check_invariants();
776 }
777}
778
779fn broadcast<F>(
780 sender_id: Option<ConnectionId>,
781 receiver_ids: impl IntoIterator<Item = ConnectionId>,
782 mut f: F,
783) where
784 F: FnMut(ConnectionId) -> anyhow::Result<()>,
785{
786 for receiver_id in receiver_ids {
787 if Some(receiver_id) != sender_id {
788 if let Err(error) = f(receiver_id) {
789 tracing::error!("failed to send to {:?} {}", receiver_id, error);
790 }
791 }
792 }
793}
794
795lazy_static! {
796 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
797 static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
798}
799
800pub struct ProtocolVersion(u32);
801
802impl Header for ProtocolVersion {
803 fn name() -> &'static HeaderName {
804 &ZED_PROTOCOL_VERSION
805 }
806
807 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
808 where
809 Self: Sized,
810 I: Iterator<Item = &'i axum::http::HeaderValue>,
811 {
812 let version = values
813 .next()
814 .ok_or_else(axum::headers::Error::invalid)?
815 .to_str()
816 .map_err(|_| axum::headers::Error::invalid())?
817 .parse()
818 .map_err(|_| axum::headers::Error::invalid())?;
819 Ok(Self(version))
820 }
821
822 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
823 values.extend([self.0.to_string().parse().unwrap()]);
824 }
825}
826
827pub struct AppVersionHeader(SemanticVersion);
828impl Header for AppVersionHeader {
829 fn name() -> &'static HeaderName {
830 &ZED_APP_VERSION
831 }
832
833 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
834 where
835 Self: Sized,
836 I: Iterator<Item = &'i axum::http::HeaderValue>,
837 {
838 let version = values
839 .next()
840 .ok_or_else(axum::headers::Error::invalid)?
841 .to_str()
842 .map_err(|_| axum::headers::Error::invalid())?
843 .parse()
844 .map_err(|_| axum::headers::Error::invalid())?;
845 Ok(Self(version))
846 }
847
848 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
849 values.extend([self.0.to_string().parse().unwrap()]);
850 }
851}
852
853pub fn routes(server: Arc<Server>) -> Router<Body> {
854 Router::new()
855 .route("/rpc", get(handle_websocket_request))
856 .layer(
857 ServiceBuilder::new()
858 .layer(Extension(server.app_state.clone()))
859 .layer(middleware::from_fn(auth::validate_header)),
860 )
861 .route("/metrics", get(handle_metrics))
862 .layer(Extension(server))
863}
864
865pub async fn handle_websocket_request(
866 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
867 app_version_header: Option<TypedHeader<AppVersionHeader>>,
868 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
869 Extension(server): Extension<Arc<Server>>,
870 Extension(user): Extension<User>,
871 Extension(impersonator): Extension<Impersonator>,
872 ws: WebSocketUpgrade,
873) -> axum::response::Response {
874 if protocol_version != rpc::PROTOCOL_VERSION {
875 return (
876 StatusCode::UPGRADE_REQUIRED,
877 "client must be upgraded".to_string(),
878 )
879 .into_response();
880 }
881
882 // the first version of zed that sent this header was 0.121.x
883 if let Some(version) = app_version_header.map(|header| header.0 .0) {
884 // 0.123.0 was a nightly version with incompatible collab changes
885 // that were reverted.
886 if version == "0.123.0".parse().unwrap() {
887 return (
888 StatusCode::UPGRADE_REQUIRED,
889 "client must be upgraded".to_string(),
890 )
891 .into_response();
892 }
893 }
894
895 let socket_address = socket_address.to_string();
896 ws.on_upgrade(move |socket| {
897 use util::ResultExt;
898 let socket = socket
899 .map_ok(to_tungstenite_message)
900 .err_into()
901 .with(|message| async move { Ok(to_axum_message(message)) });
902 let connection = Connection::new(Box::pin(socket));
903 async move {
904 server
905 .handle_connection(
906 connection,
907 socket_address,
908 user,
909 impersonator.0,
910 None,
911 Executor::Production,
912 )
913 .await
914 .log_err();
915 }
916 })
917}
918
919pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
920 let connections = server
921 .connection_pool
922 .lock()
923 .connections()
924 .filter(|connection| !connection.admin)
925 .count();
926
927 METRIC_CONNECTIONS.set(connections as _);
928
929 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
930 METRIC_SHARED_PROJECTS.set(shared_projects as _);
931
932 let encoder = prometheus::TextEncoder::new();
933 let metric_families = prometheus::gather();
934 let encoded_metrics = encoder
935 .encode_to_string(&metric_families)
936 .map_err(|err| anyhow!("{}", err))?;
937 Ok(encoded_metrics)
938}
939
940#[instrument(err, skip(executor))]
941async fn connection_lost(
942 session: Session,
943 mut teardown: watch::Receiver<()>,
944 executor: Executor,
945) -> Result<()> {
946 session.peer.disconnect(session.connection_id);
947 session
948 .connection_pool()
949 .await
950 .remove_connection(session.connection_id)?;
951
952 session
953 .db()
954 .await
955 .connection_lost(session.connection_id)
956 .await
957 .trace_err();
958
959 futures::select_biased! {
960 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
961 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
962 leave_room_for_session(&session).await.trace_err();
963 leave_channel_buffers_for_session(&session)
964 .await
965 .trace_err();
966
967 if !session
968 .connection_pool()
969 .await
970 .is_user_online(session.user_id)
971 {
972 let db = session.db().await;
973 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
974 room_updated(&room, &session.peer);
975 }
976 }
977
978 update_user_contacts(session.user_id, &session).await?;
979 }
980 _ = teardown.changed().fuse() => {}
981 }
982
983 Ok(())
984}
985
986/// Acknowledges a ping from a client, used to keep the connection alive.
987async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
988 response.send(proto::Ack {})?;
989 Ok(())
990}
991
992/// Creates a new room for calling (outside of channels)
993async fn create_room(
994 _request: proto::CreateRoom,
995 response: Response<proto::CreateRoom>,
996 session: Session,
997) -> Result<()> {
998 let live_kit_room = nanoid::nanoid!(30);
999
1000 let live_kit_connection_info = {
1001 let live_kit_room = live_kit_room.clone();
1002 let live_kit = session.live_kit_client.as_ref();
1003
1004 util::async_maybe!({
1005 let live_kit = live_kit?;
1006
1007 let token = live_kit
1008 .room_token(&live_kit_room, &session.user_id.to_string())
1009 .trace_err()?;
1010
1011 Some(proto::LiveKitConnectionInfo {
1012 server_url: live_kit.url().into(),
1013 token,
1014 can_publish: true,
1015 })
1016 })
1017 }
1018 .await;
1019
1020 let room = session
1021 .db()
1022 .await
1023 .create_room(session.user_id, session.connection_id, &live_kit_room)
1024 .await?;
1025
1026 response.send(proto::CreateRoomResponse {
1027 room: Some(room.clone()),
1028 live_kit_connection_info,
1029 })?;
1030
1031 update_user_contacts(session.user_id, &session).await?;
1032 Ok(())
1033}
1034
1035/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1036async fn join_room(
1037 request: proto::JoinRoom,
1038 response: Response<proto::JoinRoom>,
1039 session: Session,
1040) -> Result<()> {
1041 let room_id = RoomId::from_proto(request.id);
1042
1043 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1044
1045 if let Some(channel_id) = channel_id {
1046 return join_channel_internal(channel_id, Box::new(response), session).await;
1047 }
1048
1049 let joined_room = {
1050 let room = session
1051 .db()
1052 .await
1053 .join_room(room_id, session.user_id, session.connection_id)
1054 .await?;
1055 room_updated(&room.room, &session.peer);
1056 room.into_inner()
1057 };
1058
1059 for connection_id in session
1060 .connection_pool()
1061 .await
1062 .user_connection_ids(session.user_id)
1063 {
1064 session
1065 .peer
1066 .send(
1067 connection_id,
1068 proto::CallCanceled {
1069 room_id: room_id.to_proto(),
1070 },
1071 )
1072 .trace_err();
1073 }
1074
1075 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1076 if let Some(token) = live_kit
1077 .room_token(
1078 &joined_room.room.live_kit_room,
1079 &session.user_id.to_string(),
1080 )
1081 .trace_err()
1082 {
1083 Some(proto::LiveKitConnectionInfo {
1084 server_url: live_kit.url().into(),
1085 token,
1086 can_publish: true,
1087 })
1088 } else {
1089 None
1090 }
1091 } else {
1092 None
1093 };
1094
1095 response.send(proto::JoinRoomResponse {
1096 room: Some(joined_room.room),
1097 channel_id: None,
1098 live_kit_connection_info,
1099 })?;
1100
1101 update_user_contacts(session.user_id, &session).await?;
1102 Ok(())
1103}
1104
/// Reconnects the caller to a room after a connection error.
///
/// The caller first receives a `RejoinRoomResponse` describing the room and
/// its projects, and the room is re-broadcast.
///
/// For each *reshared* project, the project's collaborators are told that the
/// caller's peer id changed (old connection -> current connection). They are
/// also forwarded the project's worktree list. Presumably these are projects
/// the caller hosts — confirm against `Database::rejoin_room`.
///
/// For each *rejoined* project:
/// - collaborators are told about the peer id change;
/// - the caller is streamed each worktree's entries (in chunks), diagnostic
///   summaries and settings files;
/// - the caller is streamed a disk-based-diagnostics update for every
///   language server.
///
/// Presumably the caller is a guest in these projects — confirm.
///
/// Finally, if the room belongs to a channel the channel is updated, and the
/// caller's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned at the end of the block below, after the database result has
    // been consumed by `into_inner()`.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Reply with the room and a summary of every project being resumed.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: announce the caller's new peer id, then forward
        // the project's worktree list to its collaborators.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: announce the caller's new peer id to collaborators.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Rejoined projects: stream the full project state back to the caller.
        // `mem::take` moves the worktrees out so their contents can be consumed
        // without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tiny chunks in test builds, presumably so tests exercise the
                // multi-message splitting path.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1296
1297/// leave room disconnects from the room.
1298async fn leave_room(
1299 _: proto::LeaveRoom,
1300 response: Response<proto::LeaveRoom>,
1301 session: Session,
1302) -> Result<()> {
1303 leave_room_for_session(&session).await?;
1304 response.send(proto::Ack {})?;
1305 Ok(())
1306}
1307
1308/// Updates the permissions of someone else in the room.
1309async fn set_room_participant_role(
1310 request: proto::SetRoomParticipantRole,
1311 response: Response<proto::SetRoomParticipantRole>,
1312 session: Session,
1313) -> Result<()> {
1314 let (live_kit_room, can_publish) = {
1315 let room = session
1316 .db()
1317 .await
1318 .set_room_participant_role(
1319 session.user_id,
1320 RoomId::from_proto(request.room_id),
1321 UserId::from_proto(request.user_id),
1322 ChannelRole::from(request.role()),
1323 )
1324 .await?;
1325
1326 let live_kit_room = room.live_kit_room.clone();
1327 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1328 room_updated(&room, &session.peer);
1329 (live_kit_room, can_publish)
1330 };
1331
1332 if let Some(live_kit) = session.live_kit_client.as_ref() {
1333 live_kit
1334 .update_participant(
1335 live_kit_room.clone(),
1336 request.user_id.to_string(),
1337 live_kit_server::proto::ParticipantPermission {
1338 can_subscribe: true,
1339 can_publish,
1340 can_publish_data: can_publish,
1341 hidden: false,
1342 recorder: false,
1343 },
1344 )
1345 .await
1346 .trace_err();
1347 }
1348
1349 response.send(proto::Ack {})?;
1350 Ok(())
1351}
1352
1353/// Call someone else into the current room
1354async fn call(
1355 request: proto::Call,
1356 response: Response<proto::Call>,
1357 session: Session,
1358) -> Result<()> {
1359 let room_id = RoomId::from_proto(request.room_id);
1360 let calling_user_id = session.user_id;
1361 let calling_connection_id = session.connection_id;
1362 let called_user_id = UserId::from_proto(request.called_user_id);
1363 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1364 if !session
1365 .db()
1366 .await
1367 .has_contact(calling_user_id, called_user_id)
1368 .await?
1369 {
1370 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1371 }
1372
1373 let incoming_call = {
1374 let (room, incoming_call) = &mut *session
1375 .db()
1376 .await
1377 .call(
1378 room_id,
1379 calling_user_id,
1380 calling_connection_id,
1381 called_user_id,
1382 initial_project_id,
1383 )
1384 .await?;
1385 room_updated(&room, &session.peer);
1386 mem::take(incoming_call)
1387 };
1388 update_user_contacts(called_user_id, &session).await?;
1389
1390 let mut calls = session
1391 .connection_pool()
1392 .await
1393 .user_connection_ids(called_user_id)
1394 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1395 .collect::<FuturesUnordered<_>>();
1396
1397 while let Some(call_response) = calls.next().await {
1398 match call_response.as_ref() {
1399 Ok(_) => {
1400 response.send(proto::Ack {})?;
1401 return Ok(());
1402 }
1403 Err(_) => {
1404 call_response.trace_err();
1405 }
1406 }
1407 }
1408
1409 {
1410 let room = session
1411 .db()
1412 .await
1413 .call_failed(room_id, called_user_id)
1414 .await?;
1415 room_updated(&room, &session.peer);
1416 }
1417 update_user_contacts(called_user_id, &session).await?;
1418
1419 Err(anyhow!("failed to ring user"))?
1420}
1421
1422/// Cancel an outgoing call.
1423async fn cancel_call(
1424 request: proto::CancelCall,
1425 response: Response<proto::CancelCall>,
1426 session: Session,
1427) -> Result<()> {
1428 let called_user_id = UserId::from_proto(request.called_user_id);
1429 let room_id = RoomId::from_proto(request.room_id);
1430 {
1431 let room = session
1432 .db()
1433 .await
1434 .cancel_call(room_id, session.connection_id, called_user_id)
1435 .await?;
1436 room_updated(&room, &session.peer);
1437 }
1438
1439 for connection_id in session
1440 .connection_pool()
1441 .await
1442 .user_connection_ids(called_user_id)
1443 {
1444 session
1445 .peer
1446 .send(
1447 connection_id,
1448 proto::CallCanceled {
1449 room_id: room_id.to_proto(),
1450 },
1451 )
1452 .trace_err();
1453 }
1454 response.send(proto::Ack {})?;
1455
1456 update_user_contacts(called_user_id, &session).await?;
1457 Ok(())
1458}
1459
1460/// Decline an incoming call.
1461async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1462 let room_id = RoomId::from_proto(message.room_id);
1463 {
1464 let room = session
1465 .db()
1466 .await
1467 .decline_call(Some(room_id), session.user_id)
1468 .await?
1469 .ok_or_else(|| anyhow!("failed to decline call"))?;
1470 room_updated(&room, &session.peer);
1471 }
1472
1473 for connection_id in session
1474 .connection_pool()
1475 .await
1476 .user_connection_ids(session.user_id)
1477 {
1478 session
1479 .peer
1480 .send(
1481 connection_id,
1482 proto::CallCanceled {
1483 room_id: room_id.to_proto(),
1484 },
1485 )
1486 .trace_err();
1487 }
1488 update_user_contacts(session.user_id, &session).await?;
1489 Ok(())
1490}
1491
1492/// Updates other participants in the room with your current location.
1493async fn update_participant_location(
1494 request: proto::UpdateParticipantLocation,
1495 response: Response<proto::UpdateParticipantLocation>,
1496 session: Session,
1497) -> Result<()> {
1498 let room_id = RoomId::from_proto(request.room_id);
1499 let location = request
1500 .location
1501 .ok_or_else(|| anyhow!("invalid location"))?;
1502
1503 let db = session.db().await;
1504 let room = db
1505 .update_room_participant_location(room_id, session.connection_id, location)
1506 .await?;
1507
1508 room_updated(&room, &session.peer);
1509 response.send(proto::Ack {})?;
1510 Ok(())
1511}
1512
1513/// Share a project into the room.
1514async fn share_project(
1515 request: proto::ShareProject,
1516 response: Response<proto::ShareProject>,
1517 session: Session,
1518) -> Result<()> {
1519 let (project_id, room) = &*session
1520 .db()
1521 .await
1522 .share_project(
1523 RoomId::from_proto(request.room_id),
1524 session.connection_id,
1525 &request.worktrees,
1526 )
1527 .await?;
1528 response.send(proto::ShareProjectResponse {
1529 project_id: project_id.to_proto(),
1530 })?;
1531 room_updated(&room, &session.peer);
1532
1533 Ok(())
1534}
1535
1536/// Unshare a project from the room.
1537async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1538 let project_id = ProjectId::from_proto(message.project_id);
1539
1540 let (room, guest_connection_ids) = &*session
1541 .db()
1542 .await
1543 .unshare_project(project_id, session.connection_id)
1544 .await?;
1545
1546 broadcast(
1547 Some(session.connection_id),
1548 guest_connection_ids.iter().copied(),
1549 |conn_id| session.peer.send(conn_id, message.clone()),
1550 );
1551 room_updated(&room, &session.peer);
1552
1553 Ok(())
1554}
1555
1556/// Join someone elses shared project.
1557async fn join_project(
1558 request: proto::JoinProject,
1559 response: Response<proto::JoinProject>,
1560 session: Session,
1561) -> Result<()> {
1562 let project_id = ProjectId::from_proto(request.project_id);
1563 let guest_user_id = session.user_id;
1564
1565 tracing::info!(%project_id, "join project");
1566
1567 let (project, replica_id) = &mut *session
1568 .db()
1569 .await
1570 .join_project(project_id, session.connection_id)
1571 .await?;
1572
1573 let collaborators = project
1574 .collaborators
1575 .iter()
1576 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1577 .map(|collaborator| collaborator.to_proto())
1578 .collect::<Vec<_>>();
1579
1580 let worktrees = project
1581 .worktrees
1582 .iter()
1583 .map(|(id, worktree)| proto::WorktreeMetadata {
1584 id: *id,
1585 root_name: worktree.root_name.clone(),
1586 visible: worktree.visible,
1587 abs_path: worktree.abs_path.clone(),
1588 })
1589 .collect::<Vec<_>>();
1590
1591 for collaborator in &collaborators {
1592 session
1593 .peer
1594 .send(
1595 collaborator.peer_id.unwrap().into(),
1596 proto::AddProjectCollaborator {
1597 project_id: project_id.to_proto(),
1598 collaborator: Some(proto::Collaborator {
1599 peer_id: Some(session.connection_id.into()),
1600 replica_id: replica_id.0 as u32,
1601 user_id: guest_user_id.to_proto(),
1602 }),
1603 },
1604 )
1605 .trace_err();
1606 }
1607
1608 // First, we send the metadata associated with each worktree.
1609 response.send(proto::JoinProjectResponse {
1610 worktrees: worktrees.clone(),
1611 replica_id: replica_id.0 as u32,
1612 collaborators: collaborators.clone(),
1613 language_servers: project.language_servers.clone(),
1614 })?;
1615
1616 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1617 #[cfg(any(test, feature = "test-support"))]
1618 const MAX_CHUNK_SIZE: usize = 2;
1619 #[cfg(not(any(test, feature = "test-support")))]
1620 const MAX_CHUNK_SIZE: usize = 256;
1621
1622 // Stream this worktree's entries.
1623 let message = proto::UpdateWorktree {
1624 project_id: project_id.to_proto(),
1625 worktree_id,
1626 abs_path: worktree.abs_path.clone(),
1627 root_name: worktree.root_name,
1628 updated_entries: worktree.entries,
1629 removed_entries: Default::default(),
1630 scan_id: worktree.scan_id,
1631 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1632 updated_repositories: worktree.repository_entries.into_values().collect(),
1633 removed_repositories: Default::default(),
1634 };
1635 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1636 session.peer.send(session.connection_id, update.clone())?;
1637 }
1638
1639 // Stream this worktree's diagnostics.
1640 for summary in worktree.diagnostic_summaries {
1641 session.peer.send(
1642 session.connection_id,
1643 proto::UpdateDiagnosticSummary {
1644 project_id: project_id.to_proto(),
1645 worktree_id: worktree.id,
1646 summary: Some(summary),
1647 },
1648 )?;
1649 }
1650
1651 for settings_file in worktree.settings_files {
1652 session.peer.send(
1653 session.connection_id,
1654 proto::UpdateWorktreeSettings {
1655 project_id: project_id.to_proto(),
1656 worktree_id: worktree.id,
1657 path: settings_file.path,
1658 content: Some(settings_file.content),
1659 },
1660 )?;
1661 }
1662 }
1663
1664 for language_server in &project.language_servers {
1665 session.peer.send(
1666 session.connection_id,
1667 proto::UpdateLanguageServer {
1668 project_id: project_id.to_proto(),
1669 language_server_id: language_server.id,
1670 variant: Some(
1671 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1672 proto::LspDiskBasedDiagnosticsUpdated {},
1673 ),
1674 ),
1675 },
1676 )?;
1677 }
1678
1679 Ok(())
1680}
1681
1682/// Leave someone elses shared project.
1683async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1684 let sender_id = session.connection_id;
1685 let project_id = ProjectId::from_proto(request.project_id);
1686
1687 let (room, project) = &*session
1688 .db()
1689 .await
1690 .leave_project(project_id, sender_id)
1691 .await?;
1692 tracing::info!(
1693 %project_id,
1694 host_user_id = %project.host_user_id,
1695 host_connection_id = ?project.host_connection_id,
1696 "leave project"
1697 );
1698
1699 project_left(&project, &session);
1700 room_updated(&room, &session.peer);
1701
1702 Ok(())
1703}
1704
1705/// Updates other participants with changes to the project
1706async fn update_project(
1707 request: proto::UpdateProject,
1708 response: Response<proto::UpdateProject>,
1709 session: Session,
1710) -> Result<()> {
1711 let project_id = ProjectId::from_proto(request.project_id);
1712 let (room, guest_connection_ids) = &*session
1713 .db()
1714 .await
1715 .update_project(project_id, session.connection_id, &request.worktrees)
1716 .await?;
1717 broadcast(
1718 Some(session.connection_id),
1719 guest_connection_ids.iter().copied(),
1720 |connection_id| {
1721 session
1722 .peer
1723 .forward_send(session.connection_id, connection_id, request.clone())
1724 },
1725 );
1726 room_updated(&room, &session.peer);
1727 response.send(proto::Ack {})?;
1728
1729 Ok(())
1730}
1731
1732/// Updates other participants with changes to the worktree
1733async fn update_worktree(
1734 request: proto::UpdateWorktree,
1735 response: Response<proto::UpdateWorktree>,
1736 session: Session,
1737) -> Result<()> {
1738 let guest_connection_ids = session
1739 .db()
1740 .await
1741 .update_worktree(&request, session.connection_id)
1742 .await?;
1743
1744 broadcast(
1745 Some(session.connection_id),
1746 guest_connection_ids.iter().copied(),
1747 |connection_id| {
1748 session
1749 .peer
1750 .forward_send(session.connection_id, connection_id, request.clone())
1751 },
1752 );
1753 response.send(proto::Ack {})?;
1754 Ok(())
1755}
1756
1757/// Updates other participants with changes to the diagnostics
1758async fn update_diagnostic_summary(
1759 message: proto::UpdateDiagnosticSummary,
1760 session: Session,
1761) -> Result<()> {
1762 let guest_connection_ids = session
1763 .db()
1764 .await
1765 .update_diagnostic_summary(&message, session.connection_id)
1766 .await?;
1767
1768 broadcast(
1769 Some(session.connection_id),
1770 guest_connection_ids.iter().copied(),
1771 |connection_id| {
1772 session
1773 .peer
1774 .forward_send(session.connection_id, connection_id, message.clone())
1775 },
1776 );
1777
1778 Ok(())
1779}
1780
1781/// Updates other participants with changes to the worktree settings
1782async fn update_worktree_settings(
1783 message: proto::UpdateWorktreeSettings,
1784 session: Session,
1785) -> Result<()> {
1786 let guest_connection_ids = session
1787 .db()
1788 .await
1789 .update_worktree_settings(&message, session.connection_id)
1790 .await?;
1791
1792 broadcast(
1793 Some(session.connection_id),
1794 guest_connection_ids.iter().copied(),
1795 |connection_id| {
1796 session
1797 .peer
1798 .forward_send(session.connection_id, connection_id, message.clone())
1799 },
1800 );
1801
1802 Ok(())
1803}
1804
1805/// Notify other participants that a language server has started.
1806async fn start_language_server(
1807 request: proto::StartLanguageServer,
1808 session: Session,
1809) -> Result<()> {
1810 let guest_connection_ids = session
1811 .db()
1812 .await
1813 .start_language_server(&request, session.connection_id)
1814 .await?;
1815
1816 broadcast(
1817 Some(session.connection_id),
1818 guest_connection_ids.iter().copied(),
1819 |connection_id| {
1820 session
1821 .peer
1822 .forward_send(session.connection_id, connection_id, request.clone())
1823 },
1824 );
1825 Ok(())
1826}
1827
1828/// Notify other participants that a language server has changed.
1829async fn update_language_server(
1830 request: proto::UpdateLanguageServer,
1831 session: Session,
1832) -> Result<()> {
1833 let project_id = ProjectId::from_proto(request.project_id);
1834 let project_connection_ids = session
1835 .db()
1836 .await
1837 .project_connection_ids(project_id, session.connection_id)
1838 .await?;
1839 broadcast(
1840 Some(session.connection_id),
1841 project_connection_ids.iter().copied(),
1842 |connection_id| {
1843 session
1844 .peer
1845 .forward_send(session.connection_id, connection_id, request.clone())
1846 },
1847 );
1848 Ok(())
1849}
1850
1851/// forward a project request to the host. These requests should be read only
1852/// as guests are allowed to send them.
1853async fn forward_read_only_project_request<T>(
1854 request: T,
1855 response: Response<T>,
1856 session: Session,
1857) -> Result<()>
1858where
1859 T: EntityMessage + RequestMessage,
1860{
1861 let project_id = ProjectId::from_proto(request.remote_entity_id());
1862 let host_connection_id = session
1863 .db()
1864 .await
1865 .host_for_read_only_project_request(project_id, session.connection_id)
1866 .await?;
1867 let payload = session
1868 .peer
1869 .forward_request(session.connection_id, host_connection_id, request)
1870 .await?;
1871 response.send(payload)?;
1872 Ok(())
1873}
1874
1875/// forward a project request to the host. These requests are disallowed
1876/// for guests.
1877async fn forward_mutating_project_request<T>(
1878 request: T,
1879 response: Response<T>,
1880 session: Session,
1881) -> Result<()>
1882where
1883 T: EntityMessage + RequestMessage,
1884{
1885 let project_id = ProjectId::from_proto(request.remote_entity_id());
1886 let host_connection_id = session
1887 .db()
1888 .await
1889 .host_for_mutating_project_request(project_id, session.connection_id)
1890 .await?;
1891 let payload = session
1892 .peer
1893 .forward_request(session.connection_id, host_connection_id, request)
1894 .await?;
1895 response.send(payload)?;
1896 Ok(())
1897}
1898
1899/// Notify other participants that a new buffer has been created
1900async fn create_buffer_for_peer(
1901 request: proto::CreateBufferForPeer,
1902 session: Session,
1903) -> Result<()> {
1904 session
1905 .db()
1906 .await
1907 .check_user_is_project_host(
1908 ProjectId::from_proto(request.project_id),
1909 session.connection_id,
1910 )
1911 .await?;
1912 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1913 session
1914 .peer
1915 .forward_send(session.connection_id, peer_id.into(), request)?;
1916 Ok(())
1917}
1918
1919/// Notify other participants that a buffer has been updated. This is
1920/// allowed for guests as long as the update is limited to selections.
1921async fn update_buffer(
1922 request: proto::UpdateBuffer,
1923 response: Response<proto::UpdateBuffer>,
1924 session: Session,
1925) -> Result<()> {
1926 let project_id = ProjectId::from_proto(request.project_id);
1927 let mut guest_connection_ids;
1928 let mut host_connection_id = None;
1929
1930 let mut requires_write_permission = false;
1931
1932 for op in request.operations.iter() {
1933 match op.variant {
1934 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1935 Some(_) => requires_write_permission = true,
1936 }
1937 }
1938
1939 {
1940 let collaborators = session
1941 .db()
1942 .await
1943 .project_collaborators_for_buffer_update(
1944 project_id,
1945 session.connection_id,
1946 requires_write_permission,
1947 )
1948 .await?;
1949 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1950 for collaborator in collaborators.iter() {
1951 if collaborator.is_host {
1952 host_connection_id = Some(collaborator.connection_id);
1953 } else {
1954 guest_connection_ids.push(collaborator.connection_id);
1955 }
1956 }
1957 }
1958 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1959
1960 broadcast(
1961 Some(session.connection_id),
1962 guest_connection_ids,
1963 |connection_id| {
1964 session
1965 .peer
1966 .forward_send(session.connection_id, connection_id, request.clone())
1967 },
1968 );
1969 if host_connection_id != session.connection_id {
1970 session
1971 .peer
1972 .forward_request(session.connection_id, host_connection_id, request.clone())
1973 .await?;
1974 }
1975
1976 response.send(proto::Ack {})?;
1977 Ok(())
1978}
1979
1980/// Notify other participants that a project has been updated.
1981async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1982 request: T,
1983 session: Session,
1984) -> Result<()> {
1985 let project_id = ProjectId::from_proto(request.remote_entity_id());
1986 let project_connection_ids = session
1987 .db()
1988 .await
1989 .project_connection_ids(project_id, session.connection_id)
1990 .await?;
1991
1992 broadcast(
1993 Some(session.connection_id),
1994 project_connection_ids.iter().copied(),
1995 |connection_id| {
1996 session
1997 .peer
1998 .forward_send(session.connection_id, connection_id, request.clone())
1999 },
2000 );
2001 Ok(())
2002}
2003
2004/// Start following another user in a call.
2005async fn follow(
2006 request: proto::Follow,
2007 response: Response<proto::Follow>,
2008 session: Session,
2009) -> Result<()> {
2010 let room_id = RoomId::from_proto(request.room_id);
2011 let project_id = request.project_id.map(ProjectId::from_proto);
2012 let leader_id = request
2013 .leader_id
2014 .ok_or_else(|| anyhow!("invalid leader id"))?
2015 .into();
2016 let follower_id = session.connection_id;
2017
2018 session
2019 .db()
2020 .await
2021 .check_room_participants(room_id, leader_id, session.connection_id)
2022 .await?;
2023
2024 let response_payload = session
2025 .peer
2026 .forward_request(session.connection_id, leader_id, request)
2027 .await?;
2028 response.send(response_payload)?;
2029
2030 if let Some(project_id) = project_id {
2031 let room = session
2032 .db()
2033 .await
2034 .follow(room_id, project_id, leader_id, follower_id)
2035 .await?;
2036 room_updated(&room, &session.peer);
2037 }
2038
2039 Ok(())
2040}
2041
2042/// Stop following another user in a call.
2043async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2044 let room_id = RoomId::from_proto(request.room_id);
2045 let project_id = request.project_id.map(ProjectId::from_proto);
2046 let leader_id = request
2047 .leader_id
2048 .ok_or_else(|| anyhow!("invalid leader id"))?
2049 .into();
2050 let follower_id = session.connection_id;
2051
2052 session
2053 .db()
2054 .await
2055 .check_room_participants(room_id, leader_id, session.connection_id)
2056 .await?;
2057
2058 session
2059 .peer
2060 .forward_send(session.connection_id, leader_id, request)?;
2061
2062 if let Some(project_id) = project_id {
2063 let room = session
2064 .db()
2065 .await
2066 .unfollow(room_id, project_id, leader_id, follower_id)
2067 .await?;
2068 room_updated(&room, &session.peer);
2069 }
2070
2071 Ok(())
2072}
2073
2074/// Notify everyone following you of your current location.
2075async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2076 let room_id = RoomId::from_proto(request.room_id);
2077 let database = session.db.lock().await;
2078
2079 let connection_ids = if let Some(project_id) = request.project_id {
2080 let project_id = ProjectId::from_proto(project_id);
2081 database
2082 .project_connection_ids(project_id, session.connection_id)
2083 .await?
2084 } else {
2085 database
2086 .room_connection_ids(room_id, session.connection_id)
2087 .await?
2088 };
2089
2090 // For now, don't send view update messages back to that view's current leader.
2091 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2092 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2093 _ => None,
2094 });
2095
2096 for follower_peer_id in request.follower_ids.iter().copied() {
2097 let follower_connection_id = follower_peer_id.into();
2098 if Some(follower_peer_id) != connection_id_to_omit
2099 && connection_ids.contains(&follower_connection_id)
2100 {
2101 session.peer.forward_send(
2102 session.connection_id,
2103 follower_connection_id,
2104 request.clone(),
2105 )?;
2106 }
2107 }
2108 Ok(())
2109}
2110
2111/// Get public data about users.
2112async fn get_users(
2113 request: proto::GetUsers,
2114 response: Response<proto::GetUsers>,
2115 session: Session,
2116) -> Result<()> {
2117 let user_ids = request
2118 .user_ids
2119 .into_iter()
2120 .map(UserId::from_proto)
2121 .collect();
2122 let users = session
2123 .db()
2124 .await
2125 .get_users_by_ids(user_ids)
2126 .await?
2127 .into_iter()
2128 .map(|user| proto::User {
2129 id: user.id.to_proto(),
2130 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2131 github_login: user.github_login,
2132 })
2133 .collect();
2134 response.send(proto::UsersResponse { users })?;
2135 Ok(())
2136}
2137
2138/// Search for users (to invite) buy Github login
2139async fn fuzzy_search_users(
2140 request: proto::FuzzySearchUsers,
2141 response: Response<proto::FuzzySearchUsers>,
2142 session: Session,
2143) -> Result<()> {
2144 let query = request.query;
2145 let users = match query.len() {
2146 0 => vec![],
2147 1 | 2 => session
2148 .db()
2149 .await
2150 .get_user_by_github_login(&query)
2151 .await?
2152 .into_iter()
2153 .collect(),
2154 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2155 };
2156 let users = users
2157 .into_iter()
2158 .filter(|user| user.id != session.user_id)
2159 .map(|user| proto::User {
2160 id: user.id.to_proto(),
2161 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2162 github_login: user.github_login,
2163 })
2164 .collect();
2165 response.send(proto::UsersResponse { users })?;
2166 Ok(())
2167}
2168
2169/// Send a contact request to another user.
2170async fn request_contact(
2171 request: proto::RequestContact,
2172 response: Response<proto::RequestContact>,
2173 session: Session,
2174) -> Result<()> {
2175 let requester_id = session.user_id;
2176 let responder_id = UserId::from_proto(request.responder_id);
2177 if requester_id == responder_id {
2178 return Err(anyhow!("cannot add yourself as a contact"))?;
2179 }
2180
2181 let notifications = session
2182 .db()
2183 .await
2184 .send_contact_request(requester_id, responder_id)
2185 .await?;
2186
2187 // Update outgoing contact requests of requester
2188 let mut update = proto::UpdateContacts::default();
2189 update.outgoing_requests.push(responder_id.to_proto());
2190 for connection_id in session
2191 .connection_pool()
2192 .await
2193 .user_connection_ids(requester_id)
2194 {
2195 session.peer.send(connection_id, update.clone())?;
2196 }
2197
2198 // Update incoming contact requests of responder
2199 let mut update = proto::UpdateContacts::default();
2200 update
2201 .incoming_requests
2202 .push(proto::IncomingContactRequest {
2203 requester_id: requester_id.to_proto(),
2204 });
2205 let connection_pool = session.connection_pool().await;
2206 for connection_id in connection_pool.user_connection_ids(responder_id) {
2207 session.peer.send(connection_id, update.clone())?;
2208 }
2209
2210 send_notifications(&*connection_pool, &session.peer, notifications);
2211
2212 response.send(proto::Ack {})?;
2213 Ok(())
2214}
2215
2216/// Accept or decline a contact request
2217async fn respond_to_contact_request(
2218 request: proto::RespondToContactRequest,
2219 response: Response<proto::RespondToContactRequest>,
2220 session: Session,
2221) -> Result<()> {
2222 let responder_id = session.user_id;
2223 let requester_id = UserId::from_proto(request.requester_id);
2224 let db = session.db().await;
2225 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2226 db.dismiss_contact_notification(responder_id, requester_id)
2227 .await?;
2228 } else {
2229 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2230
2231 let notifications = db
2232 .respond_to_contact_request(responder_id, requester_id, accept)
2233 .await?;
2234 let requester_busy = db.is_user_busy(requester_id).await?;
2235 let responder_busy = db.is_user_busy(responder_id).await?;
2236
2237 let pool = session.connection_pool().await;
2238 // Update responder with new contact
2239 let mut update = proto::UpdateContacts::default();
2240 if accept {
2241 update
2242 .contacts
2243 .push(contact_for_user(requester_id, requester_busy, &pool));
2244 }
2245 update
2246 .remove_incoming_requests
2247 .push(requester_id.to_proto());
2248 for connection_id in pool.user_connection_ids(responder_id) {
2249 session.peer.send(connection_id, update.clone())?;
2250 }
2251
2252 // Update requester with new contact
2253 let mut update = proto::UpdateContacts::default();
2254 if accept {
2255 update
2256 .contacts
2257 .push(contact_for_user(responder_id, responder_busy, &pool));
2258 }
2259 update
2260 .remove_outgoing_requests
2261 .push(responder_id.to_proto());
2262
2263 for connection_id in pool.user_connection_ids(requester_id) {
2264 session.peer.send(connection_id, update.clone())?;
2265 }
2266
2267 send_notifications(&*pool, &session.peer, notifications);
2268 }
2269
2270 response.send(proto::Ack {})?;
2271 Ok(())
2272}
2273
2274/// Remove a contact.
2275async fn remove_contact(
2276 request: proto::RemoveContact,
2277 response: Response<proto::RemoveContact>,
2278 session: Session,
2279) -> Result<()> {
2280 let requester_id = session.user_id;
2281 let responder_id = UserId::from_proto(request.user_id);
2282 let db = session.db().await;
2283 let (contact_accepted, deleted_notification_id) =
2284 db.remove_contact(requester_id, responder_id).await?;
2285
2286 let pool = session.connection_pool().await;
2287 // Update outgoing contact requests of requester
2288 let mut update = proto::UpdateContacts::default();
2289 if contact_accepted {
2290 update.remove_contacts.push(responder_id.to_proto());
2291 } else {
2292 update
2293 .remove_outgoing_requests
2294 .push(responder_id.to_proto());
2295 }
2296 for connection_id in pool.user_connection_ids(requester_id) {
2297 session.peer.send(connection_id, update.clone())?;
2298 }
2299
2300 // Update incoming contact requests of responder
2301 let mut update = proto::UpdateContacts::default();
2302 if contact_accepted {
2303 update.remove_contacts.push(requester_id.to_proto());
2304 } else {
2305 update
2306 .remove_incoming_requests
2307 .push(requester_id.to_proto());
2308 }
2309 for connection_id in pool.user_connection_ids(responder_id) {
2310 session.peer.send(connection_id, update.clone())?;
2311 if let Some(notification_id) = deleted_notification_id {
2312 session.peer.send(
2313 connection_id,
2314 proto::DeleteNotification {
2315 notification_id: notification_id.to_proto(),
2316 },
2317 )?;
2318 }
2319 }
2320
2321 response.send(proto::Ack {})?;
2322 Ok(())
2323}
2324
2325/// Creates a new channel.
2326async fn create_channel(
2327 request: proto::CreateChannel,
2328 response: Response<proto::CreateChannel>,
2329 session: Session,
2330) -> Result<()> {
2331 let db = session.db().await;
2332
2333 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2334 let (channel, owner, channel_members) = db
2335 .create_channel(&request.name, parent_id, session.user_id)
2336 .await?;
2337
2338 response.send(proto::CreateChannelResponse {
2339 channel: Some(channel.to_proto()),
2340 parent_id: request.parent_id,
2341 })?;
2342
2343 let connection_pool = session.connection_pool().await;
2344 if let Some(owner) = owner {
2345 let update = proto::UpdateUserChannels {
2346 channel_memberships: vec![proto::ChannelMembership {
2347 channel_id: owner.channel_id.to_proto(),
2348 role: owner.role.into(),
2349 }],
2350 ..Default::default()
2351 };
2352 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2353 session.peer.send(connection_id, update.clone())?;
2354 }
2355 }
2356
2357 for channel_member in channel_members {
2358 if !channel_member.role.can_see_channel(channel.visibility) {
2359 continue;
2360 }
2361
2362 let update = proto::UpdateChannels {
2363 channels: vec![channel.to_proto()],
2364 ..Default::default()
2365 };
2366 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2367 session.peer.send(connection_id, update.clone())?;
2368 }
2369 }
2370
2371 Ok(())
2372}
2373
2374/// Delete a channel
2375async fn delete_channel(
2376 request: proto::DeleteChannel,
2377 response: Response<proto::DeleteChannel>,
2378 session: Session,
2379) -> Result<()> {
2380 let db = session.db().await;
2381
2382 let channel_id = request.channel_id;
2383 let (removed_channels, member_ids) = db
2384 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2385 .await?;
2386 response.send(proto::Ack {})?;
2387
2388 // Notify members of removed channels
2389 let mut update = proto::UpdateChannels::default();
2390 update
2391 .delete_channels
2392 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2393
2394 let connection_pool = session.connection_pool().await;
2395 for member_id in member_ids {
2396 for connection_id in connection_pool.user_connection_ids(member_id) {
2397 session.peer.send(connection_id, update.clone())?;
2398 }
2399 }
2400
2401 Ok(())
2402}
2403
2404/// Invite someone to join a channel.
2405async fn invite_channel_member(
2406 request: proto::InviteChannelMember,
2407 response: Response<proto::InviteChannelMember>,
2408 session: Session,
2409) -> Result<()> {
2410 let db = session.db().await;
2411 let channel_id = ChannelId::from_proto(request.channel_id);
2412 let invitee_id = UserId::from_proto(request.user_id);
2413 let InviteMemberResult {
2414 channel,
2415 notifications,
2416 } = db
2417 .invite_channel_member(
2418 channel_id,
2419 invitee_id,
2420 session.user_id,
2421 request.role().into(),
2422 )
2423 .await?;
2424
2425 let update = proto::UpdateChannels {
2426 channel_invitations: vec![channel.to_proto()],
2427 ..Default::default()
2428 };
2429
2430 let connection_pool = session.connection_pool().await;
2431 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2432 session.peer.send(connection_id, update.clone())?;
2433 }
2434
2435 send_notifications(&*connection_pool, &session.peer, notifications);
2436
2437 response.send(proto::Ack {})?;
2438 Ok(())
2439}
2440
2441/// remove someone from a channel
2442async fn remove_channel_member(
2443 request: proto::RemoveChannelMember,
2444 response: Response<proto::RemoveChannelMember>,
2445 session: Session,
2446) -> Result<()> {
2447 let db = session.db().await;
2448 let channel_id = ChannelId::from_proto(request.channel_id);
2449 let member_id = UserId::from_proto(request.user_id);
2450
2451 let RemoveChannelMemberResult {
2452 membership_update,
2453 notification_id,
2454 } = db
2455 .remove_channel_member(channel_id, member_id, session.user_id)
2456 .await?;
2457
2458 let connection_pool = &session.connection_pool().await;
2459 notify_membership_updated(
2460 &connection_pool,
2461 membership_update,
2462 member_id,
2463 &session.peer,
2464 );
2465 for connection_id in connection_pool.user_connection_ids(member_id) {
2466 if let Some(notification_id) = notification_id {
2467 session
2468 .peer
2469 .send(
2470 connection_id,
2471 proto::DeleteNotification {
2472 notification_id: notification_id.to_proto(),
2473 },
2474 )
2475 .trace_err();
2476 }
2477 }
2478
2479 response.send(proto::Ack {})?;
2480 Ok(())
2481}
2482
2483/// Toggle the channel between public and private.
2484/// Care is taken to maintain the invariant that public channels only descend from public channels,
2485/// (though members-only channels can appear at any point in the hierarchy).
2486async fn set_channel_visibility(
2487 request: proto::SetChannelVisibility,
2488 response: Response<proto::SetChannelVisibility>,
2489 session: Session,
2490) -> Result<()> {
2491 let db = session.db().await;
2492 let channel_id = ChannelId::from_proto(request.channel_id);
2493 let visibility = request.visibility().into();
2494
2495 let (channel, channel_members) = db
2496 .set_channel_visibility(channel_id, visibility, session.user_id)
2497 .await?;
2498
2499 let connection_pool = session.connection_pool().await;
2500 for member in channel_members {
2501 let update = if member.role.can_see_channel(channel.visibility) {
2502 proto::UpdateChannels {
2503 channels: vec![channel.to_proto()],
2504 ..Default::default()
2505 }
2506 } else {
2507 proto::UpdateChannels {
2508 delete_channels: vec![channel.id.to_proto()],
2509 ..Default::default()
2510 }
2511 };
2512
2513 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2514 session.peer.send(connection_id, update.clone())?;
2515 }
2516 }
2517
2518 response.send(proto::Ack {})?;
2519 Ok(())
2520}
2521
2522/// Alter the role for a user in the channel.
2523async fn set_channel_member_role(
2524 request: proto::SetChannelMemberRole,
2525 response: Response<proto::SetChannelMemberRole>,
2526 session: Session,
2527) -> Result<()> {
2528 let db = session.db().await;
2529 let channel_id = ChannelId::from_proto(request.channel_id);
2530 let member_id = UserId::from_proto(request.user_id);
2531 let result = db
2532 .set_channel_member_role(
2533 channel_id,
2534 session.user_id,
2535 member_id,
2536 request.role().into(),
2537 )
2538 .await?;
2539
2540 match result {
2541 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2542 let connection_pool = session.connection_pool().await;
2543 notify_membership_updated(
2544 &connection_pool,
2545 membership_update,
2546 member_id,
2547 &session.peer,
2548 )
2549 }
2550 db::SetMemberRoleResult::InviteUpdated(channel) => {
2551 let update = proto::UpdateChannels {
2552 channel_invitations: vec![channel.to_proto()],
2553 ..Default::default()
2554 };
2555
2556 for connection_id in session
2557 .connection_pool()
2558 .await
2559 .user_connection_ids(member_id)
2560 {
2561 session.peer.send(connection_id, update.clone())?;
2562 }
2563 }
2564 }
2565
2566 response.send(proto::Ack {})?;
2567 Ok(())
2568}
2569
2570/// Change the name of a channel
2571async fn rename_channel(
2572 request: proto::RenameChannel,
2573 response: Response<proto::RenameChannel>,
2574 session: Session,
2575) -> Result<()> {
2576 let db = session.db().await;
2577 let channel_id = ChannelId::from_proto(request.channel_id);
2578 let (channel, channel_members) = db
2579 .rename_channel(channel_id, session.user_id, &request.name)
2580 .await?;
2581
2582 response.send(proto::RenameChannelResponse {
2583 channel: Some(channel.to_proto()),
2584 })?;
2585
2586 let connection_pool = session.connection_pool().await;
2587 for channel_member in channel_members {
2588 if !channel_member.role.can_see_channel(channel.visibility) {
2589 continue;
2590 }
2591 let update = proto::UpdateChannels {
2592 channels: vec![channel.to_proto()],
2593 ..Default::default()
2594 };
2595 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2596 session.peer.send(connection_id, update.clone())?;
2597 }
2598 }
2599
2600 Ok(())
2601}
2602
2603/// Move a channel to a new parent.
2604async fn move_channel(
2605 request: proto::MoveChannel,
2606 response: Response<proto::MoveChannel>,
2607 session: Session,
2608) -> Result<()> {
2609 let channel_id = ChannelId::from_proto(request.channel_id);
2610 let to = ChannelId::from_proto(request.to);
2611
2612 let (channels, channel_members) = session
2613 .db()
2614 .await
2615 .move_channel(channel_id, to, session.user_id)
2616 .await?;
2617
2618 let connection_pool = session.connection_pool().await;
2619 for member in channel_members {
2620 let channels = channels
2621 .iter()
2622 .filter_map(|channel| {
2623 if member.role.can_see_channel(channel.visibility) {
2624 Some(channel.to_proto())
2625 } else {
2626 None
2627 }
2628 })
2629 .collect::<Vec<_>>();
2630 if channels.is_empty() {
2631 continue;
2632 }
2633
2634 let update = proto::UpdateChannels {
2635 channels,
2636 ..Default::default()
2637 };
2638
2639 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2640 session.peer.send(connection_id, update.clone())?;
2641 }
2642 }
2643
2644 response.send(Ack {})?;
2645 Ok(())
2646}
2647
2648/// Get the list of channel members
2649async fn get_channel_members(
2650 request: proto::GetChannelMembers,
2651 response: Response<proto::GetChannelMembers>,
2652 session: Session,
2653) -> Result<()> {
2654 let db = session.db().await;
2655 let channel_id = ChannelId::from_proto(request.channel_id);
2656 let members = db
2657 .get_channel_participant_details(channel_id, session.user_id)
2658 .await?;
2659 response.send(proto::GetChannelMembersResponse { members })?;
2660 Ok(())
2661}
2662
2663/// Accept or decline a channel invitation.
2664async fn respond_to_channel_invite(
2665 request: proto::RespondToChannelInvite,
2666 response: Response<proto::RespondToChannelInvite>,
2667 session: Session,
2668) -> Result<()> {
2669 let db = session.db().await;
2670 let channel_id = ChannelId::from_proto(request.channel_id);
2671 let RespondToChannelInvite {
2672 membership_update,
2673 notifications,
2674 } = db
2675 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2676 .await?;
2677
2678 let connection_pool = session.connection_pool().await;
2679 if let Some(membership_update) = membership_update {
2680 notify_membership_updated(
2681 &connection_pool,
2682 membership_update,
2683 session.user_id,
2684 &session.peer,
2685 );
2686 } else {
2687 let update = proto::UpdateChannels {
2688 remove_channel_invitations: vec![channel_id.to_proto()],
2689 ..Default::default()
2690 };
2691
2692 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2693 session.peer.send(connection_id, update.clone())?;
2694 }
2695 };
2696
2697 send_notifications(&*connection_pool, &session.peer, notifications);
2698
2699 response.send(proto::Ack {})?;
2700
2701 Ok(())
2702}
2703
2704/// Join the channels' room
2705async fn join_channel(
2706 request: proto::JoinChannel,
2707 response: Response<proto::JoinChannel>,
2708 session: Session,
2709) -> Result<()> {
2710 let channel_id = ChannelId::from_proto(request.channel_id);
2711 join_channel_internal(channel_id, Box::new(response), session).await
2712}
2713
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both `JoinChannel` and `JoinRoom` responses so that
/// `join_channel_internal` can reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` back as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The bodies below name the concrete `Response::<T>` type explicitly so the call
// resolves to `Response`'s own `send`, not to this same-named trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2727
/// Shared implementation for joining a channel's room.
///
/// `response` may answer either a `JoinChannel` or a `JoinRoom` request (see
/// `JoinChannelInternalResponse`). Steps:
/// 1. Leave any room this session is currently in.
/// 2. Join the channel in the database, which yields the joined room, an optional
///    membership change and the user's role.
/// 3. Build LiveKit credentials: guests get a guest token and cannot publish;
///    everyone else gets a room token and can publish. If there is no LiveKit
///    client, or token creation fails (logged via `trace_err`), no connection
///    info is sent.
/// 4. Reply, push any membership change to the user, and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped block: `db` and `connection_pool` are locals of this block, so they are
    // dropped before the `update_user_contacts(...).await` below. Presumably this
    // keeps those guards from being held across that await — TODO confirm before
    // restructuring.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `?` inside this closure short-circuits to `None` when a token cannot be
        // created, so the join still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is locked again here; this guard is a temporary that is
    // released at the end of this statement.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2803
2804/// Start editing the channel notes
2805async fn join_channel_buffer(
2806 request: proto::JoinChannelBuffer,
2807 response: Response<proto::JoinChannelBuffer>,
2808 session: Session,
2809) -> Result<()> {
2810 let db = session.db().await;
2811 let channel_id = ChannelId::from_proto(request.channel_id);
2812
2813 let open_response = db
2814 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2815 .await?;
2816
2817 let collaborators = open_response.collaborators.clone();
2818 response.send(open_response)?;
2819
2820 let update = UpdateChannelBufferCollaborators {
2821 channel_id: channel_id.to_proto(),
2822 collaborators: collaborators.clone(),
2823 };
2824 channel_buffer_updated(
2825 session.connection_id,
2826 collaborators
2827 .iter()
2828 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2829 &update,
2830 &session.peer,
2831 );
2832
2833 Ok(())
2834}
2835
2836/// Edit the channel notes
2837async fn update_channel_buffer(
2838 request: proto::UpdateChannelBuffer,
2839 session: Session,
2840) -> Result<()> {
2841 let db = session.db().await;
2842 let channel_id = ChannelId::from_proto(request.channel_id);
2843
2844 let (collaborators, non_collaborators, epoch, version) = db
2845 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2846 .await?;
2847
2848 channel_buffer_updated(
2849 session.connection_id,
2850 collaborators,
2851 &proto::UpdateChannelBuffer {
2852 channel_id: channel_id.to_proto(),
2853 operations: request.operations,
2854 },
2855 &session.peer,
2856 );
2857
2858 let pool = &*session.connection_pool().await;
2859
2860 broadcast(
2861 None,
2862 non_collaborators
2863 .iter()
2864 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2865 |peer_id| {
2866 session.peer.send(
2867 peer_id.into(),
2868 proto::UpdateChannels {
2869 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2870 channel_id: channel_id.to_proto(),
2871 epoch: epoch as u64,
2872 version: version.clone(),
2873 }],
2874 ..Default::default()
2875 },
2876 )
2877 },
2878 );
2879
2880 Ok(())
2881}
2882
2883/// Rejoin the channel notes after a connection blip
2884async fn rejoin_channel_buffers(
2885 request: proto::RejoinChannelBuffers,
2886 response: Response<proto::RejoinChannelBuffers>,
2887 session: Session,
2888) -> Result<()> {
2889 let db = session.db().await;
2890 let buffers = db
2891 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2892 .await?;
2893
2894 for rejoined_buffer in &buffers {
2895 let collaborators_to_notify = rejoined_buffer
2896 .buffer
2897 .collaborators
2898 .iter()
2899 .filter_map(|c| Some(c.peer_id?.into()));
2900 channel_buffer_updated(
2901 session.connection_id,
2902 collaborators_to_notify,
2903 &proto::UpdateChannelBufferCollaborators {
2904 channel_id: rejoined_buffer.buffer.channel_id,
2905 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2906 },
2907 &session.peer,
2908 );
2909 }
2910
2911 response.send(proto::RejoinChannelBuffersResponse {
2912 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2913 })?;
2914
2915 Ok(())
2916}
2917
2918/// Stop editing the channel notes
2919async fn leave_channel_buffer(
2920 request: proto::LeaveChannelBuffer,
2921 response: Response<proto::LeaveChannelBuffer>,
2922 session: Session,
2923) -> Result<()> {
2924 let db = session.db().await;
2925 let channel_id = ChannelId::from_proto(request.channel_id);
2926
2927 let left_buffer = db
2928 .leave_channel_buffer(channel_id, session.connection_id)
2929 .await?;
2930
2931 response.send(Ack {})?;
2932
2933 channel_buffer_updated(
2934 session.connection_id,
2935 left_buffer.connections,
2936 &proto::UpdateChannelBufferCollaborators {
2937 channel_id: channel_id.to_proto(),
2938 collaborators: left_buffer.collaborators,
2939 },
2940 &session.peer,
2941 );
2942
2943 Ok(())
2944}
2945
2946fn channel_buffer_updated<T: EnvelopedMessage>(
2947 sender_id: ConnectionId,
2948 collaborators: impl IntoIterator<Item = ConnectionId>,
2949 message: &T,
2950 peer: &Peer,
2951) {
2952 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2953 peer.send(peer_id.into(), message.clone())
2954 });
2955}
2956
2957fn send_notifications(
2958 connection_pool: &ConnectionPool,
2959 peer: &Peer,
2960 notifications: db::NotificationBatch,
2961) {
2962 for (user_id, notification) in notifications {
2963 for connection_id in connection_pool.user_connection_ids(user_id) {
2964 if let Err(error) = peer.send(
2965 connection_id,
2966 proto::AddNotification {
2967 notification: Some(notification.clone()),
2968 },
2969 ) {
2970 tracing::error!(
2971 "failed to send notification to {:?} {}",
2972 connection_id,
2973 error
2974 );
2975 }
2976 }
2977 }
2978}
2979
2980/// Send a message to the channel
2981async fn send_channel_message(
2982 request: proto::SendChannelMessage,
2983 response: Response<proto::SendChannelMessage>,
2984 session: Session,
2985) -> Result<()> {
2986 // Validate the message body.
2987 let body = request.body.trim().to_string();
2988 if body.len() > MAX_MESSAGE_LEN {
2989 return Err(anyhow!("message is too long"))?;
2990 }
2991 if body.is_empty() {
2992 return Err(anyhow!("message can't be blank"))?;
2993 }
2994
2995 // TODO: adjust mentions if body is trimmed
2996
2997 let timestamp = OffsetDateTime::now_utc();
2998 let nonce = request
2999 .nonce
3000 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3001
3002 let channel_id = ChannelId::from_proto(request.channel_id);
3003 let CreatedChannelMessage {
3004 message_id,
3005 participant_connection_ids,
3006 channel_members,
3007 notifications,
3008 } = session
3009 .db()
3010 .await
3011 .create_channel_message(
3012 channel_id,
3013 session.user_id,
3014 &body,
3015 &request.mentions,
3016 timestamp,
3017 nonce.clone().into(),
3018 match request.reply_to_message_id {
3019 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3020 None => None,
3021 },
3022 )
3023 .await?;
3024 let message = proto::ChannelMessage {
3025 sender_id: session.user_id.to_proto(),
3026 id: message_id.to_proto(),
3027 body,
3028 mentions: request.mentions,
3029 timestamp: timestamp.unix_timestamp() as u64,
3030 nonce: Some(nonce),
3031 reply_to_message_id: request.reply_to_message_id,
3032 };
3033 broadcast(
3034 Some(session.connection_id),
3035 participant_connection_ids,
3036 |connection| {
3037 session.peer.send(
3038 connection,
3039 proto::ChannelMessageSent {
3040 channel_id: channel_id.to_proto(),
3041 message: Some(message.clone()),
3042 },
3043 )
3044 },
3045 );
3046 response.send(proto::SendChannelMessageResponse {
3047 message: Some(message),
3048 })?;
3049
3050 let pool = &*session.connection_pool().await;
3051 broadcast(
3052 None,
3053 channel_members
3054 .iter()
3055 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3056 |peer_id| {
3057 session.peer.send(
3058 peer_id.into(),
3059 proto::UpdateChannels {
3060 latest_channel_message_ids: vec![proto::ChannelMessageId {
3061 channel_id: channel_id.to_proto(),
3062 message_id: message_id.to_proto(),
3063 }],
3064 ..Default::default()
3065 },
3066 )
3067 },
3068 );
3069 send_notifications(pool, &session.peer, notifications);
3070
3071 Ok(())
3072}
3073
3074/// Delete a channel message
3075async fn remove_channel_message(
3076 request: proto::RemoveChannelMessage,
3077 response: Response<proto::RemoveChannelMessage>,
3078 session: Session,
3079) -> Result<()> {
3080 let channel_id = ChannelId::from_proto(request.channel_id);
3081 let message_id = MessageId::from_proto(request.message_id);
3082 let connection_ids = session
3083 .db()
3084 .await
3085 .remove_channel_message(channel_id, message_id, session.user_id)
3086 .await?;
3087 broadcast(Some(session.connection_id), connection_ids, |connection| {
3088 session.peer.send(connection, request.clone())
3089 });
3090 response.send(proto::Ack {})?;
3091 Ok(())
3092}
3093
3094/// Mark a channel message as read
3095async fn acknowledge_channel_message(
3096 request: proto::AckChannelMessage,
3097 session: Session,
3098) -> Result<()> {
3099 let channel_id = ChannelId::from_proto(request.channel_id);
3100 let message_id = MessageId::from_proto(request.message_id);
3101 let notifications = session
3102 .db()
3103 .await
3104 .observe_channel_message(channel_id, session.user_id, message_id)
3105 .await?;
3106 send_notifications(
3107 &*session.connection_pool().await,
3108 &session.peer,
3109 notifications,
3110 );
3111 Ok(())
3112}
3113
3114/// Mark a buffer version as synced
3115async fn acknowledge_buffer_version(
3116 request: proto::AckBufferOperation,
3117 session: Session,
3118) -> Result<()> {
3119 let buffer_id = BufferId::from_proto(request.buffer_id);
3120 session
3121 .db()
3122 .await
3123 .observe_buffer_version(
3124 buffer_id,
3125 session.user_id,
3126 request.epoch as i32,
3127 &request.version,
3128 )
3129 .await?;
3130 Ok(())
3131}
3132
3133/// Start receiving chat updates for a channel
3134async fn join_channel_chat(
3135 request: proto::JoinChannelChat,
3136 response: Response<proto::JoinChannelChat>,
3137 session: Session,
3138) -> Result<()> {
3139 let channel_id = ChannelId::from_proto(request.channel_id);
3140
3141 let db = session.db().await;
3142 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3143 .await?;
3144 let messages = db
3145 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3146 .await?;
3147 response.send(proto::JoinChannelChatResponse {
3148 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3149 messages,
3150 })?;
3151 Ok(())
3152}
3153
3154/// Stop receiving chat updates for a channel
3155async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3156 let channel_id = ChannelId::from_proto(request.channel_id);
3157 session
3158 .db()
3159 .await
3160 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3161 .await?;
3162 Ok(())
3163}
3164
3165/// Retrieve the chat history for a channel
3166async fn get_channel_messages(
3167 request: proto::GetChannelMessages,
3168 response: Response<proto::GetChannelMessages>,
3169 session: Session,
3170) -> Result<()> {
3171 let channel_id = ChannelId::from_proto(request.channel_id);
3172 let messages = session
3173 .db()
3174 .await
3175 .get_channel_messages(
3176 channel_id,
3177 session.user_id,
3178 MESSAGE_COUNT_PER_PAGE,
3179 Some(MessageId::from_proto(request.before_message_id)),
3180 )
3181 .await?;
3182 response.send(proto::GetChannelMessagesResponse {
3183 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3184 messages,
3185 })?;
3186 Ok(())
3187}
3188
3189/// Retrieve specific chat messages
3190async fn get_channel_messages_by_id(
3191 request: proto::GetChannelMessagesById,
3192 response: Response<proto::GetChannelMessagesById>,
3193 session: Session,
3194) -> Result<()> {
3195 let message_ids = request
3196 .message_ids
3197 .iter()
3198 .map(|id| MessageId::from_proto(*id))
3199 .collect::<Vec<_>>();
3200 let messages = session
3201 .db()
3202 .await
3203 .get_channel_messages_by_id(session.user_id, &message_ids)
3204 .await?;
3205 response.send(proto::GetChannelMessagesResponse {
3206 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3207 messages,
3208 })?;
3209 Ok(())
3210}
3211
3212/// Retrieve the current users notifications
3213async fn get_notifications(
3214 request: proto::GetNotifications,
3215 response: Response<proto::GetNotifications>,
3216 session: Session,
3217) -> Result<()> {
3218 let notifications = session
3219 .db()
3220 .await
3221 .get_notifications(
3222 session.user_id,
3223 NOTIFICATION_COUNT_PER_PAGE,
3224 request
3225 .before_id
3226 .map(|id| db::NotificationId::from_proto(id)),
3227 )
3228 .await?;
3229 response.send(proto::GetNotificationsResponse {
3230 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3231 notifications,
3232 })?;
3233 Ok(())
3234}
3235
3236/// Mark notifications as read
3237async fn mark_notification_as_read(
3238 request: proto::MarkNotificationRead,
3239 response: Response<proto::MarkNotificationRead>,
3240 session: Session,
3241) -> Result<()> {
3242 let database = &session.db().await;
3243 let notifications = database
3244 .mark_notification_as_read_by_id(
3245 session.user_id,
3246 NotificationId::from_proto(request.notification_id),
3247 )
3248 .await?;
3249 send_notifications(
3250 &*session.connection_pool().await,
3251 &session.peer,
3252 notifications,
3253 );
3254 response.send(proto::Ack {})?;
3255 Ok(())
3256}
3257
3258/// Get the current users information
3259async fn get_private_user_info(
3260 _request: proto::GetPrivateUserInfo,
3261 response: Response<proto::GetPrivateUserInfo>,
3262 session: Session,
3263) -> Result<()> {
3264 let db = session.db().await;
3265
3266 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3267 let user = db
3268 .get_user_by_id(session.user_id)
3269 .await?
3270 .ok_or_else(|| anyhow!("user not found"))?;
3271 let flags = db.get_user_flags(session.user_id).await?;
3272
3273 response.send(proto::GetPrivateUserInfoResponse {
3274 metrics_id,
3275 staff: user.admin,
3276 flags,
3277 })?;
3278 Ok(())
3279}
3280
3281fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3282 match message {
3283 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3284 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3285 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3286 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3287 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3288 code: frame.code.into(),
3289 reason: frame.reason,
3290 })),
3291 }
3292}
3293
3294fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3295 match message {
3296 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3297 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3298 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3299 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3300 AxumMessage::Close(frame) => {
3301 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3302 code: frame.code.into(),
3303 reason: frame.reason,
3304 }))
3305 }
3306 }
3307}
3308
3309fn notify_membership_updated(
3310 connection_pool: &ConnectionPool,
3311 result: MembershipUpdated,
3312 user_id: UserId,
3313 peer: &Peer,
3314) {
3315 let user_channels_update = proto::UpdateUserChannels {
3316 channel_memberships: result
3317 .new_channels
3318 .channel_memberships
3319 .iter()
3320 .map(|cm| proto::ChannelMembership {
3321 channel_id: cm.channel_id.to_proto(),
3322 role: cm.role.into(),
3323 })
3324 .collect(),
3325 ..Default::default()
3326 };
3327 let mut update = build_channels_update(result.new_channels, vec![]);
3328 update.delete_channels = result
3329 .removed_channels
3330 .into_iter()
3331 .map(|id| id.to_proto())
3332 .collect();
3333 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3334
3335 for connection_id in connection_pool.user_connection_ids(user_id) {
3336 peer.send(connection_id, user_channels_update.clone())
3337 .trace_err();
3338 peer.send(connection_id, update.clone()).trace_err();
3339 }
3340}
3341
3342fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3343 proto::UpdateUserChannels {
3344 channel_memberships: channels
3345 .channel_memberships
3346 .iter()
3347 .map(|m| proto::ChannelMembership {
3348 channel_id: m.channel_id.to_proto(),
3349 role: m.role.into(),
3350 })
3351 .collect(),
3352 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3353 observed_channel_message_id: channels.observed_channel_messages.clone(),
3354 ..Default::default()
3355 }
3356}
3357
3358fn build_channels_update(
3359 channels: ChannelsForUser,
3360 channel_invites: Vec<db::Channel>,
3361) -> proto::UpdateChannels {
3362 let mut update = proto::UpdateChannels::default();
3363
3364 for channel in channels.channels {
3365 update.channels.push(channel.to_proto());
3366 }
3367
3368 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3369 update.latest_channel_message_ids = channels.latest_channel_messages;
3370
3371 for (channel_id, participants) in channels.channel_participants {
3372 update
3373 .channel_participants
3374 .push(proto::ChannelParticipants {
3375 channel_id: channel_id.to_proto(),
3376 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3377 });
3378 }
3379
3380 for channel in channel_invites {
3381 update.channel_invitations.push(channel.to_proto());
3382 }
3383
3384 update
3385}
3386
3387fn build_initial_contacts_update(
3388 contacts: Vec<db::Contact>,
3389 pool: &ConnectionPool,
3390) -> proto::UpdateContacts {
3391 let mut update = proto::UpdateContacts::default();
3392
3393 for contact in contacts {
3394 match contact {
3395 db::Contact::Accepted { user_id, busy } => {
3396 update.contacts.push(contact_for_user(user_id, busy, &pool));
3397 }
3398 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3399 db::Contact::Incoming { user_id } => {
3400 update
3401 .incoming_requests
3402 .push(proto::IncomingContactRequest {
3403 requester_id: user_id.to_proto(),
3404 })
3405 }
3406 }
3407 }
3408
3409 update
3410}
3411
3412fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3413 proto::Contact {
3414 user_id: user_id.to_proto(),
3415 online: pool.is_user_online(user_id),
3416 busy,
3417 }
3418}
3419
3420fn room_updated(room: &proto::Room, peer: &Peer) {
3421 broadcast(
3422 None,
3423 room.participants
3424 .iter()
3425 .filter_map(|participant| Some(participant.peer_id?.into())),
3426 |peer_id| {
3427 peer.send(
3428 peer_id.into(),
3429 proto::RoomUpdated {
3430 room: Some(room.clone()),
3431 },
3432 )
3433 },
3434 );
3435}
3436
3437fn channel_updated(
3438 channel_id: ChannelId,
3439 room: &proto::Room,
3440 channel_members: &[UserId],
3441 peer: &Peer,
3442 pool: &ConnectionPool,
3443) {
3444 let participants = room
3445 .participants
3446 .iter()
3447 .map(|p| p.user_id)
3448 .collect::<Vec<_>>();
3449
3450 broadcast(
3451 None,
3452 channel_members
3453 .iter()
3454 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3455 |peer_id| {
3456 peer.send(
3457 peer_id.into(),
3458 proto::UpdateChannels {
3459 channel_participants: vec![proto::ChannelParticipants {
3460 channel_id: channel_id.to_proto(),
3461 participant_user_ids: participants.clone(),
3462 }],
3463 ..Default::default()
3464 },
3465 )
3466 },
3467 );
3468}
3469
3470async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3471 let db = session.db().await;
3472
3473 let contacts = db.get_contacts(user_id).await?;
3474 let busy = db.is_user_busy(user_id).await?;
3475
3476 let pool = session.connection_pool().await;
3477 let updated_contact = contact_for_user(user_id, busy, &pool);
3478 for contact in contacts {
3479 if let db::Contact::Accepted {
3480 user_id: contact_user_id,
3481 ..
3482 } = contact
3483 {
3484 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3485 session
3486 .peer
3487 .send(
3488 contact_conn_id,
3489 proto::UpdateContacts {
3490 contacts: vec![updated_contact.clone()],
3491 remove_contacts: Default::default(),
3492 incoming_requests: Default::default(),
3493 remove_incoming_requests: Default::default(),
3494 outgoing_requests: Default::default(),
3495 remove_outgoing_requests: Default::default(),
3496 },
3497 )
3498 .trace_err();
3499 }
3500 }
3501 }
3502 Ok(())
3503}
3504
3505async fn leave_room_for_session(session: &Session) -> Result<()> {
3506 let mut contacts_to_update = HashSet::default();
3507
3508 let room_id;
3509 let canceled_calls_to_user_ids;
3510 let live_kit_room;
3511 let delete_live_kit_room;
3512 let room;
3513 let channel_members;
3514 let channel_id;
3515
3516 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3517 contacts_to_update.insert(session.user_id);
3518
3519 for project in left_room.left_projects.values() {
3520 project_left(project, session);
3521 }
3522
3523 room_id = RoomId::from_proto(left_room.room.id);
3524 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3525 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3526 delete_live_kit_room = left_room.deleted;
3527 room = mem::take(&mut left_room.room);
3528 channel_members = mem::take(&mut left_room.channel_members);
3529 channel_id = left_room.channel_id;
3530
3531 room_updated(&room, &session.peer);
3532 } else {
3533 return Ok(());
3534 }
3535
3536 if let Some(channel_id) = channel_id {
3537 channel_updated(
3538 channel_id,
3539 &room,
3540 &channel_members,
3541 &session.peer,
3542 &*session.connection_pool().await,
3543 );
3544 }
3545
3546 {
3547 let pool = session.connection_pool().await;
3548 for canceled_user_id in canceled_calls_to_user_ids {
3549 for connection_id in pool.user_connection_ids(canceled_user_id) {
3550 session
3551 .peer
3552 .send(
3553 connection_id,
3554 proto::CallCanceled {
3555 room_id: room_id.to_proto(),
3556 },
3557 )
3558 .trace_err();
3559 }
3560 contacts_to_update.insert(canceled_user_id);
3561 }
3562 }
3563
3564 for contact_user_id in contacts_to_update {
3565 update_user_contacts(contact_user_id, &session).await?;
3566 }
3567
3568 if let Some(live_kit) = session.live_kit_client.as_ref() {
3569 live_kit
3570 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3571 .await
3572 .trace_err();
3573
3574 if delete_live_kit_room {
3575 live_kit.delete_room(live_kit_room).await.trace_err();
3576 }
3577 }
3578
3579 Ok(())
3580}
3581
3582async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3583 let left_channel_buffers = session
3584 .db()
3585 .await
3586 .leave_channel_buffers(session.connection_id)
3587 .await?;
3588
3589 for left_buffer in left_channel_buffers {
3590 channel_buffer_updated(
3591 session.connection_id,
3592 left_buffer.connections,
3593 &proto::UpdateChannelBufferCollaborators {
3594 channel_id: left_buffer.channel_id.to_proto(),
3595 collaborators: left_buffer.collaborators,
3596 },
3597 &session.peer,
3598 );
3599 }
3600
3601 Ok(())
3602}
3603
3604fn project_left(project: &db::LeftProject, session: &Session) {
3605 for connection_id in &project.connection_ids {
3606 if project.host_user_id == session.user_id {
3607 session
3608 .peer
3609 .send(
3610 *connection_id,
3611 proto::UnshareProject {
3612 project_id: project.id.to_proto(),
3613 },
3614 )
3615 .trace_err();
3616 } else {
3617 session
3618 .peer
3619 .send(
3620 *connection_id,
3621 proto::RemoveProjectCollaborator {
3622 project_id: project.id.to_proto(),
3623 peer_id: Some(session.connection_id.into()),
3624 },
3625 )
3626 .trace_err();
3627 }
3628 }
3629}
3630
3631pub trait ResultExt {
3632 type Ok;
3633
3634 fn trace_err(self) -> Option<Self::Ok>;
3635}
3636
3637impl<T, E> ResultExt for Result<T, E>
3638where
3639 E: std::fmt::Debug,
3640{
3641 type Ok = T;
3642
3643 fn trace_err(self) -> Option<T> {
3644 match self {
3645 Ok(value) => Some(value),
3646 Err(error) => {
3647 tracing::error!("{:?}", error);
3648 None
3649 }
3650 }
3651 }
3652}