1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project, ProjectId,
8 RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId, User,
9 UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc, OnceLock,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// Grace period after a connection drops before `connection_lost` releases the rooms,
/// channel buffers and pending calls held by that session.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins cleaning up stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, their old
/// resources can be cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits used by handlers outside this section.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1_024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
77
78type MessageHandler =
79 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
80
81struct Response<R> {
82 peer: Arc<Peer>,
83 receipt: Receipt<R>,
84 responded: Arc<AtomicBool>,
85}
86
87impl<R: RequestMessage> Response<R> {
88 fn send(self, payload: R::Response) -> Result<()> {
89 self.responded.store(true, SeqCst);
90 self.peer.respond(self.receipt, payload)?;
91 Ok(())
92 }
93}
94
/// Per-connection context handed to every message handler.
///
/// A clone is passed to each handler invocation; the `Arc`-wrapped fields are shared
/// between those clones.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The same pool the `Server` owns; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// `None` when no LiveKit client is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}

impl Session {
    /// Locks this session's database handle.
    ///
    /// In test builds the task yields before and after acquiring the lock, presumably to
    /// surface interleaving-dependent bugs in tests.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool and returns a `!Send` guard.
    ///
    /// The lock itself is a synchronous `parking_lot` mutex; the only `.await` here is the
    /// test-build yield. Keep the returned guard's scope short.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}

/// Debug output is limited to the user and connection ids.
impl fmt::Debug for Session {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Session")
            .field("user_id", &self.user_id)
            .field("connection_id", &self.connection_id)
            .finish()
    }
}
135
136struct DbHandle(Arc<Database>);
137
138impl Deref for DbHandle {
139 type Target = Database;
140
141 fn deref(&self) -> &Self::Target {
142 self.0.as_ref()
143 }
144}
145
/// The RPC server: owns the peer, the pool of live connections, and the table of message
/// handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it and exit when it changes.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
///
/// It holds the pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
167
168pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
169where
170 S: Serializer,
171 T: Deref<Target = U>,
172 U: Serialize,
173{
174 Serialize::serialize(value.deref(), serializer)
175}
176
impl Server {
    /// Creates a server with the given id and registers a handler for every supported RPC
    /// message type.
    ///
    /// Request handlers must reply through their `Response`; message handlers have no reply.
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created later via `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Spawns a detached cleanup task and returns immediately.
    ///
    /// The task waits for `CLEANUP_TIMEOUT`, then handles the rooms and channel buffers that
    /// `stale_server_resource_ids` reports for this environment and server id. For each one it
    /// clears the stale collaborators/participants and notifies the affected clients. It also
    /// deletes a LiveKit room once no participants remain, and finally calls
    /// `delete_stale_servers`. Failures in individual steps are logged via `trace_err` and
    /// do not stop the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Awaited inside the spawned task below.
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Tell every remaining collaborator of each channel buffer about its
                    // refreshed collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected from the refreshed room (if the DB call succeeded) and
                        // acted on after the room update has been broadcast.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is only held for synchronous sends, never across an
                        // `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed contact entry of every affected user to each of
                        // that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room when LiveKit is configured and the
                        // refreshed room has no participants left.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears the server down.
    ///
    /// It tears down the peer, resets the connection pool, and publishes `true` on the
    /// teardown channel, which makes connection loops return.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Sending fails only when there are no subscribers, which is fine here.
        let _ = self.teardown.send(true);
    }

    /// Test-only: tears the server down, then reinitializes it under `id` and clears the
    /// teardown flag.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for payload type `M`, wrapping it with tracing.
    ///
    /// The wrapper logs receipt of each message and, on completion, the total, processing and
    /// queue durations in milliseconds. Handler errors are logged rather than propagated.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // Handlers are keyed by `TypeId::of::<M>()`, so this downcast matches the key.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Total time since receipt; the queue time is whatever part of it was not
                    // spent running the handler.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a fire-and-forget message; only the payload is passed on.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request that expects a reply.
    ///
    /// If the handler returns `Ok` without calling `Response::send`, the wrapper returns a
    /// "handler did not send a response" error, which `add_handler` logs. If the handler
    /// returns `Err`, the client is sent an error reply: `Error::Internal` errors use their
    /// own proto form, and any other error is sent as `ErrorCode::Internal` with its display
    /// text.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives one client connection from handshake to sign-out.
    ///
    /// The connection is first registered with the peer, and the client is sent `Hello`,
    /// `ShowContacts` (first connection only), its contacts, channels and any incoming call.
    /// Incoming messages are then dispatched until I/O fails, the client closes the stream,
    /// or the server tears down. Foreground handlers run concurrently in a `FuturesUnordered`;
    /// background handlers are spawned detached. A semaphore caps each connection at 256
    /// in-flight handlers.
    ///
    /// When the loop ends normally, `connection_lost` runs. On teardown the future returns
    /// `Ok(())` immediately without running it. Fails early if the server is already tearing
    /// down.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                return Err(anyhow!("server is tearing down"))?;
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            // Optional side channel (callers may pass `None`) that learns the assigned id.
            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Add the connection to the pool and send the initial state while holding the pool
            // lock; no `.await` happens inside this scope.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin, zed_version);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* reading the next message and released when that
                // message's handler finishes, so at most 256 handlers are in flight.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased order: teardown, then connection I/O, then completing in-flight
                // foreground handlers, and only then reading a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Drops (cancels) any foreground handlers that are still pending before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies every connection of `inviter_id` that `invitee_id` redeemed their invite code.
    ///
    /// Each connection receives the invitee as a (not busy) contact plus refreshed invite info.
    /// Does nothing if the inviter is unknown or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and remaining count) to every connection of `user_id`.
    ///
    /// Does nothing if the user is unknown or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The pool lock is held until the snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
759
760impl<'a> Deref for ConnectionPoolGuard<'a> {
761 type Target = ConnectionPool;
762
763 fn deref(&self) -> &Self::Target {
764 &self.guard
765 }
766}
767
768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
769 fn deref_mut(&mut self) -> &mut Self::Target {
770 &mut self.guard
771 }
772}
773
774impl<'a> Drop for ConnectionPoolGuard<'a> {
775 fn drop(&mut self) {
776 #[cfg(test)]
777 self.check_invariants();
778 }
779}
780
781fn broadcast<F>(
782 sender_id: Option<ConnectionId>,
783 receiver_ids: impl IntoIterator<Item = ConnectionId>,
784 mut f: F,
785) where
786 F: FnMut(ConnectionId) -> anyhow::Result<()>,
787{
788 for receiver_id in receiver_ids {
789 if Some(receiver_id) != sender_id {
790 if let Err(error) = f(receiver_id) {
791 tracing::error!("failed to send to {:?} {}", receiver_id, error);
792 }
793 }
794 }
795}
796
797pub struct ProtocolVersion(u32);
798
799impl Header for ProtocolVersion {
800 fn name() -> &'static HeaderName {
801 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
802 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
803 }
804
805 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
806 where
807 Self: Sized,
808 I: Iterator<Item = &'i axum::http::HeaderValue>,
809 {
810 let version = values
811 .next()
812 .ok_or_else(axum::headers::Error::invalid)?
813 .to_str()
814 .map_err(|_| axum::headers::Error::invalid())?
815 .parse()
816 .map_err(|_| axum::headers::Error::invalid())?;
817 Ok(Self(version))
818 }
819
820 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
821 values.extend([self.0.to_string().parse().unwrap()]);
822 }
823}
824
825pub struct AppVersionHeader(SemanticVersion);
826impl Header for AppVersionHeader {
827 fn name() -> &'static HeaderName {
828 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
829 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
830 }
831
832 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
833 where
834 Self: Sized,
835 I: Iterator<Item = &'i axum::http::HeaderValue>,
836 {
837 let version = values
838 .next()
839 .ok_or_else(axum::headers::Error::invalid)?
840 .to_str()
841 .map_err(|_| axum::headers::Error::invalid())?
842 .parse()
843 .map_err(|_| axum::headers::Error::invalid())?;
844 Ok(Self(version))
845 }
846
847 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
848 values.extend([self.0.to_string().parse().unwrap()]);
849 }
850}
851
852pub fn routes(server: Arc<Server>) -> Router<(), Body> {
853 Router::new()
854 .route("/rpc", get(handle_websocket_request))
855 .layer(
856 ServiceBuilder::new()
857 .layer(Extension(server.app_state.clone()))
858 .layer(middleware::from_fn(auth::validate_header)),
859 )
860 .route("/metrics", get(handle_metrics))
861 .layer(Extension(server))
862}
863
864pub async fn handle_websocket_request(
865 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
866 app_version_header: Option<TypedHeader<AppVersionHeader>>,
867 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
868 Extension(server): Extension<Arc<Server>>,
869 Extension(user): Extension<User>,
870 Extension(impersonator): Extension<Impersonator>,
871 ws: WebSocketUpgrade,
872) -> axum::response::Response {
873 if protocol_version != rpc::PROTOCOL_VERSION {
874 return (
875 StatusCode::UPGRADE_REQUIRED,
876 "client must be upgraded".to_string(),
877 )
878 .into_response();
879 }
880
881 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
882 return (
883 StatusCode::UPGRADE_REQUIRED,
884 "no version header found".to_string(),
885 )
886 .into_response();
887 };
888
889 if !version.is_supported() {
890 return (
891 StatusCode::UPGRADE_REQUIRED,
892 "client must be upgraded".to_string(),
893 )
894 .into_response();
895 }
896
897 let socket_address = socket_address.to_string();
898 ws.on_upgrade(move |socket| {
899 use util::ResultExt;
900 let socket = socket
901 .map_ok(to_tungstenite_message)
902 .err_into()
903 .with(|message| async move { Ok(to_axum_message(message)) });
904 let connection = Connection::new(Box::pin(socket));
905 async move {
906 server
907 .handle_connection(
908 connection,
909 socket_address,
910 user,
911 version,
912 impersonator.0,
913 None,
914 Executor::Production,
915 )
916 .await
917 .log_err();
918 }
919 })
920}
921
922pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
923 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
924 let connections_metric = CONNECTIONS_METRIC
925 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
926
927 let connections = server
928 .connection_pool
929 .lock()
930 .connections()
931 .filter(|connection| !connection.admin)
932 .count();
933 connections_metric.set(connections as _);
934
935 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
936 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
937 register_int_gauge!(
938 "shared_projects",
939 "number of open projects with one or more guests"
940 )
941 .unwrap()
942 });
943
944 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
945 shared_projects_metric.set(shared_projects as _);
946
947 let encoder = prometheus::TextEncoder::new();
948 let metric_families = prometheus::gather();
949 let encoded_metrics = encoder
950 .encode_to_string(&metric_families)
951 .map_err(|err| anyhow!("{}", err))?;
952 Ok(encoded_metrics)
953}
954
/// Handles a client connection that has gone away.
///
/// The connection is disconnected from the peer and removed from the connection pool right
/// away, and the database is told the connection was lost (best-effort; errors are logged).
/// Cleanup of the user's room and channel-buffer membership is then deferred for
/// `RECONNECT_TIMEOUT`, giving the client a window to reconnect. If `teardown` reports a
/// change before the timer fires, the deferred cleanup is skipped entirely.
///
/// # Errors
/// Returns an error if the connection cannot be removed from the pool, or if refreshing the
/// user's contacts after the timeout fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure to record the lost connection is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls its branches in the order written, so the reconnect timer is
    // polled before the teardown signal on every wake-up.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending incoming call if the user has no other live connection
            // that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // Teardown won the race: leave the user's resources untouched.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1000
