1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::{ConnectionPool, ZedVersion};
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use prometheus::{register_int_gauge, IntGauge};
39use rpc::{
40 proto::{
41 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
42 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
43 },
44 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
45};
46use serde::{Serialize, Serializer};
47use std::{
48 any::TypeId,
49 fmt,
50 future::Future,
51 marker::PhantomData,
52 mem,
53 net::SocketAddr,
54 ops::{Deref, DerefMut},
55 rc::Rc,
56 sync::{
57 atomic::{AtomicBool, Ordering::SeqCst},
58 Arc, OnceLock,
59 },
60 time::{Duration, Instant},
61};
62use time::OffsetDateTime;
63use tokio::sync::{watch, Semaphore};
64use tower::ServiceBuilder;
65use tracing::{field, info_span, instrument, Instrument};
66use util::SemanticVersion;
67
/// How long `connection_lost` waits after a disconnect before releasing the user's room
/// and channel-buffer memberships (unless the server is torn down first).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` begins clearing resources the database reports as stale.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three limits below are consumed outside this chunk; their exact
// semantics (page sizes and a message length cap) are inferred from their names.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
74
75type MessageHandler =
76 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
77
78struct Response<R> {
79 peer: Arc<Peer>,
80 receipt: Receipt<R>,
81 responded: Arc<AtomicBool>,
82}
83
84impl<R: RequestMessage> Response<R> {
85 fn send(self, payload: R::Response) -> Result<()> {
86 self.responded.store(true, SeqCst);
87 self.peer.respond(self.receipt, payload)?;
88 Ok(())
89 }
90}
91
92#[derive(Clone)]
93struct Session {
94 user_id: UserId,
95 connection_id: ConnectionId,
96 db: Arc<tokio::sync::Mutex<DbHandle>>,
97 peer: Arc<Peer>,
98 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
99 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
100 _executor: Executor,
101}
102
103impl Session {
104 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
105 #[cfg(test)]
106 tokio::task::yield_now().await;
107 let guard = self.db.lock().await;
108 #[cfg(test)]
109 tokio::task::yield_now().await;
110 guard
111 }
112
113 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
114 #[cfg(test)]
115 tokio::task::yield_now().await;
116 let guard = self.connection_pool.lock();
117 ConnectionPoolGuard {
118 guard,
119 _not_send: PhantomData,
120 }
121 }
122}
123
124impl fmt::Debug for Session {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 f.debug_struct("Session")
127 .field("user_id", &self.user_id)
128 .field("connection_id", &self.connection_id)
129 .finish()
130 }
131}
132
133struct DbHandle(Arc<Database>);
134
135impl Deref for DbHandle {
136 type Target = Database;
137
138 fn deref(&self) -> &Self::Target {
139 self.0.as_ref()
140 }
141}
142
/// The collaboration RPC server: owns the peer, the pool of live client connections, and
/// the table of registered message handlers.
pub struct Server {
    /// This server's id; behind a mutex so tests can reassign it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer used to exchange messages with every client connection.
    peer: Arc<Peer>,
    /// Connections currently attached to this server; shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}

/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. In test builds, dropping the
/// guard runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// A serializable snapshot of the peer and the connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Holding the guard keeps the pool locked for as long as the snapshot is alive.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
164
165pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
166where
167 S: Serializer,
168 T: Deref<Target = U>,
169 U: Serialize,
170{
171 Serialize::serialize(value.deref(), serializer)
172}
173
174impl Server {
175 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
176 let mut server = Self {
177 id: parking_lot::Mutex::new(id),
178 peer: Peer::new(id.0 as u32),
179 app_state,
180 executor,
181 connection_pool: Default::default(),
182 handlers: Default::default(),
183 teardown: watch::channel(()).0,
184 };
185
186 server
187 .add_request_handler(ping)
188 .add_request_handler(create_room)
189 .add_request_handler(join_room)
190 .add_request_handler(rejoin_room)
191 .add_request_handler(leave_room)
192 .add_request_handler(set_room_participant_role)
193 .add_request_handler(call)
194 .add_request_handler(cancel_call)
195 .add_message_handler(decline_call)
196 .add_request_handler(update_participant_location)
197 .add_request_handler(share_project)
198 .add_message_handler(unshare_project)
199 .add_request_handler(join_project)
200 .add_message_handler(leave_project)
201 .add_request_handler(update_project)
202 .add_request_handler(update_worktree)
203 .add_message_handler(start_language_server)
204 .add_message_handler(update_language_server)
205 .add_message_handler(update_diagnostic_summary)
206 .add_message_handler(update_worktree_settings)
207 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
208 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
209 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
210 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
211 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
212 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
213 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
214 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
215 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
216 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
217 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
218 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
219 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
220 .add_request_handler(
221 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
222 )
223 .add_request_handler(
224 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
225 )
226 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
227 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
228 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
229 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
230 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
231 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
232 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
233 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
234 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
235 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
236 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
237 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
238 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
239 .add_message_handler(create_buffer_for_peer)
240 .add_request_handler(update_buffer)
241 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
242 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
243 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
244 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
245 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
246 .add_request_handler(get_users)
247 .add_request_handler(fuzzy_search_users)
248 .add_request_handler(request_contact)
249 .add_request_handler(remove_contact)
250 .add_request_handler(respond_to_contact_request)
251 .add_request_handler(create_channel)
252 .add_request_handler(delete_channel)
253 .add_request_handler(invite_channel_member)
254 .add_request_handler(remove_channel_member)
255 .add_request_handler(set_channel_member_role)
256 .add_request_handler(set_channel_visibility)
257 .add_request_handler(rename_channel)
258 .add_request_handler(join_channel_buffer)
259 .add_request_handler(leave_channel_buffer)
260 .add_message_handler(update_channel_buffer)
261 .add_request_handler(rejoin_channel_buffers)
262 .add_request_handler(get_channel_members)
263 .add_request_handler(respond_to_channel_invite)
264 .add_request_handler(join_channel)
265 .add_request_handler(join_channel_chat)
266 .add_message_handler(leave_channel_chat)
267 .add_request_handler(send_channel_message)
268 .add_request_handler(remove_channel_message)
269 .add_request_handler(get_channel_messages)
270 .add_request_handler(get_channel_messages_by_id)
271 .add_request_handler(get_notifications)
272 .add_request_handler(mark_notification_as_read)
273 .add_request_handler(move_channel)
274 .add_request_handler(follow)
275 .add_message_handler(unfollow)
276 .add_message_handler(update_followers)
277 .add_request_handler(get_private_user_info)
278 .add_message_handler(acknowledge_channel_message)
279 .add_message_handler(acknowledge_buffer_version);
280
281 Arc::new(server)
282 }
283
284 pub async fn start(&self) -> Result<()> {
285 let server_id = *self.id.lock();
286 let app_state = self.app_state.clone();
287 let peer = self.peer.clone();
288 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
289 let pool = self.connection_pool.clone();
290 let live_kit_client = self.app_state.live_kit_client.clone();
291
292 let span = info_span!("start server");
293 self.executor.spawn_detached(
294 async move {
295 tracing::info!("waiting for cleanup timeout");
296 timeout.await;
297 tracing::info!("cleanup timeout expired, retrieving stale rooms");
298 if let Some((room_ids, channel_ids)) = app_state
299 .db
300 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
301 .await
302 .trace_err()
303 {
304 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
305 tracing::info!(
306 stale_channel_buffer_count = channel_ids.len(),
307 "retrieved stale channel buffers"
308 );
309
310 for channel_id in channel_ids {
311 if let Some(refreshed_channel_buffer) = app_state
312 .db
313 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
314 .await
315 .trace_err()
316 {
317 for connection_id in refreshed_channel_buffer.connection_ids {
318 peer.send(
319 connection_id,
320 proto::UpdateChannelBufferCollaborators {
321 channel_id: channel_id.to_proto(),
322 collaborators: refreshed_channel_buffer
323 .collaborators
324 .clone(),
325 },
326 )
327 .trace_err();
328 }
329 }
330 }
331
332 for room_id in room_ids {
333 let mut contacts_to_update = HashSet::default();
334 let mut canceled_calls_to_user_ids = Vec::new();
335 let mut live_kit_room = String::new();
336 let mut delete_live_kit_room = false;
337
338 if let Some(mut refreshed_room) = app_state
339 .db
340 .clear_stale_room_participants(room_id, server_id)
341 .await
342 .trace_err()
343 {
344 tracing::info!(
345 room_id = room_id.0,
346 new_participant_count = refreshed_room.room.participants.len(),
347 "refreshed room"
348 );
349 room_updated(&refreshed_room.room, &peer);
350 if let Some(channel_id) = refreshed_room.channel_id {
351 channel_updated(
352 channel_id,
353 &refreshed_room.room,
354 &refreshed_room.channel_members,
355 &peer,
356 &*pool.lock(),
357 );
358 }
359 contacts_to_update
360 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
361 contacts_to_update
362 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
363 canceled_calls_to_user_ids =
364 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
365 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
366 delete_live_kit_room = refreshed_room.room.participants.is_empty();
367 }
368
369 {
370 let pool = pool.lock();
371 for canceled_user_id in canceled_calls_to_user_ids {
372 for connection_id in pool.user_connection_ids(canceled_user_id) {
373 peer.send(
374 connection_id,
375 proto::CallCanceled {
376 room_id: room_id.to_proto(),
377 },
378 )
379 .trace_err();
380 }
381 }
382 }
383
384 for user_id in contacts_to_update {
385 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
386 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
387 if let Some((busy, contacts)) = busy.zip(contacts) {
388 let pool = pool.lock();
389 let updated_contact = contact_for_user(user_id, busy, &pool);
390 for contact in contacts {
391 if let db::Contact::Accepted {
392 user_id: contact_user_id,
393 ..
394 } = contact
395 {
396 for contact_conn_id in
397 pool.user_connection_ids(contact_user_id)
398 {
399 peer.send(
400 contact_conn_id,
401 proto::UpdateContacts {
402 contacts: vec![updated_contact.clone()],
403 remove_contacts: Default::default(),
404 incoming_requests: Default::default(),
405 remove_incoming_requests: Default::default(),
406 outgoing_requests: Default::default(),
407 remove_outgoing_requests: Default::default(),
408 },
409 )
410 .trace_err();
411 }
412 }
413 }
414 }
415 }
416
417 if let Some(live_kit) = live_kit_client.as_ref() {
418 if delete_live_kit_room {
419 live_kit.delete_room(live_kit_room).await.trace_err();
420 }
421 }
422 }
423 }
424
425 app_state
426 .db
427 .delete_stale_servers(&app_state.config.zed_environment, server_id)
428 .await
429 .trace_err();
430 }
431 .instrument(span),
432 );
433 Ok(())
434 }
435
436 pub fn teardown(&self) {
437 self.peer.teardown();
438 self.connection_pool.lock().reset();
439 let _ = self.teardown.send(());
440 }
441
442 #[cfg(test)]
443 pub fn reset(&self, id: ServerId) {
444 self.teardown();
445 *self.id.lock() = id;
446 self.peer.reset(id.0 as u32);
447 }
448
449 #[cfg(test)]
450 pub fn id(&self) -> ServerId {
451 *self.id.lock()
452 }
453
454 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
455 where
456 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
457 Fut: 'static + Send + Future<Output = Result<()>>,
458 M: EnvelopedMessage,
459 {
460 let prev_handler = self.handlers.insert(
461 TypeId::of::<M>(),
462 Box::new(move |envelope, session| {
463 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
464 let span = info_span!(
465 "handle message",
466 payload_type = envelope.payload_type_name()
467 );
468 span.in_scope(|| {
469 tracing::info!(
470 payload_type = envelope.payload_type_name(),
471 "message received"
472 );
473 });
474 let start_time = Instant::now();
475 let future = (handler)(*envelope, session);
476 async move {
477 let result = future.await;
478 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
479 match result {
480 Err(error) => {
481 tracing::error!(%error, ?duration_ms, "error handling message")
482 }
483 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
484 }
485 }
486 .instrument(span)
487 .boxed()
488 }),
489 );
490 if prev_handler.is_some() {
491 panic!("registered a handler for the same message twice");
492 }
493 self
494 }
495
496 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
497 where
498 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
499 Fut: 'static + Send + Future<Output = Result<()>>,
500 M: EnvelopedMessage,
501 {
502 self.add_handler(move |envelope, session| handler(envelope.payload, session));
503 self
504 }
505
506 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
507 where
508 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
509 Fut: Send + Future<Output = Result<()>>,
510 M: RequestMessage,
511 {
512 let handler = Arc::new(handler);
513 self.add_handler(move |envelope, session| {
514 let receipt = envelope.receipt();
515 let handler = handler.clone();
516 async move {
517 let peer = session.peer.clone();
518 let responded = Arc::new(AtomicBool::default());
519 let response = Response {
520 peer: peer.clone(),
521 responded: responded.clone(),
522 receipt,
523 };
524 match (handler)(envelope.payload, response, session).await {
525 Ok(()) => {
526 if responded.load(std::sync::atomic::Ordering::SeqCst) {
527 Ok(())
528 } else {
529 Err(anyhow!("handler did not send a response"))?
530 }
531 }
532 Err(error) => {
533 let proto_err = match &error {
534 Error::Internal(err) => err.to_proto(),
535 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
536 };
537 peer.respond_with_error(receipt, proto_err)?;
538 Err(error)
539 }
540 }
541 }
542 })
543 }
544
545 pub fn handle_connection(
546 self: &Arc<Self>,
547 connection: Connection,
548 address: String,
549 user: User,
550 zed_version: ZedVersion,
551 impersonator: Option<User>,
552 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
553 executor: Executor,
554 ) -> impl Future<Output = Result<()>> {
555 let this = self.clone();
556 let user_id = user.id;
557 let login = user.github_login;
558 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
559 if let Some(impersonator) = impersonator {
560 span.record("impersonator", &impersonator.github_login);
561 }
562 let mut teardown = self.teardown.subscribe();
563 async move {
564 let (connection_id, handle_io, mut incoming_rx) = this
565 .peer
566 .add_connection(connection, {
567 let executor = executor.clone();
568 move |duration| executor.sleep(duration)
569 });
570
571 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
572 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
573 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
574
575 if let Some(send_connection_id) = send_connection_id.take() {
576 let _ = send_connection_id.send(connection_id);
577 }
578
579 if !user.connected_once {
580 this.peer.send(connection_id, proto::ShowContacts {})?;
581 this.app_state.db.set_user_connected_once(user_id, true).await?;
582 }
583
584 let (contacts, channels_for_user, channel_invites) = future::try_join3(
585 this.app_state.db.get_contacts(user_id),
586 this.app_state.db.get_channels_for_user(user_id),
587 this.app_state.db.get_channel_invites_for_user(user_id),
588 ).await?;
589
590 {
591 let mut pool = this.connection_pool.lock();
592 pool.add_connection(connection_id, user_id, user.admin, zed_version);
593 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
594 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
595 this.peer.send(connection_id, build_channels_update(
596 channels_for_user,
597 channel_invites
598 ))?;
599 }
600
601 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
602 this.peer.send(connection_id, incoming_call)?;
603 }
604
605 let session = Session {
606 user_id,
607 connection_id,
608 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
609 peer: this.peer.clone(),
610 connection_pool: this.connection_pool.clone(),
611 live_kit_client: this.app_state.live_kit_client.clone(),
612 _executor: executor.clone()
613 };
614 update_user_contacts(user_id, &session).await?;
615
616 let handle_io = handle_io.fuse();
617 futures::pin_mut!(handle_io);
618
619 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
620 // This prevents deadlocks when e.g., client A performs a request to client B and
621 // client B performs a request to client A. If both clients stop processing further
622 // messages until their respective request completes, they won't have a chance to
623 // respond to the other client's request and cause a deadlock.
624 //
625 // This arrangement ensures we will attempt to process earlier messages first, but fall
626 // back to processing messages arrived later in the spirit of making progress.
627 let mut foreground_message_handlers = FuturesUnordered::new();
628 let concurrent_handlers = Arc::new(Semaphore::new(256));
629 loop {
630 let next_message = async {
631 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
632 let message = incoming_rx.next().await;
633 (permit, message)
634 }.fuse();
635 futures::pin_mut!(next_message);
636 futures::select_biased! {
637 _ = teardown.changed().fuse() => return Ok(()),
638 result = handle_io => {
639 if let Err(error) = result {
640 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
641 }
642 break;
643 }
644 _ = foreground_message_handlers.next() => {}
645 next_message = next_message => {
646 let (permit, message) = next_message;
647 if let Some(message) = message {
648 let type_name = message.payload_type_name();
649 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
650 let span_enter = span.enter();
651 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
652 let is_background = message.is_background();
653 let handle_message = (handler)(message, session.clone());
654 drop(span_enter);
655
656 let handle_message = async move {
657 handle_message.await;
658 drop(permit);
659 }.instrument(span);
660 if is_background {
661 executor.spawn_detached(handle_message);
662 } else {
663 foreground_message_handlers.push(handle_message);
664 }
665 } else {
666 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
667 }
668 } else {
669 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
670 break;
671 }
672 }
673 }
674 }
675
676 drop(foreground_message_handlers);
677 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
678 if let Err(error) = connection_lost(session, teardown, executor).await {
679 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
680 }
681
682 Ok(())
683 }.instrument(span)
684 }
685
686 pub async fn invite_code_redeemed(
687 self: &Arc<Self>,
688 inviter_id: UserId,
689 invitee_id: UserId,
690 ) -> Result<()> {
691 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
692 if let Some(code) = &user.invite_code {
693 let pool = self.connection_pool.lock();
694 let invitee_contact = contact_for_user(invitee_id, false, &pool);
695 for connection_id in pool.user_connection_ids(inviter_id) {
696 self.peer.send(
697 connection_id,
698 proto::UpdateContacts {
699 contacts: vec![invitee_contact.clone()],
700 ..Default::default()
701 },
702 )?;
703 self.peer.send(
704 connection_id,
705 proto::UpdateInviteInfo {
706 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
707 count: user.invite_count as u32,
708 },
709 )?;
710 }
711 }
712 }
713 Ok(())
714 }
715
716 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
717 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
718 if let Some(invite_code) = &user.invite_code {
719 let pool = self.connection_pool.lock();
720 for connection_id in pool.user_connection_ids(user_id) {
721 self.peer.send(
722 connection_id,
723 proto::UpdateInviteInfo {
724 url: format!(
725 "{}{}",
726 self.app_state.config.invite_link_prefix, invite_code
727 ),
728 count: user.invite_count as u32,
729 },
730 )?;
731 }
732 }
733 }
734 Ok(())
735 }
736
737 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
738 ServerSnapshot {
739 connection_pool: ConnectionPoolGuard {
740 guard: self.connection_pool.lock(),
741 _not_send: PhantomData,
742 },
743 peer: &self.peer,
744 }
745 }
746}
747
748impl<'a> Deref for ConnectionPoolGuard<'a> {
749 type Target = ConnectionPool;
750
751 fn deref(&self) -> &Self::Target {
752 &*self.guard
753 }
754}
755
756impl<'a> DerefMut for ConnectionPoolGuard<'a> {
757 fn deref_mut(&mut self) -> &mut Self::Target {
758 &mut *self.guard
759 }
760}
761
762impl<'a> Drop for ConnectionPoolGuard<'a> {
763 fn drop(&mut self) {
764 #[cfg(test)]
765 self.check_invariants();
766 }
767}
768
769fn broadcast<F>(
770 sender_id: Option<ConnectionId>,
771 receiver_ids: impl IntoIterator<Item = ConnectionId>,
772 mut f: F,
773) where
774 F: FnMut(ConnectionId) -> anyhow::Result<()>,
775{
776 for receiver_id in receiver_ids {
777 if Some(receiver_id) != sender_id {
778 if let Err(error) = f(receiver_id) {
779 tracing::error!("failed to send to {:?} {}", receiver_id, error);
780 }
781 }
782 }
783}
784
785pub struct ProtocolVersion(u32);
786
787impl Header for ProtocolVersion {
788 fn name() -> &'static HeaderName {
789 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
790 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
791 }
792
793 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
794 where
795 Self: Sized,
796 I: Iterator<Item = &'i axum::http::HeaderValue>,
797 {
798 let version = values
799 .next()
800 .ok_or_else(axum::headers::Error::invalid)?
801 .to_str()
802 .map_err(|_| axum::headers::Error::invalid())?
803 .parse()
804 .map_err(|_| axum::headers::Error::invalid())?;
805 Ok(Self(version))
806 }
807
808 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
809 values.extend([self.0.to_string().parse().unwrap()]);
810 }
811}
812
813pub struct AppVersionHeader(SemanticVersion);
814impl Header for AppVersionHeader {
815 fn name() -> &'static HeaderName {
816 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
817 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
818 }
819
820 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
821 where
822 Self: Sized,
823 I: Iterator<Item = &'i axum::http::HeaderValue>,
824 {
825 let version = values
826 .next()
827 .ok_or_else(axum::headers::Error::invalid)?
828 .to_str()
829 .map_err(|_| axum::headers::Error::invalid())?
830 .parse()
831 .map_err(|_| axum::headers::Error::invalid())?;
832 Ok(Self(version))
833 }
834
835 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
836 values.extend([self.0.to_string().parse().unwrap()]);
837 }
838}
839
840pub fn routes(server: Arc<Server>) -> Router<Body> {
841 Router::new()
842 .route("/rpc", get(handle_websocket_request))
843 .layer(
844 ServiceBuilder::new()
845 .layer(Extension(server.app_state.clone()))
846 .layer(middleware::from_fn(auth::validate_header)),
847 )
848 .route("/metrics", get(handle_metrics))
849 .layer(Extension(server))
850}
851
852pub async fn handle_websocket_request(
853 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
854 app_version_header: Option<TypedHeader<AppVersionHeader>>,
855 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
856 Extension(server): Extension<Arc<Server>>,
857 Extension(user): Extension<User>,
858 Extension(impersonator): Extension<Impersonator>,
859 ws: WebSocketUpgrade,
860) -> axum::response::Response {
861 if protocol_version != rpc::PROTOCOL_VERSION {
862 return (
863 StatusCode::UPGRADE_REQUIRED,
864 "client must be upgraded".to_string(),
865 )
866 .into_response();
867 }
868
869 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
870 return (
871 StatusCode::UPGRADE_REQUIRED,
872 "no version header found".to_string(),
873 )
874 .into_response();
875 };
876
877 if !version.is_supported() {
878 return (
879 StatusCode::UPGRADE_REQUIRED,
880 "client must be upgraded".to_string(),
881 )
882 .into_response();
883 }
884
885 let socket_address = socket_address.to_string();
886 ws.on_upgrade(move |socket| {
887 use util::ResultExt;
888 let socket = socket
889 .map_ok(to_tungstenite_message)
890 .err_into()
891 .with(|message| async move { Ok(to_axum_message(message)) });
892 let connection = Connection::new(Box::pin(socket));
893 async move {
894 server
895 .handle_connection(
896 connection,
897 socket_address,
898 user,
899 version,
900 impersonator.0,
901 None,
902 Executor::Production,
903 )
904 .await
905 .log_err();
906 }
907 })
908}
909
910pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
911 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
912 let connections_metric = CONNECTIONS_METRIC
913 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
914
915 let connections = server
916 .connection_pool
917 .lock()
918 .connections()
919 .filter(|connection| !connection.admin)
920 .count();
921 connections_metric.set(connections as _);
922
923 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
924 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
925 register_int_gauge!(
926 "shared_projects",
927 "number of open projects with one or more guests"
928 )
929 .unwrap()
930 });
931
932 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
933 shared_projects_metric.set(shared_projects as _);
934
935 let encoder = prometheus::TextEncoder::new();
936 let metric_families = prometheus::gather();
937 let encoded_metrics = encoder
938 .encode_to_string(&metric_families)
939 .map_err(|err| anyhow!("{}", err))?;
940 Ok(encoded_metrics)
941}
942
943#[instrument(err, skip(executor))]
944async fn connection_lost(
945 session: Session,
946 mut teardown: watch::Receiver<()>,
947 executor: Executor,
948) -> Result<()> {
949 session.peer.disconnect(session.connection_id);
950 session
951 .connection_pool()
952 .await
953 .remove_connection(session.connection_id)?;
954
955 session
956 .db()
957 .await
958 .connection_lost(session.connection_id)
959 .await
960 .trace_err();
961
962 futures::select_biased! {
963 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
964 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
965 leave_room_for_session(&session).await.trace_err();
966 leave_channel_buffers_for_session(&session)
967 .await
968 .trace_err();
969
970 if !session
971 .connection_pool()
972 .await
973 .is_user_online(session.user_id)
974 {
975 let db = session.db().await;
976 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
977 room_updated(&room, &session.peer);
978 }
979 }
980
981 update_user_contacts(session.user_id, &session).await?;
982 }
983 _ = teardown.changed().fuse() => {}
984 }
985
986 Ok(())
987}
988
989/// Acknowledges a ping from a client, used to keep the connection alive.
990async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
991 response.send(proto::Ack {})?;
992 Ok(())
993}
994
995/// Creates a new room for calling (outside of channels)
996async fn create_room(
997 _request: proto::CreateRoom,
998 response: Response<proto::CreateRoom>,
999 session: Session,
1000) -> Result<()> {
1001 let live_kit_room = nanoid::nanoid!(30);
1002
1003 let live_kit_connection_info = {
1004 let live_kit_room = live_kit_room.clone();
1005 let live_kit = session.live_kit_client.as_ref();
1006
1007 util::async_maybe!({
1008 let live_kit = live_kit?;
1009
1010 let token = live_kit
1011 .room_token(&live_kit_room, &session.user_id.to_string())
1012 .trace_err()?;
1013
1014 Some(proto::LiveKitConnectionInfo {
1015 server_url: live_kit.url().into(),
1016 token,
1017 can_publish: true,
1018 })
1019 })
1020 }
1021 .await;
1022
1023 let room = session
1024 .db()
1025 .await
1026 .create_room(session.user_id, session.connection_id, &live_kit_room)
1027 .await?;
1028
1029 response.send(proto::CreateRoomResponse {
1030 room: Some(room.clone()),
1031 live_kit_connection_info,
1032 })?;
1033
1034 update_user_contacts(session.user_id, &session).await?;
1035 Ok(())
1036}
1037
1038/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1039async fn join_room(
1040 request: proto::JoinRoom,
1041 response: Response<proto::JoinRoom>,
1042 session: Session,
1043) -> Result<()> {
1044 let room_id = RoomId::from_proto(request.id);
1045
1046 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1047
1048 if let Some(channel_id) = channel_id {
1049 return join_channel_internal(channel_id, Box::new(response), session).await;
1050 }
1051
1052 let joined_room = {
1053 let room = session
1054 .db()
1055 .await
1056 .join_room(room_id, session.user_id, session.connection_id)
1057 .await?;
1058 room_updated(&room.room, &session.peer);
1059 room.into_inner()
1060 };
1061
1062 for connection_id in session
1063 .connection_pool()
1064 .await
1065 .user_connection_ids(session.user_id)
1066 {
1067 session
1068 .peer
1069 .send(
1070 connection_id,
1071 proto::CallCanceled {
1072 room_id: room_id.to_proto(),
1073 },
1074 )
1075 .trace_err();
1076 }
1077
1078 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1079 if let Some(token) = live_kit
1080 .room_token(
1081 &joined_room.room.live_kit_room,
1082 &session.user_id.to_string(),
1083 )
1084 .trace_err()
1085 {
1086 Some(proto::LiveKitConnectionInfo {
1087 server_url: live_kit.url().into(),
1088 token,
1089 can_publish: true,
1090 })
1091 } else {
1092 None
1093 }
1094 } else {
1095 None
1096 };
1097
1098 response.send(proto::JoinRoomResponse {
1099 room: Some(joined_room.room),
1100 channel_id: None,
1101 live_kit_connection_info,
1102 })?;
1103
1104 update_user_contacts(session.user_id, &session).await?;
1105 Ok(())
1106}
1107
1108/// Rejoin room is used to reconnect to a room after connection errors.
1109async fn rejoin_room(
1110 request: proto::RejoinRoom,
1111 response: Response<proto::RejoinRoom>,
1112 session: Session,
1113) -> Result<()> {
1114 let room;
1115 let channel_id;
1116 let channel_members;
1117 {
1118 let mut rejoined_room = session
1119 .db()
1120 .await
1121 .rejoin_room(request, session.user_id, session.connection_id)
1122 .await?;
1123
1124 response.send(proto::RejoinRoomResponse {
1125 room: Some(rejoined_room.room.clone()),
1126 reshared_projects: rejoined_room
1127 .reshared_projects
1128 .iter()
1129 .map(|project| proto::ResharedProject {
1130 id: project.id.to_proto(),
1131 collaborators: project
1132 .collaborators
1133 .iter()
1134 .map(|collaborator| collaborator.to_proto())
1135 .collect(),
1136 })
1137 .collect(),
1138 rejoined_projects: rejoined_room
1139 .rejoined_projects
1140 .iter()
1141 .map(|rejoined_project| proto::RejoinedProject {
1142 id: rejoined_project.id.to_proto(),
1143 worktrees: rejoined_project
1144 .worktrees
1145 .iter()
1146 .map(|worktree| proto::WorktreeMetadata {
1147 id: worktree.id,
1148 root_name: worktree.root_name.clone(),
1149 visible: worktree.visible,
1150 abs_path: worktree.abs_path.clone(),
1151 })
1152 .collect(),
1153 collaborators: rejoined_project
1154 .collaborators
1155 .iter()
1156 .map(|collaborator| collaborator.to_proto())
1157 .collect(),
1158 language_servers: rejoined_project.language_servers.clone(),
1159 })
1160 .collect(),
1161 })?;
1162 room_updated(&rejoined_room.room, &session.peer);
1163
1164 for project in &rejoined_room.reshared_projects {
1165 for collaborator in &project.collaborators {
1166 session
1167 .peer
1168 .send(
1169 collaborator.connection_id,
1170 proto::UpdateProjectCollaborator {
1171 project_id: project.id.to_proto(),
1172 old_peer_id: Some(project.old_connection_id.into()),
1173 new_peer_id: Some(session.connection_id.into()),
1174 },
1175 )
1176 .trace_err();
1177 }
1178
1179 broadcast(
1180 Some(session.connection_id),
1181 project
1182 .collaborators
1183 .iter()
1184 .map(|collaborator| collaborator.connection_id),
1185 |connection_id| {
1186 session.peer.forward_send(
1187 session.connection_id,
1188 connection_id,
1189 proto::UpdateProject {
1190 project_id: project.id.to_proto(),
1191 worktrees: project.worktrees.clone(),
1192 },
1193 )
1194 },
1195 );
1196 }
1197
1198 for project in &rejoined_room.rejoined_projects {
1199 for collaborator in &project.collaborators {
1200 session
1201 .peer
1202 .send(
1203 collaborator.connection_id,
1204 proto::UpdateProjectCollaborator {
1205 project_id: project.id.to_proto(),
1206 old_peer_id: Some(project.old_connection_id.into()),
1207 new_peer_id: Some(session.connection_id.into()),
1208 },
1209 )
1210 .trace_err();
1211 }
1212 }
1213
1214 for project in &mut rejoined_room.rejoined_projects {
1215 for worktree in mem::take(&mut project.worktrees) {
1216 #[cfg(any(test, feature = "test-support"))]
1217 const MAX_CHUNK_SIZE: usize = 2;
1218 #[cfg(not(any(test, feature = "test-support")))]
1219 const MAX_CHUNK_SIZE: usize = 256;
1220
1221 // Stream this worktree's entries.
1222 let message = proto::UpdateWorktree {
1223 project_id: project.id.to_proto(),
1224 worktree_id: worktree.id,
1225 abs_path: worktree.abs_path.clone(),
1226 root_name: worktree.root_name,
1227 updated_entries: worktree.updated_entries,
1228 removed_entries: worktree.removed_entries,
1229 scan_id: worktree.scan_id,
1230 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1231 updated_repositories: worktree.updated_repositories,
1232 removed_repositories: worktree.removed_repositories,
1233 };
1234 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1235 session.peer.send(session.connection_id, update.clone())?;
1236 }
1237
1238 // Stream this worktree's diagnostics.
1239 for summary in worktree.diagnostic_summaries {
1240 session.peer.send(
1241 session.connection_id,
1242 proto::UpdateDiagnosticSummary {
1243 project_id: project.id.to_proto(),
1244 worktree_id: worktree.id,
1245 summary: Some(summary),
1246 },
1247 )?;
1248 }
1249
1250 for settings_file in worktree.settings_files {
1251 session.peer.send(
1252 session.connection_id,
1253 proto::UpdateWorktreeSettings {
1254 project_id: project.id.to_proto(),
1255 worktree_id: worktree.id,
1256 path: settings_file.path,
1257 content: Some(settings_file.content),
1258 },
1259 )?;
1260 }
1261 }
1262
1263 for language_server in &project.language_servers {
1264 session.peer.send(
1265 session.connection_id,
1266 proto::UpdateLanguageServer {
1267 project_id: project.id.to_proto(),
1268 language_server_id: language_server.id,
1269 variant: Some(
1270 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1271 proto::LspDiskBasedDiagnosticsUpdated {},
1272 ),
1273 ),
1274 },
1275 )?;
1276 }
1277 }
1278
1279 let rejoined_room = rejoined_room.into_inner();
1280
1281 room = rejoined_room.room;
1282 channel_id = rejoined_room.channel_id;
1283 channel_members = rejoined_room.channel_members;
1284 }
1285
1286 if let Some(channel_id) = channel_id {
1287 channel_updated(
1288 channel_id,
1289 &room,
1290 &channel_members,
1291 &session.peer,
1292 &*session.connection_pool().await,
1293 );
1294 }
1295
1296 update_user_contacts(session.user_id, &session).await?;
1297 Ok(())
1298}
1299
1300/// leave room disconnects from the room.
1301async fn leave_room(
1302 _: proto::LeaveRoom,
1303 response: Response<proto::LeaveRoom>,
1304 session: Session,
1305) -> Result<()> {
1306 leave_room_for_session(&session).await?;
1307 response.send(proto::Ack {})?;
1308 Ok(())
1309}
1310
1311/// Updates the permissions of someone else in the room.
1312async fn set_room_participant_role(
1313 request: proto::SetRoomParticipantRole,
1314 response: Response<proto::SetRoomParticipantRole>,
1315 session: Session,
1316) -> Result<()> {
1317 let user_id = UserId::from_proto(request.user_id);
1318 let role = ChannelRole::from(request.role());
1319
1320 if role == ChannelRole::Talker {
1321 let pool = session.connection_pool().await;
1322
1323 for connection in pool.user_connections(user_id) {
1324 if !connection.zed_version.supports_talker_role() {
1325 Err(anyhow!(
1326 "This user is on zed {} which does not support unmute",
1327 connection.zed_version
1328 ))?;
1329 }
1330 }
1331 }
1332
1333 let (live_kit_room, can_publish) = {
1334 let room = session
1335 .db()
1336 .await
1337 .set_room_participant_role(
1338 session.user_id,
1339 RoomId::from_proto(request.room_id),
1340 user_id,
1341 role,
1342 )
1343 .await?;
1344
1345 let live_kit_room = room.live_kit_room.clone();
1346 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1347 room_updated(&room, &session.peer);
1348 (live_kit_room, can_publish)
1349 };
1350
1351 if let Some(live_kit) = session.live_kit_client.as_ref() {
1352 live_kit
1353 .update_participant(
1354 live_kit_room.clone(),
1355 request.user_id.to_string(),
1356 live_kit_server::proto::ParticipantPermission {
1357 can_subscribe: true,
1358 can_publish,
1359 can_publish_data: can_publish,
1360 hidden: false,
1361 recorder: false,
1362 },
1363 )
1364 .await
1365 .trace_err();
1366 }
1367
1368 response.send(proto::Ack {})?;
1369 Ok(())
1370}
1371
1372/// Call someone else into the current room
1373async fn call(
1374 request: proto::Call,
1375 response: Response<proto::Call>,
1376 session: Session,
1377) -> Result<()> {
1378 let room_id = RoomId::from_proto(request.room_id);
1379 let calling_user_id = session.user_id;
1380 let calling_connection_id = session.connection_id;
1381 let called_user_id = UserId::from_proto(request.called_user_id);
1382 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1383 if !session
1384 .db()
1385 .await
1386 .has_contact(calling_user_id, called_user_id)
1387 .await?
1388 {
1389 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1390 }
1391
1392 let incoming_call = {
1393 let (room, incoming_call) = &mut *session
1394 .db()
1395 .await
1396 .call(
1397 room_id,
1398 calling_user_id,
1399 calling_connection_id,
1400 called_user_id,
1401 initial_project_id,
1402 )
1403 .await?;
1404 room_updated(&room, &session.peer);
1405 mem::take(incoming_call)
1406 };
1407 update_user_contacts(called_user_id, &session).await?;
1408
1409 let mut calls = session
1410 .connection_pool()
1411 .await
1412 .user_connection_ids(called_user_id)
1413 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1414 .collect::<FuturesUnordered<_>>();
1415
1416 while let Some(call_response) = calls.next().await {
1417 match call_response.as_ref() {
1418 Ok(_) => {
1419 response.send(proto::Ack {})?;
1420 return Ok(());
1421 }
1422 Err(_) => {
1423 call_response.trace_err();
1424 }
1425 }
1426 }
1427
1428 {
1429 let room = session
1430 .db()
1431 .await
1432 .call_failed(room_id, called_user_id)
1433 .await?;
1434 room_updated(&room, &session.peer);
1435 }
1436 update_user_contacts(called_user_id, &session).await?;
1437
1438 Err(anyhow!("failed to ring user"))?
1439}
1440
1441/// Cancel an outgoing call.
1442async fn cancel_call(
1443 request: proto::CancelCall,
1444 response: Response<proto::CancelCall>,
1445 session: Session,
1446) -> Result<()> {
1447 let called_user_id = UserId::from_proto(request.called_user_id);
1448 let room_id = RoomId::from_proto(request.room_id);
1449 {
1450 let room = session
1451 .db()
1452 .await
1453 .cancel_call(room_id, session.connection_id, called_user_id)
1454 .await?;
1455 room_updated(&room, &session.peer);
1456 }
1457
1458 for connection_id in session
1459 .connection_pool()
1460 .await
1461 .user_connection_ids(called_user_id)
1462 {
1463 session
1464 .peer
1465 .send(
1466 connection_id,
1467 proto::CallCanceled {
1468 room_id: room_id.to_proto(),
1469 },
1470 )
1471 .trace_err();
1472 }
1473 response.send(proto::Ack {})?;
1474
1475 update_user_contacts(called_user_id, &session).await?;
1476 Ok(())
1477}
1478
1479/// Decline an incoming call.
1480async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1481 let room_id = RoomId::from_proto(message.room_id);
1482 {
1483 let room = session
1484 .db()
1485 .await
1486 .decline_call(Some(room_id), session.user_id)
1487 .await?
1488 .ok_or_else(|| anyhow!("failed to decline call"))?;
1489 room_updated(&room, &session.peer);
1490 }
1491
1492 for connection_id in session
1493 .connection_pool()
1494 .await
1495 .user_connection_ids(session.user_id)
1496 {
1497 session
1498 .peer
1499 .send(
1500 connection_id,
1501 proto::CallCanceled {
1502 room_id: room_id.to_proto(),
1503 },
1504 )
1505 .trace_err();
1506 }
1507 update_user_contacts(session.user_id, &session).await?;
1508 Ok(())
1509}
1510
1511/// Updates other participants in the room with your current location.
1512async fn update_participant_location(
1513 request: proto::UpdateParticipantLocation,
1514 response: Response<proto::UpdateParticipantLocation>,
1515 session: Session,
1516) -> Result<()> {
1517 let room_id = RoomId::from_proto(request.room_id);
1518 let location = request
1519 .location
1520 .ok_or_else(|| anyhow!("invalid location"))?;
1521
1522 let db = session.db().await;
1523 let room = db
1524 .update_room_participant_location(room_id, session.connection_id, location)
1525 .await?;
1526
1527 room_updated(&room, &session.peer);
1528 response.send(proto::Ack {})?;
1529 Ok(())
1530}
1531
1532/// Share a project into the room.
1533async fn share_project(
1534 request: proto::ShareProject,
1535 response: Response<proto::ShareProject>,
1536 session: Session,
1537) -> Result<()> {
1538 let (project_id, room) = &*session
1539 .db()
1540 .await
1541 .share_project(
1542 RoomId::from_proto(request.room_id),
1543 session.connection_id,
1544 &request.worktrees,
1545 )
1546 .await?;
1547 response.send(proto::ShareProjectResponse {
1548 project_id: project_id.to_proto(),
1549 })?;
1550 room_updated(&room, &session.peer);
1551
1552 Ok(())
1553}
1554
1555/// Unshare a project from the room.
1556async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1557 let project_id = ProjectId::from_proto(message.project_id);
1558
1559 let (room, guest_connection_ids) = &*session
1560 .db()
1561 .await
1562 .unshare_project(project_id, session.connection_id)
1563 .await?;
1564
1565 broadcast(
1566 Some(session.connection_id),
1567 guest_connection_ids.iter().copied(),
1568 |conn_id| session.peer.send(conn_id, message.clone()),
1569 );
1570 room_updated(&room, &session.peer);
1571
1572 Ok(())
1573}
1574
1575/// Join someone elses shared project.
1576async fn join_project(
1577 request: proto::JoinProject,
1578 response: Response<proto::JoinProject>,
1579 session: Session,
1580) -> Result<()> {
1581 let project_id = ProjectId::from_proto(request.project_id);
1582 let guest_user_id = session.user_id;
1583
1584 tracing::info!(%project_id, "join project");
1585
1586 let (project, replica_id) = &mut *session
1587 .db()
1588 .await
1589 .join_project(project_id, session.connection_id)
1590 .await?;
1591
1592 let collaborators = project
1593 .collaborators
1594 .iter()
1595 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1596 .map(|collaborator| collaborator.to_proto())
1597 .collect::<Vec<_>>();
1598
1599 let worktrees = project
1600 .worktrees
1601 .iter()
1602 .map(|(id, worktree)| proto::WorktreeMetadata {
1603 id: *id,
1604 root_name: worktree.root_name.clone(),
1605 visible: worktree.visible,
1606 abs_path: worktree.abs_path.clone(),
1607 })
1608 .collect::<Vec<_>>();
1609
1610 for collaborator in &collaborators {
1611 session
1612 .peer
1613 .send(
1614 collaborator.peer_id.unwrap().into(),
1615 proto::AddProjectCollaborator {
1616 project_id: project_id.to_proto(),
1617 collaborator: Some(proto::Collaborator {
1618 peer_id: Some(session.connection_id.into()),
1619 replica_id: replica_id.0 as u32,
1620 user_id: guest_user_id.to_proto(),
1621 }),
1622 },
1623 )
1624 .trace_err();
1625 }
1626
1627 // First, we send the metadata associated with each worktree.
1628 response.send(proto::JoinProjectResponse {
1629 worktrees: worktrees.clone(),
1630 replica_id: replica_id.0 as u32,
1631 collaborators: collaborators.clone(),
1632 language_servers: project.language_servers.clone(),
1633 })?;
1634
1635 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1636 #[cfg(any(test, feature = "test-support"))]
1637 const MAX_CHUNK_SIZE: usize = 2;
1638 #[cfg(not(any(test, feature = "test-support")))]
1639 const MAX_CHUNK_SIZE: usize = 256;
1640
1641 // Stream this worktree's entries.
1642 let message = proto::UpdateWorktree {
1643 project_id: project_id.to_proto(),
1644 worktree_id,
1645 abs_path: worktree.abs_path.clone(),
1646 root_name: worktree.root_name,
1647 updated_entries: worktree.entries,
1648 removed_entries: Default::default(),
1649 scan_id: worktree.scan_id,
1650 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1651 updated_repositories: worktree.repository_entries.into_values().collect(),
1652 removed_repositories: Default::default(),
1653 };
1654 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1655 session.peer.send(session.connection_id, update.clone())?;
1656 }
1657
1658 // Stream this worktree's diagnostics.
1659 for summary in worktree.diagnostic_summaries {
1660 session.peer.send(
1661 session.connection_id,
1662 proto::UpdateDiagnosticSummary {
1663 project_id: project_id.to_proto(),
1664 worktree_id: worktree.id,
1665 summary: Some(summary),
1666 },
1667 )?;
1668 }
1669
1670 for settings_file in worktree.settings_files {
1671 session.peer.send(
1672 session.connection_id,
1673 proto::UpdateWorktreeSettings {
1674 project_id: project_id.to_proto(),
1675 worktree_id: worktree.id,
1676 path: settings_file.path,
1677 content: Some(settings_file.content),
1678 },
1679 )?;
1680 }
1681 }
1682
1683 for language_server in &project.language_servers {
1684 session.peer.send(
1685 session.connection_id,
1686 proto::UpdateLanguageServer {
1687 project_id: project_id.to_proto(),
1688 language_server_id: language_server.id,
1689 variant: Some(
1690 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1691 proto::LspDiskBasedDiagnosticsUpdated {},
1692 ),
1693 ),
1694 },
1695 )?;
1696 }
1697
1698 Ok(())
1699}
1700
1701/// Leave someone elses shared project.
1702async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1703 let sender_id = session.connection_id;
1704 let project_id = ProjectId::from_proto(request.project_id);
1705
1706 let (room, project) = &*session
1707 .db()
1708 .await
1709 .leave_project(project_id, sender_id)
1710 .await?;
1711 tracing::info!(
1712 %project_id,
1713 host_user_id = %project.host_user_id,
1714 host_connection_id = ?project.host_connection_id,
1715 "leave project"
1716 );
1717
1718 project_left(&project, &session);
1719 room_updated(&room, &session.peer);
1720
1721 Ok(())
1722}
1723
1724/// Updates other participants with changes to the project
1725async fn update_project(
1726 request: proto::UpdateProject,
1727 response: Response<proto::UpdateProject>,
1728 session: Session,
1729) -> Result<()> {
1730 let project_id = ProjectId::from_proto(request.project_id);
1731 let (room, guest_connection_ids) = &*session
1732 .db()
1733 .await
1734 .update_project(project_id, session.connection_id, &request.worktrees)
1735 .await?;
1736 broadcast(
1737 Some(session.connection_id),
1738 guest_connection_ids.iter().copied(),
1739 |connection_id| {
1740 session
1741 .peer
1742 .forward_send(session.connection_id, connection_id, request.clone())
1743 },
1744 );
1745 room_updated(&room, &session.peer);
1746 response.send(proto::Ack {})?;
1747
1748 Ok(())
1749}
1750
1751/// Updates other participants with changes to the worktree
1752async fn update_worktree(
1753 request: proto::UpdateWorktree,
1754 response: Response<proto::UpdateWorktree>,
1755 session: Session,
1756) -> Result<()> {
1757 let guest_connection_ids = session
1758 .db()
1759 .await
1760 .update_worktree(&request, session.connection_id)
1761 .await?;
1762
1763 broadcast(
1764 Some(session.connection_id),
1765 guest_connection_ids.iter().copied(),
1766 |connection_id| {
1767 session
1768 .peer
1769 .forward_send(session.connection_id, connection_id, request.clone())
1770 },
1771 );
1772 response.send(proto::Ack {})?;
1773 Ok(())
1774}
1775
1776/// Updates other participants with changes to the diagnostics
1777async fn update_diagnostic_summary(
1778 message: proto::UpdateDiagnosticSummary,
1779 session: Session,
1780) -> Result<()> {
1781 let guest_connection_ids = session
1782 .db()
1783 .await
1784 .update_diagnostic_summary(&message, session.connection_id)
1785 .await?;
1786
1787 broadcast(
1788 Some(session.connection_id),
1789 guest_connection_ids.iter().copied(),
1790 |connection_id| {
1791 session
1792 .peer
1793 .forward_send(session.connection_id, connection_id, message.clone())
1794 },
1795 );
1796
1797 Ok(())
1798}
1799
1800/// Updates other participants with changes to the worktree settings
1801async fn update_worktree_settings(
1802 message: proto::UpdateWorktreeSettings,
1803 session: Session,
1804) -> Result<()> {
1805 let guest_connection_ids = session
1806 .db()
1807 .await
1808 .update_worktree_settings(&message, session.connection_id)
1809 .await?;
1810
1811 broadcast(
1812 Some(session.connection_id),
1813 guest_connection_ids.iter().copied(),
1814 |connection_id| {
1815 session
1816 .peer
1817 .forward_send(session.connection_id, connection_id, message.clone())
1818 },
1819 );
1820
1821 Ok(())
1822}
1823
1824/// Notify other participants that a language server has started.
1825async fn start_language_server(
1826 request: proto::StartLanguageServer,
1827 session: Session,
1828) -> Result<()> {
1829 let guest_connection_ids = session
1830 .db()
1831 .await
1832 .start_language_server(&request, session.connection_id)
1833 .await?;
1834
1835 broadcast(
1836 Some(session.connection_id),
1837 guest_connection_ids.iter().copied(),
1838 |connection_id| {
1839 session
1840 .peer
1841 .forward_send(session.connection_id, connection_id, request.clone())
1842 },
1843 );
1844 Ok(())
1845}
1846
1847/// Notify other participants that a language server has changed.
1848async fn update_language_server(
1849 request: proto::UpdateLanguageServer,
1850 session: Session,
1851) -> Result<()> {
1852 let project_id = ProjectId::from_proto(request.project_id);
1853 let project_connection_ids = session
1854 .db()
1855 .await
1856 .project_connection_ids(project_id, session.connection_id)
1857 .await?;
1858 broadcast(
1859 Some(session.connection_id),
1860 project_connection_ids.iter().copied(),
1861 |connection_id| {
1862 session
1863 .peer
1864 .forward_send(session.connection_id, connection_id, request.clone())
1865 },
1866 );
1867 Ok(())
1868}
1869
1870/// forward a project request to the host. These requests should be read only
1871/// as guests are allowed to send them.
1872async fn forward_read_only_project_request<T>(
1873 request: T,
1874 response: Response<T>,
1875 session: Session,
1876) -> Result<()>
1877where
1878 T: EntityMessage + RequestMessage,
1879{
1880 let project_id = ProjectId::from_proto(request.remote_entity_id());
1881 let host_connection_id = session
1882 .db()
1883 .await
1884 .host_for_read_only_project_request(project_id, session.connection_id)
1885 .await?;
1886 let payload = session
1887 .peer
1888 .forward_request(session.connection_id, host_connection_id, request)
1889 .await?;
1890 response.send(payload)?;
1891 Ok(())
1892}
1893
1894/// forward a project request to the host. These requests are disallowed
1895/// for guests.
1896async fn forward_mutating_project_request<T>(
1897 request: T,
1898 response: Response<T>,
1899 session: Session,
1900) -> Result<()>
1901where
1902 T: EntityMessage + RequestMessage,
1903{
1904 let project_id = ProjectId::from_proto(request.remote_entity_id());
1905 let host_connection_id = session
1906 .db()
1907 .await
1908 .host_for_mutating_project_request(project_id, session.connection_id)
1909 .await?;
1910 let payload = session
1911 .peer
1912 .forward_request(session.connection_id, host_connection_id, request)
1913 .await?;
1914 response.send(payload)?;
1915 Ok(())
1916}
1917
1918/// Notify other participants that a new buffer has been created
1919async fn create_buffer_for_peer(
1920 request: proto::CreateBufferForPeer,
1921 session: Session,
1922) -> Result<()> {
1923 session
1924 .db()
1925 .await
1926 .check_user_is_project_host(
1927 ProjectId::from_proto(request.project_id),
1928 session.connection_id,
1929 )
1930 .await?;
1931 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1932 session
1933 .peer
1934 .forward_send(session.connection_id, peer_id.into(), request)?;
1935 Ok(())
1936}
1937
1938/// Notify other participants that a buffer has been updated. This is
1939/// allowed for guests as long as the update is limited to selections.
1940async fn update_buffer(
1941 request: proto::UpdateBuffer,
1942 response: Response<proto::UpdateBuffer>,
1943 session: Session,
1944) -> Result<()> {
1945 let project_id = ProjectId::from_proto(request.project_id);
1946 let mut guest_connection_ids;
1947 let mut host_connection_id = None;
1948
1949 let mut requires_write_permission = false;
1950
1951 for op in request.operations.iter() {
1952 match op.variant {
1953 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1954 Some(_) => requires_write_permission = true,
1955 }
1956 }
1957
1958 {
1959 let collaborators = session
1960 .db()
1961 .await
1962 .project_collaborators_for_buffer_update(
1963 project_id,
1964 session.connection_id,
1965 requires_write_permission,
1966 )
1967 .await?;
1968 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1969 for collaborator in collaborators.iter() {
1970 if collaborator.is_host {
1971 host_connection_id = Some(collaborator.connection_id);
1972 } else {
1973 guest_connection_ids.push(collaborator.connection_id);
1974 }
1975 }
1976 }
1977 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1978
1979 broadcast(
1980 Some(session.connection_id),
1981 guest_connection_ids,
1982 |connection_id| {
1983 session
1984 .peer
1985 .forward_send(session.connection_id, connection_id, request.clone())
1986 },
1987 );
1988 if host_connection_id != session.connection_id {
1989 session
1990 .peer
1991 .forward_request(session.connection_id, host_connection_id, request.clone())
1992 .await?;
1993 }
1994
1995 response.send(proto::Ack {})?;
1996 Ok(())
1997}
1998
1999/// Notify other participants that a project has been updated.
2000async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2001 request: T,
2002 session: Session,
2003) -> Result<()> {
2004 let project_id = ProjectId::from_proto(request.remote_entity_id());
2005 let project_connection_ids = session
2006 .db()
2007 .await
2008 .project_connection_ids(project_id, session.connection_id)
2009 .await?;
2010
2011 broadcast(
2012 Some(session.connection_id),
2013 project_connection_ids.iter().copied(),
2014 |connection_id| {
2015 session
2016 .peer
2017 .forward_send(session.connection_id, connection_id, request.clone())
2018 },
2019 );
2020 Ok(())
2021}
2022
2023/// Start following another user in a call.
2024async fn follow(
2025 request: proto::Follow,
2026 response: Response<proto::Follow>,
2027 session: Session,
2028) -> Result<()> {
2029 let room_id = RoomId::from_proto(request.room_id);
2030 let project_id = request.project_id.map(ProjectId::from_proto);
2031 let leader_id = request
2032 .leader_id
2033 .ok_or_else(|| anyhow!("invalid leader id"))?
2034 .into();
2035 let follower_id = session.connection_id;
2036
2037 session
2038 .db()
2039 .await
2040 .check_room_participants(room_id, leader_id, session.connection_id)
2041 .await?;
2042
2043 let response_payload = session
2044 .peer
2045 .forward_request(session.connection_id, leader_id, request)
2046 .await?;
2047 response.send(response_payload)?;
2048
2049 if let Some(project_id) = project_id {
2050 let room = session
2051 .db()
2052 .await
2053 .follow(room_id, project_id, leader_id, follower_id)
2054 .await?;
2055 room_updated(&room, &session.peer);
2056 }
2057
2058 Ok(())
2059}
2060
2061/// Stop following another user in a call.
2062async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2063 let room_id = RoomId::from_proto(request.room_id);
2064 let project_id = request.project_id.map(ProjectId::from_proto);
2065 let leader_id = request
2066 .leader_id
2067 .ok_or_else(|| anyhow!("invalid leader id"))?
2068 .into();
2069 let follower_id = session.connection_id;
2070
2071 session
2072 .db()
2073 .await
2074 .check_room_participants(room_id, leader_id, session.connection_id)
2075 .await?;
2076
2077 session
2078 .peer
2079 .forward_send(session.connection_id, leader_id, request)?;
2080
2081 if let Some(project_id) = project_id {
2082 let room = session
2083 .db()
2084 .await
2085 .unfollow(room_id, project_id, leader_id, follower_id)
2086 .await?;
2087 room_updated(&room, &session.peer);
2088 }
2089
2090 Ok(())
2091}
2092
2093/// Notify everyone following you of your current location.
2094async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2095 let room_id = RoomId::from_proto(request.room_id);
2096 let database = session.db.lock().await;
2097
2098 let connection_ids = if let Some(project_id) = request.project_id {
2099 let project_id = ProjectId::from_proto(project_id);
2100 database
2101 .project_connection_ids(project_id, session.connection_id)
2102 .await?
2103 } else {
2104 database
2105 .room_connection_ids(room_id, session.connection_id)
2106 .await?
2107 };
2108
2109 // For now, don't send view update messages back to that view's current leader.
2110 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2111 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2112 _ => None,
2113 });
2114
2115 for connection_id in connection_ids.iter().cloned() {
2116 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2117 session
2118 .peer
2119 .forward_send(session.connection_id, connection_id, request.clone())?;
2120 }
2121 }
2122 Ok(())
2123}
2124
2125/// Get public data about users.
2126async fn get_users(
2127 request: proto::GetUsers,
2128 response: Response<proto::GetUsers>,
2129 session: Session,
2130) -> Result<()> {
2131 let user_ids = request
2132 .user_ids
2133 .into_iter()
2134 .map(UserId::from_proto)
2135 .collect();
2136 let users = session
2137 .db()
2138 .await
2139 .get_users_by_ids(user_ids)
2140 .await?
2141 .into_iter()
2142 .map(|user| proto::User {
2143 id: user.id.to_proto(),
2144 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2145 github_login: user.github_login,
2146 })
2147 .collect();
2148 response.send(proto::UsersResponse { users })?;
2149 Ok(())
2150}
2151
2152/// Search for users (to invite) buy Github login
2153async fn fuzzy_search_users(
2154 request: proto::FuzzySearchUsers,
2155 response: Response<proto::FuzzySearchUsers>,
2156 session: Session,
2157) -> Result<()> {
2158 let query = request.query;
2159 let users = match query.len() {
2160 0 => vec![],
2161 1 | 2 => session
2162 .db()
2163 .await
2164 .get_user_by_github_login(&query)
2165 .await?
2166 .into_iter()
2167 .collect(),
2168 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2169 };
2170 let users = users
2171 .into_iter()
2172 .filter(|user| user.id != session.user_id)
2173 .map(|user| proto::User {
2174 id: user.id.to_proto(),
2175 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2176 github_login: user.github_login,
2177 })
2178 .collect();
2179 response.send(proto::UsersResponse { users })?;
2180 Ok(())
2181}
2182
/// Send a contact request from the current user to `request.responder_id`.
///
/// Requests addressed to oneself are rejected with an error. After the database records the
/// request:
/// - the requester's connections get the responder added to their outgoing requests,
/// - the responder's connections get a new incoming request,
/// - any notifications returned by the database are delivered,
/// - and the caller is acknowledged.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id;
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    // The pool guard here is a temporary of the loop header, so it is dropped when this loop
    // finishes, before the pool is acquired again below.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Notification delivery failures are logged inside `send_notifications`, not propagated.
    send_notifications(&*connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2229
/// Accept, decline, or dismiss a contact request that `request.requester_id` sent to the
/// current user.
///
/// `Dismiss` only calls `dismiss_contact_notification` and sends no contact updates. For any
/// other response, the request is resolved in the database (`accept` is true only for
/// `Accept`) and both users' connections are updated:
/// - the pending incoming/outgoing request is removed,
/// - when accepted, each side gains the other as a contact, including that user's current
///   busy status.
///
/// Notifications returned by the database are then delivered. The caller is acknowledged in
/// every case.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id;
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every response value other than `Accept` is passed to the database as a rejection.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are embedded in the contact entries each side receives below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&*pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2287
/// Remove a contact, or withdraw a pending contact request, between the current user and
/// `request.user_id`.
///
/// The database reports whether the two users were accepted contacts and which notification,
/// if any, it deleted.
/// - If they were contacts, each side removes the other from its contact list.
/// - Otherwise the requester drops the outgoing request and the other user drops the incoming
///   one.
///
/// If a notification was deleted, the other user's connections are told to delete it as well.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id;
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact, or the outgoing request if it was never accepted.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the other user: drop the contact, or the incoming request if it was never accepted.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2338
2339/// Creates a new channel.
2340async fn create_channel(
2341 request: proto::CreateChannel,
2342 response: Response<proto::CreateChannel>,
2343 session: Session,
2344) -> Result<()> {
2345 let db = session.db().await;
2346
2347 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2348 let (channel, owner, channel_members) = db
2349 .create_channel(&request.name, parent_id, session.user_id)
2350 .await?;
2351
2352 response.send(proto::CreateChannelResponse {
2353 channel: Some(channel.to_proto()),
2354 parent_id: request.parent_id,
2355 })?;
2356
2357 let connection_pool = session.connection_pool().await;
2358 if let Some(owner) = owner {
2359 let update = proto::UpdateUserChannels {
2360 channel_memberships: vec![proto::ChannelMembership {
2361 channel_id: owner.channel_id.to_proto(),
2362 role: owner.role.into(),
2363 }],
2364 ..Default::default()
2365 };
2366 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2367 session.peer.send(connection_id, update.clone())?;
2368 }
2369 }
2370
2371 for channel_member in channel_members {
2372 if !channel_member.role.can_see_channel(channel.visibility) {
2373 continue;
2374 }
2375
2376 let update = proto::UpdateChannels {
2377 channels: vec![channel.to_proto()],
2378 ..Default::default()
2379 };
2380 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2381 session.peer.send(connection_id, update.clone())?;
2382 }
2383 }
2384
2385 Ok(())
2386}
2387
2388/// Delete a channel
2389async fn delete_channel(
2390 request: proto::DeleteChannel,
2391 response: Response<proto::DeleteChannel>,
2392 session: Session,
2393) -> Result<()> {
2394 let db = session.db().await;
2395
2396 let channel_id = request.channel_id;
2397 let (removed_channels, member_ids) = db
2398 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2399 .await?;
2400 response.send(proto::Ack {})?;
2401
2402 // Notify members of removed channels
2403 let mut update = proto::UpdateChannels::default();
2404 update
2405 .delete_channels
2406 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2407
2408 let connection_pool = session.connection_pool().await;
2409 for member_id in member_ids {
2410 for connection_id in connection_pool.user_connection_ids(member_id) {
2411 session.peer.send(connection_id, update.clone())?;
2412 }
2413 }
2414
2415 Ok(())
2416}
2417
2418/// Invite someone to join a channel.
2419async fn invite_channel_member(
2420 request: proto::InviteChannelMember,
2421 response: Response<proto::InviteChannelMember>,
2422 session: Session,
2423) -> Result<()> {
2424 let db = session.db().await;
2425 let channel_id = ChannelId::from_proto(request.channel_id);
2426 let invitee_id = UserId::from_proto(request.user_id);
2427 let InviteMemberResult {
2428 channel,
2429 notifications,
2430 } = db
2431 .invite_channel_member(
2432 channel_id,
2433 invitee_id,
2434 session.user_id,
2435 request.role().into(),
2436 )
2437 .await?;
2438
2439 let update = proto::UpdateChannels {
2440 channel_invitations: vec![channel.to_proto()],
2441 ..Default::default()
2442 };
2443
2444 let connection_pool = session.connection_pool().await;
2445 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2446 session.peer.send(connection_id, update.clone())?;
2447 }
2448
2449 send_notifications(&*connection_pool, &session.peer, notifications);
2450
2451 response.send(proto::Ack {})?;
2452 Ok(())
2453}
2454
2455/// remove someone from a channel
2456async fn remove_channel_member(
2457 request: proto::RemoveChannelMember,
2458 response: Response<proto::RemoveChannelMember>,
2459 session: Session,
2460) -> Result<()> {
2461 let db = session.db().await;
2462 let channel_id = ChannelId::from_proto(request.channel_id);
2463 let member_id = UserId::from_proto(request.user_id);
2464
2465 let RemoveChannelMemberResult {
2466 membership_update,
2467 notification_id,
2468 } = db
2469 .remove_channel_member(channel_id, member_id, session.user_id)
2470 .await?;
2471
2472 let connection_pool = &session.connection_pool().await;
2473 notify_membership_updated(
2474 &connection_pool,
2475 membership_update,
2476 member_id,
2477 &session.peer,
2478 );
2479 for connection_id in connection_pool.user_connection_ids(member_id) {
2480 if let Some(notification_id) = notification_id {
2481 session
2482 .peer
2483 .send(
2484 connection_id,
2485 proto::DeleteNotification {
2486 notification_id: notification_id.to_proto(),
2487 },
2488 )
2489 .trace_err();
2490 }
2491 }
2492
2493 response.send(proto::Ack {})?;
2494 Ok(())
2495}
2496
2497/// Toggle the channel between public and private.
2498/// Care is taken to maintain the invariant that public channels only descend from public channels,
2499/// (though members-only channels can appear at any point in the hierarchy).
2500async fn set_channel_visibility(
2501 request: proto::SetChannelVisibility,
2502 response: Response<proto::SetChannelVisibility>,
2503 session: Session,
2504) -> Result<()> {
2505 let db = session.db().await;
2506 let channel_id = ChannelId::from_proto(request.channel_id);
2507 let visibility = request.visibility().into();
2508
2509 let (channel, channel_members) = db
2510 .set_channel_visibility(channel_id, visibility, session.user_id)
2511 .await?;
2512
2513 let connection_pool = session.connection_pool().await;
2514 for member in channel_members {
2515 let update = if member.role.can_see_channel(channel.visibility) {
2516 proto::UpdateChannels {
2517 channels: vec![channel.to_proto()],
2518 ..Default::default()
2519 }
2520 } else {
2521 proto::UpdateChannels {
2522 delete_channels: vec![channel.id.to_proto()],
2523 ..Default::default()
2524 }
2525 };
2526
2527 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2528 session.peer.send(connection_id, update.clone())?;
2529 }
2530 }
2531
2532 response.send(proto::Ack {})?;
2533 Ok(())
2534}
2535
2536/// Alter the role for a user in the channel.
2537async fn set_channel_member_role(
2538 request: proto::SetChannelMemberRole,
2539 response: Response<proto::SetChannelMemberRole>,
2540 session: Session,
2541) -> Result<()> {
2542 let db = session.db().await;
2543 let channel_id = ChannelId::from_proto(request.channel_id);
2544 let member_id = UserId::from_proto(request.user_id);
2545 let result = db
2546 .set_channel_member_role(
2547 channel_id,
2548 session.user_id,
2549 member_id,
2550 request.role().into(),
2551 )
2552 .await?;
2553
2554 match result {
2555 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2556 let connection_pool = session.connection_pool().await;
2557 notify_membership_updated(
2558 &connection_pool,
2559 membership_update,
2560 member_id,
2561 &session.peer,
2562 )
2563 }
2564 db::SetMemberRoleResult::InviteUpdated(channel) => {
2565 let update = proto::UpdateChannels {
2566 channel_invitations: vec![channel.to_proto()],
2567 ..Default::default()
2568 };
2569
2570 for connection_id in session
2571 .connection_pool()
2572 .await
2573 .user_connection_ids(member_id)
2574 {
2575 session.peer.send(connection_id, update.clone())?;
2576 }
2577 }
2578 }
2579
2580 response.send(proto::Ack {})?;
2581 Ok(())
2582}
2583
2584/// Change the name of a channel
2585async fn rename_channel(
2586 request: proto::RenameChannel,
2587 response: Response<proto::RenameChannel>,
2588 session: Session,
2589) -> Result<()> {
2590 let db = session.db().await;
2591 let channel_id = ChannelId::from_proto(request.channel_id);
2592 let (channel, channel_members) = db
2593 .rename_channel(channel_id, session.user_id, &request.name)
2594 .await?;
2595
2596 response.send(proto::RenameChannelResponse {
2597 channel: Some(channel.to_proto()),
2598 })?;
2599
2600 let connection_pool = session.connection_pool().await;
2601 for channel_member in channel_members {
2602 if !channel_member.role.can_see_channel(channel.visibility) {
2603 continue;
2604 }
2605 let update = proto::UpdateChannels {
2606 channels: vec![channel.to_proto()],
2607 ..Default::default()
2608 };
2609 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2610 session.peer.send(connection_id, update.clone())?;
2611 }
2612 }
2613
2614 Ok(())
2615}
2616
2617/// Move a channel to a new parent.
2618async fn move_channel(
2619 request: proto::MoveChannel,
2620 response: Response<proto::MoveChannel>,
2621 session: Session,
2622) -> Result<()> {
2623 let channel_id = ChannelId::from_proto(request.channel_id);
2624 let to = ChannelId::from_proto(request.to);
2625
2626 let (channels, channel_members) = session
2627 .db()
2628 .await
2629 .move_channel(channel_id, to, session.user_id)
2630 .await?;
2631
2632 let connection_pool = session.connection_pool().await;
2633 for member in channel_members {
2634 let channels = channels
2635 .iter()
2636 .filter_map(|channel| {
2637 if member.role.can_see_channel(channel.visibility) {
2638 Some(channel.to_proto())
2639 } else {
2640 None
2641 }
2642 })
2643 .collect::<Vec<_>>();
2644 if channels.is_empty() {
2645 continue;
2646 }
2647
2648 let update = proto::UpdateChannels {
2649 channels,
2650 ..Default::default()
2651 };
2652
2653 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2654 session.peer.send(connection_id, update.clone())?;
2655 }
2656 }
2657
2658 response.send(Ack {})?;
2659 Ok(())
2660}
2661
2662/// Get the list of channel members
2663async fn get_channel_members(
2664 request: proto::GetChannelMembers,
2665 response: Response<proto::GetChannelMembers>,
2666 session: Session,
2667) -> Result<()> {
2668 let db = session.db().await;
2669 let channel_id = ChannelId::from_proto(request.channel_id);
2670 let members = db
2671 .get_channel_participant_details(channel_id, session.user_id)
2672 .await?;
2673 response.send(proto::GetChannelMembersResponse { members })?;
2674 Ok(())
2675}
2676
2677/// Accept or decline a channel invitation.
2678async fn respond_to_channel_invite(
2679 request: proto::RespondToChannelInvite,
2680 response: Response<proto::RespondToChannelInvite>,
2681 session: Session,
2682) -> Result<()> {
2683 let db = session.db().await;
2684 let channel_id = ChannelId::from_proto(request.channel_id);
2685 let RespondToChannelInvite {
2686 membership_update,
2687 notifications,
2688 } = db
2689 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2690 .await?;
2691
2692 let connection_pool = session.connection_pool().await;
2693 if let Some(membership_update) = membership_update {
2694 notify_membership_updated(
2695 &connection_pool,
2696 membership_update,
2697 session.user_id,
2698 &session.peer,
2699 );
2700 } else {
2701 let update = proto::UpdateChannels {
2702 remove_channel_invitations: vec![channel_id.to_proto()],
2703 ..Default::default()
2704 };
2705
2706 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2707 session.peer.send(connection_id, update.clone())?;
2708 }
2709 };
2710
2711 send_notifications(&*connection_pool, &session.peer, notifications);
2712
2713 response.send(proto::Ack {})?;
2714
2715 Ok(())
2716}
2717
2718/// Join the channels' room
2719async fn join_channel(
2720 request: proto::JoinChannel,
2721 response: Response<proto::JoinChannel>,
2722 session: Session,
2723) -> Result<()> {
2724 let channel_id = ChannelId::from_proto(request.channel_id);
2725 join_channel_internal(channel_id, Box::new(response), session).await
2726}
2727
/// A response handle that can reply with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so
/// `join_channel_internal` can be driven by either kind of request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2741
/// Shared implementation for joining a channel's room, usable with any response handle that
/// implements `JoinChannelInternalResponse`.
///
/// Steps:
/// 1. Leave any room the session is currently in (`leave_room_for_session`).
/// 2. Join the channel's room through the database.
/// 3. Reply with the room, its channel id, and optional LiveKit connection info.
/// 4. Push any membership change and a room update.
/// 5. Broadcast a channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped so the `db` and `connection_pool` guards created inside are dropped before the
    // connection pool is acquired again for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // LiveKit info is optional. It is `None` in two cases:
        // - no LiveKit client is configured;
        // - token creation fails (`trace_err()` yields `None`, which `?` propagates out of
        //   this closure).
        // Guests get a guest token and `can_publish = false`; every other role gets a room
        // token and `can_publish = true`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2817
