1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RenameChannelResult, RespondToChannelInvite, RoomId, ServerId,
9 SetChannelVisibilityResult, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, Result,
13};
14use anyhow::anyhow;
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::ConnectionPool;
33use futures::{
34 channel::oneshot,
35 future::{self, BoxFuture},
36 stream::FuturesUnordered,
37 FutureExt, SinkExt, StreamExt, TryStreamExt,
38};
39use lazy_static::lazy_static;
40use prometheus::{register_int_gauge, IntGauge};
41use rpc::{
42 proto::{
43 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
44 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
45 },
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47};
48use serde::{Serialize, Serializer};
49use std::{
50 any::TypeId,
51 fmt,
52 future::Future,
53 marker::PhantomData,
54 mem,
55 net::SocketAddr,
56 ops::{Deref, DerefMut},
57 rc::Rc,
58 sync::{
59 atomic::{AtomicBool, Ordering::SeqCst},
60 Arc,
61 },
62 time::{Duration, Instant},
63};
64use time::OffsetDateTime;
65use tokio::sync::{watch, Semaphore};
66use tower::ServiceBuilder;
67use tracing::{field, info_span, instrument, Instrument};
68
/// Grace period `connection_lost` waits before releasing the room, channel buffers and
/// pending call held by a lost connection.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before cleaning up resources reported as stale.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the following limits are used outside this view; the descriptions are
// inferred from their names — confirm against the channel/notification handlers.
/// Presumably the page size for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification history requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
75
76lazy_static! {
77 static ref METRIC_CONNECTIONS: IntGauge =
78 register_int_gauge!("connections", "number of connections").unwrap();
79 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
80 "shared_projects",
81 "number of open projects with one or more guests"
82 )
83 .unwrap();
84}
85
86type MessageHandler =
87 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
89struct Response<R> {
90 peer: Arc<Peer>,
91 receipt: Receipt<R>,
92 responded: Arc<AtomicBool>,
93}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
103#[derive(Clone)]
104struct Session {
105 zed_environment: Arc<str>,
106 user_id: UserId,
107 connection_id: ConnectionId,
108 db: Arc<tokio::sync::Mutex<DbHandle>>,
109 peer: Arc<Peer>,
110 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
111 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
112 _executor: Executor,
113}
114
115impl Session {
116 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
117 #[cfg(test)]
118 tokio::task::yield_now().await;
119 let guard = self.db.lock().await;
120 #[cfg(test)]
121 tokio::task::yield_now().await;
122 guard
123 }
124
125 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
126 #[cfg(test)]
127 tokio::task::yield_now().await;
128 let guard = self.connection_pool.lock();
129 ConnectionPoolGuard {
130 guard,
131 _not_send: PhantomData,
132 }
133 }
134}
135
136impl fmt::Debug for Session {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("Session")
139 .field("user_id", &self.user_id)
140 .field("connection_id", &self.connection_id)
141 .finish()
142 }
143}
144
145struct DbHandle(Arc<Database>);
146
147impl Deref for DbHandle {
148 type Target = Database;
149
150 fn deref(&self) -> &Self::Target {
151 self.0.as_ref()
152 }
153}
154
155pub struct Server {
156 id: parking_lot::Mutex<ServerId>,
157 peer: Arc<Peer>,
158 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
159 app_state: Arc<AppState>,
160 executor: Executor,
161 handlers: HashMap<TypeId, MessageHandler>,
162 teardown: watch::Sender<()>,
163}
164
165pub(crate) struct ConnectionPoolGuard<'a> {
166 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
167 _not_send: PhantomData<Rc<()>>,
168}
169
170#[derive(Serialize)]
171pub struct ServerSnapshot<'a> {
172 peer: &'a Peer,
173 #[serde(serialize_with = "serialize_deref")]
174 connection_pool: ConnectionPoolGuard<'a>,
175}
176
177pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
178where
179 S: Serializer,
180 T: Deref<Target = U>,
181 U: Serialize,
182{
183 Serialize::serialize(value.deref(), serializer)
184}
185
186impl Server {
187 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
188 let mut server = Self {
189 id: parking_lot::Mutex::new(id),
190 peer: Peer::new(id.0 as u32),
191 app_state,
192 executor,
193 connection_pool: Default::default(),
194 handlers: Default::default(),
195 teardown: watch::channel(()).0,
196 };
197
198 server
199 .add_request_handler(ping)
200 .add_request_handler(create_room)
201 .add_request_handler(join_room)
202 .add_request_handler(rejoin_room)
203 .add_request_handler(leave_room)
204 .add_request_handler(set_room_participant_role)
205 .add_request_handler(call)
206 .add_request_handler(cancel_call)
207 .add_message_handler(decline_call)
208 .add_request_handler(update_participant_location)
209 .add_request_handler(share_project)
210 .add_message_handler(unshare_project)
211 .add_request_handler(join_project)
212 .add_message_handler(leave_project)
213 .add_request_handler(update_project)
214 .add_request_handler(update_worktree)
215 .add_message_handler(start_language_server)
216 .add_message_handler(update_language_server)
217 .add_message_handler(update_diagnostic_summary)
218 .add_message_handler(update_worktree_settings)
219 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
220 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
221 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
222 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
223 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
224 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
225 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
226 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
227 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
228 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
229 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
230 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
231 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
232 .add_request_handler(
233 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
234 )
235 .add_request_handler(
236 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
237 )
238 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
239 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
240 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
241 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
242 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
243 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
244 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
245 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
246 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
247 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
248 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
249 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
250 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
251 .add_message_handler(create_buffer_for_peer)
252 .add_request_handler(update_buffer)
253 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
254 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
255 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
256 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
257 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
258 .add_request_handler(get_users)
259 .add_request_handler(fuzzy_search_users)
260 .add_request_handler(request_contact)
261 .add_request_handler(remove_contact)
262 .add_request_handler(respond_to_contact_request)
263 .add_request_handler(create_channel)
264 .add_request_handler(delete_channel)
265 .add_request_handler(invite_channel_member)
266 .add_request_handler(remove_channel_member)
267 .add_request_handler(set_channel_member_role)
268 .add_request_handler(set_channel_visibility)
269 .add_request_handler(rename_channel)
270 .add_request_handler(join_channel_buffer)
271 .add_request_handler(leave_channel_buffer)
272 .add_message_handler(update_channel_buffer)
273 .add_request_handler(rejoin_channel_buffers)
274 .add_request_handler(get_channel_members)
275 .add_request_handler(respond_to_channel_invite)
276 .add_request_handler(join_channel)
277 .add_request_handler(join_channel_chat)
278 .add_message_handler(leave_channel_chat)
279 .add_request_handler(send_channel_message)
280 .add_request_handler(remove_channel_message)
281 .add_request_handler(get_channel_messages)
282 .add_request_handler(get_channel_messages_by_id)
283 .add_request_handler(get_notifications)
284 .add_request_handler(mark_notification_as_read)
285 .add_request_handler(move_channel)
286 .add_request_handler(follow)
287 .add_message_handler(unfollow)
288 .add_message_handler(update_followers)
289 .add_request_handler(get_private_user_info)
290 .add_message_handler(acknowledge_channel_message)
291 .add_message_handler(acknowledge_buffer_version);
292
293 Arc::new(server)
294 }
295
296 pub async fn start(&self) -> Result<()> {
297 let server_id = *self.id.lock();
298 let app_state = self.app_state.clone();
299 let peer = self.peer.clone();
300 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
301 let pool = self.connection_pool.clone();
302 let live_kit_client = self.app_state.live_kit_client.clone();
303
304 let span = info_span!("start server");
305 self.executor.spawn_detached(
306 async move {
307 tracing::info!("waiting for cleanup timeout");
308 timeout.await;
309 tracing::info!("cleanup timeout expired, retrieving stale rooms");
310 if let Some((room_ids, channel_ids)) = app_state
311 .db
312 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
313 .await
314 .trace_err()
315 {
316 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
317 tracing::info!(
318 stale_channel_buffer_count = channel_ids.len(),
319 "retrieved stale channel buffers"
320 );
321
322 for channel_id in channel_ids {
323 if let Some(refreshed_channel_buffer) = app_state
324 .db
325 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
326 .await
327 .trace_err()
328 {
329 for connection_id in refreshed_channel_buffer.connection_ids {
330 peer.send(
331 connection_id,
332 proto::UpdateChannelBufferCollaborators {
333 channel_id: channel_id.to_proto(),
334 collaborators: refreshed_channel_buffer
335 .collaborators
336 .clone(),
337 },
338 )
339 .trace_err();
340 }
341 }
342 }
343
344 for room_id in room_ids {
345 let mut contacts_to_update = HashSet::default();
346 let mut canceled_calls_to_user_ids = Vec::new();
347 let mut live_kit_room = String::new();
348 let mut delete_live_kit_room = false;
349
350 if let Some(mut refreshed_room) = app_state
351 .db
352 .clear_stale_room_participants(room_id, server_id)
353 .await
354 .trace_err()
355 {
356 tracing::info!(
357 room_id = room_id.0,
358 new_participant_count = refreshed_room.room.participants.len(),
359 "refreshed room"
360 );
361 room_updated(&refreshed_room.room, &peer);
362 if let Some(channel_id) = refreshed_room.channel_id {
363 channel_updated(
364 channel_id,
365 &refreshed_room.room,
366 &refreshed_room.channel_members,
367 &peer,
368 &*pool.lock(),
369 );
370 }
371 contacts_to_update
372 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
373 contacts_to_update
374 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
375 canceled_calls_to_user_ids =
376 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
377 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
378 delete_live_kit_room = refreshed_room.room.participants.is_empty();
379 }
380
381 {
382 let pool = pool.lock();
383 for canceled_user_id in canceled_calls_to_user_ids {
384 for connection_id in pool.user_connection_ids(canceled_user_id) {
385 peer.send(
386 connection_id,
387 proto::CallCanceled {
388 room_id: room_id.to_proto(),
389 },
390 )
391 .trace_err();
392 }
393 }
394 }
395
396 for user_id in contacts_to_update {
397 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
398 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
399 if let Some((busy, contacts)) = busy.zip(contacts) {
400 let pool = pool.lock();
401 let updated_contact = contact_for_user(user_id, busy, &pool);
402 for contact in contacts {
403 if let db::Contact::Accepted {
404 user_id: contact_user_id,
405 ..
406 } = contact
407 {
408 for contact_conn_id in
409 pool.user_connection_ids(contact_user_id)
410 {
411 peer.send(
412 contact_conn_id,
413 proto::UpdateContacts {
414 contacts: vec![updated_contact.clone()],
415 remove_contacts: Default::default(),
416 incoming_requests: Default::default(),
417 remove_incoming_requests: Default::default(),
418 outgoing_requests: Default::default(),
419 remove_outgoing_requests: Default::default(),
420 },
421 )
422 .trace_err();
423 }
424 }
425 }
426 }
427 }
428
429 if let Some(live_kit) = live_kit_client.as_ref() {
430 if delete_live_kit_room {
431 live_kit.delete_room(live_kit_room).await.trace_err();
432 }
433 }
434 }
435 }
436
437 app_state
438 .db
439 .delete_stale_servers(&app_state.config.zed_environment, server_id)
440 .await
441 .trace_err();
442 }
443 .instrument(span),
444 );
445 Ok(())
446 }
447
448 pub fn teardown(&self) {
449 self.peer.teardown();
450 self.connection_pool.lock().reset();
451 let _ = self.teardown.send(());
452 }
453
454 #[cfg(test)]
455 pub fn reset(&self, id: ServerId) {
456 self.teardown();
457 *self.id.lock() = id;
458 self.peer.reset(id.0 as u32);
459 }
460
461 #[cfg(test)]
462 pub fn id(&self) -> ServerId {
463 *self.id.lock()
464 }
465
466 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
467 where
468 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
469 Fut: 'static + Send + Future<Output = Result<()>>,
470 M: EnvelopedMessage,
471 {
472 let prev_handler = self.handlers.insert(
473 TypeId::of::<M>(),
474 Box::new(move |envelope, session| {
475 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
476 let span = info_span!(
477 "handle message",
478 payload_type = envelope.payload_type_name()
479 );
480 span.in_scope(|| {
481 tracing::info!(
482 payload_type = envelope.payload_type_name(),
483 "message received"
484 );
485 });
486 let start_time = Instant::now();
487 let future = (handler)(*envelope, session);
488 async move {
489 let result = future.await;
490 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
491 match result {
492 Err(error) => {
493 tracing::error!(%error, ?duration_ms, "error handling message")
494 }
495 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
496 }
497 }
498 .instrument(span)
499 .boxed()
500 }),
501 );
502 if prev_handler.is_some() {
503 panic!("registered a handler for the same message twice");
504 }
505 self
506 }
507
508 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
509 where
510 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
511 Fut: 'static + Send + Future<Output = Result<()>>,
512 M: EnvelopedMessage,
513 {
514 self.add_handler(move |envelope, session| handler(envelope.payload, session));
515 self
516 }
517
518 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
519 where
520 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
521 Fut: Send + Future<Output = Result<()>>,
522 M: RequestMessage,
523 {
524 let handler = Arc::new(handler);
525 self.add_handler(move |envelope, session| {
526 let receipt = envelope.receipt();
527 let handler = handler.clone();
528 async move {
529 let peer = session.peer.clone();
530 let responded = Arc::new(AtomicBool::default());
531 let response = Response {
532 peer: peer.clone(),
533 responded: responded.clone(),
534 receipt,
535 };
536 match (handler)(envelope.payload, response, session).await {
537 Ok(()) => {
538 if responded.load(std::sync::atomic::Ordering::SeqCst) {
539 Ok(())
540 } else {
541 Err(anyhow!("handler did not send a response"))?
542 }
543 }
544 Err(error) => {
545 let proto_err = match &error {
546 Error::Internal(err) => err.to_proto(),
547 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
548 };
549 peer.respond_with_error(receipt, proto_err)?;
550 Err(error)
551 }
552 }
553 }
554 })
555 }
556
557 pub fn handle_connection(
558 self: &Arc<Self>,
559 connection: Connection,
560 address: String,
561 user: User,
562 impersonator: Option<User>,
563 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
564 executor: Executor,
565 ) -> impl Future<Output = Result<()>> {
566 let this = self.clone();
567 let user_id = user.id;
568 let login = user.github_login;
569 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
570 if let Some(impersonator) = impersonator {
571 span.record("impersonator", &impersonator.github_login);
572 }
573 let mut teardown = self.teardown.subscribe();
574 async move {
575 let (connection_id, handle_io, mut incoming_rx) = this
576 .peer
577 .add_connection(connection, {
578 let executor = executor.clone();
579 move |duration| executor.sleep(duration)
580 });
581
582 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
583 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
584 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
585
586 if let Some(send_connection_id) = send_connection_id.take() {
587 let _ = send_connection_id.send(connection_id);
588 }
589
590 if !user.connected_once {
591 this.peer.send(connection_id, proto::ShowContacts {})?;
592 this.app_state.db.set_user_connected_once(user_id, true).await?;
593 }
594
595 let (contacts, channels_for_user, channel_invites) = future::try_join3(
596 this.app_state.db.get_contacts(user_id),
597 this.app_state.db.get_channels_for_user(user_id),
598 this.app_state.db.get_channel_invites_for_user(user_id),
599 ).await?;
600
601 {
602 let mut pool = this.connection_pool.lock();
603 pool.add_connection(connection_id, user_id, user.admin);
604 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
605 this.peer.send(connection_id, build_channels_update(
606 channels_for_user,
607 channel_invites
608 ))?;
609 }
610
611 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
612 this.peer.send(connection_id, incoming_call)?;
613 }
614
615 let session = Session {
616 user_id,
617 connection_id,
618 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
619 zed_environment: this.app_state.config.zed_environment.clone(),
620 peer: this.peer.clone(),
621 connection_pool: this.connection_pool.clone(),
622 live_kit_client: this.app_state.live_kit_client.clone(),
623 _executor: executor.clone()
624 };
625 update_user_contacts(user_id, &session).await?;
626
627 let handle_io = handle_io.fuse();
628 futures::pin_mut!(handle_io);
629
630 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
631 // This prevents deadlocks when e.g., client A performs a request to client B and
632 // client B performs a request to client A. If both clients stop processing further
633 // messages until their respective request completes, they won't have a chance to
634 // respond to the other client's request and cause a deadlock.
635 //
636 // This arrangement ensures we will attempt to process earlier messages first, but fall
637 // back to processing messages arrived later in the spirit of making progress.
638 let mut foreground_message_handlers = FuturesUnordered::new();
639 let concurrent_handlers = Arc::new(Semaphore::new(256));
640 loop {
641 let next_message = async {
642 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
643 let message = incoming_rx.next().await;
644 (permit, message)
645 }.fuse();
646 futures::pin_mut!(next_message);
647 futures::select_biased! {
648 _ = teardown.changed().fuse() => return Ok(()),
649 result = handle_io => {
650 if let Err(error) = result {
651 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
652 }
653 break;
654 }
655 _ = foreground_message_handlers.next() => {}
656 next_message = next_message => {
657 let (permit, message) = next_message;
658 if let Some(message) = message {
659 let type_name = message.payload_type_name();
660 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
661 let span_enter = span.enter();
662 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
663 let is_background = message.is_background();
664 let handle_message = (handler)(message, session.clone());
665 drop(span_enter);
666
667 let handle_message = async move {
668 handle_message.await;
669 drop(permit);
670 }.instrument(span);
671 if is_background {
672 executor.spawn_detached(handle_message);
673 } else {
674 foreground_message_handlers.push(handle_message);
675 }
676 } else {
677 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
678 }
679 } else {
680 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
681 break;
682 }
683 }
684 }
685 }
686
687 drop(foreground_message_handlers);
688 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
689 if let Err(error) = connection_lost(session, teardown, executor).await {
690 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
691 }
692
693 Ok(())
694 }.instrument(span)
695 }
696
697 pub async fn invite_code_redeemed(
698 self: &Arc<Self>,
699 inviter_id: UserId,
700 invitee_id: UserId,
701 ) -> Result<()> {
702 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
703 if let Some(code) = &user.invite_code {
704 let pool = self.connection_pool.lock();
705 let invitee_contact = contact_for_user(invitee_id, false, &pool);
706 for connection_id in pool.user_connection_ids(inviter_id) {
707 self.peer.send(
708 connection_id,
709 proto::UpdateContacts {
710 contacts: vec![invitee_contact.clone()],
711 ..Default::default()
712 },
713 )?;
714 self.peer.send(
715 connection_id,
716 proto::UpdateInviteInfo {
717 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
718 count: user.invite_count as u32,
719 },
720 )?;
721 }
722 }
723 }
724 Ok(())
725 }
726
727 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
728 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
729 if let Some(invite_code) = &user.invite_code {
730 let pool = self.connection_pool.lock();
731 for connection_id in pool.user_connection_ids(user_id) {
732 self.peer.send(
733 connection_id,
734 proto::UpdateInviteInfo {
735 url: format!(
736 "{}{}",
737 self.app_state.config.invite_link_prefix, invite_code
738 ),
739 count: user.invite_count as u32,
740 },
741 )?;
742 }
743 }
744 }
745 Ok(())
746 }
747
748 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
749 ServerSnapshot {
750 connection_pool: ConnectionPoolGuard {
751 guard: self.connection_pool.lock(),
752 _not_send: PhantomData,
753 },
754 peer: &self.peer,
755 }
756 }
757}
758
759impl<'a> Deref for ConnectionPoolGuard<'a> {
760 type Target = ConnectionPool;
761
762 fn deref(&self) -> &Self::Target {
763 &*self.guard
764 }
765}
766
767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
768 fn deref_mut(&mut self) -> &mut Self::Target {
769 &mut *self.guard
770 }
771}
772
773impl<'a> Drop for ConnectionPoolGuard<'a> {
774 fn drop(&mut self) {
775 #[cfg(test)]
776 self.check_invariants();
777 }
778}
779
780fn broadcast<F>(
781 sender_id: Option<ConnectionId>,
782 receiver_ids: impl IntoIterator<Item = ConnectionId>,
783 mut f: F,
784) where
785 F: FnMut(ConnectionId) -> anyhow::Result<()>,
786{
787 for receiver_id in receiver_ids {
788 if Some(receiver_id) != sender_id {
789 if let Err(error) = f(receiver_id) {
790 tracing::error!("failed to send to {:?} {}", receiver_id, error);
791 }
792 }
793 }
794}
795
796lazy_static! {
797 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
798}
799
800pub struct ProtocolVersion(u32);
801
802impl Header for ProtocolVersion {
803 fn name() -> &'static HeaderName {
804 &ZED_PROTOCOL_VERSION
805 }
806
807 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
808 where
809 Self: Sized,
810 I: Iterator<Item = &'i axum::http::HeaderValue>,
811 {
812 let version = values
813 .next()
814 .ok_or_else(axum::headers::Error::invalid)?
815 .to_str()
816 .map_err(|_| axum::headers::Error::invalid())?
817 .parse()
818 .map_err(|_| axum::headers::Error::invalid())?;
819 Ok(Self(version))
820 }
821
822 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
823 values.extend([self.0.to_string().parse().unwrap()]);
824 }
825}
826
827pub fn routes(server: Arc<Server>) -> Router<Body> {
828 Router::new()
829 .route("/rpc", get(handle_websocket_request))
830 .layer(
831 ServiceBuilder::new()
832 .layer(Extension(server.app_state.clone()))
833 .layer(middleware::from_fn(auth::validate_header)),
834 )
835 .route("/metrics", get(handle_metrics))
836 .layer(Extension(server))
837}
838
839pub async fn handle_websocket_request(
840 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
841 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
842 Extension(server): Extension<Arc<Server>>,
843 Extension(user): Extension<User>,
844 Extension(impersonator): Extension<Impersonator>,
845 ws: WebSocketUpgrade,
846) -> axum::response::Response {
847 if protocol_version != rpc::PROTOCOL_VERSION {
848 return (
849 StatusCode::UPGRADE_REQUIRED,
850 "client must be upgraded".to_string(),
851 )
852 .into_response();
853 }
854 let socket_address = socket_address.to_string();
855 ws.on_upgrade(move |socket| {
856 use util::ResultExt;
857 let socket = socket
858 .map_ok(to_tungstenite_message)
859 .err_into()
860 .with(|message| async move { Ok(to_axum_message(message)) });
861 let connection = Connection::new(Box::pin(socket));
862 async move {
863 server
864 .handle_connection(
865 connection,
866 socket_address,
867 user,
868 impersonator.0,
869 None,
870 Executor::Production,
871 )
872 .await
873 .log_err();
874 }
875 })
876}
877
878pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
879 let connections = server
880 .connection_pool
881 .lock()
882 .connections()
883 .filter(|connection| !connection.admin)
884 .count();
885
886 METRIC_CONNECTIONS.set(connections as _);
887
888 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
889 METRIC_SHARED_PROJECTS.set(shared_projects as _);
890
891 let encoder = prometheus::TextEncoder::new();
892 let metric_families = prometheus::gather();
893 let encoded_metrics = encoder
894 .encode_to_string(&metric_families)
895 .map_err(|err| anyhow!("{}", err))?;
896 Ok(encoded_metrics)
897}
898
899#[instrument(err, skip(executor))]
900async fn connection_lost(
901 session: Session,
902 mut teardown: watch::Receiver<()>,
903 executor: Executor,
904) -> Result<()> {
905 session.peer.disconnect(session.connection_id);
906 session
907 .connection_pool()
908 .await
909 .remove_connection(session.connection_id)?;
910
911 session
912 .db()
913 .await
914 .connection_lost(session.connection_id)
915 .await
916 .trace_err();
917
918 futures::select_biased! {
919 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
920 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
921 leave_room_for_session(&session).await.trace_err();
922 leave_channel_buffers_for_session(&session)
923 .await
924 .trace_err();
925
926 if !session
927 .connection_pool()
928 .await
929 .is_user_online(session.user_id)
930 {
931 let db = session.db().await;
932 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
933 room_updated(&room, &session.peer);
934 }
935 }
936
937 update_user_contacts(session.user_id, &session).await?;
938 }
939 _ = teardown.changed().fuse() => {}
940 }
941
942 Ok(())
943}
944
945/// Acknowledges a ping from a client, used to keep the connection alive.
946async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
947 response.send(proto::Ack {})?;
948 Ok(())
949}
950
951/// Creates a new room for calling (outside of channels)
952async fn create_room(
953 _request: proto::CreateRoom,
954 response: Response<proto::CreateRoom>,
955 session: Session,
956) -> Result<()> {
957 let live_kit_room = nanoid::nanoid!(30);
958
959 let live_kit_connection_info = {
960 let live_kit_room = live_kit_room.clone();
961 let live_kit = session.live_kit_client.as_ref();
962
963 util::async_maybe!({
964 let live_kit = live_kit?;
965
966 let token = live_kit
967 .room_token(&live_kit_room, &session.user_id.to_string())
968 .trace_err()?;
969
970 Some(proto::LiveKitConnectionInfo {
971 server_url: live_kit.url().into(),
972 token,
973 can_publish: true,
974 })
975 })
976 }
977 .await;
978
979 let room = session
980 .db()
981 .await
982 .create_room(
983 session.user_id,
984 session.connection_id,
985 &live_kit_room,
986 &session.zed_environment,
987 )
988 .await?;
989
990 response.send(proto::CreateRoomResponse {
991 room: Some(room.clone()),
992 live_kit_connection_info,
993 })?;
994
995 update_user_contacts(session.user_id, &session).await?;
996 Ok(())
997}
998
999/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1000async fn join_room(
1001 request: proto::JoinRoom,
1002 response: Response<proto::JoinRoom>,
1003 session: Session,
1004) -> Result<()> {
1005 let room_id = RoomId::from_proto(request.id);
1006
1007 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1008
1009 if let Some(channel_id) = channel_id {
1010 return join_channel_internal(channel_id, Box::new(response), session).await;
1011 }
1012
1013 let joined_room = {
1014 let room = session
1015 .db()
1016 .await
1017 .join_room(
1018 room_id,
1019 session.user_id,
1020 session.connection_id,
1021 session.zed_environment.as_ref(),
1022 )
1023 .await?;
1024 room_updated(&room.room, &session.peer);
1025 room.into_inner()
1026 };
1027
1028 for connection_id in session
1029 .connection_pool()
1030 .await
1031 .user_connection_ids(session.user_id)
1032 {
1033 session
1034 .peer
1035 .send(
1036 connection_id,
1037 proto::CallCanceled {
1038 room_id: room_id.to_proto(),
1039 },
1040 )
1041 .trace_err();
1042 }
1043
1044 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1045 if let Some(token) = live_kit
1046 .room_token(
1047 &joined_room.room.live_kit_room,
1048 &session.user_id.to_string(),
1049 )
1050 .trace_err()
1051 {
1052 Some(proto::LiveKitConnectionInfo {
1053 server_url: live_kit.url().into(),
1054 token,
1055 can_publish: true,
1056 })
1057 } else {
1058 None
1059 }
1060 } else {
1061 None
1062 };
1063
1064 response.send(proto::JoinRoomResponse {
1065 room: Some(joined_room.room),
1066 channel_id: None,
1067 live_kit_connection_info,
1068 })?;
1069
1070 update_user_contacts(session.user_id, &session).await?;
1071 Ok(())
1072}
1073
1074/// Rejoin room is used to reconnect to a room after connection errors.
1075async fn rejoin_room(
1076 request: proto::RejoinRoom,
1077 response: Response<proto::RejoinRoom>,
1078 session: Session,
1079) -> Result<()> {
1080 let room;
1081 let channel_id;
1082 let channel_members;
1083 {
1084 let mut rejoined_room = session
1085 .db()
1086 .await
1087 .rejoin_room(request, session.user_id, session.connection_id)
1088 .await?;
1089
1090 response.send(proto::RejoinRoomResponse {
1091 room: Some(rejoined_room.room.clone()),
1092 reshared_projects: rejoined_room
1093 .reshared_projects
1094 .iter()
1095 .map(|project| proto::ResharedProject {
1096 id: project.id.to_proto(),
1097 collaborators: project
1098 .collaborators
1099 .iter()
1100 .map(|collaborator| collaborator.to_proto())
1101 .collect(),
1102 })
1103 .collect(),
1104 rejoined_projects: rejoined_room
1105 .rejoined_projects
1106 .iter()
1107 .map(|rejoined_project| proto::RejoinedProject {
1108 id: rejoined_project.id.to_proto(),
1109 worktrees: rejoined_project
1110 .worktrees
1111 .iter()
1112 .map(|worktree| proto::WorktreeMetadata {
1113 id: worktree.id,
1114 root_name: worktree.root_name.clone(),
1115 visible: worktree.visible,
1116 abs_path: worktree.abs_path.clone(),
1117 })
1118 .collect(),
1119 collaborators: rejoined_project
1120 .collaborators
1121 .iter()
1122 .map(|collaborator| collaborator.to_proto())
1123 .collect(),
1124 language_servers: rejoined_project.language_servers.clone(),
1125 })
1126 .collect(),
1127 })?;
1128 room_updated(&rejoined_room.room, &session.peer);
1129
1130 for project in &rejoined_room.reshared_projects {
1131 for collaborator in &project.collaborators {
1132 session
1133 .peer
1134 .send(
1135 collaborator.connection_id,
1136 proto::UpdateProjectCollaborator {
1137 project_id: project.id.to_proto(),
1138 old_peer_id: Some(project.old_connection_id.into()),
1139 new_peer_id: Some(session.connection_id.into()),
1140 },
1141 )
1142 .trace_err();
1143 }
1144
1145 broadcast(
1146 Some(session.connection_id),
1147 project
1148 .collaborators
1149 .iter()
1150 .map(|collaborator| collaborator.connection_id),
1151 |connection_id| {
1152 session.peer.forward_send(
1153 session.connection_id,
1154 connection_id,
1155 proto::UpdateProject {
1156 project_id: project.id.to_proto(),
1157 worktrees: project.worktrees.clone(),
1158 },
1159 )
1160 },
1161 );
1162 }
1163
1164 for project in &rejoined_room.rejoined_projects {
1165 for collaborator in &project.collaborators {
1166 session
1167 .peer
1168 .send(
1169 collaborator.connection_id,
1170 proto::UpdateProjectCollaborator {
1171 project_id: project.id.to_proto(),
1172 old_peer_id: Some(project.old_connection_id.into()),
1173 new_peer_id: Some(session.connection_id.into()),
1174 },
1175 )
1176 .trace_err();
1177 }
1178 }
1179
1180 for project in &mut rejoined_room.rejoined_projects {
1181 for worktree in mem::take(&mut project.worktrees) {
1182 #[cfg(any(test, feature = "test-support"))]
1183 const MAX_CHUNK_SIZE: usize = 2;
1184 #[cfg(not(any(test, feature = "test-support")))]
1185 const MAX_CHUNK_SIZE: usize = 256;
1186
1187 // Stream this worktree's entries.
1188 let message = proto::UpdateWorktree {
1189 project_id: project.id.to_proto(),
1190 worktree_id: worktree.id,
1191 abs_path: worktree.abs_path.clone(),
1192 root_name: worktree.root_name,
1193 updated_entries: worktree.updated_entries,
1194 removed_entries: worktree.removed_entries,
1195 scan_id: worktree.scan_id,
1196 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1197 updated_repositories: worktree.updated_repositories,
1198 removed_repositories: worktree.removed_repositories,
1199 };
1200 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1201 session.peer.send(session.connection_id, update.clone())?;
1202 }
1203
1204 // Stream this worktree's diagnostics.
1205 for summary in worktree.diagnostic_summaries {
1206 session.peer.send(
1207 session.connection_id,
1208 proto::UpdateDiagnosticSummary {
1209 project_id: project.id.to_proto(),
1210 worktree_id: worktree.id,
1211 summary: Some(summary),
1212 },
1213 )?;
1214 }
1215
1216 for settings_file in worktree.settings_files {
1217 session.peer.send(
1218 session.connection_id,
1219 proto::UpdateWorktreeSettings {
1220 project_id: project.id.to_proto(),
1221 worktree_id: worktree.id,
1222 path: settings_file.path,
1223 content: Some(settings_file.content),
1224 },
1225 )?;
1226 }
1227 }
1228
1229 for language_server in &project.language_servers {
1230 session.peer.send(
1231 session.connection_id,
1232 proto::UpdateLanguageServer {
1233 project_id: project.id.to_proto(),
1234 language_server_id: language_server.id,
1235 variant: Some(
1236 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1237 proto::LspDiskBasedDiagnosticsUpdated {},
1238 ),
1239 ),
1240 },
1241 )?;
1242 }
1243 }
1244
1245 let rejoined_room = rejoined_room.into_inner();
1246
1247 room = rejoined_room.room;
1248 channel_id = rejoined_room.channel_id;
1249 channel_members = rejoined_room.channel_members;
1250 }
1251
1252 if let Some(channel_id) = channel_id {
1253 channel_updated(
1254 channel_id,
1255 &room,
1256 &channel_members,
1257 &session.peer,
1258 &*session.connection_pool().await,
1259 );
1260 }
1261
1262 update_user_contacts(session.user_id, &session).await?;
1263 Ok(())
1264}
1265
1266/// leave room disonnects from the room.
1267async fn leave_room(
1268 _: proto::LeaveRoom,
1269 response: Response<proto::LeaveRoom>,
1270 session: Session,
1271) -> Result<()> {
1272 leave_room_for_session(&session).await?;
1273 response.send(proto::Ack {})?;
1274 Ok(())
1275}
1276
1277/// Updates the permissions of someone else in the room.
1278async fn set_room_participant_role(
1279 request: proto::SetRoomParticipantRole,
1280 response: Response<proto::SetRoomParticipantRole>,
1281 session: Session,
1282) -> Result<()> {
1283 let (live_kit_room, can_publish) = {
1284 let room = session
1285 .db()
1286 .await
1287 .set_room_participant_role(
1288 session.user_id,
1289 RoomId::from_proto(request.room_id),
1290 UserId::from_proto(request.user_id),
1291 ChannelRole::from(request.role()),
1292 )
1293 .await?;
1294
1295 let live_kit_room = room.live_kit_room.clone();
1296 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1297 room_updated(&room, &session.peer);
1298 (live_kit_room, can_publish)
1299 };
1300
1301 if let Some(live_kit) = session.live_kit_client.as_ref() {
1302 live_kit
1303 .update_participant(
1304 live_kit_room.clone(),
1305 request.user_id.to_string(),
1306 live_kit_server::proto::ParticipantPermission {
1307 can_subscribe: true,
1308 can_publish,
1309 can_publish_data: can_publish,
1310 hidden: false,
1311 recorder: false,
1312 },
1313 )
1314 .await
1315 .trace_err();
1316 }
1317
1318 response.send(proto::Ack {})?;
1319 Ok(())
1320}
1321
/// Calls another user into the current room.
///
/// The callee must be one of the caller's contacts; otherwise the request fails.
///
/// On success:
/// 1. The call is recorded in the database, the room is re-broadcast, and the
///    callee's contacts are refreshed.
/// 2. The incoming-call message returned by the database is sent as a request
///    to every connection of the callee concurrently.
/// 3. The first successful reply acknowledges the caller.
///
/// If no connection replies successfully, the call is marked failed in the
/// database, the room and the callee's contacts are updated again, and the
/// handler returns "failed to ring user".
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the message out of the database result so it can be used after
        // this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Trace the failure and keep waiting on the remaining connections.
                call_response.trace_err();
            }
        }
    }

    // No connection replied successfully: record the failed call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1390