1001/// Acknowledges a ping from a client, used to keep the connection alive.
1002async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1003 response.send(proto::Ack {})?;
1004 Ok(())
1005}
1006
1007/// Creates a new room for calling (outside of channels)
1008async fn create_room(
1009 _request: proto::CreateRoom,
1010 response: Response<proto::CreateRoom>,
1011 session: Session,
1012) -> Result<()> {
1013 let live_kit_room = nanoid::nanoid!(30);
1014
1015 let live_kit_connection_info = {
1016 let live_kit_room = live_kit_room.clone();
1017 let live_kit = session.live_kit_client.as_ref();
1018
1019 util::async_maybe!({
1020 let live_kit = live_kit?;
1021
1022 let token = live_kit
1023 .room_token(&live_kit_room, &session.user_id.to_string())
1024 .trace_err()?;
1025
1026 Some(proto::LiveKitConnectionInfo {
1027 server_url: live_kit.url().into(),
1028 token,
1029 can_publish: true,
1030 })
1031 })
1032 }
1033 .await;
1034
1035 let room = session
1036 .db()
1037 .await
1038 .create_room(session.user_id, session.connection_id, &live_kit_room)
1039 .await?;
1040
1041 response.send(proto::CreateRoomResponse {
1042 room: Some(room.clone()),
1043 live_kit_connection_info,
1044 })?;
1045
1046 update_user_contacts(session.user_id, &session).await?;
1047 Ok(())
1048}
1049
1050/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1051async fn join_room(
1052 request: proto::JoinRoom,
1053 response: Response<proto::JoinRoom>,
1054 session: Session,
1055) -> Result<()> {
1056 let room_id = RoomId::from_proto(request.id);
1057
1058 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1059
1060 if let Some(channel_id) = channel_id {
1061 return join_channel_internal(channel_id, Box::new(response), session).await;
1062 }
1063
1064 let joined_room = {
1065 let room = session
1066 .db()
1067 .await
1068 .join_room(room_id, session.user_id, session.connection_id)
1069 .await?;
1070 room_updated(&room.room, &session.peer);
1071 room.into_inner()
1072 };
1073
1074 for connection_id in session
1075 .connection_pool()
1076 .await
1077 .user_connection_ids(session.user_id)
1078 {
1079 session
1080 .peer
1081 .send(
1082 connection_id,
1083 proto::CallCanceled {
1084 room_id: room_id.to_proto(),
1085 },
1086 )
1087 .trace_err();
1088 }
1089
1090 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1091 if let Some(token) = live_kit
1092 .room_token(
1093 &joined_room.room.live_kit_room,
1094 &session.user_id.to_string(),
1095 )
1096 .trace_err()
1097 {
1098 Some(proto::LiveKitConnectionInfo {
1099 server_url: live_kit.url().into(),
1100 token,
1101 can_publish: true,
1102 })
1103 } else {
1104 None
1105 }
1106 } else {
1107 None
1108 };
1109
1110 response.send(proto::JoinRoomResponse {
1111 room: Some(joined_room.room),
1112 channel_id: None,
1113 live_kit_connection_info,
1114 })?;
1115
1116 update_user_contacts(session.user_id, &session).await?;
1117 Ok(())
1118}
1119
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database rejoins this session's connection to the room and reports two sets of
/// projects: `reshared_projects` (presumably projects this user hosts) and
/// `rejoined_projects` (presumably projects this user had joined as a guest). This handler
/// then:
/// - replies with the room and the per-project state,
/// - re-broadcasts the room,
/// - tells each project's collaborators that the user's peer id moved from the old
///   connection to the current one,
/// - forwards each re-shared project's worktree list to its collaborators,
/// - re-streams each re-joined project's worktree entries, diagnostics, settings files and
///   language-server status to this connection.
///
/// Finally the room's channel (if any) is updated and the user's contacts are refreshed.
///
/// Sends to other collaborators are best-effort (errors are logged). A failure to send the
/// response or any message to this connection aborts with an error.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Assigned at the end of the scope below, once `rejoined_room` has been consumed with
    // `into_inner()`, and used by the channel/contacts updates afterwards.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: point every collaborator at this connection's new peer id,
        // then forward the project's current worktree list to them.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-joined projects: point every collaborator at this connection's new peer id.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-stream each re-joined project's state to this connection. Worktrees are moved
        // out with `mem::take` so their contents can be sent without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size, presumably so multi-chunk streaming is
                // exercised by tests.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report every language server's disk-based diagnostics as updated.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1311
