1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::ConnectionPool;
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use lazy_static::lazy_static;
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// Grace period a disconnected client has before `connection_lost` releases the user's room
/// and channel buffers and refreshes their contacts.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before cleaning up resources reported for stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the constants below are not referenced in this part of the file; their
// descriptions are inferred from their names — confirm against their uses.
/// Page size for channel message history (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Page size for notification listings (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum accepted channel message length (unit to be confirmed).
const MAX_MESSAGE_LEN: usize = 1024;
75
76lazy_static! {
77 static ref METRIC_CONNECTIONS: IntGauge =
78 register_int_gauge!("connections", "number of connections").unwrap();
79 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
80 "shared_projects",
81 "number of open projects with one or more guests"
82 )
83 .unwrap();
84}
85
86type MessageHandler =
87 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
/// Handle a request handler uses to reply to the requesting peer.
///
/// `responded` is shared with the dispatch wrapper built in `Server::add_request_handler`,
/// which reads it after the handler finishes to detect handlers that never replied.
struct Response<R> {
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`.
    responded: Arc<AtomicBool>,
}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
/// Per-connection state; `handle_connection` builds one per connection and hands a clone to
/// every message handler it dispatches.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    user_id: UserId,
    connection_id: ConnectionId,
    /// Created fresh for each connection, so while one handler holds the guard returned by
    /// `Session::db`, other handlers of the same connection wait for it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide pool, shared with `Server`; lock it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Optional LiveKit API client; when absent, handlers issue no LiveKit connection info.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
114
115impl Session {
116 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 let guard = self.db.lock().await;
120 #[cfg(test)]
121 tokio::task::yield_now().await;
122 guard
123 }
124
125 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
126 #[cfg(test)]
127 tokio::task::yield_now().await;
128 let guard = self.connection_pool.lock();
129 ConnectionPoolGuard {
130 guard,
131 _not_send: PhantomData,
132 }
133 }
134}
135
136impl fmt::Debug for Session {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("Session")
139 .field("user_id", &self.user_id)
140 .field("connection_id", &self.connection_id)
141 .finish()
142 }
143}
144
/// Newtype over the shared [`Database`] that derefs to it; `Session` keeps one behind a
/// per-connection mutex.
struct DbHandle(Arc<Database>);
146
147impl Deref for DbHandle {
148 type Target = Database;
149
150 fn deref(&self) -> &Self::Target {
151 self.0.as_ref()
152 }
153}
154
/// The RPC server: owns the peer, the shared connection pool, and the message-handler table.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message payload they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `Server::teardown`; every connection loop holds a subscription.
    teardown: watch::Sender<()>,
}
164
/// Lock guard over the shared [`ConnectionPool`], handed out by `Session::connection_pool`
/// and `Server::snapshot`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a future that holds it across
/// an `.await` is itself `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
169
/// Serializable view of the server's peer and connection pool, produced by
/// `Server::snapshot`.
///
/// The connection-pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it derefs to instead.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
176
177pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
178where
179 S: Serializer,
180 T: Deref<Target = U>,
181 U: Serialize,
182{
183 Serialize::serialize(value.deref(), serializer)
184}
185
impl Server {
    /// Builds a server with the given `id`, registers every RPC handler, and wraps it in an
    /// `Arc`.
    ///
    /// # Panics
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe`.
            teardown: watch::channel(()).0,
        };

        // Request handlers must reply through their `Response`; message handlers only receive
        // the payload and send no reply.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Spawns a detached cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT`, the task fetches the room and channel ids reported by
    /// `stale_server_resource_ids` for this environment and server id. It then:
    /// - refreshes each channel buffer's collaborators;
    /// - clears stale participants from each room and notifies the affected users;
    /// - deletes LiveKit rooms left empty;
    /// - calls `delete_stale_servers`.
    ///
    /// Failures inside the task are traced, not returned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Push each refreshed collaborator list to the connections the database
                    // reports for that channel buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in from the refreshed room below; these defaults mean
                        // "nothing to do" if refreshing the room failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // Only an empty room's LiveKit room is deleted below.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each user whose call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's updated contact entry to the connections of
                        // their accepted contacts. The pool lock is taken only after both
                        // database calls have completed.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears the server down.
    ///
    /// Tears down the peer and resets the connection pool. It then signals every connection
    /// loop (and pending `connection_lost` wait) subscribed to `teardown`.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // `send` errors only when no receiver is subscribed; that is fine here.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and reinitializes it under a new server id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only: returns the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers a type-erased handler for envelopes whose payload type is `M`.
    ///
    /// The stored closure does three things:
    /// - downcasts the envelope and logs its arrival inside a "handle message" span;
    /// - runs `handler`;
    /// - logs the outcome together with the elapsed time in milliseconds.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // NOTE(review): relies on `payload_type_id()` (used for the lookup in
                // `handle_connection`) being `TypeId::of::<M>()`, so this cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Measured in microseconds, reported as fractional milliseconds.
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler that receives only the message payload and sends no reply.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request that expects a reply.
    ///
    /// The handler gets a [`Response`] to reply with:
    /// - If it returns `Ok` without having called `Response::send`, an error is returned, and
    ///   `add_handler`'s wrapper logs it.
    /// - If it returns `Err`, the error is first sent to the requester via
    ///   `respond_with_error`, then returned. `Error::Internal` values are converted with
    ///   their own `to_proto`; any other error is wrapped as `ErrorCode::Internal` with its
    ///   display text.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// Handshake: the future registers the connection with the peer and sends `Hello`. On
    /// the user's first connection it also sends `ShowContacts`. It then pushes the initial
    /// contacts and channels state plus any pending incoming call. Errors during this phase
    /// are returned immediately.
    ///
    /// It then dispatches incoming messages to the registered handlers until one of these
    /// happens: the connection closes, its I/O fails, or the server is torn down. On close or
    /// I/O failure it runs `connection_lost`; on teardown it returns `Ok(())` without doing so.
    ///
    /// If `send_connection_id` is provided, it receives the assigned `ConnectionId` right
    /// after `Hello` is sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribed eagerly so a teardown that happens before the future is first polled
        // is still observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            // Register the connection and send the initial state under one pool lock.
            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_update_user_channels(&channels_for_user.channel_memberships))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers of this connection may be in flight at once. A permit is
            // acquired before reading each message and released when its handler finishes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown, then I/O completion, then progress on foreground
                // handlers, then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is held until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor;
                                // foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter's connections that `invitee_id` redeemed their invite code.
    ///
    /// Sends the invitee as a contact entry, followed by the inviter's invite link and
    /// invite count. Does nothing if the inviter does not exist or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Pushes `user_id`'s current invite link and invite count to all of their connections.
    ///
    /// Does nothing if the user does not exist or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Captures a serializable view of the peer and the connection pool.
    ///
    /// The returned snapshot holds the connection-pool lock until it is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
759
760impl<'a> Deref for ConnectionPoolGuard<'a> {
761 type Target = ConnectionPool;
762
763 fn deref(&self) -> &Self::Target {
764 &*self.guard
765 }
766}
767
768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
769 fn deref_mut(&mut self) -> &mut Self::Target {
770 &mut *self.guard
771 }
772}
773
774impl<'a> Drop for ConnectionPoolGuard<'a> {
775 fn drop(&mut self) {
776 #[cfg(test)]
777 self.check_invariants();
778 }
779}
780
781fn broadcast<F>(
782 sender_id: Option<ConnectionId>,
783 receiver_ids: impl IntoIterator<Item = ConnectionId>,
784 mut f: F,
785) where
786 F: FnMut(ConnectionId) -> anyhow::Result<()>,
787{
788 for receiver_id in receiver_ids {
789 if Some(receiver_id) != sender_id {
790 if let Err(error) = f(receiver_id) {
791 tracing::error!("failed to send to {:?} {}", receiver_id, error);
792 }
793 }
794 }
795}
796
797lazy_static! {
798 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
799 static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
800}
801
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client speaks.
pub struct ProtocolVersion(u32);
803
804impl Header for ProtocolVersion {
805 fn name() -> &'static HeaderName {
806 &ZED_PROTOCOL_VERSION
807 }
808
809 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
810 where
811 Self: Sized,
812 I: Iterator<Item = &'i axum::http::HeaderValue>,
813 {
814 let version = values
815 .next()
816 .ok_or_else(axum::headers::Error::invalid)?
817 .to_str()
818 .map_err(|_| axum::headers::Error::invalid())?
819 .parse()
820 .map_err(|_| axum::headers::Error::invalid())?;
821 Ok(Self(version))
822 }
823
824 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
825 values.extend([self.0.to_string().parse().unwrap()]);
826 }
827}
828
/// Typed `x-zed-app-version` header: the client application's semantic version.
pub struct AppVersionHeader(SemanticVersion);
830impl Header for AppVersionHeader {
831 fn name() -> &'static HeaderName {
832 &ZED_APP_VERSION
833 }
834
835 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
836 where
837 Self: Sized,
838 I: Iterator<Item = &'i axum::http::HeaderValue>,
839 {
840 let version = values
841 .next()
842 .ok_or_else(axum::headers::Error::invalid)?
843 .to_str()
844 .map_err(|_| axum::headers::Error::invalid())?
845 .parse()
846 .map_err(|_| axum::headers::Error::invalid())?;
847 Ok(Self(version))
848 }
849
850 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
851 values.extend([self.0.to_string().parse().unwrap()]);
852 }
853}
854
855pub fn routes(server: Arc<Server>) -> Router<Body> {
856 Router::new()
857 .route("/rpc", get(handle_websocket_request))
858 .layer(
859 ServiceBuilder::new()
860 .layer(Extension(server.app_state.clone()))
861 .layer(middleware::from_fn(auth::validate_header)),
862 )
863 .route("/metrics", get(handle_metrics))
864 .layer(Extension(server))
865}
866
867pub async fn handle_websocket_request(
868 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
869 _app_version_header: Option<TypedHeader<AppVersionHeader>>,
870 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
871 Extension(server): Extension<Arc<Server>>,
872 Extension(user): Extension<User>,
873 Extension(impersonator): Extension<Impersonator>,
874 ws: WebSocketUpgrade,
875) -> axum::response::Response {
876 if protocol_version != rpc::PROTOCOL_VERSION {
877 return (
878 StatusCode::UPGRADE_REQUIRED,
879 "client must be upgraded".to_string(),
880 )
881 .into_response();
882 }
883
884 let socket_address = socket_address.to_string();
885 ws.on_upgrade(move |socket| {
886 use util::ResultExt;
887 let socket = socket
888 .map_ok(to_tungstenite_message)
889 .err_into()
890 .with(|message| async move { Ok(to_axum_message(message)) });
891 let connection = Connection::new(Box::pin(socket));
892 async move {
893 server
894 .handle_connection(
895 connection,
896 socket_address,
897 user,
898 impersonator.0,
899 None,
900 Executor::Production,
901 )
902 .await
903 .log_err();
904 }
905 })
906}
907
908pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
909 let connections = server
910 .connection_pool
911 .lock()
912 .connections()
913 .filter(|connection| !connection.admin)
914 .count();
915
916 METRIC_CONNECTIONS.set(connections as _);
917
918 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
919 METRIC_SHARED_PROJECTS.set(shared_projects as _);
920
921 let encoder = prometheus::TextEncoder::new();
922 let metric_families = prometheus::gather();
923 let encoded_metrics = encoder
924 .encode_to_string(&metric_families)
925 .map_err(|err| anyhow!("{}", err))?;
926 Ok(encoded_metrics)
927}
928
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool (an error here is returned);
/// - report the lost connection to the database (an error here is traced).
///
/// It then waits `RECONNECT_TIMEOUT`. If the server is torn down first, it returns
/// without further cleanup. Otherwise it:
/// - leaves the session's room and channel buffers;
/// - if the user has no other connection online, calls `decline_call(None, ..)` (presumably
///   declining any pending incoming call — confirm against the db layer);
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Resources are only released if `RECONNECT_TIMEOUT` elapses before a teardown.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
974
975/// Acknowledges a ping from a client, used to keep the connection alive.
976async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
977 response.send(proto::Ack {})?;
978 Ok(())
979}
980
981/// Creates a new room for calling (outside of channels)
982async fn create_room(
983 _request: proto::CreateRoom,
984 response: Response<proto::CreateRoom>,
985 session: Session,
986) -> Result<()> {
987 let live_kit_room = nanoid::nanoid!(30);
988
989 let live_kit_connection_info = {
990 let live_kit_room = live_kit_room.clone();
991 let live_kit = session.live_kit_client.as_ref();
992
993 util::async_maybe!({
994 let live_kit = live_kit?;
995
996 let token = live_kit
997 .room_token(&live_kit_room, &session.user_id.to_string())
998 .trace_err()?;
999
1000 Some(proto::LiveKitConnectionInfo {
1001 server_url: live_kit.url().into(),
1002 token,
1003 can_publish: true,
1004 })
1005 })
1006 }
1007 .await;
1008
1009 let room = session
1010 .db()
1011 .await
1012 .create_room(
1013 session.user_id,
1014 session.connection_id,
1015 &live_kit_room,
1016 &session.zed_environment,
1017 )
1018 .await?;
1019
1020 response.send(proto::CreateRoomResponse {
1021 room: Some(room.clone()),
1022 live_kit_connection_info,
1023 })?;
1024
1025 update_user_contacts(session.user_id, &session).await?;
1026 Ok(())
1027}
1028
1029/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1030async fn join_room(
1031 request: proto::JoinRoom,
1032 response: Response<proto::JoinRoom>,
1033 session: Session,
1034) -> Result<()> {
1035 let room_id = RoomId::from_proto(request.id);
1036
1037 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1038
1039 if let Some(channel_id) = channel_id {
1040 return join_channel_internal(channel_id, Box::new(response), session).await;
1041 }
1042
1043 let joined_room = {
1044 let room = session
1045 .db()
1046 .await
1047 .join_room(
1048 room_id,
1049 session.user_id,
1050 session.connection_id,
1051 session.zed_environment.as_ref(),
1052 )
1053 .await?;
1054 room_updated(&room.room, &session.peer);
1055 room.into_inner()
1056 };
1057
1058 for connection_id in session
1059 .connection_pool()
1060 .await
1061 .user_connection_ids(session.user_id)
1062 {
1063 session
1064 .peer
1065 .send(
1066 connection_id,
1067 proto::CallCanceled {
1068 room_id: room_id.to_proto(),
1069 },
1070 )
1071 .trace_err();
1072 }
1073
1074 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1075 if let Some(token) = live_kit
1076 .room_token(
1077 &joined_room.room.live_kit_room,
1078 &session.user_id.to_string(),
1079 )
1080 .trace_err()
1081 {
1082 Some(proto::LiveKitConnectionInfo {
1083 server_url: live_kit.url().into(),
1084 token,
1085 can_publish: true,
1086 })
1087 } else {
1088 None
1089 }
1090 } else {
1091 None
1092 };
1093
1094 response.send(proto::JoinRoomResponse {
1095 room: Some(joined_room.room),
1096 channel_id: None,
1097 live_kit_connection_info,
1098 })?;
1099
1100 update_user_contacts(session.user_id, &session).await?;
1101 Ok(())
1102}
1103
1104/// Rejoin room is used to reconnect to a room after connection errors.
1105async fn rejoin_room(
1106 request: proto::RejoinRoom,
1107 response: Response<proto::RejoinRoom>,
1108 session: Session,
1109) -> Result<()> {
1110 let room;
1111 let channel_id;
1112 let channel_members;
1113 {
1114 let mut rejoined_room = session
1115 .db()
1116 .await
1117 .rejoin_room(request, session.user_id, session.connection_id)
1118 .await?;
1119
1120 response.send(proto::RejoinRoomResponse {
1121 room: Some(rejoined_room.room.clone()),
1122 reshared_projects: rejoined_room
1123 .reshared_projects
1124 .iter()
1125 .map(|project| proto::ResharedProject {
1126 id: project.id.to_proto(),
1127 collaborators: project
1128 .collaborators
1129 .iter()
1130 .map(|collaborator| collaborator.to_proto())
1131 .collect(),
1132 })
1133 .collect(),
1134 rejoined_projects: rejoined_room
1135 .rejoined_projects
1136 .iter()
1137 .map(|rejoined_project| proto::RejoinedProject {
1138 id: rejoined_project.id.to_proto(),
1139 worktrees: rejoined_project
1140 .worktrees
1141 .iter()
1142 .map(|worktree| proto::WorktreeMetadata {
1143 id: worktree.id,
1144 root_name: worktree.root_name.clone(),
1145 visible: worktree.visible,
1146 abs_path: worktree.abs_path.clone(),
1147 })
1148 .collect(),
1149 collaborators: rejoined_project
1150 .collaborators
1151 .iter()
1152 .map(|collaborator| collaborator.to_proto())
1153 .collect(),
1154 language_servers: rejoined_project.language_servers.clone(),
1155 })
1156 .collect(),
1157 })?;
1158 room_updated(&rejoined_room.room, &session.peer);
1159
1160 for project in &rejoined_room.reshared_projects {
1161 for collaborator in &project.collaborators {
1162 session
1163 .peer
1164 .send(
1165 collaborator.connection_id,
1166 proto::UpdateProjectCollaborator {
1167 project_id: project.id.to_proto(),
1168 old_peer_id: Some(project.old_connection_id.into()),
1169 new_peer_id: Some(session.connection_id.into()),
1170 },
1171 )
1172 .trace_err();
1173 }
1174
1175 broadcast(
1176 Some(session.connection_id),
1177 project
1178 .collaborators
1179 .iter()
1180 .map(|collaborator| collaborator.connection_id),
1181 |connection_id| {
1182 session.peer.forward_send(
1183 session.connection_id,
1184 connection_id,
1185 proto::UpdateProject {
1186 project_id: project.id.to_proto(),
1187 worktrees: project.worktrees.clone(),
1188 },
1189 )
1190 },
1191 );
1192 }
1193
1194 for project in &rejoined_room.rejoined_projects {
1195 for collaborator in &project.collaborators {
1196 session
1197 .peer
1198 .send(
1199 collaborator.connection_id,
1200 proto::UpdateProjectCollaborator {
1201 project_id: project.id.to_proto(),
1202 old_peer_id: Some(project.old_connection_id.into()),
1203 new_peer_id: Some(session.connection_id.into()),
1204 },
1205 )
1206 .trace_err();
1207 }
1208 }
1209
1210 for project in &mut rejoined_room.rejoined_projects {
1211 for worktree in mem::take(&mut project.worktrees) {
1212 #[cfg(any(test, feature = "test-support"))]
1213 const MAX_CHUNK_SIZE: usize = 2;
1214 #[cfg(not(any(test, feature = "test-support")))]
1215 const MAX_CHUNK_SIZE: usize = 256;
1216
1217 // Stream this worktree's entries.
1218 let message = proto::UpdateWorktree {
1219 project_id: project.id.to_proto(),
1220 worktree_id: worktree.id,
1221 abs_path: worktree.abs_path.clone(),
1222 root_name: worktree.root_name,
1223 updated_entries: worktree.updated_entries,
1224 removed_entries: worktree.removed_entries,
1225 scan_id: worktree.scan_id,
1226 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1227 updated_repositories: worktree.updated_repositories,
1228 removed_repositories: worktree.removed_repositories,
1229 };
1230 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1231 session.peer.send(session.connection_id, update.clone())?;
1232 }
1233
1234 // Stream this worktree's diagnostics.
1235 for summary in worktree.diagnostic_summaries {
1236 session.peer.send(
1237 session.connection_id,
1238 proto::UpdateDiagnosticSummary {
1239 project_id: project.id.to_proto(),
1240 worktree_id: worktree.id,
1241 summary: Some(summary),
1242 },
1243 )?;
1244 }
1245
1246 for settings_file in worktree.settings_files {
1247 session.peer.send(
1248 session.connection_id,
1249 proto::UpdateWorktreeSettings {
1250 project_id: project.id.to_proto(),
1251 worktree_id: worktree.id,
1252 path: settings_file.path,
1253 content: Some(settings_file.content),
1254 },
1255 )?;
1256 }
1257 }
1258
1259 for language_server in &project.language_servers {
1260 session.peer.send(
1261 session.connection_id,
1262 proto::UpdateLanguageServer {
1263 project_id: project.id.to_proto(),
1264 language_server_id: language_server.id,
1265 variant: Some(
1266 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1267 proto::LspDiskBasedDiagnosticsUpdated {},
1268 ),
1269 ),
1270 },
1271 )?;
1272 }
1273 }
1274
1275 let rejoined_room = rejoined_room.into_inner();
1276
1277 room = rejoined_room.room;
1278 channel_id = rejoined_room.channel_id;
1279 channel_members = rejoined_room.channel_members;
1280 }
1281
1282 if let Some(channel_id) = channel_id {
1283 channel_updated(
1284 channel_id,
1285 &room,
1286 &channel_members,
1287 &session.peer,
1288 &*session.connection_pool().await,
1289 );
1290 }
1291
1292 update_user_contacts(session.user_id, &session).await?;
1293 Ok(())
1294}
1295
1296/// leave room disconnects from the room.
1297async fn leave_room(
1298 _: proto::LeaveRoom,
1299 response: Response<proto::LeaveRoom>,
1300 session: Session,
1301) -> Result<()> {
1302 leave_room_for_session(&session).await?;
1303 response.send(proto::Ack {})?;
1304 Ok(())
1305}
1306
1307/// Updates the permissions of someone else in the room.
1308async fn set_room_participant_role(
1309 request: proto::SetRoomParticipantRole,
1310 response: Response<proto::SetRoomParticipantRole>,
1311 session: Session,
1312) -> Result<()> {
1313 let (live_kit_room, can_publish) = {
1314 let room = session
1315 .db()
1316 .await
1317 .set_room_participant_role(
1318 session.user_id,
1319 RoomId::from_proto(request.room_id),
1320 UserId::from_proto(request.user_id),
1321 ChannelRole::from(request.role()),
1322 )
1323 .await?;
1324
1325 let live_kit_room = room.live_kit_room.clone();
1326 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1327 room_updated(&room, &session.peer);
1328 (live_kit_room, can_publish)
1329 };
1330
1331 if let Some(live_kit) = session.live_kit_client.as_ref() {
1332 live_kit
1333 .update_participant(
1334 live_kit_room.clone(),
1335 request.user_id.to_string(),
1336 live_kit_server::proto::ParticipantPermission {
1337 can_subscribe: true,
1338 can_publish,
1339 can_publish_data: can_publish,
1340 hidden: false,
1341 recorder: false,
1342 },
1343 )
1344 .await
1345 .trace_err();
1346 }
1347
1348 response.send(proto::Ack {})?;
1349 Ok(())
1350}
1351
/// Call someone else into the current room.
///
/// Only contacts can be called. The call is recorded in the database and the
/// updated room is published. The incoming-call request is then sent to every
/// connection of the called user concurrently, and the first successful reply
/// acknowledges the caller. If no connection replies successfully (including
/// when the user has no connections), the call is recorded as failed, the
/// room is published again, and an error is returned.
///
/// The called user's contacts are refreshed via `update_user_contacts` after
/// the call is recorded, and again after a failure.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The value returned by the database `call` is borrowed through `&mut *`
    // and lives until the end of this block. `mem::take` moves the
    // incoming-call message out so it can outlive that value.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Built in a single statement, so the connection-pool handle is dropped
    // before any of the ring requests are awaited below.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Replies are handled in completion order; the first success wins.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Report this connection's failure and keep waiting on the rest.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: record the failure and republish the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1420
