1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 HostedProjectId, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
8 ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
9 User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc, OnceLock,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// How long `connection_lost` waits after a connection drops before it releases the user's
/// room, channel buffers, and pending call (unless the server tears down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long the background task spawned by `Server::start` waits before it cleans up
/// resources the database reports as stale for this server.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three limits below are used outside this view; from their names they look
// like page sizes / a length cap for channel chat and notifications — confirm at call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: receives a boxed envelope and the
/// sender's `Session`, and returns a `'static` boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
78
79struct Response<R> {
80 peer: Arc<Peer>,
81 receipt: Receipt<R>,
82 responded: Arc<AtomicBool>,
83}
84
85impl<R: RequestMessage> Response<R> {
86 fn send(self, payload: R::Response) -> Result<()> {
87 self.responded.store(true, SeqCst);
88 self.peer.respond(self.receipt, payload)?;
89 Ok(())
90 }
91}
92
/// Per-connection state passed to every message handler.
///
/// `Server::handle_connection` builds one `Session` per connection and clones it for each
/// incoming message. The clones share the database handle, peer, and connection pool through
/// `Arc`s.
#[derive(Clone)]
struct Session {
    user_id: UserId,
    connection_id: ConnectionId,
    // Async mutex around the database handle; acquired through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}

impl Session {
    /// Locks this session's database handle and returns the guard.
    ///
    /// In test builds the task yields both before and after taking the lock. This is presumably
    /// meant to give a test scheduler extra interleaving points — TODO confirm.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool.
    ///
    /// The returned guard is `!Send` (see `ConnectionPoolGuard::_not_send`), so it cannot be
    /// held across an `.await` in a `Send` future. In test builds the task yields once before
    /// locking.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
124
125impl fmt::Debug for Session {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("Session")
128 .field("user_id", &self.user_id)
129 .field("connection_id", &self.connection_id)
130 .finish()
131 }
132}
133
134struct DbHandle(Arc<Database>);
135
136impl Deref for DbHandle {
137 type Target = Database;
138
139 fn deref(&self) -> &Self::Target {
140 self.0.as_ref()
141 }
142}
143
/// The collaboration RPC server.
///
/// Owns the peer used to talk to clients, the pool of live connections, and the table of
/// registered message handlers.
pub struct Server {
    // Behind a mutex so that the test-only `reset` can replace the id through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    // Handlers keyed by the `TypeId` of the message type each one accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` by `teardown`; connection tasks subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}

/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, which makes the guard `!Send` so it cannot be held across an `.await`
    // inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
158
159#[derive(Serialize)]
160pub struct ServerSnapshot<'a> {
161 peer: &'a Peer,
162 #[serde(serialize_with = "serialize_deref")]
163 connection_pool: ConnectionPoolGuard<'a>,
164}
165
166pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
167where
168 S: Serializer,
169 T: Deref<Target = U>,
170 U: Serialize,
171{
172 Serialize::serialize(value.deref(), serializer)
173}
174
175impl Server {
176 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
177 let mut server = Self {
178 id: parking_lot::Mutex::new(id),
179 peer: Peer::new(id.0 as u32),
180 app_state,
181 executor,
182 connection_pool: Default::default(),
183 handlers: Default::default(),
184 teardown: watch::channel(false).0,
185 };
186
187 server
188 .add_request_handler(ping)
189 .add_request_handler(create_room)
190 .add_request_handler(join_room)
191 .add_request_handler(rejoin_room)
192 .add_request_handler(leave_room)
193 .add_request_handler(set_room_participant_role)
194 .add_request_handler(call)
195 .add_request_handler(cancel_call)
196 .add_message_handler(decline_call)
197 .add_request_handler(update_participant_location)
198 .add_request_handler(share_project)
199 .add_message_handler(unshare_project)
200 .add_request_handler(join_project)
201 .add_request_handler(join_hosted_project)
202 .add_message_handler(leave_project)
203 .add_request_handler(update_project)
204 .add_request_handler(update_worktree)
205 .add_message_handler(start_language_server)
206 .add_message_handler(update_language_server)
207 .add_message_handler(update_diagnostic_summary)
208 .add_message_handler(update_worktree_settings)
209 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
210 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
211 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
213 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
214 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
215 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
216 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
217 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
218 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
219 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
220 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
221 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
222 .add_request_handler(
223 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
224 )
225 .add_request_handler(
226 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
227 )
228 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
229 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
230 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
231 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
232 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
233 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
234 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
235 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
236 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
238 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
239 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
240 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
241 .add_message_handler(create_buffer_for_peer)
242 .add_request_handler(update_buffer)
243 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
244 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
245 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
246 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
247 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
248 .add_request_handler(get_users)
249 .add_request_handler(fuzzy_search_users)
250 .add_request_handler(request_contact)
251 .add_request_handler(remove_contact)
252 .add_request_handler(respond_to_contact_request)
253 .add_request_handler(create_channel)
254 .add_request_handler(delete_channel)
255 .add_request_handler(invite_channel_member)
256 .add_request_handler(remove_channel_member)
257 .add_request_handler(set_channel_member_role)
258 .add_request_handler(set_channel_visibility)
259 .add_request_handler(rename_channel)
260 .add_request_handler(join_channel_buffer)
261 .add_request_handler(leave_channel_buffer)
262 .add_message_handler(update_channel_buffer)
263 .add_request_handler(rejoin_channel_buffers)
264 .add_request_handler(get_channel_members)
265 .add_request_handler(respond_to_channel_invite)
266 .add_request_handler(join_channel)
267 .add_request_handler(join_channel_chat)
268 .add_message_handler(leave_channel_chat)
269 .add_request_handler(send_channel_message)
270 .add_request_handler(remove_channel_message)
271 .add_request_handler(get_channel_messages)
272 .add_request_handler(get_channel_messages_by_id)
273 .add_request_handler(get_notifications)
274 .add_request_handler(mark_notification_as_read)
275 .add_request_handler(move_channel)
276 .add_request_handler(follow)
277 .add_message_handler(unfollow)
278 .add_message_handler(update_followers)
279 .add_request_handler(get_private_user_info)
280 .add_message_handler(acknowledge_channel_message)
281 .add_message_handler(acknowledge_buffer_version);
282
283 Arc::new(server)
284 }
285
286 pub async fn start(&self) -> Result<()> {
287 let server_id = *self.id.lock();
288 let app_state = self.app_state.clone();
289 let peer = self.peer.clone();
290 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
291 let pool = self.connection_pool.clone();
292 let live_kit_client = self.app_state.live_kit_client.clone();
293
294 let span = info_span!("start server");
295 self.executor.spawn_detached(
296 async move {
297 tracing::info!("waiting for cleanup timeout");
298 timeout.await;
299 tracing::info!("cleanup timeout expired, retrieving stale rooms");
300 if let Some((room_ids, channel_ids)) = app_state
301 .db
302 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
303 .await
304 .trace_err()
305 {
306 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
307 tracing::info!(
308 stale_channel_buffer_count = channel_ids.len(),
309 "retrieved stale channel buffers"
310 );
311
312 for channel_id in channel_ids {
313 if let Some(refreshed_channel_buffer) = app_state
314 .db
315 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
316 .await
317 .trace_err()
318 {
319 for connection_id in refreshed_channel_buffer.connection_ids {
320 peer.send(
321 connection_id,
322 proto::UpdateChannelBufferCollaborators {
323 channel_id: channel_id.to_proto(),
324 collaborators: refreshed_channel_buffer
325 .collaborators
326 .clone(),
327 },
328 )
329 .trace_err();
330 }
331 }
332 }
333
334 for room_id in room_ids {
335 let mut contacts_to_update = HashSet::default();
336 let mut canceled_calls_to_user_ids = Vec::new();
337 let mut live_kit_room = String::new();
338 let mut delete_live_kit_room = false;
339
340 if let Some(mut refreshed_room) = app_state
341 .db
342 .clear_stale_room_participants(room_id, server_id)
343 .await
344 .trace_err()
345 {
346 tracing::info!(
347 room_id = room_id.0,
348 new_participant_count = refreshed_room.room.participants.len(),
349 "refreshed room"
350 );
351 room_updated(&refreshed_room.room, &peer);
352 if let Some(channel_id) = refreshed_room.channel_id {
353 channel_updated(
354 channel_id,
355 &refreshed_room.room,
356 &refreshed_room.channel_members,
357 &peer,
358 &pool.lock(),
359 );
360 }
361 contacts_to_update
362 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
363 contacts_to_update
364 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
365 canceled_calls_to_user_ids =
366 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
367 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
368 delete_live_kit_room = refreshed_room.room.participants.is_empty();
369 }
370
371 {
372 let pool = pool.lock();
373 for canceled_user_id in canceled_calls_to_user_ids {
374 for connection_id in pool.user_connection_ids(canceled_user_id) {
375 peer.send(
376 connection_id,
377 proto::CallCanceled {
378 room_id: room_id.to_proto(),
379 },
380 )
381 .trace_err();
382 }
383 }
384 }
385
386 for user_id in contacts_to_update {
387 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
388 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
389 if let Some((busy, contacts)) = busy.zip(contacts) {
390 let pool = pool.lock();
391 let updated_contact = contact_for_user(user_id, busy, &pool);
392 for contact in contacts {
393 if let db::Contact::Accepted {
394 user_id: contact_user_id,
395 ..
396 } = contact
397 {
398 for contact_conn_id in
399 pool.user_connection_ids(contact_user_id)
400 {
401 peer.send(
402 contact_conn_id,
403 proto::UpdateContacts {
404 contacts: vec![updated_contact.clone()],
405 remove_contacts: Default::default(),
406 incoming_requests: Default::default(),
407 remove_incoming_requests: Default::default(),
408 outgoing_requests: Default::default(),
409 remove_outgoing_requests: Default::default(),
410 },
411 )
412 .trace_err();
413 }
414 }
415 }
416 }
417 }
418
419 if let Some(live_kit) = live_kit_client.as_ref() {
420 if delete_live_kit_room {
421 live_kit.delete_room(live_kit_room).await.trace_err();
422 }
423 }
424 }
425 }
426
427 app_state
428 .db
429 .delete_stale_servers(&app_state.config.zed_environment, server_id)
430 .await
431 .trace_err();
432 }
433 .instrument(span),
434 );
435 Ok(())
436 }
437
438 pub fn teardown(&self) {
439 self.peer.teardown();
440 self.connection_pool.lock().reset();
441 let _ = self.teardown.send(true);
442 }
443
444 #[cfg(test)]
445 pub fn reset(&self, id: ServerId) {
446 self.teardown();
447 *self.id.lock() = id;
448 self.peer.reset(id.0 as u32);
449 let _ = self.teardown.send(false);
450 }
451
452 #[cfg(test)]
453 pub fn id(&self) -> ServerId {
454 *self.id.lock()
455 }
456
457 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
458 where
459 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
460 Fut: 'static + Send + Future<Output = Result<()>>,
461 M: EnvelopedMessage,
462 {
463 let prev_handler = self.handlers.insert(
464 TypeId::of::<M>(),
465 Box::new(move |envelope, session| {
466 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
467 let received_at = envelope.received_at;
468 let span = info_span!(
469 "handle message",
470 payload_type = envelope.payload_type_name()
471 );
472 span.in_scope(|| {
473 tracing::info!(
474 payload_type = envelope.payload_type_name(),
475 "message received"
476 );
477 });
478 let start_time = Instant::now();
479 let future = (handler)(*envelope, session);
480 async move {
481 let result = future.await;
482 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
483 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
484 let queue_duration_ms = processing_duration_ms - total_duration_ms;
485 match result {
486 Err(error) => {
487 tracing::error!(%error, ?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "error handling message")
488 }
489 Ok(()) => tracing::info!(?total_duration_ms, ?processing_duration_ms, ?queue_duration_ms, "finished handling message"),
490 }
491 }
492 .instrument(span)
493 .boxed()
494 }),
495 );
496 if prev_handler.is_some() {
497 panic!("registered a handler for the same message twice");
498 }
499 self
500 }
501
502 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
503 where
504 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
505 Fut: 'static + Send + Future<Output = Result<()>>,
506 M: EnvelopedMessage,
507 {
508 self.add_handler(move |envelope, session| handler(envelope.payload, session));
509 self
510 }
511
512 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
513 where
514 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
515 Fut: Send + Future<Output = Result<()>>,
516 M: RequestMessage,
517 {
518 let handler = Arc::new(handler);
519 self.add_handler(move |envelope, session| {
520 let receipt = envelope.receipt();
521 let handler = handler.clone();
522 async move {
523 let peer = session.peer.clone();
524 let responded = Arc::new(AtomicBool::default());
525 let response = Response {
526 peer: peer.clone(),
527 responded: responded.clone(),
528 receipt,
529 };
530 match (handler)(envelope.payload, response, session).await {
531 Ok(()) => {
532 if responded.load(std::sync::atomic::Ordering::SeqCst) {
533 Ok(())
534 } else {
535 Err(anyhow!("handler did not send a response"))?
536 }
537 }
538 Err(error) => {
539 let proto_err = match &error {
540 Error::Internal(err) => err.to_proto(),
541 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
542 };
543 peer.respond_with_error(receipt, proto_err)?;
544 Err(error)
545 }
546 }
547 }
548 })
549 }
550
551 #[allow(clippy::too_many_arguments)]
552 pub fn handle_connection(
553 self: &Arc<Self>,
554 connection: Connection,
555 address: String,
556 user: User,
557 zed_version: ZedVersion,
558 impersonator: Option<User>,
559 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
560 executor: Executor,
561 ) -> impl Future<Output = Result<()>> {
562 let this = self.clone();
563 let user_id = user.id;
564 let login = user.github_login;
565 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
566 if let Some(impersonator) = impersonator {
567 span.record("impersonator", &impersonator.github_login);
568 }
569 let mut teardown = self.teardown.subscribe();
570 async move {
571 if *teardown.borrow() {
572 return Err(anyhow!("server is tearing down"))?;
573 }
574 let (connection_id, handle_io, mut incoming_rx) = this
575 .peer
576 .add_connection(connection, {
577 let executor = executor.clone();
578 move |duration| executor.sleep(duration)
579 });
580
581 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
582 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
583 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
584
585 if let Some(send_connection_id) = send_connection_id.take() {
586 let _ = send_connection_id.send(connection_id);
587 }
588
589 if !user.connected_once {
590 this.peer.send(connection_id, proto::ShowContacts {})?;
591 this.app_state.db.set_user_connected_once(user_id, true).await?;
592 }
593
594 let (contacts, channels_for_user, channel_invites) = future::try_join3(
595 this.app_state.db.get_contacts(user_id),
596 this.app_state.db.get_channels_for_user(user_id),
597 this.app_state.db.get_channel_invites_for_user(user_id),
598 ).await?;
599
600 {
601 let mut pool = this.connection_pool.lock();
602 pool.add_connection(connection_id, user_id, user.admin, zed_version);
603 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
604 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
605 this.peer.send(connection_id, build_channels_update(
606 channels_for_user,
607 channel_invites
608 ))?;
609 }
610
611 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
612 this.peer.send(connection_id, incoming_call)?;
613 }
614
615 let session = Session {
616 user_id,
617 connection_id,
618 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
619 peer: this.peer.clone(),
620 connection_pool: this.connection_pool.clone(),
621 live_kit_client: this.app_state.live_kit_client.clone(),
622 _executor: executor.clone()
623 };
624 update_user_contacts(user_id, &session).await?;
625
626 let handle_io = handle_io.fuse();
627 futures::pin_mut!(handle_io);
628
629 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
630 // This prevents deadlocks when e.g., client A performs a request to client B and
631 // client B performs a request to client A. If both clients stop processing further
632 // messages until their respective request completes, they won't have a chance to
633 // respond to the other client's request and cause a deadlock.
634 //
635 // This arrangement ensures we will attempt to process earlier messages first, but fall
636 // back to processing messages arrived later in the spirit of making progress.
637 let mut foreground_message_handlers = FuturesUnordered::new();
638 let concurrent_handlers = Arc::new(Semaphore::new(256));
639 loop {
640 let next_message = async {
641 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
642 let message = incoming_rx.next().await;
643 (permit, message)
644 }.fuse();
645 futures::pin_mut!(next_message);
646 futures::select_biased! {
647 _ = teardown.changed().fuse() => return Ok(()),
648 result = handle_io => {
649 if let Err(error) = result {
650 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
651 }
652 break;
653 }
654 _ = foreground_message_handlers.next() => {}
655 next_message = next_message => {
656 let (permit, message) = next_message;
657 if let Some(message) = message {
658 let type_name = message.payload_type_name();
659 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
660 let span_enter = span.enter();
661 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
662 let is_background = message.is_background();
663 let handle_message = (handler)(message, session.clone());
664 drop(span_enter);
665
666 let handle_message = async move {
667 handle_message.await;
668 drop(permit);
669 }.instrument(span);
670 if is_background {
671 executor.spawn_detached(handle_message);
672 } else {
673 foreground_message_handlers.push(handle_message);
674 }
675 } else {
676 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
677 }
678 } else {
679 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
680 break;
681 }
682 }
683 }
684 }
685
686 drop(foreground_message_handlers);
687 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
688 if let Err(error) = connection_lost(session, teardown, executor).await {
689 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
690 }
691
692 Ok(())
693 }.instrument(span)
694 }
695
696 pub async fn invite_code_redeemed(
697 self: &Arc<Self>,
698 inviter_id: UserId,
699 invitee_id: UserId,
700 ) -> Result<()> {
701 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
702 if let Some(code) = &user.invite_code {
703 let pool = self.connection_pool.lock();
704 let invitee_contact = contact_for_user(invitee_id, false, &pool);
705 for connection_id in pool.user_connection_ids(inviter_id) {
706 self.peer.send(
707 connection_id,
708 proto::UpdateContacts {
709 contacts: vec![invitee_contact.clone()],
710 ..Default::default()
711 },
712 )?;
713 self.peer.send(
714 connection_id,
715 proto::UpdateInviteInfo {
716 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
717 count: user.invite_count as u32,
718 },
719 )?;
720 }
721 }
722 }
723 Ok(())
724 }
725
726 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
727 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
728 if let Some(invite_code) = &user.invite_code {
729 let pool = self.connection_pool.lock();
730 for connection_id in pool.user_connection_ids(user_id) {
731 self.peer.send(
732 connection_id,
733 proto::UpdateInviteInfo {
734 url: format!(
735 "{}{}",
736 self.app_state.config.invite_link_prefix, invite_code
737 ),
738 count: user.invite_count as u32,
739 },
740 )?;
741 }
742 }
743 }
744 Ok(())
745 }
746
747 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
748 ServerSnapshot {
749 connection_pool: ConnectionPoolGuard {
750 guard: self.connection_pool.lock(),
751 _not_send: PhantomData,
752 },
753 peer: &self.peer,
754 }
755 }
756}
757
758impl<'a> Deref for ConnectionPoolGuard<'a> {
759 type Target = ConnectionPool;
760
761 fn deref(&self) -> &Self::Target {
762 &self.guard
763 }
764}
765
766impl<'a> DerefMut for ConnectionPoolGuard<'a> {
767 fn deref_mut(&mut self) -> &mut Self::Target {
768 &mut self.guard
769 }
770}
771
772impl<'a> Drop for ConnectionPoolGuard<'a> {
773 fn drop(&mut self) {
774 #[cfg(test)]
775 self.check_invariants();
776 }
777}
778
779fn broadcast<F>(
780 sender_id: Option<ConnectionId>,
781 receiver_ids: impl IntoIterator<Item = ConnectionId>,
782 mut f: F,
783) where
784 F: FnMut(ConnectionId) -> anyhow::Result<()>,
785{
786 for receiver_id in receiver_ids {
787 if Some(receiver_id) != sender_id {
788 if let Err(error) = f(receiver_id) {
789 tracing::error!("failed to send to {:?} {}", receiver_id, error);
790 }
791 }
792 }
793}
794
795pub struct ProtocolVersion(u32);
796
797impl Header for ProtocolVersion {
798 fn name() -> &'static HeaderName {
799 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
800 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
801 }
802
803 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
804 where
805 Self: Sized,
806 I: Iterator<Item = &'i axum::http::HeaderValue>,
807 {
808 let version = values
809 .next()
810 .ok_or_else(axum::headers::Error::invalid)?
811 .to_str()
812 .map_err(|_| axum::headers::Error::invalid())?
813 .parse()
814 .map_err(|_| axum::headers::Error::invalid())?;
815 Ok(Self(version))
816 }
817
818 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
819 values.extend([self.0.to_string().parse().unwrap()]);
820 }
821}
822
823pub struct AppVersionHeader(SemanticVersion);
824impl Header for AppVersionHeader {
825 fn name() -> &'static HeaderName {
826 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
827 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
828 }
829
830 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
831 where
832 Self: Sized,
833 I: Iterator<Item = &'i axum::http::HeaderValue>,
834 {
835 let version = values
836 .next()
837 .ok_or_else(axum::headers::Error::invalid)?
838 .to_str()
839 .map_err(|_| axum::headers::Error::invalid())?
840 .parse()
841 .map_err(|_| axum::headers::Error::invalid())?;
842 Ok(Self(version))
843 }
844
845 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
846 values.extend([self.0.to_string().parse().unwrap()]);
847 }
848}
849
850pub fn routes(server: Arc<Server>) -> Router<(), Body> {
851 Router::new()
852 .route("/rpc", get(handle_websocket_request))
853 .layer(
854 ServiceBuilder::new()
855 .layer(Extension(server.app_state.clone()))
856 .layer(middleware::from_fn(auth::validate_header)),
857 )
858 .route("/metrics", get(handle_metrics))
859 .layer(Extension(server))
860}
861
862pub async fn handle_websocket_request(
863 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
864 app_version_header: Option<TypedHeader<AppVersionHeader>>,
865 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
866 Extension(server): Extension<Arc<Server>>,
867 Extension(user): Extension<User>,
868 Extension(impersonator): Extension<Impersonator>,
869 ws: WebSocketUpgrade,
870) -> axum::response::Response {
871 if protocol_version != rpc::PROTOCOL_VERSION {
872 return (
873 StatusCode::UPGRADE_REQUIRED,
874 "client must be upgraded".to_string(),
875 )
876 .into_response();
877 }
878
879 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
880 return (
881 StatusCode::UPGRADE_REQUIRED,
882 "no version header found".to_string(),
883 )
884 .into_response();
885 };
886
887 if !version.is_supported() {
888 return (
889 StatusCode::UPGRADE_REQUIRED,
890 "client must be upgraded".to_string(),
891 )
892 .into_response();
893 }
894
895 let socket_address = socket_address.to_string();
896 ws.on_upgrade(move |socket| {
897 use util::ResultExt;
898 let socket = socket
899 .map_ok(to_tungstenite_message)
900 .err_into()
901 .with(|message| async move { Ok(to_axum_message(message)) });
902 let connection = Connection::new(Box::pin(socket));
903 async move {
904 server
905 .handle_connection(
906 connection,
907 socket_address,
908 user,
909 version,
910 impersonator.0,
911 None,
912 Executor::Production,
913 )
914 .await
915 .log_err();
916 }
917 })
918}
919
920pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
921 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
922 let connections_metric = CONNECTIONS_METRIC
923 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
924
925 let connections = server
926 .connection_pool
927 .lock()
928 .connections()
929 .filter(|connection| !connection.admin)
930 .count();
931 connections_metric.set(connections as _);
932
933 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
934 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
935 register_int_gauge!(
936 "shared_projects",
937 "number of open projects with one or more guests"
938 )
939 .unwrap()
940 });
941
942 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
943 shared_projects_metric.set(shared_projects as _);
944
945 let encoder = prometheus::TextEncoder::new();
946 let metric_families = prometheus::gather();
947 let encoded_metrics = encoder
948 .encode_to_string(&metric_families)
949 .map_err(|err| anyhow!("{}", err))?;
950 Ok(encoded_metrics)
951}
952
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - records the lost connection in the database (a database failure here is traced, not
///   returned).
///
/// It then waits up to `RECONNECT_TIMEOUT`. If the timeout elapses first, the session leaves
/// its room and channel buffers. If the pool then no longer reports the user as online, their
/// pending call is declined. Finally, contacts are sent the user's updated state. If the
/// teardown channel changes first, this delayed cleanup is skipped.
///
/// # Errors
///
/// Fails if removing the connection from the pool fails, or if updating the user's contacts
/// fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the reconnect timer is checked
    // before the teardown signal on every poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline the pending call when no other connection keeps the user online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
998
999/// Acknowledges a ping from a client, used to keep the connection alive.
1000async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1001 response.send(proto::Ack {})?;
1002 Ok(())
1003}
1004
1005/// Creates a new room for calling (outside of channels)
1006async fn create_room(
1007 _request: proto::CreateRoom,
1008 response: Response<proto::CreateRoom>,
1009 session: Session,
1010) -> Result<()> {
1011 let live_kit_room = nanoid::nanoid!(30);
1012
1013 let live_kit_connection_info = {
1014 let live_kit_room = live_kit_room.clone();
1015 let live_kit = session.live_kit_client.as_ref();
1016
1017 util::async_maybe!({
1018 let live_kit = live_kit?;
1019
1020 let token = live_kit
1021 .room_token(&live_kit_room, &session.user_id.to_string())
1022 .trace_err()?;
1023
1024 Some(proto::LiveKitConnectionInfo {
1025 server_url: live_kit.url().into(),
1026 token,
1027 can_publish: true,
1028 })
1029 })
1030 }
1031 .await;
1032
1033 let room = session
1034 .db()
1035 .await
1036 .create_room(session.user_id, session.connection_id, &live_kit_room)
1037 .await?;
1038
1039 response.send(proto::CreateRoomResponse {
1040 room: Some(room.clone()),
1041 live_kit_connection_info,
1042 })?;
1043
1044 update_user_contacts(session.user_id, &session).await?;
1045 Ok(())
1046}
1047
1048/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1049async fn join_room(
1050 request: proto::JoinRoom,
1051 response: Response<proto::JoinRoom>,
1052 session: Session,
1053) -> Result<()> {
1054 let room_id = RoomId::from_proto(request.id);
1055
1056 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1057
1058 if let Some(channel_id) = channel_id {
1059 return join_channel_internal(channel_id, Box::new(response), session).await;
1060 }
1061
1062 let joined_room = {
1063 let room = session
1064 .db()
1065 .await
1066 .join_room(room_id, session.user_id, session.connection_id)
1067 .await?;
1068 room_updated(&room.room, &session.peer);
1069 room.into_inner()
1070 };
1071
1072 for connection_id in session
1073 .connection_pool()
1074 .await
1075 .user_connection_ids(session.user_id)
1076 {
1077 session
1078 .peer
1079 .send(
1080 connection_id,
1081 proto::CallCanceled {
1082 room_id: room_id.to_proto(),
1083 },
1084 )
1085 .trace_err();
1086 }
1087
1088 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1089 if let Some(token) = live_kit
1090 .room_token(
1091 &joined_room.room.live_kit_room,
1092 &session.user_id.to_string(),
1093 )
1094 .trace_err()
1095 {
1096 Some(proto::LiveKitConnectionInfo {
1097 server_url: live_kit.url().into(),
1098 token,
1099 can_publish: true,
1100 })
1101 } else {
1102 None
1103 }
1104 } else {
1105 None
1106 };
1107
1108 response.send(proto::JoinRoomResponse {
1109 room: Some(joined_room.room),
1110 channel_id: None,
1111 live_kit_connection_info,
1112 })?;
1113
1114 update_user_contacts(session.user_id, &session).await?;
1115 Ok(())
1116}
1117
1118/// Rejoin room is used to reconnect to a room after connection errors.
1119async fn rejoin_room(
1120 request: proto::RejoinRoom,
1121 response: Response<proto::RejoinRoom>,
1122 session: Session,
1123) -> Result<()> {
1124 let room;
1125 let channel_id;
1126 let channel_members;
1127 {
1128 let mut rejoined_room = session
1129 .db()
1130 .await
1131 .rejoin_room(request, session.user_id, session.connection_id)
1132 .await?;
1133
1134 response.send(proto::RejoinRoomResponse {
1135 room: Some(rejoined_room.room.clone()),
1136 reshared_projects: rejoined_room
1137 .reshared_projects
1138 .iter()
1139 .map(|project| proto::ResharedProject {
1140 id: project.id.to_proto(),
1141 collaborators: project
1142 .collaborators
1143 .iter()
1144 .map(|collaborator| collaborator.to_proto())
1145 .collect(),
1146 })
1147 .collect(),
1148 rejoined_projects: rejoined_room
1149 .rejoined_projects
1150 .iter()
1151 .map(|rejoined_project| proto::RejoinedProject {
1152 id: rejoined_project.id.to_proto(),
1153 worktrees: rejoined_project
1154 .worktrees
1155 .iter()
1156 .map(|worktree| proto::WorktreeMetadata {
1157 id: worktree.id,
1158 root_name: worktree.root_name.clone(),
1159 visible: worktree.visible,
1160 abs_path: worktree.abs_path.clone(),
1161 })
1162 .collect(),
1163 collaborators: rejoined_project
1164 .collaborators
1165 .iter()
1166 .map(|collaborator| collaborator.to_proto())
1167 .collect(),
1168 language_servers: rejoined_project.language_servers.clone(),
1169 })
1170 .collect(),
1171 })?;
1172 room_updated(&rejoined_room.room, &session.peer);
1173
1174 for project in &rejoined_room.reshared_projects {
1175 for collaborator in &project.collaborators {
1176 session
1177 .peer
1178 .send(
1179 collaborator.connection_id,
1180 proto::UpdateProjectCollaborator {
1181 project_id: project.id.to_proto(),
1182 old_peer_id: Some(project.old_connection_id.into()),
1183 new_peer_id: Some(session.connection_id.into()),
1184 },
1185 )
1186 .trace_err();
1187 }
1188
1189 broadcast(
1190 Some(session.connection_id),
1191 project
1192 .collaborators
1193 .iter()
1194 .map(|collaborator| collaborator.connection_id),
1195 |connection_id| {
1196 session.peer.forward_send(
1197 session.connection_id,
1198 connection_id,
1199 proto::UpdateProject {
1200 project_id: project.id.to_proto(),
1201 worktrees: project.worktrees.clone(),
1202 },
1203 )
1204 },
1205 );
1206 }
1207
1208 for project in &rejoined_room.rejoined_projects {
1209 for collaborator in &project.collaborators {
1210 session
1211 .peer
1212 .send(
1213 collaborator.connection_id,
1214 proto::UpdateProjectCollaborator {
1215 project_id: project.id.to_proto(),
1216 old_peer_id: Some(project.old_connection_id.into()),
1217 new_peer_id: Some(session.connection_id.into()),
1218 },
1219 )
1220 .trace_err();
1221 }
1222 }
1223
1224 for project in &mut rejoined_room.rejoined_projects {
1225 for worktree in mem::take(&mut project.worktrees) {
1226 #[cfg(any(test, feature = "test-support"))]
1227 const MAX_CHUNK_SIZE: usize = 2;
1228 #[cfg(not(any(test, feature = "test-support")))]
1229 const MAX_CHUNK_SIZE: usize = 256;
1230
1231 // Stream this worktree's entries.
1232 let message = proto::UpdateWorktree {
1233 project_id: project.id.to_proto(),
1234 worktree_id: worktree.id,
1235 abs_path: worktree.abs_path.clone(),
1236 root_name: worktree.root_name,
1237 updated_entries: worktree.updated_entries,
1238 removed_entries: worktree.removed_entries,
1239 scan_id: worktree.scan_id,
1240 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1241 updated_repositories: worktree.updated_repositories,
1242 removed_repositories: worktree.removed_repositories,
1243 };
1244 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1245 session.peer.send(session.connection_id, update.clone())?;
1246 }
1247
1248 // Stream this worktree's diagnostics.
1249 for summary in worktree.diagnostic_summaries {
1250 session.peer.send(
1251 session.connection_id,
1252 proto::UpdateDiagnosticSummary {
1253 project_id: project.id.to_proto(),
1254 worktree_id: worktree.id,
1255 summary: Some(summary),
1256 },
1257 )?;
1258 }
1259
1260 for settings_file in worktree.settings_files {
1261 session.peer.send(
1262 session.connection_id,
1263 proto::UpdateWorktreeSettings {
1264 project_id: project.id.to_proto(),
1265 worktree_id: worktree.id,
1266 path: settings_file.path,
1267 content: Some(settings_file.content),
1268 },
1269 )?;
1270 }
1271 }
1272
1273 for language_server in &project.language_servers {
1274 session.peer.send(
1275 session.connection_id,
1276 proto::UpdateLanguageServer {
1277 project_id: project.id.to_proto(),
1278 language_server_id: language_server.id,
1279 variant: Some(
1280 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1281 proto::LspDiskBasedDiagnosticsUpdated {},
1282 ),
1283 ),
1284 },
1285 )?;
1286 }
1287 }
1288
1289 let rejoined_room = rejoined_room.into_inner();
1290
1291 room = rejoined_room.room;
1292 channel_id = rejoined_room.channel_id;
1293 channel_members = rejoined_room.channel_members;
1294 }
1295
1296 if let Some(channel_id) = channel_id {
1297 channel_updated(
1298 channel_id,
1299 &room,
1300 &channel_members,
1301 &session.peer,
1302 &*session.connection_pool().await,
1303 );
1304 }
1305
1306 update_user_contacts(session.user_id, &session).await?;
1307 Ok(())
1308}
1309
1310/// leave room disconnects from the room.
1311async fn leave_room(
1312 _: proto::LeaveRoom,
1313 response: Response<proto::LeaveRoom>,
1314 session: Session,
1315) -> Result<()> {
1316 leave_room_for_session(&session).await?;
1317 response.send(proto::Ack {})?;
1318 Ok(())
1319}
1320
1321/// Updates the permissions of someone else in the room.
1322async fn set_room_participant_role(
1323 request: proto::SetRoomParticipantRole,
1324 response: Response<proto::SetRoomParticipantRole>,
1325 session: Session,
1326) -> Result<()> {
1327 let user_id = UserId::from_proto(request.user_id);
1328 let role = ChannelRole::from(request.role());
1329
1330 if role == ChannelRole::Talker {
1331 let pool = session.connection_pool().await;
1332
1333 for connection in pool.user_connections(user_id) {
1334 if !connection.zed_version.supports_talker_role() {
1335 Err(anyhow!(
1336 "This user is on zed {} which does not support unmute",
1337 connection.zed_version
1338 ))?;
1339 }
1340 }
1341 }
1342
1343 let (live_kit_room, can_publish) = {
1344 let room = session
1345 .db()
1346 .await
1347 .set_room_participant_role(
1348 session.user_id,
1349 RoomId::from_proto(request.room_id),
1350 user_id,
1351 role,
1352 )
1353 .await?;
1354
1355 let live_kit_room = room.live_kit_room.clone();
1356 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1357 room_updated(&room, &session.peer);
1358 (live_kit_room, can_publish)
1359 };
1360
1361 if let Some(live_kit) = session.live_kit_client.as_ref() {
1362 live_kit
1363 .update_participant(
1364 live_kit_room.clone(),
1365 request.user_id.to_string(),
1366 live_kit_server::proto::ParticipantPermission {
1367 can_subscribe: true,
1368 can_publish,
1369 can_publish_data: can_publish,
1370 hidden: false,
1371 recorder: false,
1372 },
1373 )
1374 .await
1375 .trace_err();
1376 }
1377
1378 response.send(proto::Ack {})?;
1379 Ok(())
1380}
1381
1382/// Call someone else into the current room
1383async fn call(
1384 request: proto::Call,
1385 response: Response<proto::Call>,
1386 session: Session,
1387) -> Result<()> {
1388 let room_id = RoomId::from_proto(request.room_id);
1389 let calling_user_id = session.user_id;
1390 let calling_connection_id = session.connection_id;
1391 let called_user_id = UserId::from_proto(request.called_user_id);
1392 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1393 if !session
1394 .db()
1395 .await
1396 .has_contact(calling_user_id, called_user_id)
1397 .await?
1398 {
1399 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1400 }
1401
1402 let incoming_call = {
1403 let (room, incoming_call) = &mut *session
1404 .db()
1405 .await
1406 .call(
1407 room_id,
1408 calling_user_id,
1409 calling_connection_id,
1410 called_user_id,
1411 initial_project_id,
1412 )
1413 .await?;
1414 room_updated(&room, &session.peer);
1415 mem::take(incoming_call)
1416 };
1417 update_user_contacts(called_user_id, &session).await?;
1418
1419 let mut calls = session
1420 .connection_pool()
1421 .await
1422 .user_connection_ids(called_user_id)
1423 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1424 .collect::<FuturesUnordered<_>>();
1425
1426 while let Some(call_response) = calls.next().await {
1427 match call_response.as_ref() {
1428 Ok(_) => {
1429 response.send(proto::Ack {})?;
1430 return Ok(());
1431 }
1432 Err(_) => {
1433 call_response.trace_err();
1434 }
1435 }
1436 }
1437
1438 {
1439 let room = session
1440 .db()
1441 .await
1442 .call_failed(room_id, called_user_id)
1443 .await?;
1444 room_updated(&room, &session.peer);
1445 }
1446 update_user_contacts(called_user_id, &session).await?;
1447
1448 Err(anyhow!("failed to ring user"))?
1449}
1450
1451/// Cancel an outgoing call.
1452async fn cancel_call(
1453 request: proto::CancelCall,
1454 response: Response<proto::CancelCall>,
1455 session: Session,
1456) -> Result<()> {
1457 let called_user_id = UserId::from_proto(request.called_user_id);
1458 let room_id = RoomId::from_proto(request.room_id);
1459 {
1460 let room = session
1461 .db()
1462 .await
1463 .cancel_call(room_id, session.connection_id, called_user_id)
1464 .await?;
1465 room_updated(&room, &session.peer);
1466 }
1467
1468 for connection_id in session
1469 .connection_pool()
1470 .await
1471 .user_connection_ids(called_user_id)
1472 {
1473 session
1474 .peer
1475 .send(
1476 connection_id,
1477 proto::CallCanceled {
1478 room_id: room_id.to_proto(),
1479 },
1480 )
1481 .trace_err();
1482 }
1483 response.send(proto::Ack {})?;
1484
1485 update_user_contacts(called_user_id, &session).await?;
1486 Ok(())
1487}
1488
1489/// Decline an incoming call.
1490async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1491 let room_id = RoomId::from_proto(message.room_id);
1492 {
1493 let room = session
1494 .db()
1495 .await
1496 .decline_call(Some(room_id), session.user_id)
1497 .await?
1498 .ok_or_else(|| anyhow!("failed to decline call"))?;
1499 room_updated(&room, &session.peer);
1500 }
1501
1502 for connection_id in session
1503 .connection_pool()
1504 .await
1505 .user_connection_ids(session.user_id)
1506 {
1507 session
1508 .peer
1509 .send(
1510 connection_id,
1511 proto::CallCanceled {
1512 room_id: room_id.to_proto(),
1513 },
1514 )
1515 .trace_err();
1516 }
1517 update_user_contacts(session.user_id, &session).await?;
1518 Ok(())
1519}
1520
1521/// Updates other participants in the room with your current location.
1522async fn update_participant_location(
1523 request: proto::UpdateParticipantLocation,
1524 response: Response<proto::UpdateParticipantLocation>,
1525 session: Session,
1526) -> Result<()> {
1527 let room_id = RoomId::from_proto(request.room_id);
1528 let location = request
1529 .location
1530 .ok_or_else(|| anyhow!("invalid location"))?;
1531
1532 let db = session.db().await;
1533 let room = db
1534 .update_room_participant_location(room_id, session.connection_id, location)
1535 .await?;
1536
1537 room_updated(&room, &session.peer);
1538 response.send(proto::Ack {})?;
1539 Ok(())
1540}
1541
1542/// Share a project into the room.
1543async fn share_project(
1544 request: proto::ShareProject,
1545 response: Response<proto::ShareProject>,
1546 session: Session,
1547) -> Result<()> {
1548 let (project_id, room) = &*session
1549 .db()
1550 .await
1551 .share_project(
1552 RoomId::from_proto(request.room_id),
1553 session.connection_id,
1554 &request.worktrees,
1555 )
1556 .await?;
1557 response.send(proto::ShareProjectResponse {
1558 project_id: project_id.to_proto(),
1559 })?;
1560 room_updated(&room, &session.peer);
1561
1562 Ok(())
1563}
1564
1565/// Unshare a project from the room.
1566async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1567 let project_id = ProjectId::from_proto(message.project_id);
1568
1569 let (room, guest_connection_ids) = &*session
1570 .db()
1571 .await
1572 .unshare_project(project_id, session.connection_id)
1573 .await?;
1574
1575 broadcast(
1576 Some(session.connection_id),
1577 guest_connection_ids.iter().copied(),
1578 |conn_id| session.peer.send(conn_id, message.clone()),
1579 );
1580 room_updated(&room, &session.peer);
1581
1582 Ok(())
1583}
1584
1585/// Join someone elses shared project.
1586async fn join_project(
1587 request: proto::JoinProject,
1588 response: Response<proto::JoinProject>,
1589 session: Session,
1590) -> Result<()> {
1591 let project_id = ProjectId::from_proto(request.project_id);
1592
1593 tracing::info!(%project_id, "join project");
1594
1595 let (project, replica_id) = &mut *session
1596 .db()
1597 .await
1598 .join_project_in_room(project_id, session.connection_id)
1599 .await?;
1600
1601 join_project_internal(response, session, project, replica_id)
1602}
1603
/// A response handle that can deliver a `proto::JoinProjectResponse`.
///
/// Lets `join_project_internal` answer both `JoinProject` and
/// `JoinHostedProject` requests with the same code.
trait JoinProjectInternalResponse {
    /// Sends the join response back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Forwards to `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1617
1618fn join_project_internal(
1619 response: impl JoinProjectInternalResponse,
1620 session: Session,
1621 project: &mut Project,
1622 replica_id: &ReplicaId,
1623) -> Result<()> {
1624 let collaborators = project
1625 .collaborators
1626 .iter()
1627 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1628 .map(|collaborator| collaborator.to_proto())
1629 .collect::<Vec<_>>();
1630 let project_id = project.id;
1631 let guest_user_id = session.user_id;
1632
1633 let worktrees = project
1634 .worktrees
1635 .iter()
1636 .map(|(id, worktree)| proto::WorktreeMetadata {
1637 id: *id,
1638 root_name: worktree.root_name.clone(),
1639 visible: worktree.visible,
1640 abs_path: worktree.abs_path.clone(),
1641 })
1642 .collect::<Vec<_>>();
1643
1644 for collaborator in &collaborators {
1645 session
1646 .peer
1647 .send(
1648 collaborator.peer_id.unwrap().into(),
1649 proto::AddProjectCollaborator {
1650 project_id: project_id.to_proto(),
1651 collaborator: Some(proto::Collaborator {
1652 peer_id: Some(session.connection_id.into()),
1653 replica_id: replica_id.0 as u32,
1654 user_id: guest_user_id.to_proto(),
1655 }),
1656 },
1657 )
1658 .trace_err();
1659 }
1660
1661 // First, we send the metadata associated with each worktree.
1662 response.send(proto::JoinProjectResponse {
1663 project_id: project.id.0 as u64,
1664 worktrees: worktrees.clone(),
1665 replica_id: replica_id.0 as u32,
1666 collaborators: collaborators.clone(),
1667 language_servers: project.language_servers.clone(),
1668 role: project.role.into(), // todo
1669 })?;
1670
1671 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1672 #[cfg(any(test, feature = "test-support"))]
1673 const MAX_CHUNK_SIZE: usize = 2;
1674 #[cfg(not(any(test, feature = "test-support")))]
1675 const MAX_CHUNK_SIZE: usize = 256;
1676
1677 // Stream this worktree's entries.
1678 let message = proto::UpdateWorktree {
1679 project_id: project_id.to_proto(),
1680 worktree_id,
1681 abs_path: worktree.abs_path.clone(),
1682 root_name: worktree.root_name,
1683 updated_entries: worktree.entries,
1684 removed_entries: Default::default(),
1685 scan_id: worktree.scan_id,
1686 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1687 updated_repositories: worktree.repository_entries.into_values().collect(),
1688 removed_repositories: Default::default(),
1689 };
1690 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1691 session.peer.send(session.connection_id, update.clone())?;
1692 }
1693
1694 // Stream this worktree's diagnostics.
1695 for summary in worktree.diagnostic_summaries {
1696 session.peer.send(
1697 session.connection_id,
1698 proto::UpdateDiagnosticSummary {
1699 project_id: project_id.to_proto(),
1700 worktree_id: worktree.id,
1701 summary: Some(summary),
1702 },
1703 )?;
1704 }
1705
1706 for settings_file in worktree.settings_files {
1707 session.peer.send(
1708 session.connection_id,
1709 proto::UpdateWorktreeSettings {
1710 project_id: project_id.to_proto(),
1711 worktree_id: worktree.id,
1712 path: settings_file.path,
1713 content: Some(settings_file.content),
1714 },
1715 )?;
1716 }
1717 }
1718
1719 for language_server in &project.language_servers {
1720 session.peer.send(
1721 session.connection_id,
1722 proto::UpdateLanguageServer {
1723 project_id: project_id.to_proto(),
1724 language_server_id: language_server.id,
1725 variant: Some(
1726 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1727 proto::LspDiskBasedDiagnosticsUpdated {},
1728 ),
1729 ),
1730 },
1731 )?;
1732 }
1733
1734 Ok(())
1735}
1736
1737/// Leave someone elses shared project.
1738async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1739 let sender_id = session.connection_id;
1740 let project_id = ProjectId::from_proto(request.project_id);
1741 let db = session.db().await;
1742 if db.is_hosted_project(project_id).await? {
1743 let project = db.leave_hosted_project(project_id, sender_id).await?;
1744 project_left(&project, &session);
1745 return Ok(());
1746 }
1747
1748 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1749 tracing::info!(
1750 %project_id,
1751 host_user_id = ?project.host_user_id,
1752 host_connection_id = ?project.host_connection_id,
1753 "leave project"
1754 );
1755
1756 project_left(&project, &session);
1757 room_updated(&room, &session.peer);
1758
1759 Ok(())
1760}
1761
1762async fn join_hosted_project(
1763 request: proto::JoinHostedProject,
1764 response: Response<proto::JoinHostedProject>,
1765 session: Session,
1766) -> Result<()> {
1767 let (mut project, replica_id) = session
1768 .db()
1769 .await
1770 .join_hosted_project(
1771 HostedProjectId(request.id as i32),
1772 session.user_id,
1773 session.connection_id,
1774 )
1775 .await?;
1776
1777 join_project_internal(response, session, &mut project, &replica_id)
1778}
1779
1780/// Updates other participants with changes to the project
1781async fn update_project(
1782 request: proto::UpdateProject,
1783 response: Response<proto::UpdateProject>,
1784 session: Session,
1785) -> Result<()> {
1786 let project_id = ProjectId::from_proto(request.project_id);
1787 let (room, guest_connection_ids) = &*session
1788 .db()
1789 .await
1790 .update_project(project_id, session.connection_id, &request.worktrees)
1791 .await?;
1792 broadcast(
1793 Some(session.connection_id),
1794 guest_connection_ids.iter().copied(),
1795 |connection_id| {
1796 session
1797 .peer
1798 .forward_send(session.connection_id, connection_id, request.clone())
1799 },
1800 );
1801 room_updated(&room, &session.peer);
1802 response.send(proto::Ack {})?;
1803
1804 Ok(())
1805}
1806
1807/// Updates other participants with changes to the worktree
1808async fn update_worktree(
1809 request: proto::UpdateWorktree,
1810 response: Response<proto::UpdateWorktree>,
1811 session: Session,
1812) -> Result<()> {
1813 let guest_connection_ids = session
1814 .db()
1815 .await
1816 .update_worktree(&request, session.connection_id)
1817 .await?;
1818
1819 broadcast(
1820 Some(session.connection_id),
1821 guest_connection_ids.iter().copied(),
1822 |connection_id| {
1823 session
1824 .peer
1825 .forward_send(session.connection_id, connection_id, request.clone())
1826 },
1827 );
1828 response.send(proto::Ack {})?;
1829 Ok(())
1830}
1831
1832/// Updates other participants with changes to the diagnostics
1833async fn update_diagnostic_summary(
1834 message: proto::UpdateDiagnosticSummary,
1835 session: Session,
1836) -> Result<()> {
1837 let guest_connection_ids = session
1838 .db()
1839 .await
1840 .update_diagnostic_summary(&message, session.connection_id)
1841 .await?;
1842
1843 broadcast(
1844 Some(session.connection_id),
1845 guest_connection_ids.iter().copied(),
1846 |connection_id| {
1847 session
1848 .peer
1849 .forward_send(session.connection_id, connection_id, message.clone())
1850 },
1851 );
1852
1853 Ok(())
1854}
1855
1856/// Updates other participants with changes to the worktree settings
1857async fn update_worktree_settings(
1858 message: proto::UpdateWorktreeSettings,
1859 session: Session,
1860) -> Result<()> {
1861 let guest_connection_ids = session
1862 .db()
1863 .await
1864 .update_worktree_settings(&message, session.connection_id)
1865 .await?;
1866
1867 broadcast(
1868 Some(session.connection_id),
1869 guest_connection_ids.iter().copied(),
1870 |connection_id| {
1871 session
1872 .peer
1873 .forward_send(session.connection_id, connection_id, message.clone())
1874 },
1875 );
1876
1877 Ok(())
1878}
1879
1880/// Notify other participants that a language server has started.
1881async fn start_language_server(
1882 request: proto::StartLanguageServer,
1883 session: Session,
1884) -> Result<()> {
1885 let guest_connection_ids = session
1886 .db()
1887 .await
1888 .start_language_server(&request, session.connection_id)
1889 .await?;
1890
1891 broadcast(
1892 Some(session.connection_id),
1893 guest_connection_ids.iter().copied(),
1894 |connection_id| {
1895 session
1896 .peer
1897 .forward_send(session.connection_id, connection_id, request.clone())
1898 },
1899 );
1900 Ok(())
1901}
1902
1903/// Notify other participants that a language server has changed.
1904async fn update_language_server(
1905 request: proto::UpdateLanguageServer,
1906 session: Session,
1907) -> Result<()> {
1908 let project_id = ProjectId::from_proto(request.project_id);
1909 let project_connection_ids = session
1910 .db()
1911 .await
1912 .project_connection_ids(project_id, session.connection_id)
1913 .await?;
1914 broadcast(
1915 Some(session.connection_id),
1916 project_connection_ids.iter().copied(),
1917 |connection_id| {
1918 session
1919 .peer
1920 .forward_send(session.connection_id, connection_id, request.clone())
1921 },
1922 );
1923 Ok(())
1924}
1925
1926/// forward a project request to the host. These requests should be read only
1927/// as guests are allowed to send them.
1928async fn forward_read_only_project_request<T>(
1929 request: T,
1930 response: Response<T>,
1931 session: Session,
1932) -> Result<()>
1933where
1934 T: EntityMessage + RequestMessage,
1935{
1936 let project_id = ProjectId::from_proto(request.remote_entity_id());
1937 let host_connection_id = session
1938 .db()
1939 .await
1940 .host_for_read_only_project_request(project_id, session.connection_id)
1941 .await?;
1942 let payload = session
1943 .peer
1944 .forward_request(session.connection_id, host_connection_id, request)
1945 .await?;
1946 response.send(payload)?;
1947 Ok(())
1948}
1949
1950/// forward a project request to the host. These requests are disallowed
1951/// for guests.
1952async fn forward_mutating_project_request<T>(
1953 request: T,
1954 response: Response<T>,
1955 session: Session,
1956) -> Result<()>
1957where
1958 T: EntityMessage + RequestMessage,
1959{
1960 let project_id = ProjectId::from_proto(request.remote_entity_id());
1961 let host_connection_id = session
1962 .db()
1963 .await
1964 .host_for_mutating_project_request(project_id, session.connection_id)
1965 .await?;
1966 let payload = session
1967 .peer
1968 .forward_request(session.connection_id, host_connection_id, request)
1969 .await?;
1970 response.send(payload)?;
1971 Ok(())
1972}
1973
1974/// Notify other participants that a new buffer has been created
1975async fn create_buffer_for_peer(
1976 request: proto::CreateBufferForPeer,
1977 session: Session,
1978) -> Result<()> {
1979 session
1980 .db()
1981 .await
1982 .check_user_is_project_host(
1983 ProjectId::from_proto(request.project_id),
1984 session.connection_id,
1985 )
1986 .await?;
1987 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1988 session
1989 .peer
1990 .forward_send(session.connection_id, peer_id.into(), request)?;
1991 Ok(())
1992}
1993
1994/// Notify other participants that a buffer has been updated. This is
1995/// allowed for guests as long as the update is limited to selections.
1996async fn update_buffer(
1997 request: proto::UpdateBuffer,
1998 response: Response<proto::UpdateBuffer>,
1999 session: Session,
2000) -> Result<()> {
2001 let project_id = ProjectId::from_proto(request.project_id);
2002 let mut guest_connection_ids;
2003 let mut host_connection_id = None;
2004
2005 let mut requires_write_permission = false;
2006
2007 for op in request.operations.iter() {
2008 match op.variant {
2009 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2010 Some(_) => requires_write_permission = true,
2011 }
2012 }
2013
2014 {
2015 let collaborators = session
2016 .db()
2017 .await
2018 .project_collaborators_for_buffer_update(
2019 project_id,
2020 session.connection_id,
2021 requires_write_permission,
2022 )
2023 .await?;
2024 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2025 for collaborator in collaborators.iter() {
2026 if collaborator.is_host {
2027 host_connection_id = Some(collaborator.connection_id);
2028 } else {
2029 guest_connection_ids.push(collaborator.connection_id);
2030 }
2031 }
2032 }
2033 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2034
2035 broadcast(
2036 Some(session.connection_id),
2037 guest_connection_ids,
2038 |connection_id| {
2039 session
2040 .peer
2041 .forward_send(session.connection_id, connection_id, request.clone())
2042 },
2043 );
2044 if host_connection_id != session.connection_id {
2045 session
2046 .peer
2047 .forward_request(session.connection_id, host_connection_id, request.clone())
2048 .await?;
2049 }
2050
2051 response.send(proto::Ack {})?;
2052 Ok(())
2053}
2054
2055/// Notify other participants that a project has been updated.
2056async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2057 request: T,
2058 session: Session,
2059) -> Result<()> {
2060 let project_id = ProjectId::from_proto(request.remote_entity_id());
2061 let project_connection_ids = session
2062 .db()
2063 .await
2064 .project_connection_ids(project_id, session.connection_id)
2065 .await?;
2066
2067 broadcast(
2068 Some(session.connection_id),
2069 project_connection_ids.iter().copied(),
2070 |connection_id| {
2071 session
2072 .peer
2073 .forward_send(session.connection_id, connection_id, request.clone())
2074 },
2075 );
2076 Ok(())
2077}
2078
2079/// Start following another user in a call.
2080async fn follow(
2081 request: proto::Follow,
2082 response: Response<proto::Follow>,
2083 session: Session,
2084) -> Result<()> {
2085 let room_id = RoomId::from_proto(request.room_id);
2086 let project_id = request.project_id.map(ProjectId::from_proto);
2087 let leader_id = request
2088 .leader_id
2089 .ok_or_else(|| anyhow!("invalid leader id"))?
2090 .into();
2091 let follower_id = session.connection_id;
2092
2093 session
2094 .db()
2095 .await
2096 .check_room_participants(room_id, leader_id, session.connection_id)
2097 .await?;
2098
2099 let response_payload = session
2100 .peer
2101 .forward_request(session.connection_id, leader_id, request)
2102 .await?;
2103 response.send(response_payload)?;
2104
2105 if let Some(project_id) = project_id {
2106 let room = session
2107 .db()
2108 .await
2109 .follow(room_id, project_id, leader_id, follower_id)
2110 .await?;
2111 room_updated(&room, &session.peer);
2112 }
2113
2114 Ok(())
2115}
2116
2117/// Stop following another user in a call.
2118async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2119 let room_id = RoomId::from_proto(request.room_id);
2120 let project_id = request.project_id.map(ProjectId::from_proto);
2121 let leader_id = request
2122 .leader_id
2123 .ok_or_else(|| anyhow!("invalid leader id"))?
2124 .into();
2125 let follower_id = session.connection_id;
2126
2127 session
2128 .db()
2129 .await
2130 .check_room_participants(room_id, leader_id, session.connection_id)
2131 .await?;
2132
2133 session
2134 .peer
2135 .forward_send(session.connection_id, leader_id, request)?;
2136
2137 if let Some(project_id) = project_id {
2138 let room = session
2139 .db()
2140 .await
2141 .unfollow(room_id, project_id, leader_id, follower_id)
2142 .await?;
2143 room_updated(&room, &session.peer);
2144 }
2145
2146 Ok(())
2147}
2148
2149/// Notify everyone following you of your current location.
2150async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2151 let room_id = RoomId::from_proto(request.room_id);
2152 let database = session.db.lock().await;
2153
2154 let connection_ids = if let Some(project_id) = request.project_id {
2155 let project_id = ProjectId::from_proto(project_id);
2156 database
2157 .project_connection_ids(project_id, session.connection_id)
2158 .await?
2159 } else {
2160 database
2161 .room_connection_ids(room_id, session.connection_id)
2162 .await?
2163 };
2164
2165 // For now, don't send view update messages back to that view's current leader.
2166 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2167 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2168 _ => None,
2169 });
2170
2171 for connection_id in connection_ids.iter().cloned() {
2172 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2173 session
2174 .peer
2175 .forward_send(session.connection_id, connection_id, request.clone())?;
2176 }
2177 }
2178 Ok(())
2179}
2180
2181/// Get public data about users.
2182async fn get_users(
2183 request: proto::GetUsers,
2184 response: Response<proto::GetUsers>,
2185 session: Session,
2186) -> Result<()> {
2187 let user_ids = request
2188 .user_ids
2189 .into_iter()
2190 .map(UserId::from_proto)
2191 .collect();
2192 let users = session
2193 .db()
2194 .await
2195 .get_users_by_ids(user_ids)
2196 .await?
2197 .into_iter()
2198 .map(|user| proto::User {
2199 id: user.id.to_proto(),
2200 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2201 github_login: user.github_login,
2202 })
2203 .collect();
2204 response.send(proto::UsersResponse { users })?;
2205 Ok(())
2206}
2207
2208/// Search for users (to invite) buy Github login
2209async fn fuzzy_search_users(
2210 request: proto::FuzzySearchUsers,
2211 response: Response<proto::FuzzySearchUsers>,
2212 session: Session,
2213) -> Result<()> {
2214 let query = request.query;
2215 let users = match query.len() {
2216 0 => vec![],
2217 1 | 2 => session
2218 .db()
2219 .await
2220 .get_user_by_github_login(&query)
2221 .await?
2222 .into_iter()
2223 .collect(),
2224 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2225 };
2226 let users = users
2227 .into_iter()
2228 .filter(|user| user.id != session.user_id)
2229 .map(|user| proto::User {
2230 id: user.id.to_proto(),
2231 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2232 github_login: user.github_login,
2233 })
2234 .collect();
2235 response.send(proto::UsersResponse { users })?;
2236 Ok(())
2237}
2238
2239/// Send a contact request to another user.
2240async fn request_contact(
2241 request: proto::RequestContact,
2242 response: Response<proto::RequestContact>,
2243 session: Session,
2244) -> Result<()> {
2245 let requester_id = session.user_id;
2246 let responder_id = UserId::from_proto(request.responder_id);
2247 if requester_id == responder_id {
2248 return Err(anyhow!("cannot add yourself as a contact"))?;
2249 }
2250
2251 let notifications = session
2252 .db()
2253 .await
2254 .send_contact_request(requester_id, responder_id)
2255 .await?;
2256
2257 // Update outgoing contact requests of requester
2258 let mut update = proto::UpdateContacts::default();
2259 update.outgoing_requests.push(responder_id.to_proto());
2260 for connection_id in session
2261 .connection_pool()
2262 .await
2263 .user_connection_ids(requester_id)
2264 {
2265 session.peer.send(connection_id, update.clone())?;
2266 }
2267
2268 // Update incoming contact requests of responder
2269 let mut update = proto::UpdateContacts::default();
2270 update
2271 .incoming_requests
2272 .push(proto::IncomingContactRequest {
2273 requester_id: requester_id.to_proto(),
2274 });
2275 let connection_pool = session.connection_pool().await;
2276 for connection_id in connection_pool.user_connection_ids(responder_id) {
2277 session.peer.send(connection_id, update.clone())?;
2278 }
2279
2280 send_notifications(&connection_pool, &session.peer, notifications);
2281
2282 response.send(proto::Ack {})?;
2283 Ok(())
2284}
2285
2286/// Accept or decline a contact request
2287async fn respond_to_contact_request(
2288 request: proto::RespondToContactRequest,
2289 response: Response<proto::RespondToContactRequest>,
2290 session: Session,
2291) -> Result<()> {
2292 let responder_id = session.user_id;
2293 let requester_id = UserId::from_proto(request.requester_id);
2294 let db = session.db().await;
2295 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2296 db.dismiss_contact_notification(responder_id, requester_id)
2297 .await?;
2298 } else {
2299 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2300
2301 let notifications = db
2302 .respond_to_contact_request(responder_id, requester_id, accept)
2303 .await?;
2304 let requester_busy = db.is_user_busy(requester_id).await?;
2305 let responder_busy = db.is_user_busy(responder_id).await?;
2306
2307 let pool = session.connection_pool().await;
2308 // Update responder with new contact
2309 let mut update = proto::UpdateContacts::default();
2310 if accept {
2311 update
2312 .contacts
2313 .push(contact_for_user(requester_id, requester_busy, &pool));
2314 }
2315 update
2316 .remove_incoming_requests
2317 .push(requester_id.to_proto());
2318 for connection_id in pool.user_connection_ids(responder_id) {
2319 session.peer.send(connection_id, update.clone())?;
2320 }
2321
2322 // Update requester with new contact
2323 let mut update = proto::UpdateContacts::default();
2324 if accept {
2325 update
2326 .contacts
2327 .push(contact_for_user(responder_id, responder_busy, &pool));
2328 }
2329 update
2330 .remove_outgoing_requests
2331 .push(responder_id.to_proto());
2332
2333 for connection_id in pool.user_connection_ids(requester_id) {
2334 session.peer.send(connection_id, update.clone())?;
2335 }
2336
2337 send_notifications(&pool, &session.peer, notifications);
2338 }
2339
2340 response.send(proto::Ack {})?;
2341 Ok(())
2342}
2343
2344/// Remove a contact.
2345async fn remove_contact(
2346 request: proto::RemoveContact,
2347 response: Response<proto::RemoveContact>,
2348 session: Session,
2349) -> Result<()> {
2350 let requester_id = session.user_id;
2351 let responder_id = UserId::from_proto(request.user_id);
2352 let db = session.db().await;
2353 let (contact_accepted, deleted_notification_id) =
2354 db.remove_contact(requester_id, responder_id).await?;
2355
2356 let pool = session.connection_pool().await;
2357 // Update outgoing contact requests of requester
2358 let mut update = proto::UpdateContacts::default();
2359 if contact_accepted {
2360 update.remove_contacts.push(responder_id.to_proto());
2361 } else {
2362 update
2363 .remove_outgoing_requests
2364 .push(responder_id.to_proto());
2365 }
2366 for connection_id in pool.user_connection_ids(requester_id) {
2367 session.peer.send(connection_id, update.clone())?;
2368 }
2369
2370 // Update incoming contact requests of responder
2371 let mut update = proto::UpdateContacts::default();
2372 if contact_accepted {
2373 update.remove_contacts.push(requester_id.to_proto());
2374 } else {
2375 update
2376 .remove_incoming_requests
2377 .push(requester_id.to_proto());
2378 }
2379 for connection_id in pool.user_connection_ids(responder_id) {
2380 session.peer.send(connection_id, update.clone())?;
2381 if let Some(notification_id) = deleted_notification_id {
2382 session.peer.send(
2383 connection_id,
2384 proto::DeleteNotification {
2385 notification_id: notification_id.to_proto(),
2386 },
2387 )?;
2388 }
2389 }
2390
2391 response.send(proto::Ack {})?;
2392 Ok(())
2393}
2394
2395/// Creates a new channel.
2396async fn create_channel(
2397 request: proto::CreateChannel,
2398 response: Response<proto::CreateChannel>,
2399 session: Session,
2400) -> Result<()> {
2401 let db = session.db().await;
2402
2403 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2404 let (channel, owner, channel_members) = db
2405 .create_channel(&request.name, parent_id, session.user_id)
2406 .await?;
2407
2408 response.send(proto::CreateChannelResponse {
2409 channel: Some(channel.to_proto()),
2410 parent_id: request.parent_id,
2411 })?;
2412
2413 let connection_pool = session.connection_pool().await;
2414 if let Some(owner) = owner {
2415 let update = proto::UpdateUserChannels {
2416 channel_memberships: vec![proto::ChannelMembership {
2417 channel_id: owner.channel_id.to_proto(),
2418 role: owner.role.into(),
2419 }],
2420 ..Default::default()
2421 };
2422 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2423 session.peer.send(connection_id, update.clone())?;
2424 }
2425 }
2426
2427 for channel_member in channel_members {
2428 if !channel_member.role.can_see_channel(channel.visibility) {
2429 continue;
2430 }
2431
2432 let update = proto::UpdateChannels {
2433 channels: vec![channel.to_proto()],
2434 ..Default::default()
2435 };
2436 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2437 session.peer.send(connection_id, update.clone())?;
2438 }
2439 }
2440
2441 Ok(())
2442}
2443
2444/// Delete a channel
2445async fn delete_channel(
2446 request: proto::DeleteChannel,
2447 response: Response<proto::DeleteChannel>,
2448 session: Session,
2449) -> Result<()> {
2450 let db = session.db().await;
2451
2452 let channel_id = request.channel_id;
2453 let (removed_channels, member_ids) = db
2454 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2455 .await?;
2456 response.send(proto::Ack {})?;
2457
2458 // Notify members of removed channels
2459 let mut update = proto::UpdateChannels::default();
2460 update
2461 .delete_channels
2462 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2463
2464 let connection_pool = session.connection_pool().await;
2465 for member_id in member_ids {
2466 for connection_id in connection_pool.user_connection_ids(member_id) {
2467 session.peer.send(connection_id, update.clone())?;
2468 }
2469 }
2470
2471 Ok(())
2472}
2473
2474/// Invite someone to join a channel.
2475async fn invite_channel_member(
2476 request: proto::InviteChannelMember,
2477 response: Response<proto::InviteChannelMember>,
2478 session: Session,
2479) -> Result<()> {
2480 let db = session.db().await;
2481 let channel_id = ChannelId::from_proto(request.channel_id);
2482 let invitee_id = UserId::from_proto(request.user_id);
2483 let InviteMemberResult {
2484 channel,
2485 notifications,
2486 } = db
2487 .invite_channel_member(
2488 channel_id,
2489 invitee_id,
2490 session.user_id,
2491 request.role().into(),
2492 )
2493 .await?;
2494
2495 let update = proto::UpdateChannels {
2496 channel_invitations: vec![channel.to_proto()],
2497 ..Default::default()
2498 };
2499
2500 let connection_pool = session.connection_pool().await;
2501 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2502 session.peer.send(connection_id, update.clone())?;
2503 }
2504
2505 send_notifications(&connection_pool, &session.peer, notifications);
2506
2507 response.send(proto::Ack {})?;
2508 Ok(())
2509}
2510
2511/// remove someone from a channel
2512async fn remove_channel_member(
2513 request: proto::RemoveChannelMember,
2514 response: Response<proto::RemoveChannelMember>,
2515 session: Session,
2516) -> Result<()> {
2517 let db = session.db().await;
2518 let channel_id = ChannelId::from_proto(request.channel_id);
2519 let member_id = UserId::from_proto(request.user_id);
2520
2521 let RemoveChannelMemberResult {
2522 membership_update,
2523 notification_id,
2524 } = db
2525 .remove_channel_member(channel_id, member_id, session.user_id)
2526 .await?;
2527
2528 let connection_pool = &session.connection_pool().await;
2529 notify_membership_updated(
2530 &connection_pool,
2531 membership_update,
2532 member_id,
2533 &session.peer,
2534 );
2535 for connection_id in connection_pool.user_connection_ids(member_id) {
2536 if let Some(notification_id) = notification_id {
2537 session
2538 .peer
2539 .send(
2540 connection_id,
2541 proto::DeleteNotification {
2542 notification_id: notification_id.to_proto(),
2543 },
2544 )
2545 .trace_err();
2546 }
2547 }
2548
2549 response.send(proto::Ack {})?;
2550 Ok(())
2551}
2552
2553/// Toggle the channel between public and private.
2554/// Care is taken to maintain the invariant that public channels only descend from public channels,
2555/// (though members-only channels can appear at any point in the hierarchy).
2556async fn set_channel_visibility(
2557 request: proto::SetChannelVisibility,
2558 response: Response<proto::SetChannelVisibility>,
2559 session: Session,
2560) -> Result<()> {
2561 let db = session.db().await;
2562 let channel_id = ChannelId::from_proto(request.channel_id);
2563 let visibility = request.visibility().into();
2564
2565 let (channel, channel_members) = db
2566 .set_channel_visibility(channel_id, visibility, session.user_id)
2567 .await?;
2568
2569 let connection_pool = session.connection_pool().await;
2570 for member in channel_members {
2571 let update = if member.role.can_see_channel(channel.visibility) {
2572 proto::UpdateChannels {
2573 channels: vec![channel.to_proto()],
2574 ..Default::default()
2575 }
2576 } else {
2577 proto::UpdateChannels {
2578 delete_channels: vec![channel.id.to_proto()],
2579 ..Default::default()
2580 }
2581 };
2582
2583 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2584 session.peer.send(connection_id, update.clone())?;
2585 }
2586 }
2587
2588 response.send(proto::Ack {})?;
2589 Ok(())
2590}
2591
2592/// Alter the role for a user in the channel.
2593async fn set_channel_member_role(
2594 request: proto::SetChannelMemberRole,
2595 response: Response<proto::SetChannelMemberRole>,
2596 session: Session,
2597) -> Result<()> {
2598 let db = session.db().await;
2599 let channel_id = ChannelId::from_proto(request.channel_id);
2600 let member_id = UserId::from_proto(request.user_id);
2601 let result = db
2602 .set_channel_member_role(
2603 channel_id,
2604 session.user_id,
2605 member_id,
2606 request.role().into(),
2607 )
2608 .await?;
2609
2610 match result {
2611 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2612 let connection_pool = session.connection_pool().await;
2613 notify_membership_updated(
2614 &connection_pool,
2615 membership_update,
2616 member_id,
2617 &session.peer,
2618 )
2619 }
2620 db::SetMemberRoleResult::InviteUpdated(channel) => {
2621 let update = proto::UpdateChannels {
2622 channel_invitations: vec![channel.to_proto()],
2623 ..Default::default()
2624 };
2625
2626 for connection_id in session
2627 .connection_pool()
2628 .await
2629 .user_connection_ids(member_id)
2630 {
2631 session.peer.send(connection_id, update.clone())?;
2632 }
2633 }
2634 }
2635
2636 response.send(proto::Ack {})?;
2637 Ok(())
2638}
2639
2640/// Change the name of a channel
2641async fn rename_channel(
2642 request: proto::RenameChannel,
2643 response: Response<proto::RenameChannel>,
2644 session: Session,
2645) -> Result<()> {
2646 let db = session.db().await;
2647 let channel_id = ChannelId::from_proto(request.channel_id);
2648 let (channel, channel_members) = db
2649 .rename_channel(channel_id, session.user_id, &request.name)
2650 .await?;
2651
2652 response.send(proto::RenameChannelResponse {
2653 channel: Some(channel.to_proto()),
2654 })?;
2655
2656 let connection_pool = session.connection_pool().await;
2657 for channel_member in channel_members {
2658 if !channel_member.role.can_see_channel(channel.visibility) {
2659 continue;
2660 }
2661 let update = proto::UpdateChannels {
2662 channels: vec![channel.to_proto()],
2663 ..Default::default()
2664 };
2665 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2666 session.peer.send(connection_id, update.clone())?;
2667 }
2668 }
2669
2670 Ok(())
2671}
2672
2673/// Move a channel to a new parent.
2674async fn move_channel(
2675 request: proto::MoveChannel,
2676 response: Response<proto::MoveChannel>,
2677 session: Session,
2678) -> Result<()> {
2679 let channel_id = ChannelId::from_proto(request.channel_id);
2680 let to = ChannelId::from_proto(request.to);
2681
2682 let (channels, channel_members) = session
2683 .db()
2684 .await
2685 .move_channel(channel_id, to, session.user_id)
2686 .await?;
2687
2688 let connection_pool = session.connection_pool().await;
2689 for member in channel_members {
2690 let channels = channels
2691 .iter()
2692 .filter_map(|channel| {
2693 if member.role.can_see_channel(channel.visibility) {
2694 Some(channel.to_proto())
2695 } else {
2696 None
2697 }
2698 })
2699 .collect::<Vec<_>>();
2700 if channels.is_empty() {
2701 continue;
2702 }
2703
2704 let update = proto::UpdateChannels {
2705 channels,
2706 ..Default::default()
2707 };
2708
2709 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2710 session.peer.send(connection_id, update.clone())?;
2711 }
2712 }
2713
2714 response.send(Ack {})?;
2715 Ok(())
2716}
2717
2718/// Get the list of channel members
2719async fn get_channel_members(
2720 request: proto::GetChannelMembers,
2721 response: Response<proto::GetChannelMembers>,
2722 session: Session,
2723) -> Result<()> {
2724 let db = session.db().await;
2725 let channel_id = ChannelId::from_proto(request.channel_id);
2726 let members = db
2727 .get_channel_participant_details(channel_id, session.user_id)
2728 .await?;
2729 response.send(proto::GetChannelMembersResponse { members })?;
2730 Ok(())
2731}
2732
2733/// Accept or decline a channel invitation.
2734async fn respond_to_channel_invite(
2735 request: proto::RespondToChannelInvite,
2736 response: Response<proto::RespondToChannelInvite>,
2737 session: Session,
2738) -> Result<()> {
2739 let db = session.db().await;
2740 let channel_id = ChannelId::from_proto(request.channel_id);
2741 let RespondToChannelInvite {
2742 membership_update,
2743 notifications,
2744 } = db
2745 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2746 .await?;
2747
2748 let connection_pool = session.connection_pool().await;
2749 if let Some(membership_update) = membership_update {
2750 notify_membership_updated(
2751 &connection_pool,
2752 membership_update,
2753 session.user_id,
2754 &session.peer,
2755 );
2756 } else {
2757 let update = proto::UpdateChannels {
2758 remove_channel_invitations: vec![channel_id.to_proto()],
2759 ..Default::default()
2760 };
2761
2762 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2763 session.peer.send(connection_id, update.clone())?;
2764 }
2765 };
2766
2767 send_notifications(&connection_pool, &session.peer, notifications);
2768
2769 response.send(proto::Ack {})?;
2770
2771 Ok(())
2772}
2773
2774/// Join the channels' room
2775async fn join_channel(
2776 request: proto::JoinChannel,
2777 response: Response<proto::JoinChannel>,
2778 session: Session,
2779) -> Result<()> {
2780 let channel_id = ChannelId::from_proto(request.channel_id);
2781 join_channel_internal(channel_id, Box::new(response), session).await
2782}
2783
/// A response handle that can complete a channel join.
///
/// Both `JoinChannel` and `JoinRoom` requests are answered with a
/// `proto::JoinRoomResponse`. This trait lets `join_channel_internal` accept
/// either handle.
trait JoinChannelInternalResponse {
    /// Send `result` back to the client that made the request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Completes a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // NOTE: the fully-qualified path is expected to resolve to `Response`'s
        // inherent `send` (inherent methods win over trait methods). If no
        // inherent method existed, this would recurse into itself.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Completes a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same disambiguation as the `JoinChannel` impl above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2797
/// Shared implementation for joining the call room attached to a channel.
///
/// Steps, in order:
/// 1. Call `leave_room_for_session`, presumably to leave any room this
///    session is currently in.
/// 2. Join the channel's room in the database. This returns the joined room,
///    an optional membership change, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint credentials for the room:
///    - `Guest`s get a guest token with `can_publish = false`;
///    - every other role gets a room token with `can_publish = true`.
///
///    A token error goes through `trace_err` and results in no LiveKit
///    connection info.
/// 4. Reply to the client, fan out any membership change, and pass the joined
///    room to `room_updated`.
/// 5. After the database guard has been released, call `channel_updated` for
///    the channel's members and refresh the user's contacts.
///
/// `response` may be either a `JoinChannel` or a `JoinRoom` response handle.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database guard (`db`) and the connection-pool
    // guard, so both are dropped before the follow-up work below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token minting failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may listen but not publish; all other roles may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool guard here is a temporary, released at the end of this statement.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2873
2874/// Start editing the channel notes
2875async fn join_channel_buffer(
2876 request: proto::JoinChannelBuffer,
2877 response: Response<proto::JoinChannelBuffer>,
2878 session: Session,
2879) -> Result<()> {
2880 let db = session.db().await;
2881 let channel_id = ChannelId::from_proto(request.channel_id);
2882
2883 let open_response = db
2884 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2885 .await?;
2886
2887 let collaborators = open_response.collaborators.clone();
2888 response.send(open_response)?;
2889
2890 let update = UpdateChannelBufferCollaborators {
2891 channel_id: channel_id.to_proto(),
2892 collaborators: collaborators.clone(),
2893 };
2894 channel_buffer_updated(
2895 session.connection_id,
2896 collaborators
2897 .iter()
2898 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2899 &update,
2900 &session.peer,
2901 );
2902
2903 Ok(())
2904}
2905
2906/// Edit the channel notes
2907async fn update_channel_buffer(
2908 request: proto::UpdateChannelBuffer,
2909 session: Session,
2910) -> Result<()> {
2911 let db = session.db().await;
2912 let channel_id = ChannelId::from_proto(request.channel_id);
2913
2914 let (collaborators, non_collaborators, epoch, version) = db
2915 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2916 .await?;
2917
2918 channel_buffer_updated(
2919 session.connection_id,
2920 collaborators,
2921 &proto::UpdateChannelBuffer {
2922 channel_id: channel_id.to_proto(),
2923 operations: request.operations,
2924 },
2925 &session.peer,
2926 );
2927
2928 let pool = &*session.connection_pool().await;
2929
2930 broadcast(
2931 None,
2932 non_collaborators
2933 .iter()
2934 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2935 |peer_id| {
2936 session.peer.send(
2937 peer_id,
2938 proto::UpdateChannels {
2939 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2940 channel_id: channel_id.to_proto(),
2941 epoch: epoch as u64,
2942 version: version.clone(),
2943 }],
2944 ..Default::default()
2945 },
2946 )
2947 },
2948 );
2949
2950 Ok(())
2951}
2952
2953/// Rejoin the channel notes after a connection blip
2954async fn rejoin_channel_buffers(
2955 request: proto::RejoinChannelBuffers,
2956 response: Response<proto::RejoinChannelBuffers>,
2957 session: Session,
2958) -> Result<()> {
2959 let db = session.db().await;
2960 let buffers = db
2961 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2962 .await?;
2963
2964 for rejoined_buffer in &buffers {
2965 let collaborators_to_notify = rejoined_buffer
2966 .buffer
2967 .collaborators
2968 .iter()
2969 .filter_map(|c| Some(c.peer_id?.into()));
2970 channel_buffer_updated(
2971 session.connection_id,
2972 collaborators_to_notify,
2973 &proto::UpdateChannelBufferCollaborators {
2974 channel_id: rejoined_buffer.buffer.channel_id,
2975 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2976 },
2977 &session.peer,
2978 );
2979 }
2980
2981 response.send(proto::RejoinChannelBuffersResponse {
2982 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2983 })?;
2984
2985 Ok(())
2986}
2987
2988/// Stop editing the channel notes
2989async fn leave_channel_buffer(
2990 request: proto::LeaveChannelBuffer,
2991 response: Response<proto::LeaveChannelBuffer>,
2992 session: Session,
2993) -> Result<()> {
2994 let db = session.db().await;
2995 let channel_id = ChannelId::from_proto(request.channel_id);
2996
2997 let left_buffer = db
2998 .leave_channel_buffer(channel_id, session.connection_id)
2999 .await?;
3000
3001 response.send(Ack {})?;
3002
3003 channel_buffer_updated(
3004 session.connection_id,
3005 left_buffer.connections,
3006 &proto::UpdateChannelBufferCollaborators {
3007 channel_id: channel_id.to_proto(),
3008 collaborators: left_buffer.collaborators,
3009 },
3010 &session.peer,
3011 );
3012
3013 Ok(())
3014}
3015
3016fn channel_buffer_updated<T: EnvelopedMessage>(
3017 sender_id: ConnectionId,
3018 collaborators: impl IntoIterator<Item = ConnectionId>,
3019 message: &T,
3020 peer: &Peer,
3021) {
3022 broadcast(Some(sender_id), collaborators, |peer_id| {
3023 peer.send(peer_id, message.clone())
3024 });
3025}
3026
3027fn send_notifications(
3028 connection_pool: &ConnectionPool,
3029 peer: &Peer,
3030 notifications: db::NotificationBatch,
3031) {
3032 for (user_id, notification) in notifications {
3033 for connection_id in connection_pool.user_connection_ids(user_id) {
3034 if let Err(error) = peer.send(
3035 connection_id,
3036 proto::AddNotification {
3037 notification: Some(notification.clone()),
3038 },
3039 ) {
3040 tracing::error!(
3041 "failed to send notification to {:?} {}",
3042 connection_id,
3043 error
3044 );
3045 }
3046 }
3047 }
3048}
3049
3050/// Send a message to the channel
3051async fn send_channel_message(
3052 request: proto::SendChannelMessage,
3053 response: Response<proto::SendChannelMessage>,
3054 session: Session,
3055) -> Result<()> {
3056 // Validate the message body.
3057 let body = request.body.trim().to_string();
3058 if body.len() > MAX_MESSAGE_LEN {
3059 return Err(anyhow!("message is too long"))?;
3060 }
3061 if body.is_empty() {
3062 return Err(anyhow!("message can't be blank"))?;
3063 }
3064
3065 // TODO: adjust mentions if body is trimmed
3066
3067 let timestamp = OffsetDateTime::now_utc();
3068 let nonce = request
3069 .nonce
3070 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3071
3072 let channel_id = ChannelId::from_proto(request.channel_id);
3073 let CreatedChannelMessage {
3074 message_id,
3075 participant_connection_ids,
3076 channel_members,
3077 notifications,
3078 } = session
3079 .db()
3080 .await
3081 .create_channel_message(
3082 channel_id,
3083 session.user_id,
3084 &body,
3085 &request.mentions,
3086 timestamp,
3087 nonce.clone().into(),
3088 match request.reply_to_message_id {
3089 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3090 None => None,
3091 },
3092 )
3093 .await?;
3094 let message = proto::ChannelMessage {
3095 sender_id: session.user_id.to_proto(),
3096 id: message_id.to_proto(),
3097 body,
3098 mentions: request.mentions,
3099 timestamp: timestamp.unix_timestamp() as u64,
3100 nonce: Some(nonce),
3101 reply_to_message_id: request.reply_to_message_id,
3102 };
3103 broadcast(
3104 Some(session.connection_id),
3105 participant_connection_ids,
3106 |connection| {
3107 session.peer.send(
3108 connection,
3109 proto::ChannelMessageSent {
3110 channel_id: channel_id.to_proto(),
3111 message: Some(message.clone()),
3112 },
3113 )
3114 },
3115 );
3116 response.send(proto::SendChannelMessageResponse {
3117 message: Some(message),
3118 })?;
3119
3120 let pool = &*session.connection_pool().await;
3121 broadcast(
3122 None,
3123 channel_members
3124 .iter()
3125 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3126 |peer_id| {
3127 session.peer.send(
3128 peer_id,
3129 proto::UpdateChannels {
3130 latest_channel_message_ids: vec![proto::ChannelMessageId {
3131 channel_id: channel_id.to_proto(),
3132 message_id: message_id.to_proto(),
3133 }],
3134 ..Default::default()
3135 },
3136 )
3137 },
3138 );
3139 send_notifications(pool, &session.peer, notifications);
3140
3141 Ok(())
3142}
3143
3144/// Delete a channel message
3145async fn remove_channel_message(
3146 request: proto::RemoveChannelMessage,
3147 response: Response<proto::RemoveChannelMessage>,
3148 session: Session,
3149) -> Result<()> {
3150 let channel_id = ChannelId::from_proto(request.channel_id);
3151 let message_id = MessageId::from_proto(request.message_id);
3152 let connection_ids = session
3153 .db()
3154 .await
3155 .remove_channel_message(channel_id, message_id, session.user_id)
3156 .await?;
3157 broadcast(Some(session.connection_id), connection_ids, |connection| {
3158 session.peer.send(connection, request.clone())
3159 });
3160 response.send(proto::Ack {})?;
3161 Ok(())
3162}
3163
3164/// Mark a channel message as read
3165async fn acknowledge_channel_message(
3166 request: proto::AckChannelMessage,
3167 session: Session,
3168) -> Result<()> {
3169 let channel_id = ChannelId::from_proto(request.channel_id);
3170 let message_id = MessageId::from_proto(request.message_id);
3171 let notifications = session
3172 .db()
3173 .await
3174 .observe_channel_message(channel_id, session.user_id, message_id)
3175 .await?;
3176 send_notifications(
3177 &*session.connection_pool().await,
3178 &session.peer,
3179 notifications,
3180 );
3181 Ok(())
3182}
3183
3184/// Mark a buffer version as synced
3185async fn acknowledge_buffer_version(
3186 request: proto::AckBufferOperation,
3187 session: Session,
3188) -> Result<()> {
3189 let buffer_id = BufferId::from_proto(request.buffer_id);
3190 session
3191 .db()
3192 .await
3193 .observe_buffer_version(
3194 buffer_id,
3195 session.user_id,
3196 request.epoch as i32,
3197 &request.version,
3198 )
3199 .await?;
3200 Ok(())
3201}
3202
3203/// Start receiving chat updates for a channel
3204async fn join_channel_chat(
3205 request: proto::JoinChannelChat,
3206 response: Response<proto::JoinChannelChat>,
3207 session: Session,
3208) -> Result<()> {
3209 let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211 let db = session.db().await;
3212 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3213 .await?;
3214 let messages = db
3215 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3216 .await?;
3217 response.send(proto::JoinChannelChatResponse {
3218 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3219 messages,
3220 })?;
3221 Ok(())
3222}
3223
3224/// Stop receiving chat updates for a channel
3225async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3226 let channel_id = ChannelId::from_proto(request.channel_id);
3227 session
3228 .db()
3229 .await
3230 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3231 .await?;
3232 Ok(())
3233}
3234
3235/// Retrieve the chat history for a channel
3236async fn get_channel_messages(
3237 request: proto::GetChannelMessages,
3238 response: Response<proto::GetChannelMessages>,
3239 session: Session,
3240) -> Result<()> {
3241 let channel_id = ChannelId::from_proto(request.channel_id);
3242 let messages = session
3243 .db()
3244 .await
3245 .get_channel_messages(
3246 channel_id,
3247 session.user_id,
3248 MESSAGE_COUNT_PER_PAGE,
3249 Some(MessageId::from_proto(request.before_message_id)),
3250 )
3251 .await?;
3252 response.send(proto::GetChannelMessagesResponse {
3253 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3254 messages,
3255 })?;
3256 Ok(())
3257}
3258
3259/// Retrieve specific chat messages
3260async fn get_channel_messages_by_id(
3261 request: proto::GetChannelMessagesById,
3262 response: Response<proto::GetChannelMessagesById>,
3263 session: Session,
3264) -> Result<()> {
3265 let message_ids = request
3266 .message_ids
3267 .iter()
3268 .map(|id| MessageId::from_proto(*id))
3269 .collect::<Vec<_>>();
3270 let messages = session
3271 .db()
3272 .await
3273 .get_channel_messages_by_id(session.user_id, &message_ids)
3274 .await?;
3275 response.send(proto::GetChannelMessagesResponse {
3276 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3277 messages,
3278 })?;
3279 Ok(())
3280}
3281
3282/// Retrieve the current users notifications
3283async fn get_notifications(
3284 request: proto::GetNotifications,
3285 response: Response<proto::GetNotifications>,
3286 session: Session,
3287) -> Result<()> {
3288 let notifications = session
3289 .db()
3290 .await
3291 .get_notifications(
3292 session.user_id,
3293 NOTIFICATION_COUNT_PER_PAGE,
3294 request
3295 .before_id
3296 .map(|id| db::NotificationId::from_proto(id)),
3297 )
3298 .await?;
3299 response.send(proto::GetNotificationsResponse {
3300 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3301 notifications,
3302 })?;
3303 Ok(())
3304}
3305
3306/// Mark notifications as read
3307async fn mark_notification_as_read(
3308 request: proto::MarkNotificationRead,
3309 response: Response<proto::MarkNotificationRead>,
3310 session: Session,
3311) -> Result<()> {
3312 let database = &session.db().await;
3313 let notifications = database
3314 .mark_notification_as_read_by_id(
3315 session.user_id,
3316 NotificationId::from_proto(request.notification_id),
3317 )
3318 .await?;
3319 send_notifications(
3320 &*session.connection_pool().await,
3321 &session.peer,
3322 notifications,
3323 );
3324 response.send(proto::Ack {})?;
3325 Ok(())
3326}
3327
3328/// Get the current users information
3329async fn get_private_user_info(
3330 _request: proto::GetPrivateUserInfo,
3331 response: Response<proto::GetPrivateUserInfo>,
3332 session: Session,
3333) -> Result<()> {
3334 let db = session.db().await;
3335
3336 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3337 let user = db
3338 .get_user_by_id(session.user_id)
3339 .await?
3340 .ok_or_else(|| anyhow!("user not found"))?;
3341 let flags = db.get_user_flags(session.user_id).await?;
3342
3343 response.send(proto::GetPrivateUserInfoResponse {
3344 metrics_id,
3345 staff: user.admin,
3346 flags,
3347 })?;
3348 Ok(())
3349}
3350
3351fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3352 match message {
3353 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3354 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3355 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3356 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3357 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3358 code: frame.code.into(),
3359 reason: frame.reason,
3360 })),
3361 }
3362}
3363
3364fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3365 match message {
3366 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3367 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3368 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3369 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3370 AxumMessage::Close(frame) => {
3371 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3372 code: frame.code.into(),
3373 reason: frame.reason,
3374 }))
3375 }
3376 }
3377}
3378
3379fn notify_membership_updated(
3380 connection_pool: &ConnectionPool,
3381 result: MembershipUpdated,
3382 user_id: UserId,
3383 peer: &Peer,
3384) {
3385 let user_channels_update = proto::UpdateUserChannels {
3386 channel_memberships: result
3387 .new_channels
3388 .channel_memberships
3389 .iter()
3390 .map(|cm| proto::ChannelMembership {
3391 channel_id: cm.channel_id.to_proto(),
3392 role: cm.role.into(),
3393 })
3394 .collect(),
3395 ..Default::default()
3396 };
3397 let mut update = build_channels_update(result.new_channels, vec![]);
3398 update.delete_channels = result
3399 .removed_channels
3400 .into_iter()
3401 .map(|id| id.to_proto())
3402 .collect();
3403 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3404
3405 for connection_id in connection_pool.user_connection_ids(user_id) {
3406 peer.send(connection_id, user_channels_update.clone())
3407 .trace_err();
3408 peer.send(connection_id, update.clone()).trace_err();
3409 }
3410}
3411
3412fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3413 proto::UpdateUserChannels {
3414 channel_memberships: channels
3415 .channel_memberships
3416 .iter()
3417 .map(|m| proto::ChannelMembership {
3418 channel_id: m.channel_id.to_proto(),
3419 role: m.role.into(),
3420 })
3421 .collect(),
3422 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3423 observed_channel_message_id: channels.observed_channel_messages.clone(),
3424 }
3425}
3426
3427fn build_channels_update(
3428 channels: ChannelsForUser,
3429 channel_invites: Vec<db::Channel>,
3430) -> proto::UpdateChannels {
3431 let mut update = proto::UpdateChannels::default();
3432
3433 for channel in channels.channels {
3434 update.channels.push(channel.to_proto());
3435 }
3436
3437 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3438 update.latest_channel_message_ids = channels.latest_channel_messages;
3439
3440 for (channel_id, participants) in channels.channel_participants {
3441 update
3442 .channel_participants
3443 .push(proto::ChannelParticipants {
3444 channel_id: channel_id.to_proto(),
3445 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3446 });
3447 }
3448
3449 for channel in channel_invites {
3450 update.channel_invitations.push(channel.to_proto());
3451 }
3452 for project in channels.hosted_projects {
3453 update.hosted_projects.push(project);
3454 }
3455
3456 update
3457}
3458
3459fn build_initial_contacts_update(
3460 contacts: Vec<db::Contact>,
3461 pool: &ConnectionPool,
3462) -> proto::UpdateContacts {
3463 let mut update = proto::UpdateContacts::default();
3464
3465 for contact in contacts {
3466 match contact {
3467 db::Contact::Accepted { user_id, busy } => {
3468 update.contacts.push(contact_for_user(user_id, busy, &pool));
3469 }
3470 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3471 db::Contact::Incoming { user_id } => {
3472 update
3473 .incoming_requests
3474 .push(proto::IncomingContactRequest {
3475 requester_id: user_id.to_proto(),
3476 })
3477 }
3478 }
3479 }
3480
3481 update
3482}
3483
3484fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3485 proto::Contact {
3486 user_id: user_id.to_proto(),
3487 online: pool.is_user_online(user_id),
3488 busy,
3489 }
3490}
3491
3492fn room_updated(room: &proto::Room, peer: &Peer) {
3493 broadcast(
3494 None,
3495 room.participants
3496 .iter()
3497 .filter_map(|participant| Some(participant.peer_id?.into())),
3498 |peer_id| {
3499 peer.send(
3500 peer_id,
3501 proto::RoomUpdated {
3502 room: Some(room.clone()),
3503 },
3504 )
3505 },
3506 );
3507}
3508
3509fn channel_updated(
3510 channel_id: ChannelId,
3511 room: &proto::Room,
3512 channel_members: &[UserId],
3513 peer: &Peer,
3514 pool: &ConnectionPool,
3515) {
3516 let participants = room
3517 .participants
3518 .iter()
3519 .map(|p| p.user_id)
3520 .collect::<Vec<_>>();
3521
3522 broadcast(
3523 None,
3524 channel_members
3525 .iter()
3526 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3527 |peer_id| {
3528 peer.send(
3529 peer_id,
3530 proto::UpdateChannels {
3531 channel_participants: vec![proto::ChannelParticipants {
3532 channel_id: channel_id.to_proto(),
3533 participant_user_ids: participants.clone(),
3534 }],
3535 ..Default::default()
3536 },
3537 )
3538 },
3539 );
3540}
3541
3542async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3543 let db = session.db().await;
3544
3545 let contacts = db.get_contacts(user_id).await?;
3546 let busy = db.is_user_busy(user_id).await?;
3547
3548 let pool = session.connection_pool().await;
3549 let updated_contact = contact_for_user(user_id, busy, &pool);
3550 for contact in contacts {
3551 if let db::Contact::Accepted {
3552 user_id: contact_user_id,
3553 ..
3554 } = contact
3555 {
3556 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3557 session
3558 .peer
3559 .send(
3560 contact_conn_id,
3561 proto::UpdateContacts {
3562 contacts: vec![updated_contact.clone()],
3563 remove_contacts: Default::default(),
3564 incoming_requests: Default::default(),
3565 remove_incoming_requests: Default::default(),
3566 outgoing_requests: Default::default(),
3567 remove_outgoing_requests: Default::default(),
3568 },
3569 )
3570 .trace_err();
3571 }
3572 }
3573 }
3574 Ok(())
3575}
3576
/// Removes this session's connection from its room (via the database's
/// `leave_room`) and propagates the consequences.
///
/// If `leave_room` reports that no room was left, this returns `Ok(())`
/// immediately. Otherwise, in order:
/// 1. notifies the connections of every project the session left (`project_left`);
/// 2. broadcasts the updated room to its remaining participants;
/// 3. if the room belongs to a channel, sends the channel's members the new
///    participant list;
/// 4. sends `CallCanceled` to every connection of each user whose pending call
///    was canceled by this departure;
/// 5. refreshes contact status for the leaving user and every canceled callee;
/// 6. if a LiveKit client is configured, removes the user from the LiveKit room
///    and deletes that room when `left_room.deleted` is set.
///
/// Database errors and errors from `update_user_contacts` are propagated;
/// message-send and LiveKit failures are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    // The leaving user plus every canceled callee get a contact-status refresh.
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once inside the `if let` below; the `else` branch returns.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before the
        // room itself is moved into `room`, so the `RoomUpdated` broadcast below
        // carries a defaulted `live_kit_room` field — confirm clients don't rely on it.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Only rooms attached to a channel have channel members to notify.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection-pool guard to the cancellation broadcast.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        // NOTE(review): an error here returns early via `?` and skips the
        // LiveKit cleanup below.
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Cloned because `live_kit_room` may be consumed by `delete_room` below.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3653
3654async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3655 let left_channel_buffers = session
3656 .db()
3657 .await
3658 .leave_channel_buffers(session.connection_id)
3659 .await?;
3660
3661 for left_buffer in left_channel_buffers {
3662 channel_buffer_updated(
3663 session.connection_id,
3664 left_buffer.connections,
3665 &proto::UpdateChannelBufferCollaborators {
3666 channel_id: left_buffer.channel_id.to_proto(),
3667 collaborators: left_buffer.collaborators,
3668 },
3669 &session.peer,
3670 );
3671 }
3672
3673 Ok(())
3674}
3675
3676fn project_left(project: &db::LeftProject, session: &Session) {
3677 for connection_id in &project.connection_ids {
3678 if project.host_user_id == Some(session.user_id) {
3679 session
3680 .peer
3681 .send(
3682 *connection_id,
3683 proto::UnshareProject {
3684 project_id: project.id.to_proto(),
3685 },
3686 )
3687 .trace_err();
3688 } else {
3689 session
3690 .peer
3691 .send(
3692 *connection_id,
3693 proto::RemoveProjectCollaborator {
3694 project_id: project.id.to_proto(),
3695 peer_id: Some(session.connection_id.into()),
3696 },
3697 )
3698 .trace_err();
3699 }
3700 }
3701}
3702
3703pub trait ResultExt {
3704 type Ok;
3705
3706 fn trace_err(self) -> Option<Self::Ok>;
3707}
3708
3709impl<T, E> ResultExt for Result<T, E>
3710where
3711 E: std::fmt::Debug,
3712{
3713 type Ok = T;
3714
3715 fn trace_err(self) -> Option<T> {
3716 match self {
3717 Ok(value) => Some(value),
3718 Err(error) => {
3719 tracing::error!("{:?}", error);
3720 None
3721 }
3722 }
3723 }
3724}