1312/// leave room disconnects from the room.
1313async fn leave_room(
1314 _: proto::LeaveRoom,
1315 response: Response<proto::LeaveRoom>,
1316 session: Session,
1317) -> Result<()> {
1318 leave_room_for_session(&session).await?;
1319 response.send(proto::Ack {})?;
1320 Ok(())
1321}
1322
1323/// Updates the permissions of someone else in the room.
1324async fn set_room_participant_role(
1325 request: proto::SetRoomParticipantRole,
1326 response: Response<proto::SetRoomParticipantRole>,
1327 session: Session,
1328) -> Result<()> {
1329 let user_id = UserId::from_proto(request.user_id);
1330 let role = ChannelRole::from(request.role());
1331
1332 if role == ChannelRole::Talker {
1333 let pool = session.connection_pool().await;
1334
1335 for connection in pool.user_connections(user_id) {
1336 if !connection.zed_version.supports_talker_role() {
1337 Err(anyhow!(
1338 "This user is on zed {} which does not support unmute",
1339 connection.zed_version
1340 ))?;
1341 }
1342 }
1343 }
1344
1345 let (live_kit_room, can_publish) = {
1346 let room = session
1347 .db()
1348 .await
1349 .set_room_participant_role(
1350 session.user_id,
1351 RoomId::from_proto(request.room_id),
1352 user_id,
1353 role,
1354 )
1355 .await?;
1356
1357 let live_kit_room = room.live_kit_room.clone();
1358 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1359 room_updated(&room, &session.peer);
1360 (live_kit_room, can_publish)
1361 };
1362
1363 if let Some(live_kit) = session.live_kit_client.as_ref() {
1364 live_kit
1365 .update_participant(
1366 live_kit_room.clone(),
1367 request.user_id.to_string(),
1368 live_kit_server::proto::ParticipantPermission {
1369 can_subscribe: true,
1370 can_publish,
1371 can_publish_data: can_publish,
1372 hidden: false,
1373 recorder: false,
1374 },
1375 )
1376 .await
1377 .trace_err();
1378 }
1379
1380 response.send(proto::Ack {})?;
1381 Ok(())
1382}
1383
1384/// Call someone else into the current room
1385async fn call(
1386 request: proto::Call,
1387 response: Response<proto::Call>,
1388 session: Session,
1389) -> Result<()> {
1390 let room_id = RoomId::from_proto(request.room_id);
1391 let calling_user_id = session.user_id;
1392 let calling_connection_id = session.connection_id;
1393 let called_user_id = UserId::from_proto(request.called_user_id);
1394 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1395 if !session
1396 .db()
1397 .await
1398 .has_contact(calling_user_id, called_user_id)
1399 .await?
1400 {
1401 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1402 }
1403
1404 let incoming_call = {
1405 let (room, incoming_call) = &mut *session
1406 .db()
1407 .await
1408 .call(
1409 room_id,
1410 calling_user_id,
1411 calling_connection_id,
1412 called_user_id,
1413 initial_project_id,
1414 )
1415 .await?;
1416 room_updated(&room, &session.peer);
1417 mem::take(incoming_call)
1418 };
1419 update_user_contacts(called_user_id, &session).await?;
1420
1421 let mut calls = session
1422 .connection_pool()
1423 .await
1424 .user_connection_ids(called_user_id)
1425 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1426 .collect::<FuturesUnordered<_>>();
1427
1428 while let Some(call_response) = calls.next().await {
1429 match call_response.as_ref() {
1430 Ok(_) => {
1431 response.send(proto::Ack {})?;
1432 return Ok(());
1433 }
1434 Err(_) => {
1435 call_response.trace_err();
1436 }
1437 }
1438 }
1439
1440 {
1441 let room = session
1442 .db()
1443 .await
1444 .call_failed(room_id, called_user_id)
1445 .await?;
1446 room_updated(&room, &session.peer);
1447 }
1448 update_user_contacts(called_user_id, &session).await?;
1449
1450 Err(anyhow!("failed to ring user"))?
1451}
1452
1453/// Cancel an outgoing call.
1454async fn cancel_call(
1455 request: proto::CancelCall,
1456 response: Response<proto::CancelCall>,
1457 session: Session,
1458) -> Result<()> {
1459 let called_user_id = UserId::from_proto(request.called_user_id);
1460 let room_id = RoomId::from_proto(request.room_id);
1461 {
1462 let room = session
1463 .db()
1464 .await
1465 .cancel_call(room_id, session.connection_id, called_user_id)
1466 .await?;
1467 room_updated(&room, &session.peer);
1468 }
1469
1470 for connection_id in session
1471 .connection_pool()
1472 .await
1473 .user_connection_ids(called_user_id)
1474 {
1475 session
1476 .peer
1477 .send(
1478 connection_id,
1479 proto::CallCanceled {
1480 room_id: room_id.to_proto(),
1481 },
1482 )
1483 .trace_err();
1484 }
1485 response.send(proto::Ack {})?;
1486
1487 update_user_contacts(called_user_id, &session).await?;
1488 Ok(())
1489}
1490
1491/// Decline an incoming call.
1492async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1493 let room_id = RoomId::from_proto(message.room_id);
1494 {
1495 let room = session
1496 .db()
1497 .await
1498 .decline_call(Some(room_id), session.user_id)
1499 .await?
1500 .ok_or_else(|| anyhow!("failed to decline call"))?;
1501 room_updated(&room, &session.peer);
1502 }
1503
1504 for connection_id in session
1505 .connection_pool()
1506 .await
1507 .user_connection_ids(session.user_id)
1508 {
1509 session
1510 .peer
1511 .send(
1512 connection_id,
1513 proto::CallCanceled {
1514 room_id: room_id.to_proto(),
1515 },
1516 )
1517 .trace_err();
1518 }
1519 update_user_contacts(session.user_id, &session).await?;
1520 Ok(())
1521}
1522
1523/// Updates other participants in the room with your current location.
1524async fn update_participant_location(
1525 request: proto::UpdateParticipantLocation,
1526 response: Response<proto::UpdateParticipantLocation>,
1527 session: Session,
1528) -> Result<()> {
1529 let room_id = RoomId::from_proto(request.room_id);
1530 let location = request
1531 .location
1532 .ok_or_else(|| anyhow!("invalid location"))?;
1533
1534 let db = session.db().await;
1535 let room = db
1536 .update_room_participant_location(room_id, session.connection_id, location)
1537 .await?;
1538
1539 room_updated(&room, &session.peer);
1540 response.send(proto::Ack {})?;
1541 Ok(())
1542}
1543
1544/// Share a project into the room.
1545async fn share_project(
1546 request: proto::ShareProject,
1547 response: Response<proto::ShareProject>,
1548 session: Session,
1549) -> Result<()> {
1550 let (project_id, room) = &*session
1551 .db()
1552 .await
1553 .share_project(
1554 RoomId::from_proto(request.room_id),
1555 session.connection_id,
1556 &request.worktrees,
1557 )
1558 .await?;
1559 response.send(proto::ShareProjectResponse {
1560 project_id: project_id.to_proto(),
1561 })?;
1562 room_updated(&room, &session.peer);
1563
1564 Ok(())
1565}
1566
1567/// Unshare a project from the room.
1568async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1569 let project_id = ProjectId::from_proto(message.project_id);
1570
1571 let (room, guest_connection_ids) = &*session
1572 .db()
1573 .await
1574 .unshare_project(project_id, session.connection_id)
1575 .await?;
1576
1577 broadcast(
1578 Some(session.connection_id),
1579 guest_connection_ids.iter().copied(),
1580 |conn_id| session.peer.send(conn_id, message.clone()),
1581 );
1582 room_updated(&room, &session.peer);
1583
1584 Ok(())
1585}
1586
1587/// Join someone elses shared project.
1588async fn join_project(
1589 request: proto::JoinProject,
1590 response: Response<proto::JoinProject>,
1591 session: Session,
1592) -> Result<()> {
1593 let project_id = ProjectId::from_proto(request.project_id);
1594
1595 tracing::info!(%project_id, "join project");
1596
1597 let (project, replica_id) = &mut *session
1598 .db()
1599 .await
1600 .join_project_in_room(project_id, session.connection_id)
1601 .await?;
1602
1603 join_project_internal(response, session, project, replica_id)
1604}
1605
/// Abstracts over the response handles of the requests that end in
/// `join_project_internal` (`JoinProject` and `JoinHostedProject`). Both reply with a
/// `proto::JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Both impls call the inherent `Response::send`. Inherent associated functions take
// precedence over trait methods in path resolution, so this does not recurse.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1619
1620fn join_project_internal(
1621 response: impl JoinProjectInternalResponse,
1622 session: Session,
1623 project: &mut Project,
1624 replica_id: &ReplicaId,
1625) -> Result<()> {
1626 let collaborators = project
1627 .collaborators
1628 .iter()
1629 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1630 .map(|collaborator| collaborator.to_proto())
1631 .collect::<Vec<_>>();
1632 let project_id = project.id;
1633 let guest_user_id = session.user_id;
1634
1635 let worktrees = project
1636 .worktrees
1637 .iter()
1638 .map(|(id, worktree)| proto::WorktreeMetadata {
1639 id: *id,
1640 root_name: worktree.root_name.clone(),
1641 visible: worktree.visible,
1642 abs_path: worktree.abs_path.clone(),
1643 })
1644 .collect::<Vec<_>>();
1645
1646 for collaborator in &collaborators {
1647 session
1648 .peer
1649 .send(
1650 collaborator.peer_id.unwrap().into(),
1651 proto::AddProjectCollaborator {
1652 project_id: project_id.to_proto(),
1653 collaborator: Some(proto::Collaborator {
1654 peer_id: Some(session.connection_id.into()),
1655 replica_id: replica_id.0 as u32,
1656 user_id: guest_user_id.to_proto(),
1657 }),
1658 },
1659 )
1660 .trace_err();
1661 }
1662
1663 // First, we send the metadata associated with each worktree.
1664 response.send(proto::JoinProjectResponse {
1665 project_id: project.id.0 as u64,
1666 worktrees: worktrees.clone(),
1667 replica_id: replica_id.0 as u32,
1668 collaborators: collaborators.clone(),
1669 language_servers: project.language_servers.clone(),
1670 role: project.role.into(), // todo
1671 })?;
1672
1673 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1674 #[cfg(any(test, feature = "test-support"))]
1675 const MAX_CHUNK_SIZE: usize = 2;
1676 #[cfg(not(any(test, feature = "test-support")))]
1677 const MAX_CHUNK_SIZE: usize = 256;
1678
1679 // Stream this worktree's entries.
1680 let message = proto::UpdateWorktree {
1681 project_id: project_id.to_proto(),
1682 worktree_id,
1683 abs_path: worktree.abs_path.clone(),
1684 root_name: worktree.root_name,
1685 updated_entries: worktree.entries,
1686 removed_entries: Default::default(),
1687 scan_id: worktree.scan_id,
1688 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1689 updated_repositories: worktree.repository_entries.into_values().collect(),
1690 removed_repositories: Default::default(),
1691 };
1692 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1693 session.peer.send(session.connection_id, update.clone())?;
1694 }
1695
1696 // Stream this worktree's diagnostics.
1697 for summary in worktree.diagnostic_summaries {
1698 session.peer.send(
1699 session.connection_id,
1700 proto::UpdateDiagnosticSummary {
1701 project_id: project_id.to_proto(),
1702 worktree_id: worktree.id,
1703 summary: Some(summary),
1704 },
1705 )?;
1706 }
1707
1708 for settings_file in worktree.settings_files {
1709 session.peer.send(
1710 session.connection_id,
1711 proto::UpdateWorktreeSettings {
1712 project_id: project_id.to_proto(),
1713 worktree_id: worktree.id,
1714 path: settings_file.path,
1715 content: Some(settings_file.content),
1716 },
1717 )?;
1718 }
1719 }
1720
1721 for language_server in &project.language_servers {
1722 session.peer.send(
1723 session.connection_id,
1724 proto::UpdateLanguageServer {
1725 project_id: project_id.to_proto(),
1726 language_server_id: language_server.id,
1727 variant: Some(
1728 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1729 proto::LspDiskBasedDiagnosticsUpdated {},
1730 ),
1731 ),
1732 },
1733 )?;
1734 }
1735
1736 Ok(())
1737}
1738
1739/// Leave someone elses shared project.
1740async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1741 let sender_id = session.connection_id;
1742 let project_id = ProjectId::from_proto(request.project_id);
1743 let db = session.db().await;
1744 if db.is_hosted_project(project_id).await? {
1745 let project = db.leave_hosted_project(project_id, sender_id).await?;
1746 project_left(&project, &session);
1747 return Ok(());
1748 }
1749
1750 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1751 tracing::info!(
1752 %project_id,
1753 host_user_id = ?project.host_user_id,
1754 host_connection_id = ?project.host_connection_id,
1755 "leave project"
1756 );
1757
1758 project_left(&project, &session);
1759 room_updated(&room, &session.peer);
1760
1761 Ok(())
1762}
1763
1764async fn join_hosted_project(
1765 request: proto::JoinHostedProject,
1766 response: Response<proto::JoinHostedProject>,
1767 session: Session,
1768) -> Result<()> {
1769 let (mut project, replica_id) = session
1770 .db()
1771 .await
1772 .join_hosted_project(
1773 ProjectId(request.project_id as i32),
1774 session.user_id,
1775 session.connection_id,
1776 )
1777 .await?;
1778
1779 join_project_internal(response, session, &mut project, &replica_id)
1780}
1781
1782/// Updates other participants with changes to the project
1783async fn update_project(
1784 request: proto::UpdateProject,
1785 response: Response<proto::UpdateProject>,
1786 session: Session,
1787) -> Result<()> {
1788 let project_id = ProjectId::from_proto(request.project_id);
1789 let (room, guest_connection_ids) = &*session
1790 .db()
1791 .await
1792 .update_project(project_id, session.connection_id, &request.worktrees)
1793 .await?;
1794 broadcast(
1795 Some(session.connection_id),
1796 guest_connection_ids.iter().copied(),
1797 |connection_id| {
1798 session
1799 .peer
1800 .forward_send(session.connection_id, connection_id, request.clone())
1801 },
1802 );
1803 room_updated(&room, &session.peer);
1804 response.send(proto::Ack {})?;
1805
1806 Ok(())
1807}
1808
1809/// Updates other participants with changes to the worktree
1810async fn update_worktree(
1811 request: proto::UpdateWorktree,
1812 response: Response<proto::UpdateWorktree>,
1813 session: Session,
1814) -> Result<()> {
1815 let guest_connection_ids = session
1816 .db()
1817 .await
1818 .update_worktree(&request, session.connection_id)
1819 .await?;
1820
1821 broadcast(
1822 Some(session.connection_id),
1823 guest_connection_ids.iter().copied(),
1824 |connection_id| {
1825 session
1826 .peer
1827 .forward_send(session.connection_id, connection_id, request.clone())
1828 },
1829 );
1830 response.send(proto::Ack {})?;
1831 Ok(())
1832}
1833
1834/// Updates other participants with changes to the diagnostics
1835async fn update_diagnostic_summary(
1836 message: proto::UpdateDiagnosticSummary,
1837 session: Session,
1838) -> Result<()> {
1839 let guest_connection_ids = session
1840 .db()
1841 .await
1842 .update_diagnostic_summary(&message, session.connection_id)
1843 .await?;
1844
1845 broadcast(
1846 Some(session.connection_id),
1847 guest_connection_ids.iter().copied(),
1848 |connection_id| {
1849 session
1850 .peer
1851 .forward_send(session.connection_id, connection_id, message.clone())
1852 },
1853 );
1854
1855 Ok(())
1856}
1857
1858/// Updates other participants with changes to the worktree settings
1859async fn update_worktree_settings(
1860 message: proto::UpdateWorktreeSettings,
1861 session: Session,
1862) -> Result<()> {
1863 let guest_connection_ids = session
1864 .db()
1865 .await
1866 .update_worktree_settings(&message, session.connection_id)
1867 .await?;
1868
1869 broadcast(
1870 Some(session.connection_id),
1871 guest_connection_ids.iter().copied(),
1872 |connection_id| {
1873 session
1874 .peer
1875 .forward_send(session.connection_id, connection_id, message.clone())
1876 },
1877 );
1878
1879 Ok(())
1880}
1881
1882/// Notify other participants that a language server has started.
1883async fn start_language_server(
1884 request: proto::StartLanguageServer,
1885 session: Session,
1886) -> Result<()> {
1887 let guest_connection_ids = session
1888 .db()
1889 .await
1890 .start_language_server(&request, session.connection_id)
1891 .await?;
1892
1893 broadcast(
1894 Some(session.connection_id),
1895 guest_connection_ids.iter().copied(),
1896 |connection_id| {
1897 session
1898 .peer
1899 .forward_send(session.connection_id, connection_id, request.clone())
1900 },
1901 );
1902 Ok(())
1903}
1904
1905/// Notify other participants that a language server has changed.
1906async fn update_language_server(
1907 request: proto::UpdateLanguageServer,
1908 session: Session,
1909) -> Result<()> {
1910 let project_id = ProjectId::from_proto(request.project_id);
1911 let project_connection_ids = session
1912 .db()
1913 .await
1914 .project_connection_ids(project_id, session.connection_id)
1915 .await?;
1916 broadcast(
1917 Some(session.connection_id),
1918 project_connection_ids.iter().copied(),
1919 |connection_id| {
1920 session
1921 .peer
1922 .forward_send(session.connection_id, connection_id, request.clone())
1923 },
1924 );
1925 Ok(())
1926}
1927
1928/// forward a project request to the host. These requests should be read only
1929/// as guests are allowed to send them.
1930async fn forward_read_only_project_request<T>(
1931 request: T,
1932 response: Response<T>,
1933 session: Session,
1934) -> Result<()>
1935where
1936 T: EntityMessage + RequestMessage,
1937{
1938 let project_id = ProjectId::from_proto(request.remote_entity_id());
1939 let host_connection_id = session
1940 .db()
1941 .await
1942 .host_for_read_only_project_request(project_id, session.connection_id)
1943 .await?;
1944 let payload = session
1945 .peer
1946 .forward_request(session.connection_id, host_connection_id, request)
1947 .await?;
1948 response.send(payload)?;
1949 Ok(())
1950}
1951
1952/// forward a project request to the host. These requests are disallowed
1953/// for guests.
1954async fn forward_mutating_project_request<T>(
1955 request: T,
1956 response: Response<T>,
1957 session: Session,
1958) -> Result<()>
1959where
1960 T: EntityMessage + RequestMessage,
1961{
1962 let project_id = ProjectId::from_proto(request.remote_entity_id());
1963 let host_connection_id = session
1964 .db()
1965 .await
1966 .host_for_mutating_project_request(project_id, session.connection_id)
1967 .await?;
1968 let payload = session
1969 .peer
1970 .forward_request(session.connection_id, host_connection_id, request)
1971 .await?;
1972 response.send(payload)?;
1973 Ok(())
1974}
1975
1976/// Notify other participants that a new buffer has been created
1977async fn create_buffer_for_peer(
1978 request: proto::CreateBufferForPeer,
1979 session: Session,
1980) -> Result<()> {
1981 session
1982 .db()
1983 .await
1984 .check_user_is_project_host(
1985 ProjectId::from_proto(request.project_id),
1986 session.connection_id,
1987 )
1988 .await?;
1989 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1990 session
1991 .peer
1992 .forward_send(session.connection_id, peer_id.into(), request)?;
1993 Ok(())
1994}
1995
1996/// Notify other participants that a buffer has been updated. This is
1997/// allowed for guests as long as the update is limited to selections.
1998async fn update_buffer(
1999 request: proto::UpdateBuffer,
2000 response: Response<proto::UpdateBuffer>,
2001 session: Session,
2002) -> Result<()> {
2003 let project_id = ProjectId::from_proto(request.project_id);
2004 let mut guest_connection_ids;
2005 let mut host_connection_id = None;
2006
2007 let mut requires_write_permission = false;
2008
2009 for op in request.operations.iter() {
2010 match op.variant {
2011 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2012 Some(_) => requires_write_permission = true,
2013 }
2014 }
2015
2016 {
2017 let collaborators = session
2018 .db()
2019 .await
2020 .project_collaborators_for_buffer_update(
2021 project_id,
2022 session.connection_id,
2023 requires_write_permission,
2024 )
2025 .await?;
2026 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2027 for collaborator in collaborators.iter() {
2028 if collaborator.is_host {
2029 host_connection_id = Some(collaborator.connection_id);
2030 } else {
2031 guest_connection_ids.push(collaborator.connection_id);
2032 }
2033 }
2034 }
2035 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2036
2037 broadcast(
2038 Some(session.connection_id),
2039 guest_connection_ids,
2040 |connection_id| {
2041 session
2042 .peer
2043 .forward_send(session.connection_id, connection_id, request.clone())
2044 },
2045 );
2046 if host_connection_id != session.connection_id {
2047 session
2048 .peer
2049 .forward_request(session.connection_id, host_connection_id, request.clone())
2050 .await?;
2051 }
2052
2053 response.send(proto::Ack {})?;
2054 Ok(())
2055}
2056
2057/// Notify other participants that a project has been updated.
2058async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2059 request: T,
2060 session: Session,
2061) -> Result<()> {
2062 let project_id = ProjectId::from_proto(request.remote_entity_id());
2063 let project_connection_ids = session
2064 .db()
2065 .await
2066 .project_connection_ids(project_id, session.connection_id)
2067 .await?;
2068
2069 broadcast(
2070 Some(session.connection_id),
2071 project_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, request.clone())
2076 },
2077 );
2078 Ok(())
2079}
2080
2081/// Start following another user in a call.
2082async fn follow(
2083 request: proto::Follow,
2084 response: Response<proto::Follow>,
2085 session: Session,
2086) -> Result<()> {
2087 let room_id = RoomId::from_proto(request.room_id);
2088 let project_id = request.project_id.map(ProjectId::from_proto);
2089 let leader_id = request
2090 .leader_id
2091 .ok_or_else(|| anyhow!("invalid leader id"))?
2092 .into();
2093 let follower_id = session.connection_id;
2094
2095 session
2096 .db()
2097 .await
2098 .check_room_participants(room_id, leader_id, session.connection_id)
2099 .await?;
2100
2101 let response_payload = session
2102 .peer
2103 .forward_request(session.connection_id, leader_id, request)
2104 .await?;
2105 response.send(response_payload)?;
2106
2107 if let Some(project_id) = project_id {
2108 let room = session
2109 .db()
2110 .await
2111 .follow(room_id, project_id, leader_id, follower_id)
2112 .await?;
2113 room_updated(&room, &session.peer);
2114 }
2115
2116 Ok(())
2117}
2118
2119/// Stop following another user in a call.
2120async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2121 let room_id = RoomId::from_proto(request.room_id);
2122 let project_id = request.project_id.map(ProjectId::from_proto);
2123 let leader_id = request
2124 .leader_id
2125 .ok_or_else(|| anyhow!("invalid leader id"))?
2126 .into();
2127 let follower_id = session.connection_id;
2128
2129 session
2130 .db()
2131 .await
2132 .check_room_participants(room_id, leader_id, session.connection_id)
2133 .await?;
2134
2135 session
2136 .peer
2137 .forward_send(session.connection_id, leader_id, request)?;
2138
2139 if let Some(project_id) = project_id {
2140 let room = session
2141 .db()
2142 .await
2143 .unfollow(room_id, project_id, leader_id, follower_id)
2144 .await?;
2145 room_updated(&room, &session.peer);
2146 }
2147
2148 Ok(())
2149}
2150
2151/// Notify everyone following you of your current location.
2152async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2153 let room_id = RoomId::from_proto(request.room_id);
2154 let database = session.db.lock().await;
2155
2156 let connection_ids = if let Some(project_id) = request.project_id {
2157 let project_id = ProjectId::from_proto(project_id);
2158 database
2159 .project_connection_ids(project_id, session.connection_id)
2160 .await?
2161 } else {
2162 database
2163 .room_connection_ids(room_id, session.connection_id)
2164 .await?
2165 };
2166
2167 // For now, don't send view update messages back to that view's current leader.
2168 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2169 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2170 _ => None,
2171 });
2172
2173 for connection_id in connection_ids.iter().cloned() {
2174 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2175 session
2176 .peer
2177 .forward_send(session.connection_id, connection_id, request.clone())?;
2178 }
2179 }
2180 Ok(())
2181}
2182
2183/// Get public data about users.
2184async fn get_users(
2185 request: proto::GetUsers,
2186 response: Response<proto::GetUsers>,
2187 session: Session,
2188) -> Result<()> {
2189 let user_ids = request
2190 .user_ids
2191 .into_iter()
2192 .map(UserId::from_proto)
2193 .collect();
2194 let users = session
2195 .db()
2196 .await
2197 .get_users_by_ids(user_ids)
2198 .await?
2199 .into_iter()
2200 .map(|user| proto::User {
2201 id: user.id.to_proto(),
2202 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2203 github_login: user.github_login,
2204 })
2205 .collect();
2206 response.send(proto::UsersResponse { users })?;
2207 Ok(())
2208}
2209
2210/// Search for users (to invite) buy Github login
2211async fn fuzzy_search_users(
2212 request: proto::FuzzySearchUsers,
2213 response: Response<proto::FuzzySearchUsers>,
2214 session: Session,
2215) -> Result<()> {
2216 let query = request.query;
2217 let users = match query.len() {
2218 0 => vec![],
2219 1 | 2 => session
2220 .db()
2221 .await
2222 .get_user_by_github_login(&query)
2223 .await?
2224 .into_iter()
2225 .collect(),
2226 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2227 };
2228 let users = users
2229 .into_iter()
2230 .filter(|user| user.id != session.user_id)
2231 .map(|user| proto::User {
2232 id: user.id.to_proto(),
2233 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2234 github_login: user.github_login,
2235 })
2236 .collect();
2237 response.send(proto::UsersResponse { users })?;
2238 Ok(())
2239}
2240
2241/// Send a contact request to another user.
2242async fn request_contact(
2243 request: proto::RequestContact,
2244 response: Response<proto::RequestContact>,
2245 session: Session,
2246) -> Result<()> {
2247 let requester_id = session.user_id;
2248 let responder_id = UserId::from_proto(request.responder_id);
2249 if requester_id == responder_id {
2250 return Err(anyhow!("cannot add yourself as a contact"))?;
2251 }
2252
2253 let notifications = session
2254 .db()
2255 .await
2256 .send_contact_request(requester_id, responder_id)
2257 .await?;
2258
2259 // Update outgoing contact requests of requester
2260 let mut update = proto::UpdateContacts::default();
2261 update.outgoing_requests.push(responder_id.to_proto());
2262 for connection_id in session
2263 .connection_pool()
2264 .await
2265 .user_connection_ids(requester_id)
2266 {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269
2270 // Update incoming contact requests of responder
2271 let mut update = proto::UpdateContacts::default();
2272 update
2273 .incoming_requests
2274 .push(proto::IncomingContactRequest {
2275 requester_id: requester_id.to_proto(),
2276 });
2277 let connection_pool = session.connection_pool().await;
2278 for connection_id in connection_pool.user_connection_ids(responder_id) {
2279 session.peer.send(connection_id, update.clone())?;
2280 }
2281
2282 send_notifications(&connection_pool, &session.peer, notifications);
2283
2284 response.send(proto::Ack {})?;
2285 Ok(())
2286}
2287
2288/// Accept or decline a contact request
2289async fn respond_to_contact_request(
2290 request: proto::RespondToContactRequest,
2291 response: Response<proto::RespondToContactRequest>,
2292 session: Session,
2293) -> Result<()> {
2294 let responder_id = session.user_id;
2295 let requester_id = UserId::from_proto(request.requester_id);
2296 let db = session.db().await;
2297 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2298 db.dismiss_contact_notification(responder_id, requester_id)
2299 .await?;
2300 } else {
2301 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2302
2303 let notifications = db
2304 .respond_to_contact_request(responder_id, requester_id, accept)
2305 .await?;
2306 let requester_busy = db.is_user_busy(requester_id).await?;
2307 let responder_busy = db.is_user_busy(responder_id).await?;
2308
2309 let pool = session.connection_pool().await;
2310 // Update responder with new contact
2311 let mut update = proto::UpdateContacts::default();
2312 if accept {
2313 update
2314 .contacts
2315 .push(contact_for_user(requester_id, requester_busy, &pool));
2316 }
2317 update
2318 .remove_incoming_requests
2319 .push(requester_id.to_proto());
2320 for connection_id in pool.user_connection_ids(responder_id) {
2321 session.peer.send(connection_id, update.clone())?;
2322 }
2323
2324 // Update requester with new contact
2325 let mut update = proto::UpdateContacts::default();
2326 if accept {
2327 update
2328 .contacts
2329 .push(contact_for_user(responder_id, responder_busy, &pool));
2330 }
2331 update
2332 .remove_outgoing_requests
2333 .push(responder_id.to_proto());
2334
2335 for connection_id in pool.user_connection_ids(requester_id) {
2336 session.peer.send(connection_id, update.clone())?;
2337 }
2338
2339 send_notifications(&pool, &session.peer, notifications);
2340 }
2341
2342 response.send(proto::Ack {})?;
2343 Ok(())
2344}
2345
2346/// Remove a contact.
2347async fn remove_contact(
2348 request: proto::RemoveContact,
2349 response: Response<proto::RemoveContact>,
2350 session: Session,
2351) -> Result<()> {
2352 let requester_id = session.user_id;
2353 let responder_id = UserId::from_proto(request.user_id);
2354 let db = session.db().await;
2355 let (contact_accepted, deleted_notification_id) =
2356 db.remove_contact(requester_id, responder_id).await?;
2357
2358 let pool = session.connection_pool().await;
2359 // Update outgoing contact requests of requester
2360 let mut update = proto::UpdateContacts::default();
2361 if contact_accepted {
2362 update.remove_contacts.push(responder_id.to_proto());
2363 } else {
2364 update
2365 .remove_outgoing_requests
2366 .push(responder_id.to_proto());
2367 }
2368 for connection_id in pool.user_connection_ids(requester_id) {
2369 session.peer.send(connection_id, update.clone())?;
2370 }
2371
2372 // Update incoming contact requests of responder
2373 let mut update = proto::UpdateContacts::default();
2374 if contact_accepted {
2375 update.remove_contacts.push(requester_id.to_proto());
2376 } else {
2377 update
2378 .remove_incoming_requests
2379 .push(requester_id.to_proto());
2380 }
2381 for connection_id in pool.user_connection_ids(responder_id) {
2382 session.peer.send(connection_id, update.clone())?;
2383 if let Some(notification_id) = deleted_notification_id {
2384 session.peer.send(
2385 connection_id,
2386 proto::DeleteNotification {
2387 notification_id: notification_id.to_proto(),
2388 },
2389 )?;
2390 }
2391 }
2392
2393 response.send(proto::Ack {})?;
2394 Ok(())
2395}
2396
2397/// Creates a new channel.
2398async fn create_channel(
2399 request: proto::CreateChannel,
2400 response: Response<proto::CreateChannel>,
2401 session: Session,
2402) -> Result<()> {
2403 let db = session.db().await;
2404
2405 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2406 let (channel, owner, channel_members) = db
2407 .create_channel(&request.name, parent_id, session.user_id)
2408 .await?;
2409
2410 response.send(proto::CreateChannelResponse {
2411 channel: Some(channel.to_proto()),
2412 parent_id: request.parent_id,
2413 })?;
2414
2415 let connection_pool = session.connection_pool().await;
2416 if let Some(owner) = owner {
2417 let update = proto::UpdateUserChannels {
2418 channel_memberships: vec![proto::ChannelMembership {
2419 channel_id: owner.channel_id.to_proto(),
2420 role: owner.role.into(),
2421 }],
2422 ..Default::default()
2423 };
2424 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2425 session.peer.send(connection_id, update.clone())?;
2426 }
2427 }
2428
2429 for channel_member in channel_members {
2430 if !channel_member.role.can_see_channel(channel.visibility) {
2431 continue;
2432 }
2433
2434 let update = proto::UpdateChannels {
2435 channels: vec![channel.to_proto()],
2436 ..Default::default()
2437 };
2438 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2439 session.peer.send(connection_id, update.clone())?;
2440 }
2441 }
2442
2443 Ok(())
2444}
2445
2446/// Delete a channel
2447async fn delete_channel(
2448 request: proto::DeleteChannel,
2449 response: Response<proto::DeleteChannel>,
2450 session: Session,
2451) -> Result<()> {
2452 let db = session.db().await;
2453
2454 let channel_id = request.channel_id;
2455 let (removed_channels, member_ids) = db
2456 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2457 .await?;
2458 response.send(proto::Ack {})?;
2459
2460 // Notify members of removed channels
2461 let mut update = proto::UpdateChannels::default();
2462 update
2463 .delete_channels
2464 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2465
2466 let connection_pool = session.connection_pool().await;
2467 for member_id in member_ids {
2468 for connection_id in connection_pool.user_connection_ids(member_id) {
2469 session.peer.send(connection_id, update.clone())?;
2470 }
2471 }
2472
2473 Ok(())
2474}
2475
2476/// Invite someone to join a channel.
2477async fn invite_channel_member(
2478 request: proto::InviteChannelMember,
2479 response: Response<proto::InviteChannelMember>,
2480 session: Session,
2481) -> Result<()> {
2482 let db = session.db().await;
2483 let channel_id = ChannelId::from_proto(request.channel_id);
2484 let invitee_id = UserId::from_proto(request.user_id);
2485 let InviteMemberResult {
2486 channel,
2487 notifications,
2488 } = db
2489 .invite_channel_member(
2490 channel_id,
2491 invitee_id,
2492 session.user_id,
2493 request.role().into(),
2494 )
2495 .await?;
2496
2497 let update = proto::UpdateChannels {
2498 channel_invitations: vec![channel.to_proto()],
2499 ..Default::default()
2500 };
2501
2502 let connection_pool = session.connection_pool().await;
2503 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2504 session.peer.send(connection_id, update.clone())?;
2505 }
2506
2507 send_notifications(&connection_pool, &session.peer, notifications);
2508
2509 response.send(proto::Ack {})?;
2510 Ok(())
2511}
2512
2513/// remove someone from a channel
2514async fn remove_channel_member(
2515 request: proto::RemoveChannelMember,
2516 response: Response<proto::RemoveChannelMember>,
2517 session: Session,
2518) -> Result<()> {
2519 let db = session.db().await;
2520 let channel_id = ChannelId::from_proto(request.channel_id);
2521 let member_id = UserId::from_proto(request.user_id);
2522
2523 let RemoveChannelMemberResult {
2524 membership_update,
2525 notification_id,
2526 } = db
2527 .remove_channel_member(channel_id, member_id, session.user_id)
2528 .await?;
2529
2530 let connection_pool = &session.connection_pool().await;
2531 notify_membership_updated(
2532 &connection_pool,
2533 membership_update,
2534 member_id,
2535 &session.peer,
2536 );
2537 for connection_id in connection_pool.user_connection_ids(member_id) {
2538 if let Some(notification_id) = notification_id {
2539 session
2540 .peer
2541 .send(
2542 connection_id,
2543 proto::DeleteNotification {
2544 notification_id: notification_id.to_proto(),
2545 },
2546 )
2547 .trace_err();
2548 }
2549 }
2550
2551 response.send(proto::Ack {})?;
2552 Ok(())
2553}
2554
2555/// Toggle the channel between public and private.
2556/// Care is taken to maintain the invariant that public channels only descend from public channels,
2557/// (though members-only channels can appear at any point in the hierarchy).
2558async fn set_channel_visibility(
2559 request: proto::SetChannelVisibility,
2560 response: Response<proto::SetChannelVisibility>,
2561 session: Session,
2562) -> Result<()> {
2563 let db = session.db().await;
2564 let channel_id = ChannelId::from_proto(request.channel_id);
2565 let visibility = request.visibility().into();
2566
2567 let (channel, channel_members) = db
2568 .set_channel_visibility(channel_id, visibility, session.user_id)
2569 .await?;
2570
2571 let connection_pool = session.connection_pool().await;
2572 for member in channel_members {
2573 let update = if member.role.can_see_channel(channel.visibility) {
2574 proto::UpdateChannels {
2575 channels: vec![channel.to_proto()],
2576 ..Default::default()
2577 }
2578 } else {
2579 proto::UpdateChannels {
2580 delete_channels: vec![channel.id.to_proto()],
2581 ..Default::default()
2582 }
2583 };
2584
2585 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2586 session.peer.send(connection_id, update.clone())?;
2587 }
2588 }
2589
2590 response.send(proto::Ack {})?;
2591 Ok(())
2592}
2593
2594/// Alter the role for a user in the channel.
2595async fn set_channel_member_role(
2596 request: proto::SetChannelMemberRole,
2597 response: Response<proto::SetChannelMemberRole>,
2598 session: Session,
2599) -> Result<()> {
2600 let db = session.db().await;
2601 let channel_id = ChannelId::from_proto(request.channel_id);
2602 let member_id = UserId::from_proto(request.user_id);
2603 let result = db
2604 .set_channel_member_role(
2605 channel_id,
2606 session.user_id,
2607 member_id,
2608 request.role().into(),
2609 )
2610 .await?;
2611
2612 match result {
2613 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2614 let connection_pool = session.connection_pool().await;
2615 notify_membership_updated(
2616 &connection_pool,
2617 membership_update,
2618 member_id,
2619 &session.peer,
2620 )
2621 }
2622 db::SetMemberRoleResult::InviteUpdated(channel) => {
2623 let update = proto::UpdateChannels {
2624 channel_invitations: vec![channel.to_proto()],
2625 ..Default::default()
2626 };
2627
2628 for connection_id in session
2629 .connection_pool()
2630 .await
2631 .user_connection_ids(member_id)
2632 {
2633 session.peer.send(connection_id, update.clone())?;
2634 }
2635 }
2636 }
2637
2638 response.send(proto::Ack {})?;
2639 Ok(())
2640}
2641
2642/// Change the name of a channel
2643async fn rename_channel(
2644 request: proto::RenameChannel,
2645 response: Response<proto::RenameChannel>,
2646 session: Session,
2647) -> Result<()> {
2648 let db = session.db().await;
2649 let channel_id = ChannelId::from_proto(request.channel_id);
2650 let (channel, channel_members) = db
2651 .rename_channel(channel_id, session.user_id, &request.name)
2652 .await?;
2653
2654 response.send(proto::RenameChannelResponse {
2655 channel: Some(channel.to_proto()),
2656 })?;
2657
2658 let connection_pool = session.connection_pool().await;
2659 for channel_member in channel_members {
2660 if !channel_member.role.can_see_channel(channel.visibility) {
2661 continue;
2662 }
2663 let update = proto::UpdateChannels {
2664 channels: vec![channel.to_proto()],
2665 ..Default::default()
2666 };
2667 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2668 session.peer.send(connection_id, update.clone())?;
2669 }
2670 }
2671
2672 Ok(())
2673}
2674
2675/// Move a channel to a new parent.
2676async fn move_channel(
2677 request: proto::MoveChannel,
2678 response: Response<proto::MoveChannel>,
2679 session: Session,
2680) -> Result<()> {
2681 let channel_id = ChannelId::from_proto(request.channel_id);
2682 let to = ChannelId::from_proto(request.to);
2683
2684 let (channels, channel_members) = session
2685 .db()
2686 .await
2687 .move_channel(channel_id, to, session.user_id)
2688 .await?;
2689
2690 let connection_pool = session.connection_pool().await;
2691 for member in channel_members {
2692 let channels = channels
2693 .iter()
2694 .filter_map(|channel| {
2695 if member.role.can_see_channel(channel.visibility) {
2696 Some(channel.to_proto())
2697 } else {
2698 None
2699 }
2700 })
2701 .collect::<Vec<_>>();
2702 if channels.is_empty() {
2703 continue;
2704 }
2705
2706 let update = proto::UpdateChannels {
2707 channels,
2708 ..Default::default()
2709 };
2710
2711 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2712 session.peer.send(connection_id, update.clone())?;
2713 }
2714 }
2715
2716 response.send(Ack {})?;
2717 Ok(())
2718}
2719
2720/// Get the list of channel members
2721async fn get_channel_members(
2722 request: proto::GetChannelMembers,
2723 response: Response<proto::GetChannelMembers>,
2724 session: Session,
2725) -> Result<()> {
2726 let db = session.db().await;
2727 let channel_id = ChannelId::from_proto(request.channel_id);
2728 let members = db
2729 .get_channel_participant_details(channel_id, session.user_id)
2730 .await?;
2731 response.send(proto::GetChannelMembersResponse { members })?;
2732 Ok(())
2733}
2734
2735/// Accept or decline a channel invitation.
2736async fn respond_to_channel_invite(
2737 request: proto::RespondToChannelInvite,
2738 response: Response<proto::RespondToChannelInvite>,
2739 session: Session,
2740) -> Result<()> {
2741 let db = session.db().await;
2742 let channel_id = ChannelId::from_proto(request.channel_id);
2743 let RespondToChannelInvite {
2744 membership_update,
2745 notifications,
2746 } = db
2747 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2748 .await?;
2749
2750 let connection_pool = session.connection_pool().await;
2751 if let Some(membership_update) = membership_update {
2752 notify_membership_updated(
2753 &connection_pool,
2754 membership_update,
2755 session.user_id,
2756 &session.peer,
2757 );
2758 } else {
2759 let update = proto::UpdateChannels {
2760 remove_channel_invitations: vec![channel_id.to_proto()],
2761 ..Default::default()
2762 };
2763
2764 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2765 session.peer.send(connection_id, update.clone())?;
2766 }
2767 };
2768
2769 send_notifications(&connection_pool, &session.peer, notifications);
2770
2771 response.send(proto::Ack {})?;
2772
2773 Ok(())
2774}
2775
2776/// Join the channels' room
2777async fn join_channel(
2778 request: proto::JoinChannel,
2779 response: Response<proto::JoinChannel>,
2780 session: Session,
2781) -> Result<()> {
2782 let channel_id = ChannelId::from_proto(request.channel_id);
2783 join_channel_internal(channel_id, Box::new(response), session).await
2784}
2785
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both the `JoinChannel` and `JoinRoom` responses so that
/// `join_channel_internal` can serve either request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response`'s own `send`. The type-qualified path makes the
        // target explicit; presumably it resolves to the inherent method, not back
        // into this trait method — keep this form to avoid accidental recursion.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` responses.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2799