1421/// Cancel an outgoing call.
1422async fn cancel_call(
1423 request: proto::CancelCall,
1424 response: Response<proto::CancelCall>,
1425 session: Session,
1426) -> Result<()> {
1427 let called_user_id = UserId::from_proto(request.called_user_id);
1428 let room_id = RoomId::from_proto(request.room_id);
1429 {
1430 let room = session
1431 .db()
1432 .await
1433 .cancel_call(room_id, session.connection_id, called_user_id)
1434 .await?;
1435 room_updated(&room, &session.peer);
1436 }
1437
1438 for connection_id in session
1439 .connection_pool()
1440 .await
1441 .user_connection_ids(called_user_id)
1442 {
1443 session
1444 .peer
1445 .send(
1446 connection_id,
1447 proto::CallCanceled {
1448 room_id: room_id.to_proto(),
1449 },
1450 )
1451 .trace_err();
1452 }
1453 response.send(proto::Ack {})?;
1454
1455 update_user_contacts(called_user_id, &session).await?;
1456 Ok(())
1457}
1458
1459/// Decline an incoming call.
1460async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1461 let room_id = RoomId::from_proto(message.room_id);
1462 {
1463 let room = session
1464 .db()
1465 .await
1466 .decline_call(Some(room_id), session.user_id)
1467 .await?
1468 .ok_or_else(|| anyhow!("failed to decline call"))?;
1469 room_updated(&room, &session.peer);
1470 }
1471
1472 for connection_id in session
1473 .connection_pool()
1474 .await
1475 .user_connection_ids(session.user_id)
1476 {
1477 session
1478 .peer
1479 .send(
1480 connection_id,
1481 proto::CallCanceled {
1482 room_id: room_id.to_proto(),
1483 },
1484 )
1485 .trace_err();
1486 }
1487 update_user_contacts(session.user_id, &session).await?;
1488 Ok(())
1489}
1490
1491/// Updates other participants in the room with your current location.
1492async fn update_participant_location(
1493 request: proto::UpdateParticipantLocation,
1494 response: Response<proto::UpdateParticipantLocation>,
1495 session: Session,
1496) -> Result<()> {
1497 let room_id = RoomId::from_proto(request.room_id);
1498 let location = request
1499 .location
1500 .ok_or_else(|| anyhow!("invalid location"))?;
1501
1502 let db = session.db().await;
1503 let room = db
1504 .update_room_participant_location(room_id, session.connection_id, location)
1505 .await?;
1506
1507 room_updated(&room, &session.peer);
1508 response.send(proto::Ack {})?;
1509 Ok(())
1510}
1511
1512/// Share a project into the room.
1513async fn share_project(
1514 request: proto::ShareProject,
1515 response: Response<proto::ShareProject>,
1516 session: Session,
1517) -> Result<()> {
1518 let (project_id, room) = &*session
1519 .db()
1520 .await
1521 .share_project(
1522 RoomId::from_proto(request.room_id),
1523 session.connection_id,
1524 &request.worktrees,
1525 )
1526 .await?;
1527 response.send(proto::ShareProjectResponse {
1528 project_id: project_id.to_proto(),
1529 })?;
1530 room_updated(&room, &session.peer);
1531
1532 Ok(())
1533}
1534
1535/// Unshare a project from the room.
1536async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1537 let project_id = ProjectId::from_proto(message.project_id);
1538
1539 let (room, guest_connection_ids) = &*session
1540 .db()
1541 .await
1542 .unshare_project(project_id, session.connection_id)
1543 .await?;
1544
1545 broadcast(
1546 Some(session.connection_id),
1547 guest_connection_ids.iter().copied(),
1548 |conn_id| session.peer.send(conn_id, message.clone()),
1549 );
1550 room_updated(&room, &session.peer);
1551
1552 Ok(())
1553}
1554
1555/// Join someone elses shared project.
1556async fn join_project(
1557 request: proto::JoinProject,
1558 response: Response<proto::JoinProject>,
1559 session: Session,
1560) -> Result<()> {
1561 let project_id = ProjectId::from_proto(request.project_id);
1562 let guest_user_id = session.user_id;
1563
1564 tracing::info!(%project_id, "join project");
1565
1566 let (project, replica_id) = &mut *session
1567 .db()
1568 .await
1569 .join_project(project_id, session.connection_id)
1570 .await?;
1571
1572 let collaborators = project
1573 .collaborators
1574 .iter()
1575 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1576 .map(|collaborator| collaborator.to_proto())
1577 .collect::<Vec<_>>();
1578
1579 let worktrees = project
1580 .worktrees
1581 .iter()
1582 .map(|(id, worktree)| proto::WorktreeMetadata {
1583 id: *id,
1584 root_name: worktree.root_name.clone(),
1585 visible: worktree.visible,
1586 abs_path: worktree.abs_path.clone(),
1587 })
1588 .collect::<Vec<_>>();
1589
1590 for collaborator in &collaborators {
1591 session
1592 .peer
1593 .send(
1594 collaborator.peer_id.unwrap().into(),
1595 proto::AddProjectCollaborator {
1596 project_id: project_id.to_proto(),
1597 collaborator: Some(proto::Collaborator {
1598 peer_id: Some(session.connection_id.into()),
1599 replica_id: replica_id.0 as u32,
1600 user_id: guest_user_id.to_proto(),
1601 }),
1602 },
1603 )
1604 .trace_err();
1605 }
1606
1607 // First, we send the metadata associated with each worktree.
1608 response.send(proto::JoinProjectResponse {
1609 worktrees: worktrees.clone(),
1610 replica_id: replica_id.0 as u32,
1611 collaborators: collaborators.clone(),
1612 language_servers: project.language_servers.clone(),
1613 })?;
1614
1615 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1616 #[cfg(any(test, feature = "test-support"))]
1617 const MAX_CHUNK_SIZE: usize = 2;
1618 #[cfg(not(any(test, feature = "test-support")))]
1619 const MAX_CHUNK_SIZE: usize = 256;
1620
1621 // Stream this worktree's entries.
1622 let message = proto::UpdateWorktree {
1623 project_id: project_id.to_proto(),
1624 worktree_id,
1625 abs_path: worktree.abs_path.clone(),
1626 root_name: worktree.root_name,
1627 updated_entries: worktree.entries,
1628 removed_entries: Default::default(),
1629 scan_id: worktree.scan_id,
1630 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1631 updated_repositories: worktree.repository_entries.into_values().collect(),
1632 removed_repositories: Default::default(),
1633 };
1634 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1635 session.peer.send(session.connection_id, update.clone())?;
1636 }
1637
1638 // Stream this worktree's diagnostics.
1639 for summary in worktree.diagnostic_summaries {
1640 session.peer.send(
1641 session.connection_id,
1642 proto::UpdateDiagnosticSummary {
1643 project_id: project_id.to_proto(),
1644 worktree_id: worktree.id,
1645 summary: Some(summary),
1646 },
1647 )?;
1648 }
1649
1650 for settings_file in worktree.settings_files {
1651 session.peer.send(
1652 session.connection_id,
1653 proto::UpdateWorktreeSettings {
1654 project_id: project_id.to_proto(),
1655 worktree_id: worktree.id,
1656 path: settings_file.path,
1657 content: Some(settings_file.content),
1658 },
1659 )?;
1660 }
1661 }
1662
1663 for language_server in &project.language_servers {
1664 session.peer.send(
1665 session.connection_id,
1666 proto::UpdateLanguageServer {
1667 project_id: project_id.to_proto(),
1668 language_server_id: language_server.id,
1669 variant: Some(
1670 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1671 proto::LspDiskBasedDiagnosticsUpdated {},
1672 ),
1673 ),
1674 },
1675 )?;
1676 }
1677
1678 Ok(())
1679}
1680
1681/// Leave someone elses shared project.
1682async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1683 let sender_id = session.connection_id;
1684 let project_id = ProjectId::from_proto(request.project_id);
1685
1686 let (room, project) = &*session
1687 .db()
1688 .await
1689 .leave_project(project_id, sender_id)
1690 .await?;
1691 tracing::info!(
1692 %project_id,
1693 host_user_id = %project.host_user_id,
1694 host_connection_id = ?project.host_connection_id,
1695 "leave project"
1696 );
1697
1698 project_left(&project, &session);
1699 room_updated(&room, &session.peer);
1700
1701 Ok(())
1702}
1703
1704/// Updates other participants with changes to the project
1705async fn update_project(
1706 request: proto::UpdateProject,
1707 response: Response<proto::UpdateProject>,
1708 session: Session,
1709) -> Result<()> {
1710 let project_id = ProjectId::from_proto(request.project_id);
1711 let (room, guest_connection_ids) = &*session
1712 .db()
1713 .await
1714 .update_project(project_id, session.connection_id, &request.worktrees)
1715 .await?;
1716 broadcast(
1717 Some(session.connection_id),
1718 guest_connection_ids.iter().copied(),
1719 |connection_id| {
1720 session
1721 .peer
1722 .forward_send(session.connection_id, connection_id, request.clone())
1723 },
1724 );
1725 room_updated(&room, &session.peer);
1726 response.send(proto::Ack {})?;
1727
1728 Ok(())
1729}
1730
1731/// Updates other participants with changes to the worktree
1732async fn update_worktree(
1733 request: proto::UpdateWorktree,
1734 response: Response<proto::UpdateWorktree>,
1735 session: Session,
1736) -> Result<()> {
1737 let guest_connection_ids = session
1738 .db()
1739 .await
1740 .update_worktree(&request, session.connection_id)
1741 .await?;
1742
1743 broadcast(
1744 Some(session.connection_id),
1745 guest_connection_ids.iter().copied(),
1746 |connection_id| {
1747 session
1748 .peer
1749 .forward_send(session.connection_id, connection_id, request.clone())
1750 },
1751 );
1752 response.send(proto::Ack {})?;
1753 Ok(())
1754}
1755
1756/// Updates other participants with changes to the diagnostics
1757async fn update_diagnostic_summary(
1758 message: proto::UpdateDiagnosticSummary,
1759 session: Session,
1760) -> Result<()> {
1761 let guest_connection_ids = session
1762 .db()
1763 .await
1764 .update_diagnostic_summary(&message, session.connection_id)
1765 .await?;
1766
1767 broadcast(
1768 Some(session.connection_id),
1769 guest_connection_ids.iter().copied(),
1770 |connection_id| {
1771 session
1772 .peer
1773 .forward_send(session.connection_id, connection_id, message.clone())
1774 },
1775 );
1776
1777 Ok(())
1778}
1779
1780/// Updates other participants with changes to the worktree settings
1781async fn update_worktree_settings(
1782 message: proto::UpdateWorktreeSettings,
1783 session: Session,
1784) -> Result<()> {
1785 let guest_connection_ids = session
1786 .db()
1787 .await
1788 .update_worktree_settings(&message, session.connection_id)
1789 .await?;
1790
1791 broadcast(
1792 Some(session.connection_id),
1793 guest_connection_ids.iter().copied(),
1794 |connection_id| {
1795 session
1796 .peer
1797 .forward_send(session.connection_id, connection_id, message.clone())
1798 },
1799 );
1800
1801 Ok(())
1802}
1803
1804/// Notify other participants that a language server has started.
1805async fn start_language_server(
1806 request: proto::StartLanguageServer,
1807 session: Session,
1808) -> Result<()> {
1809 let guest_connection_ids = session
1810 .db()
1811 .await
1812 .start_language_server(&request, session.connection_id)
1813 .await?;
1814
1815 broadcast(
1816 Some(session.connection_id),
1817 guest_connection_ids.iter().copied(),
1818 |connection_id| {
1819 session
1820 .peer
1821 .forward_send(session.connection_id, connection_id, request.clone())
1822 },
1823 );
1824 Ok(())
1825}
1826
1827/// Notify other participants that a language server has changed.
1828async fn update_language_server(
1829 request: proto::UpdateLanguageServer,
1830 session: Session,
1831) -> Result<()> {
1832 let project_id = ProjectId::from_proto(request.project_id);
1833 let project_connection_ids = session
1834 .db()
1835 .await
1836 .project_connection_ids(project_id, session.connection_id)
1837 .await?;
1838 broadcast(
1839 Some(session.connection_id),
1840 project_connection_ids.iter().copied(),
1841 |connection_id| {
1842 session
1843 .peer
1844 .forward_send(session.connection_id, connection_id, request.clone())
1845 },
1846 );
1847 Ok(())
1848}
1849
1850/// forward a project request to the host. These requests should be read only
1851/// as guests are allowed to send them.
1852async fn forward_read_only_project_request<T>(
1853 request: T,
1854 response: Response<T>,
1855 session: Session,
1856) -> Result<()>
1857where
1858 T: EntityMessage + RequestMessage,
1859{
1860 let project_id = ProjectId::from_proto(request.remote_entity_id());
1861 let host_connection_id = session
1862 .db()
1863 .await
1864 .host_for_read_only_project_request(project_id, session.connection_id)
1865 .await?;
1866 let payload = session
1867 .peer
1868 .forward_request(session.connection_id, host_connection_id, request)
1869 .await?;
1870 response.send(payload)?;
1871 Ok(())
1872}
1873
1874/// forward a project request to the host. These requests are disallowed
1875/// for guests.
1876async fn forward_mutating_project_request<T>(
1877 request: T,
1878 response: Response<T>,
1879 session: Session,
1880) -> Result<()>
1881where
1882 T: EntityMessage + RequestMessage,
1883{
1884 let project_id = ProjectId::from_proto(request.remote_entity_id());
1885 let host_connection_id = session
1886 .db()
1887 .await
1888 .host_for_mutating_project_request(project_id, session.connection_id)
1889 .await?;
1890 let payload = session
1891 .peer
1892 .forward_request(session.connection_id, host_connection_id, request)
1893 .await?;
1894 response.send(payload)?;
1895 Ok(())
1896}
1897
1898/// Notify other participants that a new buffer has been created
1899async fn create_buffer_for_peer(
1900 request: proto::CreateBufferForPeer,
1901 session: Session,
1902) -> Result<()> {
1903 session
1904 .db()
1905 .await
1906 .check_user_is_project_host(
1907 ProjectId::from_proto(request.project_id),
1908 session.connection_id,
1909 )
1910 .await?;
1911 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1912 session
1913 .peer
1914 .forward_send(session.connection_id, peer_id.into(), request)?;
1915 Ok(())
1916}
1917
1918/// Notify other participants that a buffer has been updated. This is
1919/// allowed for guests as long as the update is limited to selections.
1920async fn update_buffer(
1921 request: proto::UpdateBuffer,
1922 response: Response<proto::UpdateBuffer>,
1923 session: Session,
1924) -> Result<()> {
1925 let project_id = ProjectId::from_proto(request.project_id);
1926 let mut guest_connection_ids;
1927 let mut host_connection_id = None;
1928
1929 let mut requires_write_permission = false;
1930
1931 for op in request.operations.iter() {
1932 match op.variant {
1933 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1934 Some(_) => requires_write_permission = true,
1935 }
1936 }
1937
1938 {
1939 let collaborators = session
1940 .db()
1941 .await
1942 .project_collaborators_for_buffer_update(
1943 project_id,
1944 session.connection_id,
1945 requires_write_permission,
1946 )
1947 .await?;
1948 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1949 for collaborator in collaborators.iter() {
1950 if collaborator.is_host {
1951 host_connection_id = Some(collaborator.connection_id);
1952 } else {
1953 guest_connection_ids.push(collaborator.connection_id);
1954 }
1955 }
1956 }
1957 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1958
1959 broadcast(
1960 Some(session.connection_id),
1961 guest_connection_ids,
1962 |connection_id| {
1963 session
1964 .peer
1965 .forward_send(session.connection_id, connection_id, request.clone())
1966 },
1967 );
1968 if host_connection_id != session.connection_id {
1969 session
1970 .peer
1971 .forward_request(session.connection_id, host_connection_id, request.clone())
1972 .await?;
1973 }
1974
1975 response.send(proto::Ack {})?;
1976 Ok(())
1977}
1978
1979/// Notify other participants that a project has been updated.
1980async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1981 request: T,
1982 session: Session,
1983) -> Result<()> {
1984 let project_id = ProjectId::from_proto(request.remote_entity_id());
1985 let project_connection_ids = session
1986 .db()
1987 .await
1988 .project_connection_ids(project_id, session.connection_id)
1989 .await?;
1990
1991 broadcast(
1992 Some(session.connection_id),
1993 project_connection_ids.iter().copied(),
1994 |connection_id| {
1995 session
1996 .peer
1997 .forward_send(session.connection_id, connection_id, request.clone())
1998 },
1999 );
2000 Ok(())
2001}
2002
2003/// Start following another user in a call.
2004async fn follow(
2005 request: proto::Follow,
2006 response: Response<proto::Follow>,
2007 session: Session,
2008) -> Result<()> {
2009 let room_id = RoomId::from_proto(request.room_id);
2010 let project_id = request.project_id.map(ProjectId::from_proto);
2011 let leader_id = request
2012 .leader_id
2013 .ok_or_else(|| anyhow!("invalid leader id"))?
2014 .into();
2015 let follower_id = session.connection_id;
2016
2017 session
2018 .db()
2019 .await
2020 .check_room_participants(room_id, leader_id, session.connection_id)
2021 .await?;
2022
2023 let response_payload = session
2024 .peer
2025 .forward_request(session.connection_id, leader_id, request)
2026 .await?;
2027 response.send(response_payload)?;
2028
2029 if let Some(project_id) = project_id {
2030 let room = session
2031 .db()
2032 .await
2033 .follow(room_id, project_id, leader_id, follower_id)
2034 .await?;
2035 room_updated(&room, &session.peer);
2036 }
2037
2038 Ok(())
2039}
2040
2041/// Stop following another user in a call.
2042async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2043 let room_id = RoomId::from_proto(request.room_id);
2044 let project_id = request.project_id.map(ProjectId::from_proto);
2045 let leader_id = request
2046 .leader_id
2047 .ok_or_else(|| anyhow!("invalid leader id"))?
2048 .into();
2049 let follower_id = session.connection_id;
2050
2051 session
2052 .db()
2053 .await
2054 .check_room_participants(room_id, leader_id, session.connection_id)
2055 .await?;
2056
2057 session
2058 .peer
2059 .forward_send(session.connection_id, leader_id, request)?;
2060
2061 if let Some(project_id) = project_id {
2062 let room = session
2063 .db()
2064 .await
2065 .unfollow(room_id, project_id, leader_id, follower_id)
2066 .await?;
2067 room_updated(&room, &session.peer);
2068 }
2069
2070 Ok(())
2071}
2072
/// Notify everyone following you of your current location.
///
/// The message is forwarded to each id in `follower_ids` that is also one of
/// the connections the database reports. Those are the project's connections
/// when `project_id` is set, and the room's connections otherwise. For
/// `UpdateView` messages, the view's leader is skipped.
///
/// The database lock taken at the top is held until the function returns,
/// including while messages are forwarded.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly rather than going through
    // `session.db()`; the guard lives for the rest of the function.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for follower_peer_id in request.follower_ids.iter().copied() {
        let follower_connection_id = follower_peer_id.into();
        // Only forward to listed followers that are connected to the same
        // project/room, as reported by the database above.
        if Some(follower_peer_id) != connection_id_to_omit
            && connection_ids.contains(&follower_connection_id)
        {
            session.peer.forward_send(
                session.connection_id,
                follower_connection_id,
                request.clone(),
            )?;
        }
    }
    Ok(())
}
2109
2110/// Get public data about users.
2111async fn get_users(
2112 request: proto::GetUsers,
2113 response: Response<proto::GetUsers>,
2114 session: Session,
2115) -> Result<()> {
2116 let user_ids = request
2117 .user_ids
2118 .into_iter()
2119 .map(UserId::from_proto)
2120 .collect();
2121 let users = session
2122 .db()
2123 .await
2124 .get_users_by_ids(user_ids)
2125 .await?
2126 .into_iter()
2127 .map(|user| proto::User {
2128 id: user.id.to_proto(),
2129 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2130 github_login: user.github_login,
2131 })
2132 .collect();
2133 response.send(proto::UsersResponse { users })?;
2134 Ok(())
2135}
2136
2137/// Search for users (to invite) buy Github login
2138async fn fuzzy_search_users(
2139 request: proto::FuzzySearchUsers,
2140 response: Response<proto::FuzzySearchUsers>,
2141 session: Session,
2142) -> Result<()> {
2143 let query = request.query;
2144 let users = match query.len() {
2145 0 => vec![],
2146 1 | 2 => session
2147 .db()
2148 .await
2149 .get_user_by_github_login(&query)
2150 .await?
2151 .into_iter()
2152 .collect(),
2153 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2154 };
2155 let users = users
2156 .into_iter()
2157 .filter(|user| user.id != session.user_id)
2158 .map(|user| proto::User {
2159 id: user.id.to_proto(),
2160 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2161 github_login: user.github_login,
2162 })
2163 .collect();
2164 response.send(proto::UsersResponse { users })?;
2165 Ok(())
2166}
2167
2168/// Send a contact request to another user.
2169async fn request_contact(
2170 request: proto::RequestContact,
2171 response: Response<proto::RequestContact>,
2172 session: Session,
2173) -> Result<()> {
2174 let requester_id = session.user_id;
2175 let responder_id = UserId::from_proto(request.responder_id);
2176 if requester_id == responder_id {
2177 return Err(anyhow!("cannot add yourself as a contact"))?;
2178 }
2179
2180 let notifications = session
2181 .db()
2182 .await
2183 .send_contact_request(requester_id, responder_id)
2184 .await?;
2185
2186 // Update outgoing contact requests of requester
2187 let mut update = proto::UpdateContacts::default();
2188 update.outgoing_requests.push(responder_id.to_proto());
2189 for connection_id in session
2190 .connection_pool()
2191 .await
2192 .user_connection_ids(requester_id)
2193 {
2194 session.peer.send(connection_id, update.clone())?;
2195 }
2196
2197 // Update incoming contact requests of responder
2198 let mut update = proto::UpdateContacts::default();
2199 update
2200 .incoming_requests
2201 .push(proto::IncomingContactRequest {
2202 requester_id: requester_id.to_proto(),
2203 });
2204 let connection_pool = session.connection_pool().await;
2205 for connection_id in connection_pool.user_connection_ids(responder_id) {
2206 session.peer.send(connection_id, update.clone())?;
2207 }
2208
2209 send_notifications(&*connection_pool, &session.peer, notifications);
2210
2211 response.send(proto::Ack {})?;
2212 Ok(())
2213}
2214
2215/// Accept or decline a contact request
2216async fn respond_to_contact_request(
2217 request: proto::RespondToContactRequest,
2218 response: Response<proto::RespondToContactRequest>,
2219 session: Session,
2220) -> Result<()> {
2221 let responder_id = session.user_id;
2222 let requester_id = UserId::from_proto(request.requester_id);
2223 let db = session.db().await;
2224 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2225 db.dismiss_contact_notification(responder_id, requester_id)
2226 .await?;
2227 } else {
2228 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2229
2230 let notifications = db
2231 .respond_to_contact_request(responder_id, requester_id, accept)
2232 .await?;
2233 let requester_busy = db.is_user_busy(requester_id).await?;
2234 let responder_busy = db.is_user_busy(responder_id).await?;
2235
2236 let pool = session.connection_pool().await;
2237 // Update responder with new contact
2238 let mut update = proto::UpdateContacts::default();
2239 if accept {
2240 update
2241 .contacts
2242 .push(contact_for_user(requester_id, requester_busy, &pool));
2243 }
2244 update
2245 .remove_incoming_requests
2246 .push(requester_id.to_proto());
2247 for connection_id in pool.user_connection_ids(responder_id) {
2248 session.peer.send(connection_id, update.clone())?;
2249 }
2250
2251 // Update requester with new contact
2252 let mut update = proto::UpdateContacts::default();
2253 if accept {
2254 update
2255 .contacts
2256 .push(contact_for_user(responder_id, responder_busy, &pool));
2257 }
2258 update
2259 .remove_outgoing_requests
2260 .push(responder_id.to_proto());
2261
2262 for connection_id in pool.user_connection_ids(requester_id) {
2263 session.peer.send(connection_id, update.clone())?;
2264 }
2265
2266 send_notifications(&*pool, &session.peer, notifications);
2267 }
2268
2269 response.send(proto::Ack {})?;
2270 Ok(())
2271}
2272
2273/// Remove a contact.
2274async fn remove_contact(
2275 request: proto::RemoveContact,
2276 response: Response<proto::RemoveContact>,
2277 session: Session,
2278) -> Result<()> {
2279 let requester_id = session.user_id;
2280 let responder_id = UserId::from_proto(request.user_id);
2281 let db = session.db().await;
2282 let (contact_accepted, deleted_notification_id) =
2283 db.remove_contact(requester_id, responder_id).await?;
2284
2285 let pool = session.connection_pool().await;
2286 // Update outgoing contact requests of requester
2287 let mut update = proto::UpdateContacts::default();
2288 if contact_accepted {
2289 update.remove_contacts.push(responder_id.to_proto());
2290 } else {
2291 update
2292 .remove_outgoing_requests
2293 .push(responder_id.to_proto());
2294 }
2295 for connection_id in pool.user_connection_ids(requester_id) {
2296 session.peer.send(connection_id, update.clone())?;
2297 }
2298
2299 // Update incoming contact requests of responder
2300 let mut update = proto::UpdateContacts::default();
2301 if contact_accepted {
2302 update.remove_contacts.push(requester_id.to_proto());
2303 } else {
2304 update
2305 .remove_incoming_requests
2306 .push(requester_id.to_proto());
2307 }
2308 for connection_id in pool.user_connection_ids(responder_id) {
2309 session.peer.send(connection_id, update.clone())?;
2310 if let Some(notification_id) = deleted_notification_id {
2311 session.peer.send(
2312 connection_id,
2313 proto::DeleteNotification {
2314 notification_id: notification_id.to_proto(),
2315 },
2316 )?;
2317 }
2318 }
2319
2320 response.send(proto::Ack {})?;
2321 Ok(())
2322}
2323
2324/// Creates a new channel.
2325async fn create_channel(
2326 request: proto::CreateChannel,
2327 response: Response<proto::CreateChannel>,
2328 session: Session,
2329) -> Result<()> {
2330 let db = session.db().await;
2331
2332 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2333 let (channel, owner, channel_members) = db
2334 .create_channel(&request.name, parent_id, session.user_id)
2335 .await?;
2336
2337 response.send(proto::CreateChannelResponse {
2338 channel: Some(channel.to_proto()),
2339 parent_id: request.parent_id,
2340 })?;
2341
2342 let connection_pool = session.connection_pool().await;
2343 if let Some(owner) = owner {
2344 let update = proto::UpdateUserChannels {
2345 channel_memberships: vec![proto::ChannelMembership {
2346 channel_id: owner.channel_id.to_proto(),
2347 role: owner.role.into(),
2348 }],
2349 ..Default::default()
2350 };
2351 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2352 session.peer.send(connection_id, update.clone())?;
2353 }
2354 }
2355
2356 for channel_member in channel_members {
2357 if !channel_member.role.can_see_channel(channel.visibility) {
2358 continue;
2359 }
2360
2361 let update = proto::UpdateChannels {
2362 channels: vec![channel.to_proto()],
2363 ..Default::default()
2364 };
2365 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2366 session.peer.send(connection_id, update.clone())?;
2367 }
2368 }
2369
2370 Ok(())
2371}
2372
2373/// Delete a channel
2374async fn delete_channel(
2375 request: proto::DeleteChannel,
2376 response: Response<proto::DeleteChannel>,
2377 session: Session,
2378) -> Result<()> {
2379 let db = session.db().await;
2380
2381 let channel_id = request.channel_id;
2382 let (removed_channels, member_ids) = db
2383 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2384 .await?;
2385 response.send(proto::Ack {})?;
2386
2387 // Notify members of removed channels
2388 let mut update = proto::UpdateChannels::default();
2389 update
2390 .delete_channels
2391 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2392
2393 let connection_pool = session.connection_pool().await;
2394 for member_id in member_ids {
2395 for connection_id in connection_pool.user_connection_ids(member_id) {
2396 session.peer.send(connection_id, update.clone())?;
2397 }
2398 }
2399
2400 Ok(())
2401}
2402
2403/// Invite someone to join a channel.
2404async fn invite_channel_member(
2405 request: proto::InviteChannelMember,
2406 response: Response<proto::InviteChannelMember>,
2407 session: Session,
2408) -> Result<()> {
2409 let db = session.db().await;
2410 let channel_id = ChannelId::from_proto(request.channel_id);
2411 let invitee_id = UserId::from_proto(request.user_id);
2412 let InviteMemberResult {
2413 channel,
2414 notifications,
2415 } = db
2416 .invite_channel_member(
2417 channel_id,
2418 invitee_id,
2419 session.user_id,
2420 request.role().into(),
2421 )
2422 .await?;
2423
2424 let update = proto::UpdateChannels {
2425 channel_invitations: vec![channel.to_proto()],
2426 ..Default::default()
2427 };
2428
2429 let connection_pool = session.connection_pool().await;
2430 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2431 session.peer.send(connection_id, update.clone())?;
2432 }
2433
2434 send_notifications(&*connection_pool, &session.peer, notifications);
2435
2436 response.send(proto::Ack {})?;
2437 Ok(())
2438}
2439
2440/// remove someone from a channel
2441async fn remove_channel_member(
2442 request: proto::RemoveChannelMember,
2443 response: Response<proto::RemoveChannelMember>,
2444 session: Session,
2445) -> Result<()> {
2446 let db = session.db().await;
2447 let channel_id = ChannelId::from_proto(request.channel_id);
2448 let member_id = UserId::from_proto(request.user_id);
2449
2450 let RemoveChannelMemberResult {
2451 membership_update,
2452 notification_id,
2453 } = db
2454 .remove_channel_member(channel_id, member_id, session.user_id)
2455 .await?;
2456
2457 let connection_pool = &session.connection_pool().await;
2458 notify_membership_updated(
2459 &connection_pool,
2460 membership_update,
2461 member_id,
2462 &session.peer,
2463 );
2464 for connection_id in connection_pool.user_connection_ids(member_id) {
2465 if let Some(notification_id) = notification_id {
2466 session
2467 .peer
2468 .send(
2469 connection_id,
2470 proto::DeleteNotification {
2471 notification_id: notification_id.to_proto(),
2472 },
2473 )
2474 .trace_err();
2475 }
2476 }
2477
2478 response.send(proto::Ack {})?;
2479 Ok(())
2480}
2481
2482/// Toggle the channel between public and private.
2483/// Care is taken to maintain the invariant that public channels only descend from public channels,
2484/// (though members-only channels can appear at any point in the hierarchy).
2485async fn set_channel_visibility(
2486 request: proto::SetChannelVisibility,
2487 response: Response<proto::SetChannelVisibility>,
2488 session: Session,
2489) -> Result<()> {
2490 let db = session.db().await;
2491 let channel_id = ChannelId::from_proto(request.channel_id);
2492 let visibility = request.visibility().into();
2493
2494 let (channel, channel_members) = db
2495 .set_channel_visibility(channel_id, visibility, session.user_id)
2496 .await?;
2497
2498 let connection_pool = session.connection_pool().await;
2499 for member in channel_members {
2500 let update = if member.role.can_see_channel(channel.visibility) {
2501 proto::UpdateChannels {
2502 channels: vec![channel.to_proto()],
2503 ..Default::default()
2504 }
2505 } else {
2506 proto::UpdateChannels {
2507 delete_channels: vec![channel.id.to_proto()],
2508 ..Default::default()
2509 }
2510 };
2511
2512 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2513 session.peer.send(connection_id, update.clone())?;
2514 }
2515 }
2516
2517 response.send(proto::Ack {})?;
2518 Ok(())
2519}
2520
2521/// Alter the role for a user in the channel.
2522async fn set_channel_member_role(
2523 request: proto::SetChannelMemberRole,
2524 response: Response<proto::SetChannelMemberRole>,
2525 session: Session,
2526) -> Result<()> {
2527 let db = session.db().await;
2528 let channel_id = ChannelId::from_proto(request.channel_id);
2529 let member_id = UserId::from_proto(request.user_id);
2530 let result = db
2531 .set_channel_member_role(
2532 channel_id,
2533 session.user_id,
2534 member_id,
2535 request.role().into(),
2536 )
2537 .await?;
2538
2539 match result {
2540 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2541 let connection_pool = session.connection_pool().await;
2542 notify_membership_updated(
2543 &connection_pool,
2544 membership_update,
2545 member_id,
2546 &session.peer,
2547 )
2548 }
2549 db::SetMemberRoleResult::InviteUpdated(channel) => {
2550 let update = proto::UpdateChannels {
2551 channel_invitations: vec![channel.to_proto()],
2552 ..Default::default()
2553 };
2554
2555 for connection_id in session
2556 .connection_pool()
2557 .await
2558 .user_connection_ids(member_id)
2559 {
2560 session.peer.send(connection_id, update.clone())?;
2561 }
2562 }
2563 }
2564
2565 response.send(proto::Ack {})?;
2566 Ok(())
2567}
2568
2569/// Change the name of a channel
2570async fn rename_channel(
2571 request: proto::RenameChannel,
2572 response: Response<proto::RenameChannel>,
2573 session: Session,
2574) -> Result<()> {
2575 let db = session.db().await;
2576 let channel_id = ChannelId::from_proto(request.channel_id);
2577 let (channel, channel_members) = db
2578 .rename_channel(channel_id, session.user_id, &request.name)
2579 .await?;
2580
2581 response.send(proto::RenameChannelResponse {
2582 channel: Some(channel.to_proto()),
2583 })?;
2584
2585 let connection_pool = session.connection_pool().await;
2586 for channel_member in channel_members {
2587 if !channel_member.role.can_see_channel(channel.visibility) {
2588 continue;
2589 }
2590 let update = proto::UpdateChannels {
2591 channels: vec![channel.to_proto()],
2592 ..Default::default()
2593 };
2594 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2595 session.peer.send(connection_id, update.clone())?;
2596 }
2597 }
2598
2599 Ok(())
2600}
2601
2602/// Move a channel to a new parent.
2603async fn move_channel(
2604 request: proto::MoveChannel,
2605 response: Response<proto::MoveChannel>,
2606 session: Session,
2607) -> Result<()> {
2608 let channel_id = ChannelId::from_proto(request.channel_id);
2609 let to = ChannelId::from_proto(request.to);
2610
2611 let (channels, channel_members) = session
2612 .db()
2613 .await
2614 .move_channel(channel_id, to, session.user_id)
2615 .await?;
2616
2617 let connection_pool = session.connection_pool().await;
2618 for member in channel_members {
2619 let channels = channels
2620 .iter()
2621 .filter_map(|channel| {
2622 if member.role.can_see_channel(channel.visibility) {
2623 Some(channel.to_proto())
2624 } else {
2625 None
2626 }
2627 })
2628 .collect::<Vec<_>>();
2629 if channels.is_empty() {
2630 continue;
2631 }
2632
2633 let update = proto::UpdateChannels {
2634 channels,
2635 ..Default::default()
2636 };
2637
2638 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2639 session.peer.send(connection_id, update.clone())?;
2640 }
2641 }
2642
2643 response.send(Ack {})?;
2644 Ok(())
2645}
2646
2647/// Get the list of channel members
2648async fn get_channel_members(
2649 request: proto::GetChannelMembers,
2650 response: Response<proto::GetChannelMembers>,
2651 session: Session,
2652) -> Result<()> {
2653 let db = session.db().await;
2654 let channel_id = ChannelId::from_proto(request.channel_id);
2655 let members = db
2656 .get_channel_participant_details(channel_id, session.user_id)
2657 .await?;
2658 response.send(proto::GetChannelMembersResponse { members })?;
2659 Ok(())
2660}
2661
2662/// Accept or decline a channel invitation.
2663async fn respond_to_channel_invite(
2664 request: proto::RespondToChannelInvite,
2665 response: Response<proto::RespondToChannelInvite>,
2666 session: Session,
2667) -> Result<()> {
2668 let db = session.db().await;
2669 let channel_id = ChannelId::from_proto(request.channel_id);
2670 let RespondToChannelInvite {
2671 membership_update,
2672 notifications,
2673 } = db
2674 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2675 .await?;
2676
2677 let connection_pool = session.connection_pool().await;
2678 if let Some(membership_update) = membership_update {
2679 notify_membership_updated(
2680 &connection_pool,
2681 membership_update,
2682 session.user_id,
2683 &session.peer,
2684 );
2685 } else {
2686 let update = proto::UpdateChannels {
2687 remove_channel_invitations: vec![channel_id.to_proto()],
2688 ..Default::default()
2689 };
2690
2691 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2692 session.peer.send(connection_id, update.clone())?;
2693 }
2694 };
2695
2696 send_notifications(&*connection_pool, &session.peer, notifications);
2697
2698 response.send(proto::Ack {})?;
2699
2700 Ok(())
2701}
2702
2703/// Join the channels' room
2704async fn join_channel(
2705 request: proto::JoinChannel,
2706 response: Response<proto::JoinChannel>,
2707 session: Session,
2708) -> Result<()> {
2709 let channel_id = ChannelId::from_proto(request.channel_id);
2710 join_channel_internal(channel_id, Box::new(response), session).await
2711}
2712
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// `join_channel_internal` is generic over this trait so that both
/// `JoinChannel` and `JoinRoom` requests can be answered by the same code.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified path: inherent methods take precedence over trait
        // methods, so this calls `Response`'s own `send` instead of recursing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same inherent-method dispatch as the `JoinChannel` impl above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2726