/// Cancels an outgoing call.
///
/// Steps:
/// 1. Cancel the call in the database (keyed by room, caller connection and
///    callee) and re-broadcast the room.
/// 2. Send `CallCanceled` to every connection of the callee. Send failures are
///    only traced.
/// 3. Acknowledge the caller.
/// 4. Refresh the callee's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so `room` is dropped before the connection pool is locked below.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1428
/// Declines an incoming call.
///
/// Steps:
/// 1. Decline the call in the database for the given room. Fails with
///    "failed to decline call" if the database returns no room.
/// 2. Re-broadcast the room.
/// 3. Send `CallCanceled` to every one of the declining user's own
///    connections. Send failures are only traced.
/// 4. Refresh the user's contacts.
async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so `room` is dropped before the connection pool is locked below.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id)
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1460
/// Updates other participants in the room with your current location.
///
/// Fails with "invalid location" if the request carries no location. The
/// location is stored for the caller's connection, the room is re-broadcast,
/// and the request is acknowledged.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    // `db` is a named binding, so the database guard stays alive until the
    // function returns (after the broadcast and the ack).
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1481
/// Shares a project into the room.
///
/// The database records the project for the caller's connection with the
/// request's worktrees and returns its id and the room. The id is sent back to
/// the caller first; then the room is re-broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session.peer);

    Ok(())
}
1504
/// Unshares a project from the room.
///
/// `Database::unshare_project` is called with the caller's connection and
/// returns the room and the project's guest connections. Presumably it
/// rejects callers that are not the host (TODO confirm).
///
/// The original `UnshareProject` message is then relayed to every guest except
/// the sender, and the room is re-broadcast.
async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .unshare_project(project_id, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |conn_id| session.peer.send(conn_id, message.clone()),
    );
    room_updated(&room, &session.peer);

    Ok(())
}
1524
1525/// Join someone elses shared project.
1526async fn join_project(
1527 request: proto::JoinProject,
1528 response: Response<proto::JoinProject>,
1529 session: Session,
1530) -> Result<()> {
1531 let project_id = ProjectId::from_proto(request.project_id);
1532 let guest_user_id = session.user_id;
1533
1534 tracing::info!(%project_id, "join project");
1535
1536 let (project, replica_id) = &mut *session
1537 .db()
1538 .await
1539 .join_project(project_id, session.connection_id)
1540 .await?;
1541
1542 let collaborators = project
1543 .collaborators
1544 .iter()
1545 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1546 .map(|collaborator| collaborator.to_proto())
1547 .collect::<Vec<_>>();
1548
1549 let worktrees = project
1550 .worktrees
1551 .iter()
1552 .map(|(id, worktree)| proto::WorktreeMetadata {
1553 id: *id,
1554 root_name: worktree.root_name.clone(),
1555 visible: worktree.visible,
1556 abs_path: worktree.abs_path.clone(),
1557 })
1558 .collect::<Vec<_>>();
1559
1560 for collaborator in &collaborators {
1561 session
1562 .peer
1563 .send(
1564 collaborator.peer_id.unwrap().into(),
1565 proto::AddProjectCollaborator {
1566 project_id: project_id.to_proto(),
1567 collaborator: Some(proto::Collaborator {
1568 peer_id: Some(session.connection_id.into()),
1569 replica_id: replica_id.0 as u32,
1570 user_id: guest_user_id.to_proto(),
1571 }),
1572 },
1573 )
1574 .trace_err();
1575 }
1576
1577 // First, we send the metadata associated with each worktree.
1578 response.send(proto::JoinProjectResponse {
1579 worktrees: worktrees.clone(),
1580 replica_id: replica_id.0 as u32,
1581 collaborators: collaborators.clone(),
1582 language_servers: project.language_servers.clone(),
1583 })?;
1584
1585 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1586 #[cfg(any(test, feature = "test-support"))]
1587 const MAX_CHUNK_SIZE: usize = 2;
1588 #[cfg(not(any(test, feature = "test-support")))]
1589 const MAX_CHUNK_SIZE: usize = 256;
1590
1591 // Stream this worktree's entries.
1592 let message = proto::UpdateWorktree {
1593 project_id: project_id.to_proto(),
1594 worktree_id,
1595 abs_path: worktree.abs_path.clone(),
1596 root_name: worktree.root_name,
1597 updated_entries: worktree.entries,
1598 removed_entries: Default::default(),
1599 scan_id: worktree.scan_id,
1600 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1601 updated_repositories: worktree.repository_entries.into_values().collect(),
1602 removed_repositories: Default::default(),
1603 };
1604 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1605 session.peer.send(session.connection_id, update.clone())?;
1606 }
1607
1608 // Stream this worktree's diagnostics.
1609 for summary in worktree.diagnostic_summaries {
1610 session.peer.send(
1611 session.connection_id,
1612 proto::UpdateDiagnosticSummary {
1613 project_id: project_id.to_proto(),
1614 worktree_id: worktree.id,
1615 summary: Some(summary),
1616 },
1617 )?;
1618 }
1619
1620 for settings_file in worktree.settings_files {
1621 session.peer.send(
1622 session.connection_id,
1623 proto::UpdateWorktreeSettings {
1624 project_id: project_id.to_proto(),
1625 worktree_id: worktree.id,
1626 path: settings_file.path,
1627 content: Some(settings_file.content),
1628 },
1629 )?;
1630 }
1631 }
1632
1633 for language_server in &project.language_servers {
1634 session.peer.send(
1635 session.connection_id,
1636 proto::UpdateLanguageServer {
1637 project_id: project_id.to_proto(),
1638 language_server_id: language_server.id,
1639 variant: Some(
1640 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1641 proto::LspDiskBasedDiagnosticsUpdated {},
1642 ),
1643 ),
1644 },
1645 )?;
1646 }
1647
1648 Ok(())
1649}
1650
/// Leaves someone else's shared project.
///
/// Steps:
/// 1. Remove the caller from the project in the database
///    (`Database::leave_project`).
/// 2. Log the departure along with the host's user and connection ids.
/// 3. Call `project_left`, presumably to notify the remaining collaborators
///    (TODO confirm).
/// 4. Re-broadcast the room.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);

    let (room, project) = &*session
        .db()
        .await
        .leave_project(project_id, sender_id)
        .await?;
    tracing::info!(
        %project_id,
        host_user_id = %project.host_user_id,
        host_connection_id = %project.host_connection_id,
        "leave project"
    );

    project_left(&project, &session);
    room_updated(&room, &session.peer);

    Ok(())
}
1673
/// Updates other participants with changes to a project's worktree list.
///
/// The request's worktrees are written through `Database::update_project` for
/// the sender's connection, which returns the room and the guest connections.
/// The request is then forwarded unchanged to every guest except the sender,
/// the room is re-broadcast, and the request is acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;

    Ok(())
}
1700
/// Updates other participants with changes to a worktree.
///
/// `Database::update_worktree` applies the update for the sender's connection
/// and returns the connections to notify. The request is forwarded unchanged to
/// each of them except the sender before the request is acknowledged.
async fn update_worktree(
    request: proto::UpdateWorktree,
    response: Response<proto::UpdateWorktree>,
    session: Session,
) -> Result<()> {
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree(&request, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
1725
/// Updates other participants with changes to a worktree's diagnostics.
///
/// The summary is written through `Database::update_diagnostic_summary` for the
/// sender's connection. The message is then forwarded unchanged to every
/// returned connection except the sender. This is a one-way message, so there
/// is no acknowledgement.
async fn update_diagnostic_summary(
    message: proto::UpdateDiagnosticSummary,
    session: Session,
) -> Result<()> {
    let guest_connection_ids = session
        .db()
        .await
        .update_diagnostic_summary(&message, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
1749
/// Updates other participants with changes to a worktree's settings file.
///
/// The change is written through `Database::update_worktree_settings` for the
/// sender's connection. The message is then forwarded unchanged to every
/// returned connection except the sender. This is a one-way message, so there
/// is no acknowledgement.
async fn update_worktree_settings(
    message: proto::UpdateWorktreeSettings,
    session: Session,
) -> Result<()> {
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree_settings(&message, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
1773
/// Notifies other participants that a language server has started.
///
/// The server is recorded through `Database::start_language_server` for the
/// sender's connection. The request is then forwarded unchanged to every
/// returned connection except the sender.
async fn start_language_server(
    request: proto::StartLanguageServer,
    session: Session,
) -> Result<()> {
    let guest_connection_ids = session
        .db()
        .await
        .start_language_server(&request, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
1796
/// Notifies other participants that a language server's state has changed.
///
/// Unlike `start_language_server`, nothing is written to the database here.
/// The project's connections are only looked up (`project_connection_ids`),
/// and the request is forwarded unchanged to each of them except the sender.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
1819
1820/// forward a project request to the host. These requests should be read only
1821/// as guests are allowed to send them.
1822async fn forward_read_only_project_request<T>(
1823 request: T,
1824 response: Response<T>,
1825 session: Session,
1826) -> Result<()>
1827where
1828 T: EntityMessage + RequestMessage,
1829{
1830 let project_id = ProjectId::from_proto(request.remote_entity_id());
1831 let host_connection_id = session
1832 .db()
1833 .await
1834 .host_for_read_only_project_request(project_id, session.connection_id)
1835 .await?;
1836 let payload = session
1837 .peer
1838 .forward_request(session.connection_id, host_connection_id, request)
1839 .await?;
1840 response.send(payload)?;
1841 Ok(())
1842}
1843
1844/// forward a project request to the host. These requests are disallowed
1845/// for guests.
1846async fn forward_mutating_project_request<T>(
1847 request: T,
1848 response: Response<T>,
1849 session: Session,
1850) -> Result<()>
1851where
1852 T: EntityMessage + RequestMessage,
1853{
1854 let project_id = ProjectId::from_proto(request.remote_entity_id());
1855 let host_connection_id = session
1856 .db()
1857 .await
1858 .host_for_mutating_project_request(project_id, session.connection_id)
1859 .await?;
1860 let payload = session
1861 .peer
1862 .forward_request(session.connection_id, host_connection_id, request)
1863 .await?;
1864 response.send(payload)?;
1865 Ok(())
1866}
1867
/// Sends a newly created buffer from the project host to one specific peer.
///
/// The sender must be the project's host (`check_user_is_project_host`). The
/// request is forwarded unchanged to the connection named by `peer_id`, and
/// fails with "invalid peer id" if none was given.
async fn create_buffer_for_peer(
    request: proto::CreateBufferForPeer,
    session: Session,
) -> Result<()> {
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
1887
1888/// Notify other participants that a buffer has been updated. This is
1889/// allowed for guests as long as the update is limited to selections.
1890async fn update_buffer(
1891 request: proto::UpdateBuffer,
1892 response: Response<proto::UpdateBuffer>,
1893 session: Session,
1894) -> Result<()> {
1895 let project_id = ProjectId::from_proto(request.project_id);
1896 let mut guest_connection_ids;
1897 let mut host_connection_id = None;
1898
1899 let mut requires_write_permission = false;
1900
1901 for op in request.operations.iter() {
1902 match op.variant {
1903 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1904 Some(_) => requires_write_permission = true,
1905 }
1906 }
1907
1908 {
1909 let collaborators = session
1910 .db()
1911 .await
1912 .project_collaborators_for_buffer_update(
1913 project_id,
1914 session.connection_id,
1915 requires_write_permission,
1916 )
1917 .await?;
1918 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1919 for collaborator in collaborators.iter() {
1920 if collaborator.is_host {
1921 host_connection_id = Some(collaborator.connection_id);
1922 } else {
1923 guest_connection_ids.push(collaborator.connection_id);
1924 }
1925 }
1926 }
1927 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1928
1929 broadcast(
1930 Some(session.connection_id),
1931 guest_connection_ids,
1932 |connection_id| {
1933 session
1934 .peer
1935 .forward_send(session.connection_id, connection_id, request.clone())
1936 },
1937 );
1938 if host_connection_id != session.connection_id {
1939 session
1940 .peer
1941 .forward_request(session.connection_id, host_connection_id, request.clone())
1942 .await?;
1943 }
1944
1945 response.send(proto::Ack {})?;
1946 Ok(())
1947}
1948
/// Relays a project message to every other connection in the project.
///
/// The message's entity id is treated as a project id, and the recipients come
/// from `project_connection_ids`. The message is forwarded unchanged to each
/// of them except the sender.
///
/// NOTE(review): despite the name, no explicit host check happens in this
/// function. Confirm whether `project_connection_ids` enforces that the sender
/// is the host.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id)
        .await?;

    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
1972
/// Starts following another participant (the leader) in a call.
///
/// Steps:
/// 1. Check the leader and follower connections against the room
///    (`check_room_participants`).
/// 2. Forward the `Follow` request to the leader and relay the leader's reply
///    to the follower.
/// 3. If the request names a project, record the follow in the database and
///    re-broadcast the room.
///
/// Fails with "invalid leader id" if no leader is given.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    let response_payload = session
        .peer
        .forward_request(session.connection_id, leader_id, request)
        .await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2010
/// Stops following another participant (the leader) in a call.
///
/// Steps:
/// 1. Check the leader and follower connections against the room
///    (`check_room_participants`).
/// 2. Forward the `Unfollow` message to the leader as a one-way send.
/// 3. If the request names a project, remove the follow in the database and
///    re-broadcast the room.
///
/// Fails with "invalid leader id" if no leader is given.
async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request
        .leader_id
        .ok_or_else(|| anyhow!("invalid leader id"))?
        .into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2042
2043/// Notify everyone following you of your current location.
2044async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2045 let room_id = RoomId::from_proto(request.room_id);
2046 let database = session.db.lock().await;
2047
2048 let connection_ids = if let Some(project_id) = request.project_id {
2049 let project_id = ProjectId::from_proto(project_id);
2050 database
2051 .project_connection_ids(project_id, session.connection_id)
2052 .await?
2053 } else {
2054 database
2055 .room_connection_ids(room_id, session.connection_id)
2056 .await?
2057 };
2058
2059 // For now, don't send view update messages back to that view's current leader.
2060 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2061 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2062 _ => None,
2063 });
2064
2065 for follower_peer_id in request.follower_ids.iter().copied() {
2066 let follower_connection_id = follower_peer_id.into();
2067 if Some(follower_peer_id) != connection_id_to_omit
2068 && connection_ids.contains(&follower_connection_id)
2069 {
2070 session.peer.forward_send(
2071 session.connection_id,
2072 follower_connection_id,
2073 request.clone(),
2074 )?;
2075 }
2076 }
2077 Ok(())
2078}
2079
/// Returns public profile data for the requested user ids.
///
/// Each user's avatar URL is derived from their GitHub login
/// (`https://github.com/<login>.png?size=128`).
async fn get_users(
    request: proto::GetUsers,
    response: Response<proto::GetUsers>,
    session: Session,
) -> Result<()> {
    let user_ids = request
        .user_ids
        .into_iter()
        .map(UserId::from_proto)
        .collect();
    let users = session
        .db()
        .await
        .get_users_by_ids(user_ids)
        .await?
        .into_iter()
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2106
/// Searches for users (e.g. to invite) by GitHub login.
///
/// How the query is handled depends on its length in bytes:
/// - empty: returns nothing;
/// - one or two bytes: looked up with `get_user_by_github_login`;
/// - longer: uses the database's fuzzy search with a limit argument of 10.
///
/// The caller is always filtered out of the results.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: Session,
) -> Result<()> {
    let query = request.query;
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id)
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2137
/// Sends a contact request from the caller to another user.
///
/// Requests to oneself are rejected. After the database records the request:
/// 1. Every connection of the requester gets an outgoing-request update.
/// 2. Every connection of the responder gets an incoming-request update.
/// 3. Any notifications returned by the database are delivered.
/// 4. The caller is acknowledged.
///
/// A failed send of either contact update aborts the handler with an error.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id;
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    // This pool guard is also used to deliver the notifications below.
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&*connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2184
2185/// Accept or decline a contact request
2186async fn respond_to_contact_request(
2187 request: proto::RespondToContactRequest,
2188 response: Response<proto::RespondToContactRequest>,
2189 session: Session,
2190) -> Result<()> {
2191 let responder_id = session.user_id;
2192 let requester_id = UserId::from_proto(request.requester_id);
2193 let db = session.db().await;
2194 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2195 db.dismiss_contact_notification(responder_id, requester_id)
2196 .await?;
2197 } else {
2198 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2199
2200 let notifications = db
2201 .respond_to_contact_request(responder_id, requester_id, accept)
2202 .await?;
2203 let requester_busy = db.is_user_busy(requester_id).await?;
2204 let responder_busy = db.is_user_busy(responder_id).await?;
2205
2206 let pool = session.connection_pool().await;
2207 // Update responder with new contact
2208 let mut update = proto::UpdateContacts::default();
2209 if accept {
2210 update
2211 .contacts
2212 .push(contact_for_user(requester_id, requester_busy, &pool));
2213 }
2214 update
2215 .remove_incoming_requests
2216 .push(requester_id.to_proto());
2217 for connection_id in pool.user_connection_ids(responder_id) {
2218 session.peer.send(connection_id, update.clone())?;
2219 }
2220
2221 // Update requester with new contact
2222 let mut update = proto::UpdateContacts::default();
2223 if accept {
2224 update
2225 .contacts
2226 .push(contact_for_user(responder_id, responder_busy, &pool));
2227 }
2228 update
2229 .remove_outgoing_requests
2230 .push(responder_id.to_proto());
2231
2232 for connection_id in pool.user_connection_ids(requester_id) {
2233 session.peer.send(connection_id, update.clone())?;
2234 }
2235
2236 send_notifications(&*pool, &session.peer, notifications);
2237 }
2238
2239 response.send(proto::Ack {})?;
2240 Ok(())
2241}
2242
2243/// Remove a contact.
2244async fn remove_contact(
2245 request: proto::RemoveContact,
2246 response: Response<proto::RemoveContact>,
2247 session: Session,
2248) -> Result<()> {
2249 let requester_id = session.user_id;
2250 let responder_id = UserId::from_proto(request.user_id);
2251 let db = session.db().await;
2252 let (contact_accepted, deleted_notification_id) =
2253 db.remove_contact(requester_id, responder_id).await?;
2254
2255 let pool = session.connection_pool().await;
2256 // Update outgoing contact requests of requester
2257 let mut update = proto::UpdateContacts::default();
2258 if contact_accepted {
2259 update.remove_contacts.push(responder_id.to_proto());
2260 } else {
2261 update
2262 .remove_outgoing_requests
2263 .push(responder_id.to_proto());
2264 }
2265 for connection_id in pool.user_connection_ids(requester_id) {
2266 session.peer.send(connection_id, update.clone())?;
2267 }
2268
2269 // Update incoming contact requests of responder
2270 let mut update = proto::UpdateContacts::default();
2271 if contact_accepted {
2272 update.remove_contacts.push(requester_id.to_proto());
2273 } else {
2274 update
2275 .remove_incoming_requests
2276 .push(requester_id.to_proto());
2277 }
2278 for connection_id in pool.user_connection_ids(responder_id) {
2279 session.peer.send(connection_id, update.clone())?;
2280 if let Some(notification_id) = deleted_notification_id {
2281 session.peer.send(
2282 connection_id,
2283 proto::DeleteNotification {
2284 notification_id: notification_id.to_proto(),
2285 },
2286 )?;
2287 }
2288 }
2289
2290 response.send(proto::Ack {})?;
2291 Ok(())
2292}
2293
2294/// Creates a new channel.
2295async fn create_channel(
2296 request: proto::CreateChannel,
2297 response: Response<proto::CreateChannel>,
2298 session: Session,
2299) -> Result<()> {
2300 let db = session.db().await;
2301
2302 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2303 let channel = db
2304 .create_channel(&request.name, parent_id, session.user_id)
2305 .await?;
2306
2307 response.send(proto::CreateChannelResponse {
2308 channel: Some(channel.to_proto()),
2309 parent_id: request.parent_id,
2310 })?;
2311
2312 let participants_to_update;
2313 if let Some(parent) = parent_id {
2314 participants_to_update = db.new_participants_to_notify(parent).await?;
2315 } else {
2316 participants_to_update = vec![];
2317 }
2318
2319 let connection_pool = session.connection_pool().await;
2320 for (user_id, channels) in participants_to_update {
2321 let update = build_channels_update(channels, vec![]);
2322 for connection_id in connection_pool.user_connection_ids(user_id) {
2323 if user_id == session.user_id {
2324 continue;
2325 }
2326 session.peer.send(connection_id, update.clone())?;
2327 }
2328 }
2329
2330 Ok(())
2331}
2332
2333/// Delete a channel
2334async fn delete_channel(
2335 request: proto::DeleteChannel,
2336 response: Response<proto::DeleteChannel>,
2337 session: Session,
2338) -> Result<()> {
2339 let db = session.db().await;
2340
2341 let channel_id = request.channel_id;
2342 let (removed_channels, member_ids) = db
2343 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2344 .await?;
2345 response.send(proto::Ack {})?;
2346
2347 // Notify members of removed channels
2348 let mut update = proto::UpdateChannels::default();
2349 update
2350 .delete_channels
2351 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2352
2353 let connection_pool = session.connection_pool().await;
2354 for member_id in member_ids {
2355 for connection_id in connection_pool.user_connection_ids(member_id) {
2356 session.peer.send(connection_id, update.clone())?;
2357 }
2358 }
2359
2360 Ok(())
2361}
2362
2363/// Invite someone to join a channel.
2364async fn invite_channel_member(
2365 request: proto::InviteChannelMember,
2366 response: Response<proto::InviteChannelMember>,
2367 session: Session,
2368) -> Result<()> {
2369 let db = session.db().await;
2370 let channel_id = ChannelId::from_proto(request.channel_id);
2371 let invitee_id = UserId::from_proto(request.user_id);
2372 let InviteMemberResult {
2373 channel,
2374 notifications,
2375 } = db
2376 .invite_channel_member(
2377 channel_id,
2378 invitee_id,
2379 session.user_id,
2380 request.role().into(),
2381 )
2382 .await?;
2383
2384 let update = proto::UpdateChannels {
2385 channel_invitations: vec![channel.to_proto()],
2386 ..Default::default()
2387 };
2388
2389 let connection_pool = session.connection_pool().await;
2390 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2391 session.peer.send(connection_id, update.clone())?;
2392 }
2393
2394 send_notifications(&*connection_pool, &session.peer, notifications);
2395
2396 response.send(proto::Ack {})?;
2397 Ok(())
2398}
2399
2400/// remove someone from a channel
2401async fn remove_channel_member(
2402 request: proto::RemoveChannelMember,
2403 response: Response<proto::RemoveChannelMember>,
2404 session: Session,
2405) -> Result<()> {
2406 let db = session.db().await;
2407 let channel_id = ChannelId::from_proto(request.channel_id);
2408 let member_id = UserId::from_proto(request.user_id);
2409
2410 let RemoveChannelMemberResult {
2411 membership_update,
2412 notification_id,
2413 } = db
2414 .remove_channel_member(channel_id, member_id, session.user_id)
2415 .await?;
2416
2417 let connection_pool = &session.connection_pool().await;
2418 notify_membership_updated(
2419 &connection_pool,
2420 membership_update,
2421 member_id,
2422 &session.peer,
2423 );
2424 for connection_id in connection_pool.user_connection_ids(member_id) {
2425 if let Some(notification_id) = notification_id {
2426 session
2427 .peer
2428 .send(
2429 connection_id,
2430 proto::DeleteNotification {
2431 notification_id: notification_id.to_proto(),
2432 },
2433 )
2434 .trace_err();
2435 }
2436 }
2437
2438 response.send(proto::Ack {})?;
2439 Ok(())
2440}
2441
2442/// Toggle the channel between public and private
2443async fn set_channel_visibility(
2444 request: proto::SetChannelVisibility,
2445 response: Response<proto::SetChannelVisibility>,
2446 session: Session,
2447) -> Result<()> {
2448 let db = session.db().await;
2449 let channel_id = ChannelId::from_proto(request.channel_id);
2450 let visibility = request.visibility().into();
2451
2452 let SetChannelVisibilityResult {
2453 participants_to_update,
2454 participants_to_remove,
2455 channels_to_remove,
2456 } = db
2457 .set_channel_visibility(channel_id, visibility, session.user_id)
2458 .await?;
2459
2460 let connection_pool = session.connection_pool().await;
2461 for (user_id, channels) in participants_to_update {
2462 let update = build_channels_update(channels, vec![]);
2463 for connection_id in connection_pool.user_connection_ids(user_id) {
2464 session.peer.send(connection_id, update.clone())?;
2465 }
2466 }
2467 for user_id in participants_to_remove {
2468 let update = proto::UpdateChannels {
2469 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2470 ..Default::default()
2471 };
2472 for connection_id in connection_pool.user_connection_ids(user_id) {
2473 session.peer.send(connection_id, update.clone())?;
2474 }
2475 }
2476
2477 response.send(proto::Ack {})?;
2478 Ok(())
2479}
2480
2481/// Alter the role for a user in the channel
2482async fn set_channel_member_role(
2483 request: proto::SetChannelMemberRole,
2484 response: Response<proto::SetChannelMemberRole>,
2485 session: Session,
2486) -> Result<()> {
2487 let db = session.db().await;
2488 let channel_id = ChannelId::from_proto(request.channel_id);
2489 let member_id = UserId::from_proto(request.user_id);
2490 let result = db
2491 .set_channel_member_role(
2492 channel_id,
2493 session.user_id,
2494 member_id,
2495 request.role().into(),
2496 )
2497 .await?;
2498
2499 match result {
2500 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2501 let connection_pool = session.connection_pool().await;
2502 notify_membership_updated(
2503 &connection_pool,
2504 membership_update,
2505 member_id,
2506 &session.peer,
2507 )
2508 }
2509 db::SetMemberRoleResult::InviteUpdated(channel) => {
2510 let update = proto::UpdateChannels {
2511 channel_invitations: vec![channel.to_proto()],
2512 ..Default::default()
2513 };
2514
2515 for connection_id in session
2516 .connection_pool()
2517 .await
2518 .user_connection_ids(member_id)
2519 {
2520 session.peer.send(connection_id, update.clone())?;
2521 }
2522 }
2523 }
2524
2525 response.send(proto::Ack {})?;
2526 Ok(())
2527}
2528
2529/// Change the name of a channel
2530async fn rename_channel(
2531 request: proto::RenameChannel,
2532 response: Response<proto::RenameChannel>,
2533 session: Session,
2534) -> Result<()> {
2535 let db = session.db().await;
2536 let channel_id = ChannelId::from_proto(request.channel_id);
2537 let RenameChannelResult {
2538 channel,
2539 participants_to_update,
2540 } = db
2541 .rename_channel(channel_id, session.user_id, &request.name)
2542 .await?;
2543
2544 response.send(proto::RenameChannelResponse {
2545 channel: Some(channel.to_proto()),
2546 })?;
2547
2548 let connection_pool = session.connection_pool().await;
2549 for (user_id, channel) in participants_to_update {
2550 for connection_id in connection_pool.user_connection_ids(user_id) {
2551 let update = proto::UpdateChannels {
2552 channels: vec![channel.to_proto()],
2553 ..Default::default()
2554 };
2555
2556 session.peer.send(connection_id, update.clone())?;
2557 }
2558 }
2559
2560 Ok(())
2561}
2562
2563/// Move a channel to a new parent.
2564async fn move_channel(
2565 request: proto::MoveChannel,
2566 response: Response<proto::MoveChannel>,
2567 session: Session,
2568) -> Result<()> {
2569 let channel_id = ChannelId::from_proto(request.channel_id);
2570 let to = request.to.map(ChannelId::from_proto);
2571
2572 let result = session
2573 .db()
2574 .await
2575 .move_channel(channel_id, to, session.user_id)
2576 .await?;
2577
2578 if let Some(result) = result {
2579 let participants_to_update: HashMap<_, _> = session
2580 .db()
2581 .await
2582 .new_participants_to_notify(to.unwrap_or(channel_id))
2583 .await?
2584 .into_iter()
2585 .collect();
2586
2587 let mut moved_channels: HashSet<ChannelId> = HashSet::default();
2588 for id in result.descendent_ids {
2589 moved_channels.insert(id);
2590 }
2591 moved_channels.insert(channel_id);
2592
2593 let mut participants_to_remove: HashSet<UserId> = HashSet::default();
2594 for participant in result.previous_participants {
2595 if participant.kind == proto::channel_member::Kind::AncestorMember {
2596 if !participants_to_update.contains_key(&participant.user_id) {
2597 participants_to_remove.insert(participant.user_id);
2598 }
2599 }
2600 }
2601
2602 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2603
2604 let connection_pool = session.connection_pool().await;
2605 for (user_id, channels) in participants_to_update {
2606 let mut update = build_channels_update(channels, vec![]);
2607 update.delete_channels = moved_channels.clone();
2608 for connection_id in connection_pool.user_connection_ids(user_id) {
2609 session.peer.send(connection_id, update.clone())?;
2610 }
2611 }
2612
2613 for user_id in participants_to_remove {
2614 let update = proto::UpdateChannels {
2615 delete_channels: moved_channels.clone(),
2616 ..Default::default()
2617 };
2618 for connection_id in connection_pool.user_connection_ids(user_id) {
2619 session.peer.send(connection_id, update.clone())?;
2620 }
2621 }
2622 }
2623
2624 response.send(Ack {})?;
2625 Ok(())
2626}
2627
2628/// Get the list of channel members
2629async fn get_channel_members(
2630 request: proto::GetChannelMembers,
2631 response: Response<proto::GetChannelMembers>,
2632 session: Session,
2633) -> Result<()> {
2634 let db = session.db().await;
2635 let channel_id = ChannelId::from_proto(request.channel_id);
2636 let members = db
2637 .get_channel_participant_details(channel_id, session.user_id)
2638 .await?;
2639 response.send(proto::GetChannelMembersResponse { members })?;
2640 Ok(())
2641}
2642
2643/// Accept or decline a channel invitation.
2644async fn respond_to_channel_invite(
2645 request: proto::RespondToChannelInvite,
2646 response: Response<proto::RespondToChannelInvite>,
2647 session: Session,
2648) -> Result<()> {
2649 let db = session.db().await;
2650 let channel_id = ChannelId::from_proto(request.channel_id);
2651 let RespondToChannelInvite {
2652 membership_update,
2653 notifications,
2654 } = db
2655 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2656 .await?;
2657
2658 let connection_pool = session.connection_pool().await;
2659 if let Some(membership_update) = membership_update {
2660 notify_membership_updated(
2661 &connection_pool,
2662 membership_update,
2663 session.user_id,
2664 &session.peer,
2665 );
2666 } else {
2667 let update = proto::UpdateChannels {
2668 remove_channel_invitations: vec![channel_id.to_proto()],
2669 ..Default::default()
2670 };
2671
2672 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2673 session.peer.send(connection_id, update.clone())?;
2674 }
2675 };
2676
2677 send_notifications(&*connection_pool, &session.peer, notifications);
2678
2679 response.send(proto::Ack {})?;
2680
2681 Ok(())
2682}
2683
2684/// Join the channels' room
2685async fn join_channel(
2686 request: proto::JoinChannel,
2687 response: Response<proto::JoinChannel>,
2688 session: Session,
2689) -> Result<()> {
2690 let channel_id = ChannelId::from_proto(request.channel_id);
2691 join_channel_internal(channel_id, Box::new(response), session).await
2692}
2693
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets `join_channel_internal` reply to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path makes it unambiguous that the inherent
        // `Response::send` is called here, not this trait method recursively.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets `join_channel_internal` reply to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully qualified path makes it unambiguous that the inherent
        // `Response::send` is called here, not this trait method recursively.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2707
/// Shared implementation for joining a channel's room; `response` is any
/// `JoinChannelInternalResponse` handle.
///
/// Steps, in order:
/// 1. `leave_room_for_session` for this session.
/// 2. Join the channel in the database, yielding the joined room, an optional
///    membership change, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a token for the room. Guests
///    get a guest token with `can_publish: false`; every other role gets a
///    room token with `can_publish: true`. A token error is logged via
///    `trace_err` and yields no connection info instead of failing the join.
/// 4. Reply with the room, notify the user's connections of any membership
///    change, and broadcast the room via `room_updated`.
/// 5. Broadcast the channel via `channel_updated`, then refresh the user's
///    contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes `db` and the first connection-pool handle, so both are
    // dropped before `channel_updated` re-acquires the pool and
    // `update_user_contacts` runs (presumably to avoid holding these guards
    // across those calls — confirm before restructuring).
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` when there is no LiveKit client or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Re-acquire the pool as a temporary; it is released at the end of this
    // statement.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2788
2789/// Start editing the channel notes
2790async fn join_channel_buffer(
2791 request: proto::JoinChannelBuffer,
2792 response: Response<proto::JoinChannelBuffer>,
2793 session: Session,
2794) -> Result<()> {
2795 let db = session.db().await;
2796 let channel_id = ChannelId::from_proto(request.channel_id);
2797
2798 let open_response = db
2799 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2800 .await?;
2801
2802 let collaborators = open_response.collaborators.clone();
2803 response.send(open_response)?;
2804
2805 let update = UpdateChannelBufferCollaborators {
2806 channel_id: channel_id.to_proto(),
2807 collaborators: collaborators.clone(),
2808 };
2809 channel_buffer_updated(
2810 session.connection_id,
2811 collaborators
2812 .iter()
2813 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2814 &update,
2815 &session.peer,
2816 );
2817
2818 Ok(())
2819}
2820
2821/// Edit the channel notes
2822async fn update_channel_buffer(
2823 request: proto::UpdateChannelBuffer,
2824 session: Session,
2825) -> Result<()> {
2826 let db = session.db().await;
2827 let channel_id = ChannelId::from_proto(request.channel_id);
2828
2829 let (collaborators, non_collaborators, epoch, version) = db
2830 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2831 .await?;
2832
2833 channel_buffer_updated(
2834 session.connection_id,
2835 collaborators,
2836 &proto::UpdateChannelBuffer {
2837 channel_id: channel_id.to_proto(),
2838 operations: request.operations,
2839 },
2840 &session.peer,
2841 );
2842
2843 let pool = &*session.connection_pool().await;
2844
2845 broadcast(
2846 None,
2847 non_collaborators
2848 .iter()
2849 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2850 |peer_id| {
2851 session.peer.send(
2852 peer_id.into(),
2853 proto::UpdateChannels {
2854 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2855 channel_id: channel_id.to_proto(),
2856 epoch: epoch as u64,
2857 version: version.clone(),
2858 }],
2859 ..Default::default()
2860 },
2861 )
2862 },
2863 );
2864
2865 Ok(())
2866}
2867
2868/// Rejoin the channel notes after a connection blip
2869async fn rejoin_channel_buffers(
2870 request: proto::RejoinChannelBuffers,
2871 response: Response<proto::RejoinChannelBuffers>,
2872 session: Session,
2873) -> Result<()> {
2874 let db = session.db().await;
2875 let buffers = db
2876 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2877 .await?;
2878
2879 for rejoined_buffer in &buffers {
2880 let collaborators_to_notify = rejoined_buffer
2881 .buffer
2882 .collaborators
2883 .iter()
2884 .filter_map(|c| Some(c.peer_id?.into()));
2885 channel_buffer_updated(
2886 session.connection_id,
2887 collaborators_to_notify,
2888 &proto::UpdateChannelBufferCollaborators {
2889 channel_id: rejoined_buffer.buffer.channel_id,
2890 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2891 },
2892 &session.peer,
2893 );
2894 }
2895
2896 response.send(proto::RejoinChannelBuffersResponse {
2897 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2898 })?;
2899
2900 Ok(())
2901}
2902
2903/// Stop editing the channel notes
2904async fn leave_channel_buffer(
2905 request: proto::LeaveChannelBuffer,
2906 response: Response<proto::LeaveChannelBuffer>,
2907 session: Session,
2908) -> Result<()> {
2909 let db = session.db().await;
2910 let channel_id = ChannelId::from_proto(request.channel_id);
2911
2912 let left_buffer = db
2913 .leave_channel_buffer(channel_id, session.connection_id)
2914 .await?;
2915
2916 response.send(Ack {})?;
2917
2918 channel_buffer_updated(
2919 session.connection_id,
2920 left_buffer.connections,
2921 &proto::UpdateChannelBufferCollaborators {
2922 channel_id: channel_id.to_proto(),
2923 collaborators: left_buffer.collaborators,
2924 },
2925 &session.peer,
2926 );
2927
2928 Ok(())
2929}
2930
2931fn channel_buffer_updated<T: EnvelopedMessage>(
2932 sender_id: ConnectionId,
2933 collaborators: impl IntoIterator<Item = ConnectionId>,
2934 message: &T,
2935 peer: &Peer,
2936) {
2937 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2938 peer.send(peer_id.into(), message.clone())
2939 });
2940}
2941
2942fn send_notifications(
2943 connection_pool: &ConnectionPool,
2944 peer: &Peer,
2945 notifications: db::NotificationBatch,
2946) {
2947 for (user_id, notification) in notifications {
2948 for connection_id in connection_pool.user_connection_ids(user_id) {
2949 if let Err(error) = peer.send(
2950 connection_id,
2951 proto::AddNotification {
2952 notification: Some(notification.clone()),
2953 },
2954 ) {
2955 tracing::error!(
2956 "failed to send notification to {:?} {}",
2957 connection_id,
2958 error
2959 );
2960 }
2961 }
2962 }
2963}
2964
2965/// Send a message to the channel
2966async fn send_channel_message(
2967 request: proto::SendChannelMessage,
2968 response: Response<proto::SendChannelMessage>,
2969 session: Session,
2970) -> Result<()> {
2971 // Validate the message body.
2972 let body = request.body.trim().to_string();
2973 if body.len() > MAX_MESSAGE_LEN {
2974 return Err(anyhow!("message is too long"))?;
2975 }
2976 if body.is_empty() {
2977 return Err(anyhow!("message can't be blank"))?;
2978 }
2979
2980 // TODO: adjust mentions if body is trimmed
2981
2982 let timestamp = OffsetDateTime::now_utc();
2983 let nonce = request
2984 .nonce
2985 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2986
2987 let channel_id = ChannelId::from_proto(request.channel_id);
2988 let CreatedChannelMessage {
2989 message_id,
2990 participant_connection_ids,
2991 channel_members,
2992 notifications,
2993 } = session
2994 .db()
2995 .await
2996 .create_channel_message(
2997 channel_id,
2998 session.user_id,
2999 &body,
3000 &request.mentions,
3001 timestamp,
3002 nonce.clone().into(),
3003 )
3004 .await?;
3005 let message = proto::ChannelMessage {
3006 sender_id: session.user_id.to_proto(),
3007 id: message_id.to_proto(),
3008 body,
3009 mentions: request.mentions,
3010 timestamp: timestamp.unix_timestamp() as u64,
3011 nonce: Some(nonce),
3012 };
3013 broadcast(
3014 Some(session.connection_id),
3015 participant_connection_ids,
3016 |connection| {
3017 session.peer.send(
3018 connection,
3019 proto::ChannelMessageSent {
3020 channel_id: channel_id.to_proto(),
3021 message: Some(message.clone()),
3022 },
3023 )
3024 },
3025 );
3026 response.send(proto::SendChannelMessageResponse {
3027 message: Some(message),
3028 })?;
3029
3030 let pool = &*session.connection_pool().await;
3031 broadcast(
3032 None,
3033 channel_members
3034 .iter()
3035 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3036 |peer_id| {
3037 session.peer.send(
3038 peer_id.into(),
3039 proto::UpdateChannels {
3040 unseen_channel_messages: vec![proto::UnseenChannelMessage {
3041 channel_id: channel_id.to_proto(),
3042 message_id: message_id.to_proto(),
3043 }],
3044 ..Default::default()
3045 },
3046 )
3047 },
3048 );
3049 send_notifications(pool, &session.peer, notifications);
3050
3051 Ok(())
3052}
3053
3054/// Delete a channel message
3055async fn remove_channel_message(
3056 request: proto::RemoveChannelMessage,
3057 response: Response<proto::RemoveChannelMessage>,
3058 session: Session,
3059) -> Result<()> {
3060 let channel_id = ChannelId::from_proto(request.channel_id);
3061 let message_id = MessageId::from_proto(request.message_id);
3062 let connection_ids = session
3063 .db()
3064 .await
3065 .remove_channel_message(channel_id, message_id, session.user_id)
3066 .await?;
3067 broadcast(Some(session.connection_id), connection_ids, |connection| {
3068 session.peer.send(connection, request.clone())
3069 });
3070 response.send(proto::Ack {})?;
3071 Ok(())
3072}
3073
3074/// Mark a channel message as read
3075async fn acknowledge_channel_message(
3076 request: proto::AckChannelMessage,
3077 session: Session,
3078) -> Result<()> {
3079 let channel_id = ChannelId::from_proto(request.channel_id);
3080 let message_id = MessageId::from_proto(request.message_id);
3081 let notifications = session
3082 .db()
3083 .await
3084 .observe_channel_message(channel_id, session.user_id, message_id)
3085 .await?;
3086 send_notifications(
3087 &*session.connection_pool().await,
3088 &session.peer,
3089 notifications,
3090 );
3091 Ok(())
3092}
3093
3094/// Mark a buffer version as synced
3095async fn acknowledge_buffer_version(
3096 request: proto::AckBufferOperation,
3097 session: Session,
3098) -> Result<()> {
3099 let buffer_id = BufferId::from_proto(request.buffer_id);
3100 session
3101 .db()
3102 .await
3103 .observe_buffer_version(
3104 buffer_id,
3105 session.user_id,
3106 request.epoch as i32,
3107 &request.version,
3108 )
3109 .await?;
3110 Ok(())
3111}
3112
3113/// Start receiving chat updates for a channel
3114async fn join_channel_chat(
3115 request: proto::JoinChannelChat,
3116 response: Response<proto::JoinChannelChat>,
3117 session: Session,
3118) -> Result<()> {
3119 let channel_id = ChannelId::from_proto(request.channel_id);
3120
3121 let db = session.db().await;
3122 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3123 .await?;
3124 let messages = db
3125 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3126 .await?;
3127 response.send(proto::JoinChannelChatResponse {
3128 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3129 messages,
3130 })?;
3131 Ok(())
3132}
3133
3134/// Stop receiving chat updates for a channel
3135async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3136 let channel_id = ChannelId::from_proto(request.channel_id);
3137 session
3138 .db()
3139 .await
3140 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3141 .await?;
3142 Ok(())
3143}
3144
3145/// Retrieve the chat history for a channel
3146async fn get_channel_messages(
3147 request: proto::GetChannelMessages,
3148 response: Response<proto::GetChannelMessages>,
3149 session: Session,
3150) -> Result<()> {
3151 let channel_id = ChannelId::from_proto(request.channel_id);
3152 let messages = session
3153 .db()
3154 .await
3155 .get_channel_messages(
3156 channel_id,
3157 session.user_id,
3158 MESSAGE_COUNT_PER_PAGE,
3159 Some(MessageId::from_proto(request.before_message_id)),
3160 )
3161 .await?;
3162 response.send(proto::GetChannelMessagesResponse {
3163 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3164 messages,
3165 })?;
3166 Ok(())
3167}
3168
3169/// Retrieve specific chat messages
3170async fn get_channel_messages_by_id(
3171 request: proto::GetChannelMessagesById,
3172 response: Response<proto::GetChannelMessagesById>,
3173 session: Session,
3174) -> Result<()> {
3175 let message_ids = request
3176 .message_ids
3177 .iter()
3178 .map(|id| MessageId::from_proto(*id))
3179 .collect::<Vec<_>>();
3180 let messages = session
3181 .db()
3182 .await
3183 .get_channel_messages_by_id(session.user_id, &message_ids)
3184 .await?;
3185 response.send(proto::GetChannelMessagesResponse {
3186 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3187 messages,
3188 })?;
3189 Ok(())
3190}
3191
3192/// Retrieve the current users notifications
3193async fn get_notifications(
3194 request: proto::GetNotifications,
3195 response: Response<proto::GetNotifications>,
3196 session: Session,
3197) -> Result<()> {
3198 let notifications = session
3199 .db()
3200 .await
3201 .get_notifications(
3202 session.user_id,
3203 NOTIFICATION_COUNT_PER_PAGE,
3204 request
3205 .before_id
3206 .map(|id| db::NotificationId::from_proto(id)),
3207 )
3208 .await?;
3209 response.send(proto::GetNotificationsResponse {
3210 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3211 notifications,
3212 })?;
3213 Ok(())
3214}
3215
3216/// Mark notifications as read
3217async fn mark_notification_as_read(
3218 request: proto::MarkNotificationRead,
3219 response: Response<proto::MarkNotificationRead>,
3220 session: Session,
3221) -> Result<()> {
3222 let database = &session.db().await;
3223 let notifications = database
3224 .mark_notification_as_read_by_id(
3225 session.user_id,
3226 NotificationId::from_proto(request.notification_id),
3227 )
3228 .await?;
3229 send_notifications(
3230 &*session.connection_pool().await,
3231 &session.peer,
3232 notifications,
3233 );
3234 response.send(proto::Ack {})?;
3235 Ok(())
3236}
3237
3238/// Get the current users information
3239async fn get_private_user_info(
3240 _request: proto::GetPrivateUserInfo,
3241 response: Response<proto::GetPrivateUserInfo>,
3242 session: Session,
3243) -> Result<()> {
3244 let db = session.db().await;
3245
3246 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3247 let user = db
3248 .get_user_by_id(session.user_id)
3249 .await?
3250 .ok_or_else(|| anyhow!("user not found"))?;
3251 let flags = db.get_user_flags(session.user_id).await?;
3252
3253 response.send(proto::GetPrivateUserInfoResponse {
3254 metrics_id,
3255 staff: user.admin,
3256 flags,
3257 })?;
3258 Ok(())
3259}
3260
3261fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3262 match message {
3263 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3264 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3265 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3266 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3267 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3268 code: frame.code.into(),
3269 reason: frame.reason,
3270 })),
3271 }
3272}
3273
3274fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3275 match message {
3276 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3277 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3278 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3279 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3280 AxumMessage::Close(frame) => {
3281 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3282 code: frame.code.into(),
3283 reason: frame.reason,
3284 }))
3285 }
3286 }
3287}
3288
3289fn notify_membership_updated(
3290 connection_pool: &ConnectionPool,
3291 result: MembershipUpdated,
3292 user_id: UserId,
3293 peer: &Peer,
3294) {
3295 let mut update = build_channels_update(result.new_channels, vec![]);
3296 update.delete_channels = result
3297 .removed_channels
3298 .into_iter()
3299 .map(|id| id.to_proto())
3300 .collect();
3301 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3302
3303 for connection_id in connection_pool.user_connection_ids(user_id) {
3304 peer.send(connection_id, update.clone()).trace_err();
3305 }
3306}
3307
3308fn build_channels_update(
3309 channels: ChannelsForUser,
3310 channel_invites: Vec<db::Channel>,
3311) -> proto::UpdateChannels {
3312 let mut update = proto::UpdateChannels::default();
3313
3314 for channel in channels.channels {
3315 update.channels.push(channel.to_proto());
3316 }
3317
3318 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3319 update.unseen_channel_messages = channels.channel_messages;
3320
3321 for (channel_id, participants) in channels.channel_participants {
3322 update
3323 .channel_participants
3324 .push(proto::ChannelParticipants {
3325 channel_id: channel_id.to_proto(),
3326 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3327 });
3328 }
3329
3330 for channel in channel_invites {
3331 update.channel_invitations.push(channel.to_proto());
3332 }
3333
3334 update
3335}
3336
3337fn build_initial_contacts_update(
3338 contacts: Vec<db::Contact>,
3339 pool: &ConnectionPool,
3340) -> proto::UpdateContacts {
3341 let mut update = proto::UpdateContacts::default();
3342
3343 for contact in contacts {
3344 match contact {
3345 db::Contact::Accepted { user_id, busy } => {
3346 update.contacts.push(contact_for_user(user_id, busy, &pool));
3347 }
3348 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3349 db::Contact::Incoming { user_id } => {
3350 update
3351 .incoming_requests
3352 .push(proto::IncomingContactRequest {
3353 requester_id: user_id.to_proto(),
3354 })
3355 }
3356 }
3357 }
3358
3359 update
3360}
3361
3362fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3363 proto::Contact {
3364 user_id: user_id.to_proto(),
3365 online: pool.is_user_online(user_id),
3366 busy,
3367 }
3368}
3369
3370fn room_updated(room: &proto::Room, peer: &Peer) {
3371 broadcast(
3372 None,
3373 room.participants
3374 .iter()
3375 .filter_map(|participant| Some(participant.peer_id?.into())),
3376 |peer_id| {
3377 peer.send(
3378 peer_id.into(),
3379 proto::RoomUpdated {
3380 room: Some(room.clone()),
3381 },
3382 )
3383 },
3384 );
3385}
3386
3387fn channel_updated(
3388 channel_id: ChannelId,
3389 room: &proto::Room,
3390 channel_members: &[UserId],
3391 peer: &Peer,
3392 pool: &ConnectionPool,
3393) {
3394 let participants = room
3395 .participants
3396 .iter()
3397 .map(|p| p.user_id)
3398 .collect::<Vec<_>>();
3399
3400 broadcast(
3401 None,
3402 channel_members
3403 .iter()
3404 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3405 |peer_id| {
3406 peer.send(
3407 peer_id.into(),
3408 proto::UpdateChannels {
3409 channel_participants: vec![proto::ChannelParticipants {
3410 channel_id: channel_id.to_proto(),
3411 participant_user_ids: participants.clone(),
3412 }],
3413 ..Default::default()
3414 },
3415 )
3416 },
3417 );
3418}
3419
3420async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3421 let db = session.db().await;
3422
3423 let contacts = db.get_contacts(user_id).await?;
3424 let busy = db.is_user_busy(user_id).await?;
3425
3426 let pool = session.connection_pool().await;
3427 let updated_contact = contact_for_user(user_id, busy, &pool);
3428 for contact in contacts {
3429 if let db::Contact::Accepted {
3430 user_id: contact_user_id,
3431 ..
3432 } = contact
3433 {
3434 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3435 session
3436 .peer
3437 .send(
3438 contact_conn_id,
3439 proto::UpdateContacts {
3440 contacts: vec![updated_contact.clone()],
3441 remove_contacts: Default::default(),
3442 incoming_requests: Default::default(),
3443 remove_incoming_requests: Default::default(),
3444 outgoing_requests: Default::default(),
3445 remove_outgoing_requests: Default::default(),
3446 },
3447 )
3448 .trace_err();
3449 }
3450 }
3451 }
3452 Ok(())
3453}
3454
/// Removes this session's connection from the room it is in, if any, then
/// performs the follow-up notifications and cleanup.
///
/// Steps, in order:
/// 1. Call `leave_room` in the database. If it returns `None`, return
///    `Ok(())` immediately.
/// 2. Call `project_left` for every project the session left, then broadcast
///    the updated room via `room_updated`.
/// 3. If the room has a `channel_id`, send the new participant list to the
///    channel's members via `channel_updated`.
/// 4. Send `CallCanceled` to every connection of each user in
///    `canceled_calls_to_user_ids`.
/// 5. Call `update_user_contacts` for this user and every canceled callee.
/// 6. If a LiveKit client is configured, remove this user from the LiveKit
///    room. Also delete that room when `left_room.deleted` was set.
///
/// # Errors
///
/// Propagates errors from `leave_room` and `update_user_contacts`. Failures of
/// the `CallCanceled` sends and of the LiveKit calls are only logged via
/// `trace_err`.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only in the `Some` branch
    // below. The `else` branch returns, so all of them are initialized after
    // the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel_members;
    let channel_id;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself is moved out, so the `room` broadcast by
        // `room_updated` below carries this field at its `Default` value.
        // NOTE(review): confirm clients do not need `live_kit_room` in this update.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel_members = mem::take(&mut left_room.channel_members);
        channel_id = left_room.channel_id;

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool guard is scoped to this block so it is released before
    // `update_user_contacts` below, which acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early via `?` and skips the LiveKit
    // cleanup below; confirm that is intended.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3531
3532async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3533 let left_channel_buffers = session
3534 .db()
3535 .await
3536 .leave_channel_buffers(session.connection_id)
3537 .await?;
3538
3539 for left_buffer in left_channel_buffers {
3540 channel_buffer_updated(
3541 session.connection_id,
3542 left_buffer.connections,
3543 &proto::UpdateChannelBufferCollaborators {
3544 channel_id: left_buffer.channel_id.to_proto(),
3545 collaborators: left_buffer.collaborators,
3546 },
3547 &session.peer,
3548 );
3549 }
3550
3551 Ok(())
3552}
3553
3554fn project_left(project: &db::LeftProject, session: &Session) {
3555 for connection_id in &project.connection_ids {
3556 if project.host_user_id == session.user_id {
3557 session
3558 .peer
3559 .send(
3560 *connection_id,
3561 proto::UnshareProject {
3562 project_id: project.id.to_proto(),
3563 },
3564 )
3565 .trace_err();
3566 } else {
3567 session
3568 .peer
3569 .send(
3570 *connection_id,
3571 proto::RemoveProjectCollaborator {
3572 project_id: project.id.to_proto(),
3573 peer_id: Some(session.connection_id.into()),
3574 },
3575 )
3576 .trace_err();
3577 }
3578 }
3579}
3580
3581pub trait ResultExt {
3582 type Ok;
3583
3584 fn trace_err(self) -> Option<Self::Ok>;
3585}
3586
3587impl<T, E> ResultExt for Result<T, E>
3588where
3589 E: std::fmt::Debug,
3590{
3591 type Ok = T;
3592
3593 fn trace_err(self) -> Option<T> {
3594 match self {
3595 Ok(value) => Some(value),
3596 Err(error) => {
3597 tracing::error!("{:?}", error);
3598 None
3599 }
3600 }
3601 }
3602}