/// Shared implementation for joining a channel's call room.
///
/// Steps, in order:
/// 1. Leave the session's current room.
/// 2. Join the channel's room in the database.
/// 3. Reply through `response` with the room state, the channel id, and
///    optional LiveKit credentials.
/// 4. Broadcast any membership change and the new room state.
/// 5. Notify channel members via `channel_updated`.
/// 6. Refresh the user's contacts via `update_user_contacts`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so that the database guard (`db`) and the connection-pool guard taken
    // inside are released before `channel_updated` / `update_user_contacts` below.
    let joined_room = {
        // Leave the session's current room before joining the channel's room.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // LiveKit is optional. With no client configured, or if minting a token
        // fails (the error is traced), no connection info is sent. Guests get a
        // guest token and cannot publish; all other roles get a room token and
        // can publish.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joiner before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above have been dropped at this point.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2875
2876/// Start editing the channel notes
2877async fn join_channel_buffer(
2878 request: proto::JoinChannelBuffer,
2879 response: Response<proto::JoinChannelBuffer>,
2880 session: Session,
2881) -> Result<()> {
2882 let db = session.db().await;
2883 let channel_id = ChannelId::from_proto(request.channel_id);
2884
2885 let open_response = db
2886 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2887 .await?;
2888
2889 let collaborators = open_response.collaborators.clone();
2890 response.send(open_response)?;
2891
2892 let update = UpdateChannelBufferCollaborators {
2893 channel_id: channel_id.to_proto(),
2894 collaborators: collaborators.clone(),
2895 };
2896 channel_buffer_updated(
2897 session.connection_id,
2898 collaborators
2899 .iter()
2900 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2901 &update,
2902 &session.peer,
2903 );
2904
2905 Ok(())
2906}
2907
2908/// Edit the channel notes
2909async fn update_channel_buffer(
2910 request: proto::UpdateChannelBuffer,
2911 session: Session,
2912) -> Result<()> {
2913 let db = session.db().await;
2914 let channel_id = ChannelId::from_proto(request.channel_id);
2915
2916 let (collaborators, non_collaborators, epoch, version) = db
2917 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2918 .await?;
2919
2920 channel_buffer_updated(
2921 session.connection_id,
2922 collaborators,
2923 &proto::UpdateChannelBuffer {
2924 channel_id: channel_id.to_proto(),
2925 operations: request.operations,
2926 },
2927 &session.peer,
2928 );
2929
2930 let pool = &*session.connection_pool().await;
2931
2932 broadcast(
2933 None,
2934 non_collaborators
2935 .iter()
2936 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2937 |peer_id| {
2938 session.peer.send(
2939 peer_id,
2940 proto::UpdateChannels {
2941 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2942 channel_id: channel_id.to_proto(),
2943 epoch: epoch as u64,
2944 version: version.clone(),
2945 }],
2946 ..Default::default()
2947 },
2948 )
2949 },
2950 );
2951
2952 Ok(())
2953}
2954
2955/// Rejoin the channel notes after a connection blip
2956async fn rejoin_channel_buffers(
2957 request: proto::RejoinChannelBuffers,
2958 response: Response<proto::RejoinChannelBuffers>,
2959 session: Session,
2960) -> Result<()> {
2961 let db = session.db().await;
2962 let buffers = db
2963 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2964 .await?;
2965
2966 for rejoined_buffer in &buffers {
2967 let collaborators_to_notify = rejoined_buffer
2968 .buffer
2969 .collaborators
2970 .iter()
2971 .filter_map(|c| Some(c.peer_id?.into()));
2972 channel_buffer_updated(
2973 session.connection_id,
2974 collaborators_to_notify,
2975 &proto::UpdateChannelBufferCollaborators {
2976 channel_id: rejoined_buffer.buffer.channel_id,
2977 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2978 },
2979 &session.peer,
2980 );
2981 }
2982
2983 response.send(proto::RejoinChannelBuffersResponse {
2984 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2985 })?;
2986
2987 Ok(())
2988}
2989
2990/// Stop editing the channel notes
2991async fn leave_channel_buffer(
2992 request: proto::LeaveChannelBuffer,
2993 response: Response<proto::LeaveChannelBuffer>,
2994 session: Session,
2995) -> Result<()> {
2996 let db = session.db().await;
2997 let channel_id = ChannelId::from_proto(request.channel_id);
2998
2999 let left_buffer = db
3000 .leave_channel_buffer(channel_id, session.connection_id)
3001 .await?;
3002
3003 response.send(Ack {})?;
3004
3005 channel_buffer_updated(
3006 session.connection_id,
3007 left_buffer.connections,
3008 &proto::UpdateChannelBufferCollaborators {
3009 channel_id: channel_id.to_proto(),
3010 collaborators: left_buffer.collaborators,
3011 },
3012 &session.peer,
3013 );
3014
3015 Ok(())
3016}
3017
3018fn channel_buffer_updated<T: EnvelopedMessage>(
3019 sender_id: ConnectionId,
3020 collaborators: impl IntoIterator<Item = ConnectionId>,
3021 message: &T,
3022 peer: &Peer,
3023) {
3024 broadcast(Some(sender_id), collaborators, |peer_id| {
3025 peer.send(peer_id, message.clone())
3026 });
3027}
3028
3029fn send_notifications(
3030 connection_pool: &ConnectionPool,
3031 peer: &Peer,
3032 notifications: db::NotificationBatch,
3033) {
3034 for (user_id, notification) in notifications {
3035 for connection_id in connection_pool.user_connection_ids(user_id) {
3036 if let Err(error) = peer.send(
3037 connection_id,
3038 proto::AddNotification {
3039 notification: Some(notification.clone()),
3040 },
3041 ) {
3042 tracing::error!(
3043 "failed to send notification to {:?} {}",
3044 connection_id,
3045 error
3046 );
3047 }
3048 }
3049 }
3050}
3051
3052/// Send a message to the channel
3053async fn send_channel_message(
3054 request: proto::SendChannelMessage,
3055 response: Response<proto::SendChannelMessage>,
3056 session: Session,
3057) -> Result<()> {
3058 // Validate the message body.
3059 let body = request.body.trim().to_string();
3060 if body.len() > MAX_MESSAGE_LEN {
3061 return Err(anyhow!("message is too long"))?;
3062 }
3063 if body.is_empty() {
3064 return Err(anyhow!("message can't be blank"))?;
3065 }
3066
3067 // TODO: adjust mentions if body is trimmed
3068
3069 let timestamp = OffsetDateTime::now_utc();
3070 let nonce = request
3071 .nonce
3072 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3073
3074 let channel_id = ChannelId::from_proto(request.channel_id);
3075 let CreatedChannelMessage {
3076 message_id,
3077 participant_connection_ids,
3078 channel_members,
3079 notifications,
3080 } = session
3081 .db()
3082 .await
3083 .create_channel_message(
3084 channel_id,
3085 session.user_id,
3086 &body,
3087 &request.mentions,
3088 timestamp,
3089 nonce.clone().into(),
3090 match request.reply_to_message_id {
3091 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3092 None => None,
3093 },
3094 )
3095 .await?;
3096 let message = proto::ChannelMessage {
3097 sender_id: session.user_id.to_proto(),
3098 id: message_id.to_proto(),
3099 body,
3100 mentions: request.mentions,
3101 timestamp: timestamp.unix_timestamp() as u64,
3102 nonce: Some(nonce),
3103 reply_to_message_id: request.reply_to_message_id,
3104 };
3105 broadcast(
3106 Some(session.connection_id),
3107 participant_connection_ids,
3108 |connection| {
3109 session.peer.send(
3110 connection,
3111 proto::ChannelMessageSent {
3112 channel_id: channel_id.to_proto(),
3113 message: Some(message.clone()),
3114 },
3115 )
3116 },
3117 );
3118 response.send(proto::SendChannelMessageResponse {
3119 message: Some(message),
3120 })?;
3121
3122 let pool = &*session.connection_pool().await;
3123 broadcast(
3124 None,
3125 channel_members
3126 .iter()
3127 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3128 |peer_id| {
3129 session.peer.send(
3130 peer_id,
3131 proto::UpdateChannels {
3132 latest_channel_message_ids: vec![proto::ChannelMessageId {
3133 channel_id: channel_id.to_proto(),
3134 message_id: message_id.to_proto(),
3135 }],
3136 ..Default::default()
3137 },
3138 )
3139 },
3140 );
3141 send_notifications(pool, &session.peer, notifications);
3142
3143 Ok(())
3144}
3145
3146/// Delete a channel message
3147async fn remove_channel_message(
3148 request: proto::RemoveChannelMessage,
3149 response: Response<proto::RemoveChannelMessage>,
3150 session: Session,
3151) -> Result<()> {
3152 let channel_id = ChannelId::from_proto(request.channel_id);
3153 let message_id = MessageId::from_proto(request.message_id);
3154 let connection_ids = session
3155 .db()
3156 .await
3157 .remove_channel_message(channel_id, message_id, session.user_id)
3158 .await?;
3159 broadcast(Some(session.connection_id), connection_ids, |connection| {
3160 session.peer.send(connection, request.clone())
3161 });
3162 response.send(proto::Ack {})?;
3163 Ok(())
3164}
3165
3166/// Mark a channel message as read
3167async fn acknowledge_channel_message(
3168 request: proto::AckChannelMessage,
3169 session: Session,
3170) -> Result<()> {
3171 let channel_id = ChannelId::from_proto(request.channel_id);
3172 let message_id = MessageId::from_proto(request.message_id);
3173 let notifications = session
3174 .db()
3175 .await
3176 .observe_channel_message(channel_id, session.user_id, message_id)
3177 .await?;
3178 send_notifications(
3179 &*session.connection_pool().await,
3180 &session.peer,
3181 notifications,
3182 );
3183 Ok(())
3184}
3185
3186/// Mark a buffer version as synced
3187async fn acknowledge_buffer_version(
3188 request: proto::AckBufferOperation,
3189 session: Session,
3190) -> Result<()> {
3191 let buffer_id = BufferId::from_proto(request.buffer_id);
3192 session
3193 .db()
3194 .await
3195 .observe_buffer_version(
3196 buffer_id,
3197 session.user_id,
3198 request.epoch as i32,
3199 &request.version,
3200 )
3201 .await?;
3202 Ok(())
3203}
3204
3205/// Start receiving chat updates for a channel
3206async fn join_channel_chat(
3207 request: proto::JoinChannelChat,
3208 response: Response<proto::JoinChannelChat>,
3209 session: Session,
3210) -> Result<()> {
3211 let channel_id = ChannelId::from_proto(request.channel_id);
3212
3213 let db = session.db().await;
3214 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3215 .await?;
3216 let messages = db
3217 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3218 .await?;
3219 response.send(proto::JoinChannelChatResponse {
3220 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3221 messages,
3222 })?;
3223 Ok(())
3224}
3225
3226/// Stop receiving chat updates for a channel
3227async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3228 let channel_id = ChannelId::from_proto(request.channel_id);
3229 session
3230 .db()
3231 .await
3232 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3233 .await?;
3234 Ok(())
3235}
3236
3237/// Retrieve the chat history for a channel
3238async fn get_channel_messages(
3239 request: proto::GetChannelMessages,
3240 response: Response<proto::GetChannelMessages>,
3241 session: Session,
3242) -> Result<()> {
3243 let channel_id = ChannelId::from_proto(request.channel_id);
3244 let messages = session
3245 .db()
3246 .await
3247 .get_channel_messages(
3248 channel_id,
3249 session.user_id,
3250 MESSAGE_COUNT_PER_PAGE,
3251 Some(MessageId::from_proto(request.before_message_id)),
3252 )
3253 .await?;
3254 response.send(proto::GetChannelMessagesResponse {
3255 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3256 messages,
3257 })?;
3258 Ok(())
3259}
3260
3261/// Retrieve specific chat messages
3262async fn get_channel_messages_by_id(
3263 request: proto::GetChannelMessagesById,
3264 response: Response<proto::GetChannelMessagesById>,
3265 session: Session,
3266) -> Result<()> {
3267 let message_ids = request
3268 .message_ids
3269 .iter()
3270 .map(|id| MessageId::from_proto(*id))
3271 .collect::<Vec<_>>();
3272 let messages = session
3273 .db()
3274 .await
3275 .get_channel_messages_by_id(session.user_id, &message_ids)
3276 .await?;
3277 response.send(proto::GetChannelMessagesResponse {
3278 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3279 messages,
3280 })?;
3281 Ok(())
3282}
3283
3284/// Retrieve the current users notifications
3285async fn get_notifications(
3286 request: proto::GetNotifications,
3287 response: Response<proto::GetNotifications>,
3288 session: Session,
3289) -> Result<()> {
3290 let notifications = session
3291 .db()
3292 .await
3293 .get_notifications(
3294 session.user_id,
3295 NOTIFICATION_COUNT_PER_PAGE,
3296 request
3297 .before_id
3298 .map(|id| db::NotificationId::from_proto(id)),
3299 )
3300 .await?;
3301 response.send(proto::GetNotificationsResponse {
3302 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3303 notifications,
3304 })?;
3305 Ok(())
3306}
3307
3308/// Mark notifications as read
3309async fn mark_notification_as_read(
3310 request: proto::MarkNotificationRead,
3311 response: Response<proto::MarkNotificationRead>,
3312 session: Session,
3313) -> Result<()> {
3314 let database = &session.db().await;
3315 let notifications = database
3316 .mark_notification_as_read_by_id(
3317 session.user_id,
3318 NotificationId::from_proto(request.notification_id),
3319 )
3320 .await?;
3321 send_notifications(
3322 &*session.connection_pool().await,
3323 &session.peer,
3324 notifications,
3325 );
3326 response.send(proto::Ack {})?;
3327 Ok(())
3328}
3329
3330/// Get the current users information
3331async fn get_private_user_info(
3332 _request: proto::GetPrivateUserInfo,
3333 response: Response<proto::GetPrivateUserInfo>,
3334 session: Session,
3335) -> Result<()> {
3336 let db = session.db().await;
3337
3338 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3339 let user = db
3340 .get_user_by_id(session.user_id)
3341 .await?
3342 .ok_or_else(|| anyhow!("user not found"))?;
3343 let flags = db.get_user_flags(session.user_id).await?;
3344
3345 response.send(proto::GetPrivateUserInfoResponse {
3346 metrics_id,
3347 staff: user.admin,
3348 flags,
3349 })?;
3350 Ok(())
3351}
3352
3353fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3354 match message {
3355 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3356 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3357 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3358 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3359 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3360 code: frame.code.into(),
3361 reason: frame.reason,
3362 })),
3363 }
3364}
3365
3366fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3367 match message {
3368 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3369 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3370 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3371 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3372 AxumMessage::Close(frame) => {
3373 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3374 code: frame.code.into(),
3375 reason: frame.reason,
3376 }))
3377 }
3378 }
3379}
3380
3381fn notify_membership_updated(
3382 connection_pool: &ConnectionPool,
3383 result: MembershipUpdated,
3384 user_id: UserId,
3385 peer: &Peer,
3386) {
3387 let user_channels_update = proto::UpdateUserChannels {
3388 channel_memberships: result
3389 .new_channels
3390 .channel_memberships
3391 .iter()
3392 .map(|cm| proto::ChannelMembership {
3393 channel_id: cm.channel_id.to_proto(),
3394 role: cm.role.into(),
3395 })
3396 .collect(),
3397 ..Default::default()
3398 };
3399 let mut update = build_channels_update(result.new_channels, vec![]);
3400 update.delete_channels = result
3401 .removed_channels
3402 .into_iter()
3403 .map(|id| id.to_proto())
3404 .collect();
3405 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3406
3407 for connection_id in connection_pool.user_connection_ids(user_id) {
3408 peer.send(connection_id, user_channels_update.clone())
3409 .trace_err();
3410 peer.send(connection_id, update.clone()).trace_err();
3411 }
3412}
3413
3414fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3415 proto::UpdateUserChannels {
3416 channel_memberships: channels
3417 .channel_memberships
3418 .iter()
3419 .map(|m| proto::ChannelMembership {
3420 channel_id: m.channel_id.to_proto(),
3421 role: m.role.into(),
3422 })
3423 .collect(),
3424 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3425 observed_channel_message_id: channels.observed_channel_messages.clone(),
3426 }
3427}
3428
3429fn build_channels_update(
3430 channels: ChannelsForUser,
3431 channel_invites: Vec<db::Channel>,
3432) -> proto::UpdateChannels {
3433 let mut update = proto::UpdateChannels::default();
3434
3435 for channel in channels.channels {
3436 update.channels.push(channel.to_proto());
3437 }
3438
3439 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3440 update.latest_channel_message_ids = channels.latest_channel_messages;
3441
3442 for (channel_id, participants) in channels.channel_participants {
3443 update
3444 .channel_participants
3445 .push(proto::ChannelParticipants {
3446 channel_id: channel_id.to_proto(),
3447 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3448 });
3449 }
3450
3451 for channel in channel_invites {
3452 update.channel_invitations.push(channel.to_proto());
3453 }
3454 for project in channels.hosted_projects {
3455 update.hosted_projects.push(project);
3456 }
3457
3458 update
3459}
3460
3461fn build_initial_contacts_update(
3462 contacts: Vec<db::Contact>,
3463 pool: &ConnectionPool,
3464) -> proto::UpdateContacts {
3465 let mut update = proto::UpdateContacts::default();
3466
3467 for contact in contacts {
3468 match contact {
3469 db::Contact::Accepted { user_id, busy } => {
3470 update.contacts.push(contact_for_user(user_id, busy, &pool));
3471 }
3472 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3473 db::Contact::Incoming { user_id } => {
3474 update
3475 .incoming_requests
3476 .push(proto::IncomingContactRequest {
3477 requester_id: user_id.to_proto(),
3478 })
3479 }
3480 }
3481 }
3482
3483 update
3484}
3485
3486fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3487 proto::Contact {
3488 user_id: user_id.to_proto(),
3489 online: pool.is_user_online(user_id),
3490 busy,
3491 }
3492}
3493
3494fn room_updated(room: &proto::Room, peer: &Peer) {
3495 broadcast(
3496 None,
3497 room.participants
3498 .iter()
3499 .filter_map(|participant| Some(participant.peer_id?.into())),
3500 |peer_id| {
3501 peer.send(
3502 peer_id,
3503 proto::RoomUpdated {
3504 room: Some(room.clone()),
3505 },
3506 )
3507 },
3508 );
3509}
3510
3511fn channel_updated(
3512 channel_id: ChannelId,
3513 room: &proto::Room,
3514 channel_members: &[UserId],
3515 peer: &Peer,
3516 pool: &ConnectionPool,
3517) {
3518 let participants = room
3519 .participants
3520 .iter()
3521 .map(|p| p.user_id)
3522 .collect::<Vec<_>>();
3523
3524 broadcast(
3525 None,
3526 channel_members
3527 .iter()
3528 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3529 |peer_id| {
3530 peer.send(
3531 peer_id,
3532 proto::UpdateChannels {
3533 channel_participants: vec![proto::ChannelParticipants {
3534 channel_id: channel_id.to_proto(),
3535 participant_user_ids: participants.clone(),
3536 }],
3537 ..Default::default()
3538 },
3539 )
3540 },
3541 );
3542}
3543
3544async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3545 let db = session.db().await;
3546
3547 let contacts = db.get_contacts(user_id).await?;
3548 let busy = db.is_user_busy(user_id).await?;
3549
3550 let pool = session.connection_pool().await;
3551 let updated_contact = contact_for_user(user_id, busy, &pool);
3552 for contact in contacts {
3553 if let db::Contact::Accepted {
3554 user_id: contact_user_id,
3555 ..
3556 } = contact
3557 {
3558 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3559 session
3560 .peer
3561 .send(
3562 contact_conn_id,
3563 proto::UpdateContacts {
3564 contacts: vec![updated_contact.clone()],
3565 remove_contacts: Default::default(),
3566 incoming_requests: Default::default(),
3567 remove_incoming_requests: Default::default(),
3568 outgoing_requests: Default::default(),
3569 remove_outgoing_requests: Default::default(),
3570 },
3571 )
3572 .trace_err();
3573 }
3574 }
3575 }
3576 Ok(())
3577}
3578
/// Remove this session's connection from the room it is in (if any) and propagate the
/// consequences to everyone affected.
///
/// In order, this:
/// 1. asks the database to remove the connection from its room, returning `Ok(())`
///    immediately when the database reports no room;
/// 2. calls `project_left` for every project the session left;
/// 3. broadcasts the updated room to its participants and, when the room has a
///    `channel_id`, the new participant list to the channel's members;
/// 4. sends `CallCanceled` to every connection of the users in
///    `canceled_calls_to_user_ids`;
/// 5. re-broadcasts the contact status of this user and of those users;
/// 6. when `session.live_kit_client` is set, removes the user from the LiveKit room and,
///    if the database marked the room as deleted, deletes it. LiveKit errors are only
///    logged.
///
/// # Errors
///
/// Returns an error if the database call to leave the room fails, or if any
/// `update_user_contacts` call fails. A failed contact update aborts the remaining
/// contact updates and skips the LiveKit cleanup.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // The leaving user plus every user whose call was canceled; each of them gets their
    // contact status re-broadcast below.
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned only in the `Some` branch below. The `else` branch
    // returns early, so every later use sees these initialized.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` *before* the room
        // itself is taken. The `room` broadcast by `room_updated` below therefore carries
        // the default (empty) `live_kit_room` value; confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Rooms that belong to a channel also update that channel's participant list.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // `?` here stops at the first failing contact update.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3655