/// Shared implementation behind joining a channel's call room.
///
/// It is called from `join_channel`. The `JoinChannelInternalResponse` bound
/// lets any request answered with a `JoinRoomResponse` reuse it.
///
/// In order, this:
/// 1. leaves the session's current room (if any) via `leave_room_for_session`;
/// 2. joins the channel's room in the database, getting back the joined room,
///    an optional membership update, and the user's role in the channel;
/// 3. if a LiveKit client is configured, mints credentials:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`;
///    - every other role gets a room token with `can_publish: true`;
///    - a token error is traced and yields no connection info instead of
///      failing the join;
/// 4. replies with the `JoinRoomResponse`, pushes any membership update to this
///    user's connections, and broadcasts the room update;
/// 5. broadcasts the channel update and calls `update_user_contacts` for this
///    user.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // `db` and `connection_pool` are bound inside this block, so both guards are
    // dropped before `channel_updated` re-acquires the connection pool and
    // before `update_user_contacts` runs.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token
        // fails (`trace_err` logs the error and `?` exits the closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2807
2808/// Start editing the channel notes
2809async fn join_channel_buffer(
2810 request: proto::JoinChannelBuffer,
2811 response: Response<proto::JoinChannelBuffer>,
2812 session: Session,
2813) -> Result<()> {
2814 let db = session.db().await;
2815 let channel_id = ChannelId::from_proto(request.channel_id);
2816
2817 let open_response = db
2818 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2819 .await?;
2820
2821 let collaborators = open_response.collaborators.clone();
2822 response.send(open_response)?;
2823
2824 let update = UpdateChannelBufferCollaborators {
2825 channel_id: channel_id.to_proto(),
2826 collaborators: collaborators.clone(),
2827 };
2828 channel_buffer_updated(
2829 session.connection_id,
2830 collaborators
2831 .iter()
2832 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2833 &update,
2834 &session.peer,
2835 );
2836
2837 Ok(())
2838}
2839
2840/// Edit the channel notes
2841async fn update_channel_buffer(
2842 request: proto::UpdateChannelBuffer,
2843 session: Session,
2844) -> Result<()> {
2845 let db = session.db().await;
2846 let channel_id = ChannelId::from_proto(request.channel_id);
2847
2848 let (collaborators, non_collaborators, epoch, version) = db
2849 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2850 .await?;
2851
2852 channel_buffer_updated(
2853 session.connection_id,
2854 collaborators,
2855 &proto::UpdateChannelBuffer {
2856 channel_id: channel_id.to_proto(),
2857 operations: request.operations,
2858 },
2859 &session.peer,
2860 );
2861
2862 let pool = &*session.connection_pool().await;
2863
2864 broadcast(
2865 None,
2866 non_collaborators
2867 .iter()
2868 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2869 |peer_id| {
2870 session.peer.send(
2871 peer_id.into(),
2872 proto::UpdateChannels {
2873 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2874 channel_id: channel_id.to_proto(),
2875 epoch: epoch as u64,
2876 version: version.clone(),
2877 }],
2878 ..Default::default()
2879 },
2880 )
2881 },
2882 );
2883
2884 Ok(())
2885}
2886
2887/// Rejoin the channel notes after a connection blip
2888async fn rejoin_channel_buffers(
2889 request: proto::RejoinChannelBuffers,
2890 response: Response<proto::RejoinChannelBuffers>,
2891 session: Session,
2892) -> Result<()> {
2893 let db = session.db().await;
2894 let buffers = db
2895 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2896 .await?;
2897
2898 for rejoined_buffer in &buffers {
2899 let collaborators_to_notify = rejoined_buffer
2900 .buffer
2901 .collaborators
2902 .iter()
2903 .filter_map(|c| Some(c.peer_id?.into()));
2904 channel_buffer_updated(
2905 session.connection_id,
2906 collaborators_to_notify,
2907 &proto::UpdateChannelBufferCollaborators {
2908 channel_id: rejoined_buffer.buffer.channel_id,
2909 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2910 },
2911 &session.peer,
2912 );
2913 }
2914
2915 response.send(proto::RejoinChannelBuffersResponse {
2916 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2917 })?;
2918
2919 Ok(())
2920}
2921
2922/// Stop editing the channel notes
2923async fn leave_channel_buffer(
2924 request: proto::LeaveChannelBuffer,
2925 response: Response<proto::LeaveChannelBuffer>,
2926 session: Session,
2927) -> Result<()> {
2928 let db = session.db().await;
2929 let channel_id = ChannelId::from_proto(request.channel_id);
2930
2931 let left_buffer = db
2932 .leave_channel_buffer(channel_id, session.connection_id)
2933 .await?;
2934
2935 response.send(Ack {})?;
2936
2937 channel_buffer_updated(
2938 session.connection_id,
2939 left_buffer.connections,
2940 &proto::UpdateChannelBufferCollaborators {
2941 channel_id: channel_id.to_proto(),
2942 collaborators: left_buffer.collaborators,
2943 },
2944 &session.peer,
2945 );
2946
2947 Ok(())
2948}
2949
2950fn channel_buffer_updated<T: EnvelopedMessage>(
2951 sender_id: ConnectionId,
2952 collaborators: impl IntoIterator<Item = ConnectionId>,
2953 message: &T,
2954 peer: &Peer,
2955) {
2956 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2957 peer.send(peer_id.into(), message.clone())
2958 });
2959}
2960
2961fn send_notifications(
2962 connection_pool: &ConnectionPool,
2963 peer: &Peer,
2964 notifications: db::NotificationBatch,
2965) {
2966 for (user_id, notification) in notifications {
2967 for connection_id in connection_pool.user_connection_ids(user_id) {
2968 if let Err(error) = peer.send(
2969 connection_id,
2970 proto::AddNotification {
2971 notification: Some(notification.clone()),
2972 },
2973 ) {
2974 tracing::error!(
2975 "failed to send notification to {:?} {}",
2976 connection_id,
2977 error
2978 );
2979 }
2980 }
2981 }
2982}
2983
2984/// Send a message to the channel
2985async fn send_channel_message(
2986 request: proto::SendChannelMessage,
2987 response: Response<proto::SendChannelMessage>,
2988 session: Session,
2989) -> Result<()> {
2990 // Validate the message body.
2991 let body = request.body.trim().to_string();
2992 if body.len() > MAX_MESSAGE_LEN {
2993 return Err(anyhow!("message is too long"))?;
2994 }
2995 if body.is_empty() {
2996 return Err(anyhow!("message can't be blank"))?;
2997 }
2998
2999 // TODO: adjust mentions if body is trimmed
3000
3001 let timestamp = OffsetDateTime::now_utc();
3002 let nonce = request
3003 .nonce
3004 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3005
3006 let channel_id = ChannelId::from_proto(request.channel_id);
3007 let CreatedChannelMessage {
3008 message_id,
3009 participant_connection_ids,
3010 channel_members,
3011 notifications,
3012 } = session
3013 .db()
3014 .await
3015 .create_channel_message(
3016 channel_id,
3017 session.user_id,
3018 &body,
3019 &request.mentions,
3020 timestamp,
3021 nonce.clone().into(),
3022 match request.reply_to_message_id {
3023 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3024 None => None,
3025 },
3026 )
3027 .await?;
3028 let message = proto::ChannelMessage {
3029 sender_id: session.user_id.to_proto(),
3030 id: message_id.to_proto(),
3031 body,
3032 mentions: request.mentions,
3033 timestamp: timestamp.unix_timestamp() as u64,
3034 nonce: Some(nonce),
3035 reply_to_message_id: request.reply_to_message_id,
3036 };
3037 broadcast(
3038 Some(session.connection_id),
3039 participant_connection_ids,
3040 |connection| {
3041 session.peer.send(
3042 connection,
3043 proto::ChannelMessageSent {
3044 channel_id: channel_id.to_proto(),
3045 message: Some(message.clone()),
3046 },
3047 )
3048 },
3049 );
3050 response.send(proto::SendChannelMessageResponse {
3051 message: Some(message),
3052 })?;
3053
3054 let pool = &*session.connection_pool().await;
3055 broadcast(
3056 None,
3057 channel_members
3058 .iter()
3059 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3060 |peer_id| {
3061 session.peer.send(
3062 peer_id.into(),
3063 proto::UpdateChannels {
3064 latest_channel_message_ids: vec![proto::ChannelMessageId {
3065 channel_id: channel_id.to_proto(),
3066 message_id: message_id.to_proto(),
3067 }],
3068 ..Default::default()
3069 },
3070 )
3071 },
3072 );
3073 send_notifications(pool, &session.peer, notifications);
3074
3075 Ok(())
3076}
3077
3078/// Delete a channel message
3079async fn remove_channel_message(
3080 request: proto::RemoveChannelMessage,
3081 response: Response<proto::RemoveChannelMessage>,
3082 session: Session,
3083) -> Result<()> {
3084 let channel_id = ChannelId::from_proto(request.channel_id);
3085 let message_id = MessageId::from_proto(request.message_id);
3086 let connection_ids = session
3087 .db()
3088 .await
3089 .remove_channel_message(channel_id, message_id, session.user_id)
3090 .await?;
3091 broadcast(Some(session.connection_id), connection_ids, |connection| {
3092 session.peer.send(connection, request.clone())
3093 });
3094 response.send(proto::Ack {})?;
3095 Ok(())
3096}
3097
3098/// Mark a channel message as read
3099async fn acknowledge_channel_message(
3100 request: proto::AckChannelMessage,
3101 session: Session,
3102) -> Result<()> {
3103 let channel_id = ChannelId::from_proto(request.channel_id);
3104 let message_id = MessageId::from_proto(request.message_id);
3105 let notifications = session
3106 .db()
3107 .await
3108 .observe_channel_message(channel_id, session.user_id, message_id)
3109 .await?;
3110 send_notifications(
3111 &*session.connection_pool().await,
3112 &session.peer,
3113 notifications,
3114 );
3115 Ok(())
3116}
3117
3118/// Mark a buffer version as synced
3119async fn acknowledge_buffer_version(
3120 request: proto::AckBufferOperation,
3121 session: Session,
3122) -> Result<()> {
3123 let buffer_id = BufferId::from_proto(request.buffer_id);
3124 session
3125 .db()
3126 .await
3127 .observe_buffer_version(
3128 buffer_id,
3129 session.user_id,
3130 request.epoch as i32,
3131 &request.version,
3132 )
3133 .await?;
3134 Ok(())
3135}
3136
3137/// Start receiving chat updates for a channel
3138async fn join_channel_chat(
3139 request: proto::JoinChannelChat,
3140 response: Response<proto::JoinChannelChat>,
3141 session: Session,
3142) -> Result<()> {
3143 let channel_id = ChannelId::from_proto(request.channel_id);
3144
3145 let db = session.db().await;
3146 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3147 .await?;
3148 let messages = db
3149 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3150 .await?;
3151 response.send(proto::JoinChannelChatResponse {
3152 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3153 messages,
3154 })?;
3155 Ok(())
3156}
3157
3158/// Stop receiving chat updates for a channel
3159async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3160 let channel_id = ChannelId::from_proto(request.channel_id);
3161 session
3162 .db()
3163 .await
3164 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3165 .await?;
3166 Ok(())
3167}
3168
3169/// Retrieve the chat history for a channel
3170async fn get_channel_messages(
3171 request: proto::GetChannelMessages,
3172 response: Response<proto::GetChannelMessages>,
3173 session: Session,
3174) -> Result<()> {
3175 let channel_id = ChannelId::from_proto(request.channel_id);
3176 let messages = session
3177 .db()
3178 .await
3179 .get_channel_messages(
3180 channel_id,
3181 session.user_id,
3182 MESSAGE_COUNT_PER_PAGE,
3183 Some(MessageId::from_proto(request.before_message_id)),
3184 )
3185 .await?;
3186 response.send(proto::GetChannelMessagesResponse {
3187 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3188 messages,
3189 })?;
3190 Ok(())
3191}
3192
3193/// Retrieve specific chat messages
3194async fn get_channel_messages_by_id(
3195 request: proto::GetChannelMessagesById,
3196 response: Response<proto::GetChannelMessagesById>,
3197 session: Session,
3198) -> Result<()> {
3199 let message_ids = request
3200 .message_ids
3201 .iter()
3202 .map(|id| MessageId::from_proto(*id))
3203 .collect::<Vec<_>>();
3204 let messages = session
3205 .db()
3206 .await
3207 .get_channel_messages_by_id(session.user_id, &message_ids)
3208 .await?;
3209 response.send(proto::GetChannelMessagesResponse {
3210 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3211 messages,
3212 })?;
3213 Ok(())
3214}
3215
3216/// Retrieve the current users notifications
3217async fn get_notifications(
3218 request: proto::GetNotifications,
3219 response: Response<proto::GetNotifications>,
3220 session: Session,
3221) -> Result<()> {
3222 let notifications = session
3223 .db()
3224 .await
3225 .get_notifications(
3226 session.user_id,
3227 NOTIFICATION_COUNT_PER_PAGE,
3228 request
3229 .before_id
3230 .map(|id| db::NotificationId::from_proto(id)),
3231 )
3232 .await?;
3233 response.send(proto::GetNotificationsResponse {
3234 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3235 notifications,
3236 })?;
3237 Ok(())
3238}
3239
3240/// Mark notifications as read
3241async fn mark_notification_as_read(
3242 request: proto::MarkNotificationRead,
3243 response: Response<proto::MarkNotificationRead>,
3244 session: Session,
3245) -> Result<()> {
3246 let database = &session.db().await;
3247 let notifications = database
3248 .mark_notification_as_read_by_id(
3249 session.user_id,
3250 NotificationId::from_proto(request.notification_id),
3251 )
3252 .await?;
3253 send_notifications(
3254 &*session.connection_pool().await,
3255 &session.peer,
3256 notifications,
3257 );
3258 response.send(proto::Ack {})?;
3259 Ok(())
3260}
3261
3262/// Get the current users information
3263async fn get_private_user_info(
3264 _request: proto::GetPrivateUserInfo,
3265 response: Response<proto::GetPrivateUserInfo>,
3266 session: Session,
3267) -> Result<()> {
3268 let db = session.db().await;
3269
3270 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3271 let user = db
3272 .get_user_by_id(session.user_id)
3273 .await?
3274 .ok_or_else(|| anyhow!("user not found"))?;
3275 let flags = db.get_user_flags(session.user_id).await?;
3276
3277 response.send(proto::GetPrivateUserInfoResponse {
3278 metrics_id,
3279 staff: user.admin,
3280 flags,
3281 })?;
3282 Ok(())
3283}
3284
3285fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3286 match message {
3287 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3288 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3289 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3290 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3291 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3292 code: frame.code.into(),
3293 reason: frame.reason,
3294 })),
3295 }
3296}
3297
3298fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3299 match message {
3300 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3301 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3302 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3303 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3304 AxumMessage::Close(frame) => {
3305 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3306 code: frame.code.into(),
3307 reason: frame.reason,
3308 }))
3309 }
3310 }
3311}
3312
3313fn notify_membership_updated(
3314 connection_pool: &ConnectionPool,
3315 result: MembershipUpdated,
3316 user_id: UserId,
3317 peer: &Peer,
3318) {
3319 let user_channels_update = proto::UpdateUserChannels {
3320 channel_memberships: result
3321 .new_channels
3322 .channel_memberships
3323 .iter()
3324 .map(|cm| proto::ChannelMembership {
3325 channel_id: cm.channel_id.to_proto(),
3326 role: cm.role.into(),
3327 })
3328 .collect(),
3329 ..Default::default()
3330 };
3331 let mut update = build_channels_update(result.new_channels, vec![]);
3332 update.delete_channels = result
3333 .removed_channels
3334 .into_iter()
3335 .map(|id| id.to_proto())
3336 .collect();
3337 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3338
3339 for connection_id in connection_pool.user_connection_ids(user_id) {
3340 peer.send(connection_id, user_channels_update.clone())
3341 .trace_err();
3342 peer.send(connection_id, update.clone()).trace_err();
3343 }
3344}
3345
3346fn build_update_user_channels(
3347 memberships: &Vec<db::channel_member::Model>,
3348) -> proto::UpdateUserChannels {
3349 proto::UpdateUserChannels {
3350 channel_memberships: memberships
3351 .iter()
3352 .map(|m| proto::ChannelMembership {
3353 channel_id: m.channel_id.to_proto(),
3354 role: m.role.into(),
3355 })
3356 .collect(),
3357 ..Default::default()
3358 }
3359}
3360
3361fn build_channels_update(
3362 channels: ChannelsForUser,
3363 channel_invites: Vec<db::Channel>,
3364) -> proto::UpdateChannels {
3365 let mut update = proto::UpdateChannels::default();
3366
3367 for channel in channels.channels {
3368 update.channels.push(channel.to_proto());
3369 }
3370
3371 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3372 update.latest_channel_message_ids = channels.latest_channel_messages;
3373
3374 for (channel_id, participants) in channels.channel_participants {
3375 update
3376 .channel_participants
3377 .push(proto::ChannelParticipants {
3378 channel_id: channel_id.to_proto(),
3379 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3380 });
3381 }
3382
3383 for channel in channel_invites {
3384 update.channel_invitations.push(channel.to_proto());
3385 }
3386
3387 update
3388}
3389
3390fn build_initial_contacts_update(
3391 contacts: Vec<db::Contact>,
3392 pool: &ConnectionPool,
3393) -> proto::UpdateContacts {
3394 let mut update = proto::UpdateContacts::default();
3395
3396 for contact in contacts {
3397 match contact {
3398 db::Contact::Accepted { user_id, busy } => {
3399 update.contacts.push(contact_for_user(user_id, busy, &pool));
3400 }
3401 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3402 db::Contact::Incoming { user_id } => {
3403 update
3404 .incoming_requests
3405 .push(proto::IncomingContactRequest {
3406 requester_id: user_id.to_proto(),
3407 })
3408 }
3409 }
3410 }
3411
3412 update
3413}
3414
3415fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3416 proto::Contact {
3417 user_id: user_id.to_proto(),
3418 online: pool.is_user_online(user_id),
3419 busy,
3420 }
3421}
3422
3423fn room_updated(room: &proto::Room, peer: &Peer) {
3424 broadcast(
3425 None,
3426 room.participants
3427 .iter()
3428 .filter_map(|participant| Some(participant.peer_id?.into())),
3429 |peer_id| {
3430 peer.send(
3431 peer_id.into(),
3432 proto::RoomUpdated {
3433 room: Some(room.clone()),
3434 },
3435 )
3436 },
3437 );
3438}
3439
3440fn channel_updated(
3441 channel_id: ChannelId,
3442 room: &proto::Room,
3443 channel_members: &[UserId],
3444 peer: &Peer,
3445 pool: &ConnectionPool,
3446) {
3447 let participants = room
3448 .participants
3449 .iter()
3450 .map(|p| p.user_id)
3451 .collect::<Vec<_>>();
3452
3453 broadcast(
3454 None,
3455 channel_members
3456 .iter()
3457 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3458 |peer_id| {
3459 peer.send(
3460 peer_id.into(),
3461 proto::UpdateChannels {
3462 channel_participants: vec![proto::ChannelParticipants {
3463 channel_id: channel_id.to_proto(),
3464 participant_user_ids: participants.clone(),
3465 }],
3466 ..Default::default()
3467 },
3468 )
3469 },
3470 );
3471}
3472
3473async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3474 let db = session.db().await;
3475
3476 let contacts = db.get_contacts(user_id).await?;
3477 let busy = db.is_user_busy(user_id).await?;
3478
3479 let pool = session.connection_pool().await;
3480 let updated_contact = contact_for_user(user_id, busy, &pool);
3481 for contact in contacts {
3482 if let db::Contact::Accepted {
3483 user_id: contact_user_id,
3484 ..
3485 } = contact
3486 {
3487 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3488 session
3489 .peer
3490 .send(
3491 contact_conn_id,
3492 proto::UpdateContacts {
3493 contacts: vec![updated_contact.clone()],
3494 remove_contacts: Default::default(),
3495 incoming_requests: Default::default(),
3496 remove_incoming_requests: Default::default(),
3497 outgoing_requests: Default::default(),
3498 remove_outgoing_requests: Default::default(),
3499 },
3500 )
3501 .trace_err();
3502 }
3503 }
3504 }
3505 Ok(())
3506}
3507
/// Removes this session's connection from its room, if it is in one, and notifies
/// everyone affected.
///
/// When the database reports that a room was left, this:
/// - runs `project_left` for every project the connection left;
/// - broadcasts the updated room to the participants listed in it and, when the room
///   belongs to a channel, sends the channel's members the updated participant list;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - re-sends contact status for this user and for each canceled callee;
/// - removes the user from the LiveKit room (when a LiveKit client is configured),
///   and deletes that LiveKit room if `left_room.deleted` was set.
///
/// Returns `Ok(())` immediately if `leave_room` reports no room.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. An error from the
/// latter returns early, before the LiveKit cleanup runs. LiveKit failures are only logged.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: every binding below is assigned inside the `if let`,
    // whose `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    // NOTE(review): the `session.db()` guard is a temporary in this scrutinee. Depending on
    // the edition's temporary-scope rules, it may stay held through the body below.
    // Confirm before adding any `session.db()` call inside this block.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // This take happens before `room` is taken, so the `room` broadcast below carries
        // an empty `live_kit_room`. Keep this ordering in mind when changing it.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // Channel rooms also refresh the participant list shown to channel members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before `update_user_contacts`,
    // which acquires the pool (and the db) itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged and do not fail the call.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3584