2818/// Start editing the channel notes
2819async fn join_channel_buffer(
2820 request: proto::JoinChannelBuffer,
2821 response: Response<proto::JoinChannelBuffer>,
2822 session: Session,
2823) -> Result<()> {
2824 let db = session.db().await;
2825 let channel_id = ChannelId::from_proto(request.channel_id);
2826
2827 let open_response = db
2828 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2829 .await?;
2830
2831 let collaborators = open_response.collaborators.clone();
2832 response.send(open_response)?;
2833
2834 let update = UpdateChannelBufferCollaborators {
2835 channel_id: channel_id.to_proto(),
2836 collaborators: collaborators.clone(),
2837 };
2838 channel_buffer_updated(
2839 session.connection_id,
2840 collaborators
2841 .iter()
2842 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2843 &update,
2844 &session.peer,
2845 );
2846
2847 Ok(())
2848}
2849
2850/// Edit the channel notes
2851async fn update_channel_buffer(
2852 request: proto::UpdateChannelBuffer,
2853 session: Session,
2854) -> Result<()> {
2855 let db = session.db().await;
2856 let channel_id = ChannelId::from_proto(request.channel_id);
2857
2858 let (collaborators, non_collaborators, epoch, version) = db
2859 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2860 .await?;
2861
2862 channel_buffer_updated(
2863 session.connection_id,
2864 collaborators,
2865 &proto::UpdateChannelBuffer {
2866 channel_id: channel_id.to_proto(),
2867 operations: request.operations,
2868 },
2869 &session.peer,
2870 );
2871
2872 let pool = &*session.connection_pool().await;
2873
2874 broadcast(
2875 None,
2876 non_collaborators
2877 .iter()
2878 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2879 |peer_id| {
2880 session.peer.send(
2881 peer_id.into(),
2882 proto::UpdateChannels {
2883 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2884 channel_id: channel_id.to_proto(),
2885 epoch: epoch as u64,
2886 version: version.clone(),
2887 }],
2888 ..Default::default()
2889 },
2890 )
2891 },
2892 );
2893
2894 Ok(())
2895}
2896
2897/// Rejoin the channel notes after a connection blip
2898async fn rejoin_channel_buffers(
2899 request: proto::RejoinChannelBuffers,
2900 response: Response<proto::RejoinChannelBuffers>,
2901 session: Session,
2902) -> Result<()> {
2903 let db = session.db().await;
2904 let buffers = db
2905 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2906 .await?;
2907
2908 for rejoined_buffer in &buffers {
2909 let collaborators_to_notify = rejoined_buffer
2910 .buffer
2911 .collaborators
2912 .iter()
2913 .filter_map(|c| Some(c.peer_id?.into()));
2914 channel_buffer_updated(
2915 session.connection_id,
2916 collaborators_to_notify,
2917 &proto::UpdateChannelBufferCollaborators {
2918 channel_id: rejoined_buffer.buffer.channel_id,
2919 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2920 },
2921 &session.peer,
2922 );
2923 }
2924
2925 response.send(proto::RejoinChannelBuffersResponse {
2926 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2927 })?;
2928
2929 Ok(())
2930}
2931
2932/// Stop editing the channel notes
2933async fn leave_channel_buffer(
2934 request: proto::LeaveChannelBuffer,
2935 response: Response<proto::LeaveChannelBuffer>,
2936 session: Session,
2937) -> Result<()> {
2938 let db = session.db().await;
2939 let channel_id = ChannelId::from_proto(request.channel_id);
2940
2941 let left_buffer = db
2942 .leave_channel_buffer(channel_id, session.connection_id)
2943 .await?;
2944
2945 response.send(Ack {})?;
2946
2947 channel_buffer_updated(
2948 session.connection_id,
2949 left_buffer.connections,
2950 &proto::UpdateChannelBufferCollaborators {
2951 channel_id: channel_id.to_proto(),
2952 collaborators: left_buffer.collaborators,
2953 },
2954 &session.peer,
2955 );
2956
2957 Ok(())
2958}
2959
2960fn channel_buffer_updated<T: EnvelopedMessage>(
2961 sender_id: ConnectionId,
2962 collaborators: impl IntoIterator<Item = ConnectionId>,
2963 message: &T,
2964 peer: &Peer,
2965) {
2966 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2967 peer.send(peer_id.into(), message.clone())
2968 });
2969}
2970
2971fn send_notifications(
2972 connection_pool: &ConnectionPool,
2973 peer: &Peer,
2974 notifications: db::NotificationBatch,
2975) {
2976 for (user_id, notification) in notifications {
2977 for connection_id in connection_pool.user_connection_ids(user_id) {
2978 if let Err(error) = peer.send(
2979 connection_id,
2980 proto::AddNotification {
2981 notification: Some(notification.clone()),
2982 },
2983 ) {
2984 tracing::error!(
2985 "failed to send notification to {:?} {}",
2986 connection_id,
2987 error
2988 );
2989 }
2990 }
2991 }
2992}
2993
2994/// Send a message to the channel
2995async fn send_channel_message(
2996 request: proto::SendChannelMessage,
2997 response: Response<proto::SendChannelMessage>,
2998 session: Session,
2999) -> Result<()> {
3000 // Validate the message body.
3001 let body = request.body.trim().to_string();
3002 if body.len() > MAX_MESSAGE_LEN {
3003 return Err(anyhow!("message is too long"))?;
3004 }
3005 if body.is_empty() {
3006 return Err(anyhow!("message can't be blank"))?;
3007 }
3008
3009 // TODO: adjust mentions if body is trimmed
3010
3011 let timestamp = OffsetDateTime::now_utc();
3012 let nonce = request
3013 .nonce
3014 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3015
3016 let channel_id = ChannelId::from_proto(request.channel_id);
3017 let CreatedChannelMessage {
3018 message_id,
3019 participant_connection_ids,
3020 channel_members,
3021 notifications,
3022 } = session
3023 .db()
3024 .await
3025 .create_channel_message(
3026 channel_id,
3027 session.user_id,
3028 &body,
3029 &request.mentions,
3030 timestamp,
3031 nonce.clone().into(),
3032 match request.reply_to_message_id {
3033 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3034 None => None,
3035 },
3036 )
3037 .await?;
3038 let message = proto::ChannelMessage {
3039 sender_id: session.user_id.to_proto(),
3040 id: message_id.to_proto(),
3041 body,
3042 mentions: request.mentions,
3043 timestamp: timestamp.unix_timestamp() as u64,
3044 nonce: Some(nonce),
3045 reply_to_message_id: request.reply_to_message_id,
3046 };
3047 broadcast(
3048 Some(session.connection_id),
3049 participant_connection_ids,
3050 |connection| {
3051 session.peer.send(
3052 connection,
3053 proto::ChannelMessageSent {
3054 channel_id: channel_id.to_proto(),
3055 message: Some(message.clone()),
3056 },
3057 )
3058 },
3059 );
3060 response.send(proto::SendChannelMessageResponse {
3061 message: Some(message),
3062 })?;
3063
3064 let pool = &*session.connection_pool().await;
3065 broadcast(
3066 None,
3067 channel_members
3068 .iter()
3069 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3070 |peer_id| {
3071 session.peer.send(
3072 peer_id.into(),
3073 proto::UpdateChannels {
3074 latest_channel_message_ids: vec![proto::ChannelMessageId {
3075 channel_id: channel_id.to_proto(),
3076 message_id: message_id.to_proto(),
3077 }],
3078 ..Default::default()
3079 },
3080 )
3081 },
3082 );
3083 send_notifications(pool, &session.peer, notifications);
3084
3085 Ok(())
3086}
3087
3088/// Delete a channel message
3089async fn remove_channel_message(
3090 request: proto::RemoveChannelMessage,
3091 response: Response<proto::RemoveChannelMessage>,
3092 session: Session,
3093) -> Result<()> {
3094 let channel_id = ChannelId::from_proto(request.channel_id);
3095 let message_id = MessageId::from_proto(request.message_id);
3096 let connection_ids = session
3097 .db()
3098 .await
3099 .remove_channel_message(channel_id, message_id, session.user_id)
3100 .await?;
3101 broadcast(Some(session.connection_id), connection_ids, |connection| {
3102 session.peer.send(connection, request.clone())
3103 });
3104 response.send(proto::Ack {})?;
3105 Ok(())
3106}
3107
3108/// Mark a channel message as read
3109async fn acknowledge_channel_message(
3110 request: proto::AckChannelMessage,
3111 session: Session,
3112) -> Result<()> {
3113 let channel_id = ChannelId::from_proto(request.channel_id);
3114 let message_id = MessageId::from_proto(request.message_id);
3115 let notifications = session
3116 .db()
3117 .await
3118 .observe_channel_message(channel_id, session.user_id, message_id)
3119 .await?;
3120 send_notifications(
3121 &*session.connection_pool().await,
3122 &session.peer,
3123 notifications,
3124 );
3125 Ok(())
3126}
3127
3128/// Mark a buffer version as synced
3129async fn acknowledge_buffer_version(
3130 request: proto::AckBufferOperation,
3131 session: Session,
3132) -> Result<()> {
3133 let buffer_id = BufferId::from_proto(request.buffer_id);
3134 session
3135 .db()
3136 .await
3137 .observe_buffer_version(
3138 buffer_id,
3139 session.user_id,
3140 request.epoch as i32,
3141 &request.version,
3142 )
3143 .await?;
3144 Ok(())
3145}
3146
3147/// Start receiving chat updates for a channel
3148async fn join_channel_chat(
3149 request: proto::JoinChannelChat,
3150 response: Response<proto::JoinChannelChat>,
3151 session: Session,
3152) -> Result<()> {
3153 let channel_id = ChannelId::from_proto(request.channel_id);
3154
3155 let db = session.db().await;
3156 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3157 .await?;
3158 let messages = db
3159 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3160 .await?;
3161 response.send(proto::JoinChannelChatResponse {
3162 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3163 messages,
3164 })?;
3165 Ok(())
3166}
3167
3168/// Stop receiving chat updates for a channel
3169async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3170 let channel_id = ChannelId::from_proto(request.channel_id);
3171 session
3172 .db()
3173 .await
3174 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3175 .await?;
3176 Ok(())
3177}
3178
3179/// Retrieve the chat history for a channel
3180async fn get_channel_messages(
3181 request: proto::GetChannelMessages,
3182 response: Response<proto::GetChannelMessages>,
3183 session: Session,
3184) -> Result<()> {
3185 let channel_id = ChannelId::from_proto(request.channel_id);
3186 let messages = session
3187 .db()
3188 .await
3189 .get_channel_messages(
3190 channel_id,
3191 session.user_id,
3192 MESSAGE_COUNT_PER_PAGE,
3193 Some(MessageId::from_proto(request.before_message_id)),
3194 )
3195 .await?;
3196 response.send(proto::GetChannelMessagesResponse {
3197 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3198 messages,
3199 })?;
3200 Ok(())
3201}
3202
/// Retrieve specific chat messages by id.
///
/// The ids are converted and passed to the database together with the requesting user's id
/// (presumably so messages the user may not see are filtered out; confirm in `db`).
async fn get_channel_messages_by_id(
    request: proto::GetChannelMessagesById,
    response: Response<proto::GetChannelMessagesById>,
    session: Session,
) -> Result<()> {
    let message_ids = request
        .message_ids
        .iter()
        .map(|id| MessageId::from_proto(*id))
        .collect::<Vec<_>>();
    let messages = session
        .db()
        .await
        .get_channel_messages_by_id(session.user_id, &message_ids)
        .await?;
    response.send(proto::GetChannelMessagesResponse {
        // NOTE(review): this lookup is by explicit ids, not a page of size
        // `MESSAGE_COUNT_PER_PAGE`, yet `done` is still computed against that page size.
        // Confirm how clients interpret `done` for this response.
        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
        messages,
    })?;
    Ok(())
}
3225
3226/// Retrieve the current users notifications
3227async fn get_notifications(
3228 request: proto::GetNotifications,
3229 response: Response<proto::GetNotifications>,
3230 session: Session,
3231) -> Result<()> {
3232 let notifications = session
3233 .db()
3234 .await
3235 .get_notifications(
3236 session.user_id,
3237 NOTIFICATION_COUNT_PER_PAGE,
3238 request
3239 .before_id
3240 .map(|id| db::NotificationId::from_proto(id)),
3241 )
3242 .await?;
3243 response.send(proto::GetNotificationsResponse {
3244 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3245 notifications,
3246 })?;
3247 Ok(())
3248}
3249
3250/// Mark notifications as read
3251async fn mark_notification_as_read(
3252 request: proto::MarkNotificationRead,
3253 response: Response<proto::MarkNotificationRead>,
3254 session: Session,
3255) -> Result<()> {
3256 let database = &session.db().await;
3257 let notifications = database
3258 .mark_notification_as_read_by_id(
3259 session.user_id,
3260 NotificationId::from_proto(request.notification_id),
3261 )
3262 .await?;
3263 send_notifications(
3264 &*session.connection_pool().await,
3265 &session.peer,
3266 notifications,
3267 );
3268 response.send(proto::Ack {})?;
3269 Ok(())
3270}
3271
3272/// Get the current users information
3273async fn get_private_user_info(
3274 _request: proto::GetPrivateUserInfo,
3275 response: Response<proto::GetPrivateUserInfo>,
3276 session: Session,
3277) -> Result<()> {
3278 let db = session.db().await;
3279
3280 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3281 let user = db
3282 .get_user_by_id(session.user_id)
3283 .await?
3284 .ok_or_else(|| anyhow!("user not found"))?;
3285 let flags = db.get_user_flags(session.user_id).await?;
3286
3287 response.send(proto::GetPrivateUserInfoResponse {
3288 metrics_id,
3289 staff: user.admin,
3290 flags,
3291 })?;
3292 Ok(())
3293}
3294
3295fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3296 match message {
3297 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3298 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3299 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3300 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3301 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3302 code: frame.code.into(),
3303 reason: frame.reason,
3304 })),
3305 }
3306}
3307
3308fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3309 match message {
3310 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3311 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3312 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3313 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3314 AxumMessage::Close(frame) => {
3315 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3316 code: frame.code.into(),
3317 reason: frame.reason,
3318 }))
3319 }
3320 }
3321}
3322
3323fn notify_membership_updated(
3324 connection_pool: &ConnectionPool,
3325 result: MembershipUpdated,
3326 user_id: UserId,
3327 peer: &Peer,
3328) {
3329 let user_channels_update = proto::UpdateUserChannels {
3330 channel_memberships: result
3331 .new_channels
3332 .channel_memberships
3333 .iter()
3334 .map(|cm| proto::ChannelMembership {
3335 channel_id: cm.channel_id.to_proto(),
3336 role: cm.role.into(),
3337 })
3338 .collect(),
3339 ..Default::default()
3340 };
3341 let mut update = build_channels_update(result.new_channels, vec![]);
3342 update.delete_channels = result
3343 .removed_channels
3344 .into_iter()
3345 .map(|id| id.to_proto())
3346 .collect();
3347 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3348
3349 for connection_id in connection_pool.user_connection_ids(user_id) {
3350 peer.send(connection_id, user_channels_update.clone())
3351 .trace_err();
3352 peer.send(connection_id, update.clone()).trace_err();
3353 }
3354}
3355
3356fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3357 proto::UpdateUserChannels {
3358 channel_memberships: channels
3359 .channel_memberships
3360 .iter()
3361 .map(|m| proto::ChannelMembership {
3362 channel_id: m.channel_id.to_proto(),
3363 role: m.role.into(),
3364 })
3365 .collect(),
3366 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3367 observed_channel_message_id: channels.observed_channel_messages.clone(),
3368 ..Default::default()
3369 }
3370}
3371
3372fn build_channels_update(
3373 channels: ChannelsForUser,
3374 channel_invites: Vec<db::Channel>,
3375) -> proto::UpdateChannels {
3376 let mut update = proto::UpdateChannels::default();
3377
3378 for channel in channels.channels {
3379 update.channels.push(channel.to_proto());
3380 }
3381
3382 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3383 update.latest_channel_message_ids = channels.latest_channel_messages;
3384
3385 for (channel_id, participants) in channels.channel_participants {
3386 update
3387 .channel_participants
3388 .push(proto::ChannelParticipants {
3389 channel_id: channel_id.to_proto(),
3390 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3391 });
3392 }
3393
3394 for channel in channel_invites {
3395 update.channel_invitations.push(channel.to_proto());
3396 }
3397 for project in channels.hosted_projects {
3398 update.hosted_projects.push(project);
3399 }
3400
3401 update
3402}
3403
3404fn build_initial_contacts_update(
3405 contacts: Vec<db::Contact>,
3406 pool: &ConnectionPool,
3407) -> proto::UpdateContacts {
3408 let mut update = proto::UpdateContacts::default();
3409
3410 for contact in contacts {
3411 match contact {
3412 db::Contact::Accepted { user_id, busy } => {
3413 update.contacts.push(contact_for_user(user_id, busy, &pool));
3414 }
3415 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3416 db::Contact::Incoming { user_id } => {
3417 update
3418 .incoming_requests
3419 .push(proto::IncomingContactRequest {
3420 requester_id: user_id.to_proto(),
3421 })
3422 }
3423 }
3424 }
3425
3426 update
3427}
3428
3429fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3430 proto::Contact {
3431 user_id: user_id.to_proto(),
3432 online: pool.is_user_online(user_id),
3433 busy,
3434 }
3435}
3436
3437fn room_updated(room: &proto::Room, peer: &Peer) {
3438 broadcast(
3439 None,
3440 room.participants
3441 .iter()
3442 .filter_map(|participant| Some(participant.peer_id?.into())),
3443 |peer_id| {
3444 peer.send(
3445 peer_id.into(),
3446 proto::RoomUpdated {
3447 room: Some(room.clone()),
3448 },
3449 )
3450 },
3451 );
3452}
3453
3454fn channel_updated(
3455 channel_id: ChannelId,
3456 room: &proto::Room,
3457 channel_members: &[UserId],
3458 peer: &Peer,
3459 pool: &ConnectionPool,
3460) {
3461 let participants = room
3462 .participants
3463 .iter()
3464 .map(|p| p.user_id)
3465 .collect::<Vec<_>>();
3466
3467 broadcast(
3468 None,
3469 channel_members
3470 .iter()
3471 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3472 |peer_id| {
3473 peer.send(
3474 peer_id.into(),
3475 proto::UpdateChannels {
3476 channel_participants: vec![proto::ChannelParticipants {
3477 channel_id: channel_id.to_proto(),
3478 participant_user_ids: participants.clone(),
3479 }],
3480 ..Default::default()
3481 },
3482 )
3483 },
3484 );
3485}
3486
3487async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3488 let db = session.db().await;
3489
3490 let contacts = db.get_contacts(user_id).await?;
3491 let busy = db.is_user_busy(user_id).await?;
3492
3493 let pool = session.connection_pool().await;
3494 let updated_contact = contact_for_user(user_id, busy, &pool);
3495 for contact in contacts {
3496 if let db::Contact::Accepted {
3497 user_id: contact_user_id,
3498 ..
3499 } = contact
3500 {
3501 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3502 session
3503 .peer
3504 .send(
3505 contact_conn_id,
3506 proto::UpdateContacts {
3507 contacts: vec![updated_contact.clone()],
3508 remove_contacts: Default::default(),
3509 incoming_requests: Default::default(),
3510 remove_incoming_requests: Default::default(),
3511 outgoing_requests: Default::default(),
3512 remove_outgoing_requests: Default::default(),
3513 },
3514 )
3515 .trace_err();
3516 }
3517 }
3518 }
3519 Ok(())
3520}
3521
3522async fn leave_room_for_session(session: &Session) -> Result<()> {
3523 let mut contacts_to_update = HashSet::default();
3524
3525 let room_id;
3526 let canceled_calls_to_user_ids;
3527 let live_kit_room;
3528 let delete_live_kit_room;
3529 let room;
3530 let channel_members;
3531 let channel_id;
3532
3533 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3534 contacts_to_update.insert(session.user_id);
3535
3536 for project in left_room.left_projects.values() {
3537 project_left(project, session);
3538 }
3539
3540 room_id = RoomId::from_proto(left_room.room.id);
3541 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3542 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3543 delete_live_kit_room = left_room.deleted;
3544 room = mem::take(&mut left_room.room);
3545 channel_members = mem::take(&mut left_room.channel_members);
3546 channel_id = left_room.channel_id;
3547
3548 room_updated(&room, &session.peer);
3549 } else {
3550 return Ok(());
3551 }
3552
3553 if let Some(channel_id) = channel_id {
3554 channel_updated(
3555 channel_id,
3556 &room,
3557 &channel_members,
3558 &session.peer,
3559 &*session.connection_pool().await,
3560 );
3561 }
3562
3563 {
3564 let pool = session.connection_pool().await;
3565 for canceled_user_id in canceled_calls_to_user_ids {
3566 for connection_id in pool.user_connection_ids(canceled_user_id) {
3567 session
3568 .peer
3569 .send(
3570 connection_id,
3571 proto::CallCanceled {
3572 room_id: room_id.to_proto(),
3573 },
3574 )
3575 .trace_err();
3576 }
3577 contacts_to_update.insert(canceled_user_id);
3578 }
3579 }
3580
3581 for contact_user_id in contacts_to_update {
3582 update_user_contacts(contact_user_id, &session).await?;
3583 }
3584
3585 if let Some(live_kit) = session.live_kit_client.as_ref() {
3586 live_kit
3587 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3588 .await
3589 .trace_err();
3590
3591 if delete_live_kit_room {
3592 live_kit.delete_room(live_kit_room).await.trace_err();
3593 }
3594 }
3595
3596 Ok(())
3597}
3598
3599async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3600 let left_channel_buffers = session
3601 .db()
3602 .await
3603 .leave_channel_buffers(session.connection_id)
3604 .await?;
3605
3606 for left_buffer in left_channel_buffers {
3607 channel_buffer_updated(
3608 session.connection_id,
3609 left_buffer.connections,
3610 &proto::UpdateChannelBufferCollaborators {
3611 channel_id: left_buffer.channel_id.to_proto(),
3612 collaborators: left_buffer.collaborators,
3613 },
3614 &session.peer,
3615 );
3616 }
3617
3618 Ok(())
3619}
3620
3621fn project_left(project: &db::LeftProject, session: &Session) {
3622 for connection_id in &project.connection_ids {
3623 if project.host_user_id == session.user_id {
3624 session
3625 .peer
3626 .send(
3627 *connection_id,
3628 proto::UnshareProject {
3629 project_id: project.id.to_proto(),
3630 },
3631 )
3632 .trace_err();
3633 } else {
3634 session
3635 .peer
3636 .send(
3637 *connection_id,
3638 proto::RemoveProjectCollaborator {
3639 project_id: project.id.to_proto(),
3640 peer_id: Some(session.connection_id.into()),
3641 },
3642 )
3643 .trace_err();
3644 }
3645 }
3646}
3647
3648pub trait ResultExt {
3649 type Ok;
3650
3651 fn trace_err(self) -> Option<Self::Ok>;
3652}
3653
3654impl<T, E> ResultExt for Result<T, E>
3655where
3656 E: std::fmt::Debug,
3657{
3658 type Ok = T;
3659
3660 fn trace_err(self) -> Option<T> {
3661 match self {
3662 Ok(value) => Some(value),
3663 Err(error) => {
3664 tracing::error!("{:?}", error);
3665 None
3666 }
3667 }
3668 }
3669}