3656async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3657 let left_channel_buffers = session
3658 .db()
3659 .await
3660 .leave_channel_buffers(session.connection_id)
3661 .await?;
3662
3663 for left_buffer in left_channel_buffers {
3664 channel_buffer_updated(
3665 session.connection_id,
3666 left_buffer.connections,
3667 &proto::UpdateChannelBufferCollaborators {
3668 channel_id: left_buffer.channel_id.to_proto(),
3669 collaborators: left_buffer.collaborators,
3670 },
3671 &session.peer,
3672 );
3673 }
3674
3675 Ok(())
3676}
3677
3678fn project_left(project: &db::LeftProject, session: &Session) {
3679 for connection_id in &project.connection_ids {
3680 if project.host_user_id == Some(session.user_id) {
3681 session
3682 .peer
3683 .send(
3684 *connection_id,
3685 proto::UnshareProject {
3686 project_id: project.id.to_proto(),
3687 },
3688 )
3689 .trace_err();
3690 } else {
3691 session
3692 .peer
3693 .send(
3694 *connection_id,
3695 proto::RemoveProjectCollaborator {
3696 project_id: project.id.to_proto(),
3697 peer_id: Some(session.connection_id.into()),
3698 },
3699 )
3700 .trace_err();
3701 }
3702 }
3703}
3704
/// Extension for fallible values whose failure should be reported and then discarded,
/// rather than propagated to the caller.
pub trait ResultExt {
    /// The value produced on success.
    type Ok;

    /// Consume `self`, yielding the success value or `None`. Implementations are expected
    /// to report the discarded error.
    fn trace_err(self) -> Option<Self::Ok>;
}
3710
3711impl<T, E> ResultExt for Result<T, E>
3712where
3713 E: std::fmt::Debug,
3714{
3715 type Ok = T;
3716
3717 fn trace_err(self) -> Option<T> {
3718 match self {
3719 Ok(value) => Some(value),
3720 Err(error) => {
3721 tracing::error!("{:?}", error);
3722 None
3723 }
3724 }
3725 }
3726}