3585async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3586 let left_channel_buffers = session
3587 .db()
3588 .await
3589 .leave_channel_buffers(session.connection_id)
3590 .await?;
3591
3592 for left_buffer in left_channel_buffers {
3593 channel_buffer_updated(
3594 session.connection_id,
3595 left_buffer.connections,
3596 &proto::UpdateChannelBufferCollaborators {
3597 channel_id: left_buffer.channel_id.to_proto(),
3598 collaborators: left_buffer.collaborators,
3599 },
3600 &session.peer,
3601 );
3602 }
3603
3604 Ok(())
3605}
3606
3607fn project_left(project: &db::LeftProject, session: &Session) {
3608 for connection_id in &project.connection_ids {
3609 if project.host_user_id == session.user_id {
3610 session
3611 .peer
3612 .send(
3613 *connection_id,
3614 proto::UnshareProject {
3615 project_id: project.id.to_proto(),
3616 },
3617 )
3618 .trace_err();
3619 } else {
3620 session
3621 .peer
3622 .send(
3623 *connection_id,
3624 proto::RemoveProjectCollaborator {
3625 project_id: project.id.to_proto(),
3626 peer_id: Some(session.connection_id.into()),
3627 },
3628 )
3629 .trace_err();
3630 }
3631 }
3632}
3633
3634pub trait ResultExt {
3635 type Ok;
3636
3637 fn trace_err(self) -> Option<Self::Ok>;
3638}
3639
3640impl<T, E> ResultExt for Result<T, E>
3641where
3642 E: std::fmt::Debug,
3643{
3644 type Ok = T;
3645
3646 fn trace_err(self) -> Option<T> {
3647 match self {
3648 Ok(value) => Some(value),
3649 Err(error) => {
3650 tracing::error!("{:?}", error);
3651 None
3652 }
3653 }
3654 }
3655}