1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{field, info_span, instrument, Instrument};
69
/// Grace period a dropped connection gets before `connection_lost` releases the
/// rooms, channel buffers and calls it was holding.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay, counted from `Server::start`, before resources left behind by previous
/// server instances are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// Page sizes.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

// Upper bound on message length (unit not visible here; presumably characters or bytes).
const MAX_MESSAGE_LEN: usize = 1024;
76
77lazy_static! {
78 static ref METRIC_CONNECTIONS: IntGauge =
79 register_int_gauge!("connections", "number of connections").unwrap();
80 static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
81 "shared_projects",
82 "number of open projects with one or more guests"
83 )
84 .unwrap();
85}
86
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the boxed incoming envelope and the sender's `Session`, and returns
/// a boxed `'static` future that performs the handling (see `Server::add_handler`).
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
/// Reply handle given to request handlers registered via `Server::add_request_handler`.
///
/// Replying consumes the handle and sets `responded`. The dispatcher checks that
/// flag to detect handlers that finish successfully without replying.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Shared with the dispatcher; set to `true` once a reply has been sent.
    responded: Arc<AtomicBool>,
}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
/// Per-connection context handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Deployment environment name taken from the app config.
    zed_environment: Arc<str>,
    /// Authenticated user that owns this connection.
    user_id: UserId,
    /// Peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle created once per connection. All clones share the same
    /// mutex, so database access through one connection's sessions is serialized.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Peer used to send messages to clients.
    peer: Arc<Peer>,
    /// Server-wide pool of live connections.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, `None` when LiveKit is not configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Executor the connection was started on (not read through this field in this view).
    _executor: Executor,
}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collaboration RPC server: owns the peer, connection pool and message handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// Messaging peer shared with every session.
    peer: Arc<Peer>,
    /// All live connections, shared with every session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Application-wide state (database, config, LiveKit client).
    app_state: Arc<AppState>,
    /// Executor used for sleeps and detached tasks.
    executor: Executor,
    /// Registered handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown()`; each connection loop subscribes and exits on change.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a `Send` future
/// cannot hold it across an `.await` point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
170
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
impl Server {
    /// Creates a server with the given id and registers every RPC message handler.
    ///
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // The peer's id is derived from the server id via an `as u32` cast.
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Schedules background cleanup of resources that the database reports as stale.
    ///
    /// Spawns a detached task that first waits `CLEANUP_TIMEOUT`. It then fetches the
    /// room and channel-buffer ids returned by `stale_server_resource_ids` for this
    /// environment and `server_id`, and processes them:
    ///
    /// - For each channel buffer, it clears stale collaborators and sends the refreshed
    ///   collaborator list to the remaining connections.
    /// - For each room, it clears stale participants, notifies room and channel members,
    ///   sends `CallCanceled` for canceled calls, pushes updated contact status to the
    ///   affected users' accepted contacts, and deletes the LiveKit room when no
    ///   participants remain.
    ///
    /// Finally the task calls `delete_stale_servers`. Every fallible step is traced and
    /// skipped rather than aborting the cleanup.
    ///
    /// Always returns `Ok(())`; the work itself happens in the spawned task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell everyone still in the buffer who the collaborators are now.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room's stale participants succeeds.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to the
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Tears down the peer, resets the connection pool, and signals every connection
    /// loop (subscribed in `handle_connection`) to stop.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Ignore the send error: it only means no connection is currently subscribed.
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and reuses it under a new server id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for messages of type `M`.
    ///
    /// The handler is wrapped in a type-erased `MessageHandler` that logs the message's
    /// arrival and, on completion, the elapsed time plus any error returned.
    ///
    /// # Panics
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // NOTE(review): relies on the dispatcher selecting this handler by the
                // envelope's payload type id, so the downcast to `M` always succeeds.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    // Milliseconds with microsecond resolution.
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    /// Registers a handler for a one-way message; the handler receives only the payload.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a handler for a request that expects a reply.
    ///
    /// The handler receives a `Response` it must use to reply. If it returns `Ok`
    /// without replying, the wrapper returns a "handler did not send a response" error,
    /// which `add_handler` logs. If the handler returns `Err`, the error's message is
    /// sent back to the requester via `respond_with_error` and the error is propagated.
    /// If that error reply itself fails, the send error is returned instead.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        peer.respond_with_error(
                            receipt,
                            proto::Error {
                                message: error.to_string(),
                            },
                        )?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives one client connection until it closes, its I/O fails, or the server is
    /// torn down.
    ///
    /// Setup, in order:
    /// 1. Registers the connection with the peer and sends `Hello`.
    /// 2. Reports the assigned `ConnectionId` through `send_connection_id`, if given.
    /// 3. On a user's first connection, sends `ShowContacts` and records
    ///    `connected_once`.
    /// 4. Adds the connection to the pool and sends the initial contacts and channels,
    ///    plus any pending incoming call.
    ///
    /// It then dispatches incoming messages. Background messages are spawned detached;
    /// foreground ones are polled concurrently. A semaphore caps in-flight messages at
    /// 256. Each permit is taken before reading a message and released when its
    /// handler finishes.
    ///
    /// When the loop exits because the connection ended, `connection_lost` runs. On
    /// teardown the future returns `Ok(())` immediately, without it. Setup failures are
    /// returned as errors.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        // Subscribe before the future runs so a teardown signalled later is observed.
        let mut teardown = self.teardown.subscribe();
        async move {
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            {
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired before reading the next message, so reading stops
                // while 256 handlers are outstanding.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in the order written: teardown, then I/O,
                // then progress on foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter's connections that `invitee_id` redeemed their invite.
    ///
    /// Sends the invitee's contact entry (marked not busy) and refreshed invite info.
    /// Nothing is sent if the inviter does not exist or has no invite code.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and remaining count) to all of `user_id`'s
    /// connections.
    ///
    /// Nothing is sent if the user does not exist or has no invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot that holds the connection-pool lock until it
    /// is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
760
761impl<'a> Deref for ConnectionPoolGuard<'a> {
762 type Target = ConnectionPool;
763
764 fn deref(&self) -> &Self::Target {
765 &*self.guard
766 }
767}
768
769impl<'a> DerefMut for ConnectionPoolGuard<'a> {
770 fn deref_mut(&mut self) -> &mut Self::Target {
771 &mut *self.guard
772 }
773}
774
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants every time a guard is released.
    /// In non-test builds this does nothing; the lock is released when `guard` drops.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
781
782fn broadcast<F>(
783 sender_id: Option<ConnectionId>,
784 receiver_ids: impl IntoIterator<Item = ConnectionId>,
785 mut f: F,
786) where
787 F: FnMut(ConnectionId) -> anyhow::Result<()>,
788{
789 for receiver_id in receiver_ids {
790 if Some(receiver_id) != sender_id {
791 if let Err(error) = f(receiver_id) {
792 tracing::error!("failed to send to {:?} {}", receiver_id, error);
793 }
794 }
795 }
796}
797
lazy_static! {
    /// Name of the request header carrying the client's RPC protocol version
    /// (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}
801
/// Typed `x-zed-protocol-version` header: the RPC protocol version the client speaks.
pub struct ProtocolVersion(u32);
803
804impl Header for ProtocolVersion {
805 fn name() -> &'static HeaderName {
806 &ZED_PROTOCOL_VERSION
807 }
808
809 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
810 where
811 Self: Sized,
812 I: Iterator<Item = &'i axum::http::HeaderValue>,
813 {
814 let version = values
815 .next()
816 .ok_or_else(axum::headers::Error::invalid)?
817 .to_str()
818 .map_err(|_| axum::headers::Error::invalid())?
819 .parse()
820 .map_err(|_| axum::headers::Error::invalid())?;
821 Ok(Self(version))
822 }
823
824 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
825 values.extend([self.0.to_string().parse().unwrap()]);
826 }
827}
828
829pub fn routes(server: Arc<Server>) -> Router<Body> {
830 Router::new()
831 .route("/rpc", get(handle_websocket_request))
832 .layer(
833 ServiceBuilder::new()
834 .layer(Extension(server.app_state.clone()))
835 .layer(middleware::from_fn(auth::validate_header)),
836 )
837 .route("/metrics", get(handle_metrics))
838 .layer(Extension(server))
839}
840
841pub async fn handle_websocket_request(
842 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
843 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
844 Extension(server): Extension<Arc<Server>>,
845 Extension(user): Extension<User>,
846 Extension(impersonator): Extension<Impersonator>,
847 ws: WebSocketUpgrade,
848) -> axum::response::Response {
849 if protocol_version != rpc::PROTOCOL_VERSION {
850 return (
851 StatusCode::UPGRADE_REQUIRED,
852 "client must be upgraded".to_string(),
853 )
854 .into_response();
855 }
856 let socket_address = socket_address.to_string();
857 ws.on_upgrade(move |socket| {
858 use util::ResultExt;
859 let socket = socket
860 .map_ok(to_tungstenite_message)
861 .err_into()
862 .with(|message| async move { Ok(to_axum_message(message)) });
863 let connection = Connection::new(Box::pin(socket));
864 async move {
865 server
866 .handle_connection(
867 connection,
868 socket_address,
869 user,
870 impersonator.0,
871 None,
872 Executor::Production,
873 )
874 .await
875 .log_err();
876 }
877 })
878}
879
880pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
881 let connections = server
882 .connection_pool
883 .lock()
884 .connections()
885 .filter(|connection| !connection.admin)
886 .count();
887
888 METRIC_CONNECTIONS.set(connections as _);
889
890 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
891 METRIC_SHARED_PROJECTS.set(shared_projects as _);
892
893 let encoder = prometheus::TextEncoder::new();
894 let metric_families = prometheus::gather();
895 let encoded_metrics = encoder
896 .encode_to_string(&metric_families)
897 .map_err(|err| anyhow!("{}", err))?;
898 Ok(encoded_metrics)
899}
900
/// Cleans up after a connection ends.
///
/// It immediately disconnects the peer connection, removes it from the pool, and
/// records the loss in the database. It then waits `RECONNECT_TIMEOUT`:
///
/// - If the server is torn down first, it returns without further cleanup.
/// - Otherwise it removes the session from its room and channel buffers, declines any
///   pending call when the user has no other connection online, and refreshes the
///   user's contacts.
///
/// Failures in the room, channel-buffer, database-loss and call steps are traced and
/// ignored. Errors from removing the connection and from `update_user_contacts` are
/// returned.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
946
947/// Acknowledges a ping from a client, used to keep the connection alive.
948async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
949 response.send(proto::Ack {})?;
950 Ok(())
951}
952
953/// Create a new room for calling (outside of channels)
954async fn create_room(
955 _request: proto::CreateRoom,
956 response: Response<proto::CreateRoom>,
957 session: Session,
958) -> Result<()> {
959 let live_kit_room = nanoid::nanoid!(30);
960
961 let live_kit_connection_info = {
962 let live_kit_room = live_kit_room.clone();
963 let live_kit = session.live_kit_client.as_ref();
964
965 util::async_maybe!({
966 let live_kit = live_kit?;
967
968 let token = live_kit
969 .room_token(&live_kit_room, &session.user_id.to_string())
970 .trace_err()?;
971
972 Some(proto::LiveKitConnectionInfo {
973 server_url: live_kit.url().into(),
974 token,
975 can_publish: true,
976 })
977 })
978 }
979 .await;
980
981 let room = session
982 .db()
983 .await
984 .create_room(
985 session.user_id,
986 session.connection_id,
987 &live_kit_room,
988 &session.zed_environment,
989 )
990 .await?;
991
992 response.send(proto::CreateRoomResponse {
993 room: Some(room.clone()),
994 live_kit_connection_info,
995 })?;
996
997 update_user_contacts(session.user_id, &session).await?;
998 Ok(())
999}
1000
1001/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1002async fn join_room(
1003 request: proto::JoinRoom,
1004 response: Response<proto::JoinRoom>,
1005 session: Session,
1006) -> Result<()> {
1007 let room_id = RoomId::from_proto(request.id);
1008
1009 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1010
1011 if let Some(channel_id) = channel_id {
1012 return join_channel_internal(channel_id, Box::new(response), session).await;
1013 }
1014
1015 let joined_room = {
1016 let room = session
1017 .db()
1018 .await
1019 .join_room(
1020 room_id,
1021 session.user_id,
1022 session.connection_id,
1023 session.zed_environment.as_ref(),
1024 )
1025 .await?;
1026 room_updated(&room.room, &session.peer);
1027 room.into_inner()
1028 };
1029
1030 for connection_id in session
1031 .connection_pool()
1032 .await
1033 .user_connection_ids(session.user_id)
1034 {
1035 session
1036 .peer
1037 .send(
1038 connection_id,
1039 proto::CallCanceled {
1040 room_id: room_id.to_proto(),
1041 },
1042 )
1043 .trace_err();
1044 }
1045
1046 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1047 if let Some(token) = live_kit
1048 .room_token(
1049 &joined_room.room.live_kit_room,
1050 &session.user_id.to_string(),
1051 )
1052 .trace_err()
1053 {
1054 Some(proto::LiveKitConnectionInfo {
1055 server_url: live_kit.url().into(),
1056 token,
1057 can_publish: true,
1058 })
1059 } else {
1060 None
1061 }
1062 } else {
1063 None
1064 };
1065
1066 response.send(proto::JoinRoomResponse {
1067 room: Some(joined_room.room),
1068 channel_id: None,
1069 live_kit_connection_info,
1070 })?;
1071
1072 update_user_contacts(session.user_id, &session).await?;
1073 Ok(())
1074}
1075
1076/// Rejoin room is used to reconnect to a room after connection errors.
1077async fn rejoin_room(
1078 request: proto::RejoinRoom,
1079 response: Response<proto::RejoinRoom>,
1080 session: Session,
1081) -> Result<()> {
1082 let room;
1083 let channel_id;
1084 let channel_members;
1085 {
1086 let mut rejoined_room = session
1087 .db()
1088 .await
1089 .rejoin_room(request, session.user_id, session.connection_id)
1090 .await?;
1091
1092 response.send(proto::RejoinRoomResponse {
1093 room: Some(rejoined_room.room.clone()),
1094 reshared_projects: rejoined_room
1095 .reshared_projects
1096 .iter()
1097 .map(|project| proto::ResharedProject {
1098 id: project.id.to_proto(),
1099 collaborators: project
1100 .collaborators
1101 .iter()
1102 .map(|collaborator| collaborator.to_proto())
1103 .collect(),
1104 })
1105 .collect(),
1106 rejoined_projects: rejoined_room
1107 .rejoined_projects
1108 .iter()
1109 .map(|rejoined_project| proto::RejoinedProject {
1110 id: rejoined_project.id.to_proto(),
1111 worktrees: rejoined_project
1112 .worktrees
1113 .iter()
1114 .map(|worktree| proto::WorktreeMetadata {
1115 id: worktree.id,
1116 root_name: worktree.root_name.clone(),
1117 visible: worktree.visible,
1118 abs_path: worktree.abs_path.clone(),
1119 })
1120 .collect(),
1121 collaborators: rejoined_project
1122 .collaborators
1123 .iter()
1124 .map(|collaborator| collaborator.to_proto())
1125 .collect(),
1126 language_servers: rejoined_project.language_servers.clone(),
1127 })
1128 .collect(),
1129 })?;
1130 room_updated(&rejoined_room.room, &session.peer);
1131
1132 for project in &rejoined_room.reshared_projects {
1133 for collaborator in &project.collaborators {
1134 session
1135 .peer
1136 .send(
1137 collaborator.connection_id,
1138 proto::UpdateProjectCollaborator {
1139 project_id: project.id.to_proto(),
1140 old_peer_id: Some(project.old_connection_id.into()),
1141 new_peer_id: Some(session.connection_id.into()),
1142 },
1143 )
1144 .trace_err();
1145 }
1146
1147 broadcast(
1148 Some(session.connection_id),
1149 project
1150 .collaborators
1151 .iter()
1152 .map(|collaborator| collaborator.connection_id),
1153 |connection_id| {
1154 session.peer.forward_send(
1155 session.connection_id,
1156 connection_id,
1157 proto::UpdateProject {
1158 project_id: project.id.to_proto(),
1159 worktrees: project.worktrees.clone(),
1160 },
1161 )
1162 },
1163 );
1164 }
1165
1166 for project in &rejoined_room.rejoined_projects {
1167 for collaborator in &project.collaborators {
1168 session
1169 .peer
1170 .send(
1171 collaborator.connection_id,
1172 proto::UpdateProjectCollaborator {
1173 project_id: project.id.to_proto(),
1174 old_peer_id: Some(project.old_connection_id.into()),
1175 new_peer_id: Some(session.connection_id.into()),
1176 },
1177 )
1178 .trace_err();
1179 }
1180 }
1181
1182 for project in &mut rejoined_room.rejoined_projects {
1183 for worktree in mem::take(&mut project.worktrees) {
1184 #[cfg(any(test, feature = "test-support"))]
1185 const MAX_CHUNK_SIZE: usize = 2;
1186 #[cfg(not(any(test, feature = "test-support")))]
1187 const MAX_CHUNK_SIZE: usize = 256;
1188
1189 // Stream this worktree's entries.
1190 let message = proto::UpdateWorktree {
1191 project_id: project.id.to_proto(),
1192 worktree_id: worktree.id,
1193 abs_path: worktree.abs_path.clone(),
1194 root_name: worktree.root_name,
1195 updated_entries: worktree.updated_entries,
1196 removed_entries: worktree.removed_entries,
1197 scan_id: worktree.scan_id,
1198 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1199 updated_repositories: worktree.updated_repositories,
1200 removed_repositories: worktree.removed_repositories,
1201 };
1202 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1203 session.peer.send(session.connection_id, update.clone())?;
1204 }
1205
1206 // Stream this worktree's diagnostics.
1207 for summary in worktree.diagnostic_summaries {
1208 session.peer.send(
1209 session.connection_id,
1210 proto::UpdateDiagnosticSummary {
1211 project_id: project.id.to_proto(),
1212 worktree_id: worktree.id,
1213 summary: Some(summary),
1214 },
1215 )?;
1216 }
1217
1218 for settings_file in worktree.settings_files {
1219 session.peer.send(
1220 session.connection_id,
1221 proto::UpdateWorktreeSettings {
1222 project_id: project.id.to_proto(),
1223 worktree_id: worktree.id,
1224 path: settings_file.path,
1225 content: Some(settings_file.content),
1226 },
1227 )?;
1228 }
1229 }
1230
1231 for language_server in &project.language_servers {
1232 session.peer.send(
1233 session.connection_id,
1234 proto::UpdateLanguageServer {
1235 project_id: project.id.to_proto(),
1236 language_server_id: language_server.id,
1237 variant: Some(
1238 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1239 proto::LspDiskBasedDiagnosticsUpdated {},
1240 ),
1241 ),
1242 },
1243 )?;
1244 }
1245 }
1246
1247 let rejoined_room = rejoined_room.into_inner();
1248
1249 room = rejoined_room.room;
1250 channel_id = rejoined_room.channel_id;
1251 channel_members = rejoined_room.channel_members;
1252 }
1253
1254 if let Some(channel_id) = channel_id {
1255 channel_updated(
1256 channel_id,
1257 &room,
1258 &channel_members,
1259 &session.peer,
1260 &*session.connection_pool().await,
1261 );
1262 }
1263
1264 update_user_contacts(session.user_id, &session).await?;
1265 Ok(())
1266}
1267
1268/// leave room disonnects from the room.
1269async fn leave_room(
1270 _: proto::LeaveRoom,
1271 response: Response<proto::LeaveRoom>,
1272 session: Session,
1273) -> Result<()> {
1274 leave_room_for_session(&session).await?;
1275 response.send(proto::Ack {})?;
1276 Ok(())
1277}
1278
1279/// Update the permissions of someone else in the room.
1280async fn set_room_participant_role(
1281 request: proto::SetRoomParticipantRole,
1282 response: Response<proto::SetRoomParticipantRole>,
1283 session: Session,
1284) -> Result<()> {
1285 let (live_kit_room, can_publish) = {
1286 let room = session
1287 .db()
1288 .await
1289 .set_room_participant_role(
1290 session.user_id,
1291 RoomId::from_proto(request.room_id),
1292 UserId::from_proto(request.user_id),
1293 ChannelRole::from(request.role()),
1294 )
1295 .await?;
1296
1297 let live_kit_room = room.live_kit_room.clone();
1298 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1299 room_updated(&room, &session.peer);
1300 (live_kit_room, can_publish)
1301 };
1302
1303 if let Some(live_kit) = session.live_kit_client.as_ref() {
1304 live_kit
1305 .update_participant(
1306 live_kit_room.clone(),
1307 request.user_id.to_string(),
1308 live_kit_server::proto::ParticipantPermission {
1309 can_subscribe: true,
1310 can_publish,
1311 can_publish_data: can_publish,
1312 hidden: false,
1313 recorder: false,
1314 },
1315 )
1316 .await
1317 .trace_err();
1318 }
1319
1320 response.send(proto::Ack {})?;
1321 Ok(())
1322}
1323
1324/// Call someone else into the current room
1325async fn call(
1326 request: proto::Call,
1327 response: Response<proto::Call>,
1328 session: Session,
1329) -> Result<()> {
1330 let room_id = RoomId::from_proto(request.room_id);
1331 let calling_user_id = session.user_id;
1332 let calling_connection_id = session.connection_id;
1333 let called_user_id = UserId::from_proto(request.called_user_id);
1334 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1335 if !session
1336 .db()
1337 .await
1338 .has_contact(calling_user_id, called_user_id)
1339 .await?
1340 {
1341 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1342 }
1343
1344 let incoming_call = {
1345 let (room, incoming_call) = &mut *session
1346 .db()
1347 .await
1348 .call(
1349 room_id,
1350 calling_user_id,
1351 calling_connection_id,
1352 called_user_id,
1353 initial_project_id,
1354 )
1355 .await?;
1356 room_updated(&room, &session.peer);
1357 mem::take(incoming_call)
1358 };
1359 update_user_contacts(called_user_id, &session).await?;
1360
1361 let mut calls = session
1362 .connection_pool()
1363 .await
1364 .user_connection_ids(called_user_id)
1365 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1366 .collect::<FuturesUnordered<_>>();
1367
1368 while let Some(call_response) = calls.next().await {
1369 match call_response.as_ref() {
1370 Ok(_) => {
1371 response.send(proto::Ack {})?;
1372 return Ok(());
1373 }
1374 Err(_) => {
1375 call_response.trace_err();
1376 }
1377 }
1378 }
1379
1380 {
1381 let room = session
1382 .db()
1383 .await
1384 .call_failed(room_id, called_user_id)
1385 .await?;
1386 room_updated(&room, &session.peer);
1387 }
1388 update_user_contacts(called_user_id, &session).await?;
1389
1390 Err(anyhow!("failed to ring user"))?
1391}
1392
1393/// Cancel an outgoing call.
1394async fn cancel_call(
1395 request: proto::CancelCall,
1396 response: Response<proto::CancelCall>,
1397 session: Session,
1398) -> Result<()> {
1399 let called_user_id = UserId::from_proto(request.called_user_id);
1400 let room_id = RoomId::from_proto(request.room_id);
1401 {
1402 let room = session
1403 .db()
1404 .await
1405 .cancel_call(room_id, session.connection_id, called_user_id)
1406 .await?;
1407 room_updated(&room, &session.peer);
1408 }
1409
1410 for connection_id in session
1411 .connection_pool()
1412 .await
1413 .user_connection_ids(called_user_id)
1414 {
1415 session
1416 .peer
1417 .send(
1418 connection_id,
1419 proto::CallCanceled {
1420 room_id: room_id.to_proto(),
1421 },
1422 )
1423 .trace_err();
1424 }
1425 response.send(proto::Ack {})?;
1426
1427 update_user_contacts(called_user_id, &session).await?;
1428 Ok(())
1429}
1430
1431/// Decline an incoming call.
1432async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1433 let room_id = RoomId::from_proto(message.room_id);
1434 {
1435 let room = session
1436 .db()
1437 .await
1438 .decline_call(Some(room_id), session.user_id)
1439 .await?
1440 .ok_or_else(|| anyhow!("failed to decline call"))?;
1441 room_updated(&room, &session.peer);
1442 }
1443
1444 for connection_id in session
1445 .connection_pool()
1446 .await
1447 .user_connection_ids(session.user_id)
1448 {
1449 session
1450 .peer
1451 .send(
1452 connection_id,
1453 proto::CallCanceled {
1454 room_id: room_id.to_proto(),
1455 },
1456 )
1457 .trace_err();
1458 }
1459 update_user_contacts(session.user_id, &session).await?;
1460 Ok(())
1461}
1462
1463/// Update other participants in the room with your current location.
1464async fn update_participant_location(
1465 request: proto::UpdateParticipantLocation,
1466 response: Response<proto::UpdateParticipantLocation>,
1467 session: Session,
1468) -> Result<()> {
1469 let room_id = RoomId::from_proto(request.room_id);
1470 let location = request
1471 .location
1472 .ok_or_else(|| anyhow!("invalid location"))?;
1473
1474 let db = session.db().await;
1475 let room = db
1476 .update_room_participant_location(room_id, session.connection_id, location)
1477 .await?;
1478
1479 room_updated(&room, &session.peer);
1480 response.send(proto::Ack {})?;
1481 Ok(())
1482}
1483
1484/// Share a project into the room.
1485async fn share_project(
1486 request: proto::ShareProject,
1487 response: Response<proto::ShareProject>,
1488 session: Session,
1489) -> Result<()> {
1490 let (project_id, room) = &*session
1491 .db()
1492 .await
1493 .share_project(
1494 RoomId::from_proto(request.room_id),
1495 session.connection_id,
1496 &request.worktrees,
1497 )
1498 .await?;
1499 response.send(proto::ShareProjectResponse {
1500 project_id: project_id.to_proto(),
1501 })?;
1502 room_updated(&room, &session.peer);
1503
1504 Ok(())
1505}
1506
1507/// Unshare a project from the room.
1508async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1509 let project_id = ProjectId::from_proto(message.project_id);
1510
1511 let (room, guest_connection_ids) = &*session
1512 .db()
1513 .await
1514 .unshare_project(project_id, session.connection_id)
1515 .await?;
1516
1517 broadcast(
1518 Some(session.connection_id),
1519 guest_connection_ids.iter().copied(),
1520 |conn_id| session.peer.send(conn_id, message.clone()),
1521 );
1522 room_updated(&room, &session.peer);
1523
1524 Ok(())
1525}
1526
1527/// Join someone elses shared project.
1528async fn join_project(
1529 request: proto::JoinProject,
1530 response: Response<proto::JoinProject>,
1531 session: Session,
1532) -> Result<()> {
1533 let project_id = ProjectId::from_proto(request.project_id);
1534 let guest_user_id = session.user_id;
1535
1536 tracing::info!(%project_id, "join project");
1537
1538 let (project, replica_id) = &mut *session
1539 .db()
1540 .await
1541 .join_project(project_id, session.connection_id)
1542 .await?;
1543
1544 let collaborators = project
1545 .collaborators
1546 .iter()
1547 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1548 .map(|collaborator| collaborator.to_proto())
1549 .collect::<Vec<_>>();
1550
1551 let worktrees = project
1552 .worktrees
1553 .iter()
1554 .map(|(id, worktree)| proto::WorktreeMetadata {
1555 id: *id,
1556 root_name: worktree.root_name.clone(),
1557 visible: worktree.visible,
1558 abs_path: worktree.abs_path.clone(),
1559 })
1560 .collect::<Vec<_>>();
1561
1562 for collaborator in &collaborators {
1563 session
1564 .peer
1565 .send(
1566 collaborator.peer_id.unwrap().into(),
1567 proto::AddProjectCollaborator {
1568 project_id: project_id.to_proto(),
1569 collaborator: Some(proto::Collaborator {
1570 peer_id: Some(session.connection_id.into()),
1571 replica_id: replica_id.0 as u32,
1572 user_id: guest_user_id.to_proto(),
1573 }),
1574 },
1575 )
1576 .trace_err();
1577 }
1578
1579 // First, we send the metadata associated with each worktree.
1580 response.send(proto::JoinProjectResponse {
1581 worktrees: worktrees.clone(),
1582 replica_id: replica_id.0 as u32,
1583 collaborators: collaborators.clone(),
1584 language_servers: project.language_servers.clone(),
1585 })?;
1586
1587 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1588 #[cfg(any(test, feature = "test-support"))]
1589 const MAX_CHUNK_SIZE: usize = 2;
1590 #[cfg(not(any(test, feature = "test-support")))]
1591 const MAX_CHUNK_SIZE: usize = 256;
1592
1593 // Stream this worktree's entries.
1594 let message = proto::UpdateWorktree {
1595 project_id: project_id.to_proto(),
1596 worktree_id,
1597 abs_path: worktree.abs_path.clone(),
1598 root_name: worktree.root_name,
1599 updated_entries: worktree.entries,
1600 removed_entries: Default::default(),
1601 scan_id: worktree.scan_id,
1602 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1603 updated_repositories: worktree.repository_entries.into_values().collect(),
1604 removed_repositories: Default::default(),
1605 };
1606 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1607 session.peer.send(session.connection_id, update.clone())?;
1608 }
1609
1610 // Stream this worktree's diagnostics.
1611 for summary in worktree.diagnostic_summaries {
1612 session.peer.send(
1613 session.connection_id,
1614 proto::UpdateDiagnosticSummary {
1615 project_id: project_id.to_proto(),
1616 worktree_id: worktree.id,
1617 summary: Some(summary),
1618 },
1619 )?;
1620 }
1621
1622 for settings_file in worktree.settings_files {
1623 session.peer.send(
1624 session.connection_id,
1625 proto::UpdateWorktreeSettings {
1626 project_id: project_id.to_proto(),
1627 worktree_id: worktree.id,
1628 path: settings_file.path,
1629 content: Some(settings_file.content),
1630 },
1631 )?;
1632 }
1633 }
1634
1635 for language_server in &project.language_servers {
1636 session.peer.send(
1637 session.connection_id,
1638 proto::UpdateLanguageServer {
1639 project_id: project_id.to_proto(),
1640 language_server_id: language_server.id,
1641 variant: Some(
1642 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1643 proto::LspDiskBasedDiagnosticsUpdated {},
1644 ),
1645 ),
1646 },
1647 )?;
1648 }
1649
1650 Ok(())
1651}
1652
1653/// Leave someone elses shared project.
1654async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1655 let sender_id = session.connection_id;
1656 let project_id = ProjectId::from_proto(request.project_id);
1657
1658 let (room, project) = &*session
1659 .db()
1660 .await
1661 .leave_project(project_id, sender_id)
1662 .await?;
1663 tracing::info!(
1664 %project_id,
1665 host_user_id = %project.host_user_id,
1666 host_connection_id = %project.host_connection_id,
1667 "leave project"
1668 );
1669
1670 project_left(&project, &session);
1671 room_updated(&room, &session.peer);
1672
1673 Ok(())
1674}
1675
1676/// Update other participants with changes to the project
1677async fn update_project(
1678 request: proto::UpdateProject,
1679 response: Response<proto::UpdateProject>,
1680 session: Session,
1681) -> Result<()> {
1682 let project_id = ProjectId::from_proto(request.project_id);
1683 let (room, guest_connection_ids) = &*session
1684 .db()
1685 .await
1686 .update_project(project_id, session.connection_id, &request.worktrees)
1687 .await?;
1688 broadcast(
1689 Some(session.connection_id),
1690 guest_connection_ids.iter().copied(),
1691 |connection_id| {
1692 session
1693 .peer
1694 .forward_send(session.connection_id, connection_id, request.clone())
1695 },
1696 );
1697 room_updated(&room, &session.peer);
1698 response.send(proto::Ack {})?;
1699
1700 Ok(())
1701}
1702
1703/// Update other participants with changes to the worktree
1704async fn update_worktree(
1705 request: proto::UpdateWorktree,
1706 response: Response<proto::UpdateWorktree>,
1707 session: Session,
1708) -> Result<()> {
1709 let guest_connection_ids = session
1710 .db()
1711 .await
1712 .update_worktree(&request, session.connection_id)
1713 .await?;
1714
1715 broadcast(
1716 Some(session.connection_id),
1717 guest_connection_ids.iter().copied(),
1718 |connection_id| {
1719 session
1720 .peer
1721 .forward_send(session.connection_id, connection_id, request.clone())
1722 },
1723 );
1724 response.send(proto::Ack {})?;
1725 Ok(())
1726}
1727
1728/// Update other participants with changes to the diagnostics
1729async fn update_diagnostic_summary(
1730 message: proto::UpdateDiagnosticSummary,
1731 session: Session,
1732) -> Result<()> {
1733 let guest_connection_ids = session
1734 .db()
1735 .await
1736 .update_diagnostic_summary(&message, session.connection_id)
1737 .await?;
1738
1739 broadcast(
1740 Some(session.connection_id),
1741 guest_connection_ids.iter().copied(),
1742 |connection_id| {
1743 session
1744 .peer
1745 .forward_send(session.connection_id, connection_id, message.clone())
1746 },
1747 );
1748
1749 Ok(())
1750}
1751
1752/// Update other participants with changes to the worktree settings
1753async fn update_worktree_settings(
1754 message: proto::UpdateWorktreeSettings,
1755 session: Session,
1756) -> Result<()> {
1757 let guest_connection_ids = session
1758 .db()
1759 .await
1760 .update_worktree_settings(&message, session.connection_id)
1761 .await?;
1762
1763 broadcast(
1764 Some(session.connection_id),
1765 guest_connection_ids.iter().copied(),
1766 |connection_id| {
1767 session
1768 .peer
1769 .forward_send(session.connection_id, connection_id, message.clone())
1770 },
1771 );
1772
1773 Ok(())
1774}
1775
1776/// Notify other participants that a language server has started.
1777async fn start_language_server(
1778 request: proto::StartLanguageServer,
1779 session: Session,
1780) -> Result<()> {
1781 let guest_connection_ids = session
1782 .db()
1783 .await
1784 .start_language_server(&request, session.connection_id)
1785 .await?;
1786
1787 broadcast(
1788 Some(session.connection_id),
1789 guest_connection_ids.iter().copied(),
1790 |connection_id| {
1791 session
1792 .peer
1793 .forward_send(session.connection_id, connection_id, request.clone())
1794 },
1795 );
1796 Ok(())
1797}
1798
1799/// Notify other participants that a language server has changed.
1800async fn update_language_server(
1801 request: proto::UpdateLanguageServer,
1802 session: Session,
1803) -> Result<()> {
1804 let project_id = ProjectId::from_proto(request.project_id);
1805 let project_connection_ids = session
1806 .db()
1807 .await
1808 .project_connection_ids(project_id, session.connection_id)
1809 .await?;
1810 broadcast(
1811 Some(session.connection_id),
1812 project_connection_ids.iter().copied(),
1813 |connection_id| {
1814 session
1815 .peer
1816 .forward_send(session.connection_id, connection_id, request.clone())
1817 },
1818 );
1819 Ok(())
1820}
1821
1822/// forward a project request to the host. These requests should be read only
1823/// as guests are allowed to send them.
1824async fn forward_read_only_project_request<T>(
1825 request: T,
1826 response: Response<T>,
1827 session: Session,
1828) -> Result<()>
1829where
1830 T: EntityMessage + RequestMessage,
1831{
1832 let project_id = ProjectId::from_proto(request.remote_entity_id());
1833 let host_connection_id = session
1834 .db()
1835 .await
1836 .host_for_read_only_project_request(project_id, session.connection_id)
1837 .await?;
1838 let payload = session
1839 .peer
1840 .forward_request(session.connection_id, host_connection_id, request)
1841 .await?;
1842 response.send(payload)?;
1843 Ok(())
1844}
1845
1846/// forward a project request to the host. These requests are disallowed
1847/// for guests.
1848async fn forward_mutating_project_request<T>(
1849 request: T,
1850 response: Response<T>,
1851 session: Session,
1852) -> Result<()>
1853where
1854 T: EntityMessage + RequestMessage,
1855{
1856 let project_id = ProjectId::from_proto(request.remote_entity_id());
1857 let host_connection_id = session
1858 .db()
1859 .await
1860 .host_for_mutating_project_request(project_id, session.connection_id)
1861 .await?;
1862 let payload = session
1863 .peer
1864 .forward_request(session.connection_id, host_connection_id, request)
1865 .await?;
1866 response.send(payload)?;
1867 Ok(())
1868}
1869
1870/// Notify other participants that a new buffer has been created
1871async fn create_buffer_for_peer(
1872 request: proto::CreateBufferForPeer,
1873 session: Session,
1874) -> Result<()> {
1875 session
1876 .db()
1877 .await
1878 .check_user_is_project_host(
1879 ProjectId::from_proto(request.project_id),
1880 session.connection_id,
1881 )
1882 .await?;
1883 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1884 session
1885 .peer
1886 .forward_send(session.connection_id, peer_id.into(), request)?;
1887 Ok(())
1888}
1889
1890/// Notify other participants that a buffer has been updated. This is
1891/// allowed for guests as long as the update is limited to selections.
1892async fn update_buffer(
1893 request: proto::UpdateBuffer,
1894 response: Response<proto::UpdateBuffer>,
1895 session: Session,
1896) -> Result<()> {
1897 let project_id = ProjectId::from_proto(request.project_id);
1898 let mut guest_connection_ids;
1899 let mut host_connection_id = None;
1900
1901 let mut requires_write_permission = false;
1902
1903 for op in request.operations.iter() {
1904 match op.variant {
1905 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1906 Some(_) => requires_write_permission = true,
1907 }
1908 }
1909
1910 {
1911 let collaborators = session
1912 .db()
1913 .await
1914 .project_collaborators_for_buffer_update(
1915 project_id,
1916 session.connection_id,
1917 requires_write_permission,
1918 )
1919 .await?;
1920 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1921 for collaborator in collaborators.iter() {
1922 if collaborator.is_host {
1923 host_connection_id = Some(collaborator.connection_id);
1924 } else {
1925 guest_connection_ids.push(collaborator.connection_id);
1926 }
1927 }
1928 }
1929 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1930
1931 broadcast(
1932 Some(session.connection_id),
1933 guest_connection_ids,
1934 |connection_id| {
1935 session
1936 .peer
1937 .forward_send(session.connection_id, connection_id, request.clone())
1938 },
1939 );
1940 if host_connection_id != session.connection_id {
1941 session
1942 .peer
1943 .forward_request(session.connection_id, host_connection_id, request.clone())
1944 .await?;
1945 }
1946
1947 response.send(proto::Ack {})?;
1948 Ok(())
1949}
1950
1951/// Notify other participants that a project has been updated.
1952async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1953 request: T,
1954 session: Session,
1955) -> Result<()> {
1956 let project_id = ProjectId::from_proto(request.remote_entity_id());
1957 let project_connection_ids = session
1958 .db()
1959 .await
1960 .project_connection_ids(project_id, session.connection_id)
1961 .await?;
1962
1963 broadcast(
1964 Some(session.connection_id),
1965 project_connection_ids.iter().copied(),
1966 |connection_id| {
1967 session
1968 .peer
1969 .forward_send(session.connection_id, connection_id, request.clone())
1970 },
1971 );
1972 Ok(())
1973}
1974
1975/// Start following another user in a call.
1976async fn follow(
1977 request: proto::Follow,
1978 response: Response<proto::Follow>,
1979 session: Session,
1980) -> Result<()> {
1981 let room_id = RoomId::from_proto(request.room_id);
1982 let project_id = request.project_id.map(ProjectId::from_proto);
1983 let leader_id = request
1984 .leader_id
1985 .ok_or_else(|| anyhow!("invalid leader id"))?
1986 .into();
1987 let follower_id = session.connection_id;
1988
1989 session
1990 .db()
1991 .await
1992 .check_room_participants(room_id, leader_id, session.connection_id)
1993 .await?;
1994
1995 let response_payload = session
1996 .peer
1997 .forward_request(session.connection_id, leader_id, request)
1998 .await?;
1999 response.send(response_payload)?;
2000
2001 if let Some(project_id) = project_id {
2002 let room = session
2003 .db()
2004 .await
2005 .follow(room_id, project_id, leader_id, follower_id)
2006 .await?;
2007 room_updated(&room, &session.peer);
2008 }
2009
2010 Ok(())
2011}
2012
2013/// Stop following another user in a call.
2014async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2015 let room_id = RoomId::from_proto(request.room_id);
2016 let project_id = request.project_id.map(ProjectId::from_proto);
2017 let leader_id = request
2018 .leader_id
2019 .ok_or_else(|| anyhow!("invalid leader id"))?
2020 .into();
2021 let follower_id = session.connection_id;
2022
2023 session
2024 .db()
2025 .await
2026 .check_room_participants(room_id, leader_id, session.connection_id)
2027 .await?;
2028
2029 session
2030 .peer
2031 .forward_send(session.connection_id, leader_id, request)?;
2032
2033 if let Some(project_id) = project_id {
2034 let room = session
2035 .db()
2036 .await
2037 .unfollow(room_id, project_id, leader_id, follower_id)
2038 .await?;
2039 room_updated(&room, &session.peer);
2040 }
2041
2042 Ok(())
2043}
2044
2045/// Notify everyone following you of your current location.
2046async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2047 let room_id = RoomId::from_proto(request.room_id);
2048 let database = session.db.lock().await;
2049
2050 let connection_ids = if let Some(project_id) = request.project_id {
2051 let project_id = ProjectId::from_proto(project_id);
2052 database
2053 .project_connection_ids(project_id, session.connection_id)
2054 .await?
2055 } else {
2056 database
2057 .room_connection_ids(room_id, session.connection_id)
2058 .await?
2059 };
2060
2061 // For now, don't send view update messages back to that view's current leader.
2062 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2063 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2064 _ => None,
2065 });
2066
2067 for follower_peer_id in request.follower_ids.iter().copied() {
2068 let follower_connection_id = follower_peer_id.into();
2069 if Some(follower_peer_id) != connection_id_to_omit
2070 && connection_ids.contains(&follower_connection_id)
2071 {
2072 session.peer.forward_send(
2073 session.connection_id,
2074 follower_connection_id,
2075 request.clone(),
2076 )?;
2077 }
2078 }
2079 Ok(())
2080}
2081
2082/// Get public data about users.
2083async fn get_users(
2084 request: proto::GetUsers,
2085 response: Response<proto::GetUsers>,
2086 session: Session,
2087) -> Result<()> {
2088 let user_ids = request
2089 .user_ids
2090 .into_iter()
2091 .map(UserId::from_proto)
2092 .collect();
2093 let users = session
2094 .db()
2095 .await
2096 .get_users_by_ids(user_ids)
2097 .await?
2098 .into_iter()
2099 .map(|user| proto::User {
2100 id: user.id.to_proto(),
2101 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2102 github_login: user.github_login,
2103 })
2104 .collect();
2105 response.send(proto::UsersResponse { users })?;
2106 Ok(())
2107}
2108
2109/// Search for users (to invite) buy Github login
2110async fn fuzzy_search_users(
2111 request: proto::FuzzySearchUsers,
2112 response: Response<proto::FuzzySearchUsers>,
2113 session: Session,
2114) -> Result<()> {
2115 let query = request.query;
2116 let users = match query.len() {
2117 0 => vec![],
2118 1 | 2 => session
2119 .db()
2120 .await
2121 .get_user_by_github_login(&query)
2122 .await?
2123 .into_iter()
2124 .collect(),
2125 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2126 };
2127 let users = users
2128 .into_iter()
2129 .filter(|user| user.id != session.user_id)
2130 .map(|user| proto::User {
2131 id: user.id.to_proto(),
2132 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2133 github_login: user.github_login,
2134 })
2135 .collect();
2136 response.send(proto::UsersResponse { users })?;
2137 Ok(())
2138}
2139
2140/// Send a contact request to another user.
2141async fn request_contact(
2142 request: proto::RequestContact,
2143 response: Response<proto::RequestContact>,
2144 session: Session,
2145) -> Result<()> {
2146 let requester_id = session.user_id;
2147 let responder_id = UserId::from_proto(request.responder_id);
2148 if requester_id == responder_id {
2149 return Err(anyhow!("cannot add yourself as a contact"))?;
2150 }
2151
2152 let notifications = session
2153 .db()
2154 .await
2155 .send_contact_request(requester_id, responder_id)
2156 .await?;
2157
2158 // Update outgoing contact requests of requester
2159 let mut update = proto::UpdateContacts::default();
2160 update.outgoing_requests.push(responder_id.to_proto());
2161 for connection_id in session
2162 .connection_pool()
2163 .await
2164 .user_connection_ids(requester_id)
2165 {
2166 session.peer.send(connection_id, update.clone())?;
2167 }
2168
2169 // Update incoming contact requests of responder
2170 let mut update = proto::UpdateContacts::default();
2171 update
2172 .incoming_requests
2173 .push(proto::IncomingContactRequest {
2174 requester_id: requester_id.to_proto(),
2175 });
2176 let connection_pool = session.connection_pool().await;
2177 for connection_id in connection_pool.user_connection_ids(responder_id) {
2178 session.peer.send(connection_id, update.clone())?;
2179 }
2180
2181 send_notifications(&*connection_pool, &session.peer, notifications);
2182
2183 response.send(proto::Ack {})?;
2184 Ok(())
2185}
2186
2187/// Accept or decline a contact request
2188async fn respond_to_contact_request(
2189 request: proto::RespondToContactRequest,
2190 response: Response<proto::RespondToContactRequest>,
2191 session: Session,
2192) -> Result<()> {
2193 let responder_id = session.user_id;
2194 let requester_id = UserId::from_proto(request.requester_id);
2195 let db = session.db().await;
2196 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2197 db.dismiss_contact_notification(responder_id, requester_id)
2198 .await?;
2199 } else {
2200 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2201
2202 let notifications = db
2203 .respond_to_contact_request(responder_id, requester_id, accept)
2204 .await?;
2205 let requester_busy = db.is_user_busy(requester_id).await?;
2206 let responder_busy = db.is_user_busy(responder_id).await?;
2207
2208 let pool = session.connection_pool().await;
2209 // Update responder with new contact
2210 let mut update = proto::UpdateContacts::default();
2211 if accept {
2212 update
2213 .contacts
2214 .push(contact_for_user(requester_id, requester_busy, &pool));
2215 }
2216 update
2217 .remove_incoming_requests
2218 .push(requester_id.to_proto());
2219 for connection_id in pool.user_connection_ids(responder_id) {
2220 session.peer.send(connection_id, update.clone())?;
2221 }
2222
2223 // Update requester with new contact
2224 let mut update = proto::UpdateContacts::default();
2225 if accept {
2226 update
2227 .contacts
2228 .push(contact_for_user(responder_id, responder_busy, &pool));
2229 }
2230 update
2231 .remove_outgoing_requests
2232 .push(responder_id.to_proto());
2233
2234 for connection_id in pool.user_connection_ids(requester_id) {
2235 session.peer.send(connection_id, update.clone())?;
2236 }
2237
2238 send_notifications(&*pool, &session.peer, notifications);
2239 }
2240
2241 response.send(proto::Ack {})?;
2242 Ok(())
2243}
2244
2245/// Remove a contact.
2246async fn remove_contact(
2247 request: proto::RemoveContact,
2248 response: Response<proto::RemoveContact>,
2249 session: Session,
2250) -> Result<()> {
2251 let requester_id = session.user_id;
2252 let responder_id = UserId::from_proto(request.user_id);
2253 let db = session.db().await;
2254 let (contact_accepted, deleted_notification_id) =
2255 db.remove_contact(requester_id, responder_id).await?;
2256
2257 let pool = session.connection_pool().await;
2258 // Update outgoing contact requests of requester
2259 let mut update = proto::UpdateContacts::default();
2260 if contact_accepted {
2261 update.remove_contacts.push(responder_id.to_proto());
2262 } else {
2263 update
2264 .remove_outgoing_requests
2265 .push(responder_id.to_proto());
2266 }
2267 for connection_id in pool.user_connection_ids(requester_id) {
2268 session.peer.send(connection_id, update.clone())?;
2269 }
2270
2271 // Update incoming contact requests of responder
2272 let mut update = proto::UpdateContacts::default();
2273 if contact_accepted {
2274 update.remove_contacts.push(requester_id.to_proto());
2275 } else {
2276 update
2277 .remove_incoming_requests
2278 .push(requester_id.to_proto());
2279 }
2280 for connection_id in pool.user_connection_ids(responder_id) {
2281 session.peer.send(connection_id, update.clone())?;
2282 if let Some(notification_id) = deleted_notification_id {
2283 session.peer.send(
2284 connection_id,
2285 proto::DeleteNotification {
2286 notification_id: notification_id.to_proto(),
2287 },
2288 )?;
2289 }
2290 }
2291
2292 response.send(proto::Ack {})?;
2293 Ok(())
2294}
2295
2296/// Create a new channel.
2297async fn create_channel(
2298 request: proto::CreateChannel,
2299 response: Response<proto::CreateChannel>,
2300 session: Session,
2301) -> Result<()> {
2302 let db = session.db().await;
2303
2304 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2305 let CreateChannelResult {
2306 channel,
2307 participants_to_update,
2308 } = db
2309 .create_channel(&request.name, parent_id, session.user_id)
2310 .await?;
2311
2312 response.send(proto::CreateChannelResponse {
2313 channel: Some(channel.to_proto()),
2314 parent_id: request.parent_id,
2315 })?;
2316
2317 let connection_pool = session.connection_pool().await;
2318 for (user_id, channels) in participants_to_update {
2319 let update = build_channels_update(channels, vec![]);
2320 for connection_id in connection_pool.user_connection_ids(user_id) {
2321 if user_id == session.user_id {
2322 continue;
2323 }
2324 session.peer.send(connection_id, update.clone())?;
2325 }
2326 }
2327
2328 Ok(())
2329}
2330
2331/// Delete a channel
2332async fn delete_channel(
2333 request: proto::DeleteChannel,
2334 response: Response<proto::DeleteChannel>,
2335 session: Session,
2336) -> Result<()> {
2337 let db = session.db().await;
2338
2339 let channel_id = request.channel_id;
2340 let (removed_channels, member_ids) = db
2341 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2342 .await?;
2343 response.send(proto::Ack {})?;
2344
2345 // Notify members of removed channels
2346 let mut update = proto::UpdateChannels::default();
2347 update
2348 .delete_channels
2349 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2350
2351 let connection_pool = session.connection_pool().await;
2352 for member_id in member_ids {
2353 for connection_id in connection_pool.user_connection_ids(member_id) {
2354 session.peer.send(connection_id, update.clone())?;
2355 }
2356 }
2357
2358 Ok(())
2359}
2360
2361/// Invite someone to join a channel.
2362async fn invite_channel_member(
2363 request: proto::InviteChannelMember,
2364 response: Response<proto::InviteChannelMember>,
2365 session: Session,
2366) -> Result<()> {
2367 let db = session.db().await;
2368 let channel_id = ChannelId::from_proto(request.channel_id);
2369 let invitee_id = UserId::from_proto(request.user_id);
2370 let InviteMemberResult {
2371 channel,
2372 notifications,
2373 } = db
2374 .invite_channel_member(
2375 channel_id,
2376 invitee_id,
2377 session.user_id,
2378 request.role().into(),
2379 )
2380 .await?;
2381
2382 let update = proto::UpdateChannels {
2383 channel_invitations: vec![channel.to_proto()],
2384 ..Default::default()
2385 };
2386
2387 let connection_pool = session.connection_pool().await;
2388 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2389 session.peer.send(connection_id, update.clone())?;
2390 }
2391
2392 send_notifications(&*connection_pool, &session.peer, notifications);
2393
2394 response.send(proto::Ack {})?;
2395 Ok(())
2396}
2397
2398/// remove someone from a channel
2399async fn remove_channel_member(
2400 request: proto::RemoveChannelMember,
2401 response: Response<proto::RemoveChannelMember>,
2402 session: Session,
2403) -> Result<()> {
2404 let db = session.db().await;
2405 let channel_id = ChannelId::from_proto(request.channel_id);
2406 let member_id = UserId::from_proto(request.user_id);
2407
2408 let RemoveChannelMemberResult {
2409 membership_update,
2410 notification_id,
2411 } = db
2412 .remove_channel_member(channel_id, member_id, session.user_id)
2413 .await?;
2414
2415 let connection_pool = &session.connection_pool().await;
2416 notify_membership_updated(
2417 &connection_pool,
2418 membership_update,
2419 member_id,
2420 &session.peer,
2421 );
2422 for connection_id in connection_pool.user_connection_ids(member_id) {
2423 if let Some(notification_id) = notification_id {
2424 session
2425 .peer
2426 .send(
2427 connection_id,
2428 proto::DeleteNotification {
2429 notification_id: notification_id.to_proto(),
2430 },
2431 )
2432 .trace_err();
2433 }
2434 }
2435
2436 response.send(proto::Ack {})?;
2437 Ok(())
2438}
2439
2440/// Toggle the channel between public and private
2441async fn set_channel_visibility(
2442 request: proto::SetChannelVisibility,
2443 response: Response<proto::SetChannelVisibility>,
2444 session: Session,
2445) -> Result<()> {
2446 let db = session.db().await;
2447 let channel_id = ChannelId::from_proto(request.channel_id);
2448 let visibility = request.visibility().into();
2449
2450 let SetChannelVisibilityResult {
2451 participants_to_update,
2452 participants_to_remove,
2453 channels_to_remove,
2454 } = db
2455 .set_channel_visibility(channel_id, visibility, session.user_id)
2456 .await?;
2457
2458 let connection_pool = session.connection_pool().await;
2459 for (user_id, channels) in participants_to_update {
2460 let update = build_channels_update(channels, vec![]);
2461 for connection_id in connection_pool.user_connection_ids(user_id) {
2462 session.peer.send(connection_id, update.clone())?;
2463 }
2464 }
2465 for user_id in participants_to_remove {
2466 let update = proto::UpdateChannels {
2467 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2468 ..Default::default()
2469 };
2470 for connection_id in connection_pool.user_connection_ids(user_id) {
2471 session.peer.send(connection_id, update.clone())?;
2472 }
2473 }
2474
2475 response.send(proto::Ack {})?;
2476 Ok(())
2477}
2478
2479/// Alter the role for a user in the channel
2480async fn set_channel_member_role(
2481 request: proto::SetChannelMemberRole,
2482 response: Response<proto::SetChannelMemberRole>,
2483 session: Session,
2484) -> Result<()> {
2485 let db = session.db().await;
2486 let channel_id = ChannelId::from_proto(request.channel_id);
2487 let member_id = UserId::from_proto(request.user_id);
2488 let result = db
2489 .set_channel_member_role(
2490 channel_id,
2491 session.user_id,
2492 member_id,
2493 request.role().into(),
2494 )
2495 .await?;
2496
2497 match result {
2498 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2499 let connection_pool = session.connection_pool().await;
2500 notify_membership_updated(
2501 &connection_pool,
2502 membership_update,
2503 member_id,
2504 &session.peer,
2505 )
2506 }
2507 db::SetMemberRoleResult::InviteUpdated(channel) => {
2508 let update = proto::UpdateChannels {
2509 channel_invitations: vec![channel.to_proto()],
2510 ..Default::default()
2511 };
2512
2513 for connection_id in session
2514 .connection_pool()
2515 .await
2516 .user_connection_ids(member_id)
2517 {
2518 session.peer.send(connection_id, update.clone())?;
2519 }
2520 }
2521 }
2522
2523 response.send(proto::Ack {})?;
2524 Ok(())
2525}
2526
2527/// Change the name of a channel
2528async fn rename_channel(
2529 request: proto::RenameChannel,
2530 response: Response<proto::RenameChannel>,
2531 session: Session,
2532) -> Result<()> {
2533 let db = session.db().await;
2534 let channel_id = ChannelId::from_proto(request.channel_id);
2535 let RenameChannelResult {
2536 channel,
2537 participants_to_update,
2538 } = db
2539 .rename_channel(channel_id, session.user_id, &request.name)
2540 .await?;
2541
2542 response.send(proto::RenameChannelResponse {
2543 channel: Some(channel.to_proto()),
2544 })?;
2545
2546 let connection_pool = session.connection_pool().await;
2547 for (user_id, channel) in participants_to_update {
2548 for connection_id in connection_pool.user_connection_ids(user_id) {
2549 let update = proto::UpdateChannels {
2550 channels: vec![channel.to_proto()],
2551 ..Default::default()
2552 };
2553
2554 session.peer.send(connection_id, update.clone())?;
2555 }
2556 }
2557
2558 Ok(())
2559}
2560
2561/// Move a channel to a new parent.
2562async fn move_channel(
2563 request: proto::MoveChannel,
2564 response: Response<proto::MoveChannel>,
2565 session: Session,
2566) -> Result<()> {
2567 let channel_id = ChannelId::from_proto(request.channel_id);
2568 let to = request.to.map(ChannelId::from_proto);
2569
2570 let result = session
2571 .db()
2572 .await
2573 .move_channel(channel_id, to, session.user_id)
2574 .await?;
2575
2576 notify_channel_moved(result, session).await?;
2577
2578 response.send(Ack {})?;
2579 Ok(())
2580}
2581
2582async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2583 let Some(MoveChannelResult {
2584 participants_to_remove,
2585 participants_to_update,
2586 moved_channels,
2587 }) = result
2588 else {
2589 return Ok(());
2590 };
2591 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2592
2593 let connection_pool = session.connection_pool().await;
2594 for (user_id, channels) in participants_to_update {
2595 let mut update = build_channels_update(channels, vec![]);
2596 update.delete_channels = moved_channels.clone();
2597 for connection_id in connection_pool.user_connection_ids(user_id) {
2598 session.peer.send(connection_id, update.clone())?;
2599 }
2600 }
2601
2602 for user_id in participants_to_remove {
2603 let update = proto::UpdateChannels {
2604 delete_channels: moved_channels.clone(),
2605 ..Default::default()
2606 };
2607 for connection_id in connection_pool.user_connection_ids(user_id) {
2608 session.peer.send(connection_id, update.clone())?;
2609 }
2610 }
2611 Ok(())
2612}
2613
2614/// Get the list of channel members
2615async fn get_channel_members(
2616 request: proto::GetChannelMembers,
2617 response: Response<proto::GetChannelMembers>,
2618 session: Session,
2619) -> Result<()> {
2620 let db = session.db().await;
2621 let channel_id = ChannelId::from_proto(request.channel_id);
2622 let members = db
2623 .get_channel_participant_details(channel_id, session.user_id)
2624 .await?;
2625 response.send(proto::GetChannelMembersResponse { members })?;
2626 Ok(())
2627}
2628
2629/// Accept or decline a channel invitation.
2630async fn respond_to_channel_invite(
2631 request: proto::RespondToChannelInvite,
2632 response: Response<proto::RespondToChannelInvite>,
2633 session: Session,
2634) -> Result<()> {
2635 let db = session.db().await;
2636 let channel_id = ChannelId::from_proto(request.channel_id);
2637 let RespondToChannelInvite {
2638 membership_update,
2639 notifications,
2640 } = db
2641 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2642 .await?;
2643
2644 let connection_pool = session.connection_pool().await;
2645 if let Some(membership_update) = membership_update {
2646 notify_membership_updated(
2647 &connection_pool,
2648 membership_update,
2649 session.user_id,
2650 &session.peer,
2651 );
2652 } else {
2653 let update = proto::UpdateChannels {
2654 remove_channel_invitations: vec![channel_id.to_proto()],
2655 ..Default::default()
2656 };
2657
2658 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2659 session.peer.send(connection_id, update.clone())?;
2660 }
2661 };
2662
2663 send_notifications(&*connection_pool, &session.peer, notifications);
2664
2665 response.send(proto::Ack {})?;
2666
2667 Ok(())
2668}
2669
2670/// Join the channels' room
2671async fn join_channel(
2672 request: proto::JoinChannel,
2673 response: Response<proto::JoinChannel>,
2674 session: Session,
2675) -> Result<()> {
2676 let channel_id = ChannelId::from_proto(request.channel_id);
2677 join_channel_internal(channel_id, Box::new(response), session).await
2678}
2679
/// Response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Lets `join_channel_internal` reply through either a `JoinChannel` or a
/// `JoinRoom` response handle.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinChannel>::send`. The fully qualified path
// selects the type's own `send` rather than recursing into this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding for `JoinRoom` requests. NOTE(review): presumably used by
// the JoinRoom handler (outside this view) — confirm.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2693
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Await `leave_room_for_session`. NOTE(review): presumably this leaves
///    whatever room the session is currently in — confirm.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership change, and the user's `ChannelRole`.
/// 3. If a LiveKit client is configured, build connection info:
///    - `Guest`s get a `guest_token` with `can_publish = false`;
///    - every other role gets a `room_token` with `can_publish = true`.
/// 4. Reply through `response`, push any membership change to the user's
///    connections, and call `room_updated` with the new room state.
/// 5. Once the inner scope has ended, call `channel_updated` for the channel
///    and `update_user_contacts` for the user.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Inner scope: the `db` guard and the first `connection_pool` guard taken
    // here are dropped when it ends, before `channel_updated` and
    // `update_user_contacts` run. NOTE(review): presumably this avoids holding
    // the database handle while `update_user_contacts` needs it — confirm.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Only produce LiveKit info when a LiveKit client is configured.
        // `trace_err()?` turns a token error into `None` for the whole
        // closure, so the join still succeeds, just without LiveKit
        // connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests may not publish and get a guest token; other roles get a
            // regular room token.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // `room` is cloned because `joined_room` is still used below and is
        // returned from this scope.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        // Push the membership change, if any, to all of this user's connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2774
2775/// Start editing the channel notes
2776async fn join_channel_buffer(
2777 request: proto::JoinChannelBuffer,
2778 response: Response<proto::JoinChannelBuffer>,
2779 session: Session,
2780) -> Result<()> {
2781 let db = session.db().await;
2782 let channel_id = ChannelId::from_proto(request.channel_id);
2783
2784 let open_response = db
2785 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2786 .await?;
2787
2788 let collaborators = open_response.collaborators.clone();
2789 response.send(open_response)?;
2790
2791 let update = UpdateChannelBufferCollaborators {
2792 channel_id: channel_id.to_proto(),
2793 collaborators: collaborators.clone(),
2794 };
2795 channel_buffer_updated(
2796 session.connection_id,
2797 collaborators
2798 .iter()
2799 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2800 &update,
2801 &session.peer,
2802 );
2803
2804 Ok(())
2805}
2806
2807/// Edit the channel notes
2808async fn update_channel_buffer(
2809 request: proto::UpdateChannelBuffer,
2810 session: Session,
2811) -> Result<()> {
2812 let db = session.db().await;
2813 let channel_id = ChannelId::from_proto(request.channel_id);
2814
2815 let (collaborators, non_collaborators, epoch, version) = db
2816 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2817 .await?;
2818
2819 channel_buffer_updated(
2820 session.connection_id,
2821 collaborators,
2822 &proto::UpdateChannelBuffer {
2823 channel_id: channel_id.to_proto(),
2824 operations: request.operations,
2825 },
2826 &session.peer,
2827 );
2828
2829 let pool = &*session.connection_pool().await;
2830
2831 broadcast(
2832 None,
2833 non_collaborators
2834 .iter()
2835 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2836 |peer_id| {
2837 session.peer.send(
2838 peer_id.into(),
2839 proto::UpdateChannels {
2840 unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
2841 channel_id: channel_id.to_proto(),
2842 epoch: epoch as u64,
2843 version: version.clone(),
2844 }],
2845 ..Default::default()
2846 },
2847 )
2848 },
2849 );
2850
2851 Ok(())
2852}
2853
2854/// Rejoin the channel notes after a connection blip
2855async fn rejoin_channel_buffers(
2856 request: proto::RejoinChannelBuffers,
2857 response: Response<proto::RejoinChannelBuffers>,
2858 session: Session,
2859) -> Result<()> {
2860 let db = session.db().await;
2861 let buffers = db
2862 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2863 .await?;
2864
2865 for rejoined_buffer in &buffers {
2866 let collaborators_to_notify = rejoined_buffer
2867 .buffer
2868 .collaborators
2869 .iter()
2870 .filter_map(|c| Some(c.peer_id?.into()));
2871 channel_buffer_updated(
2872 session.connection_id,
2873 collaborators_to_notify,
2874 &proto::UpdateChannelBufferCollaborators {
2875 channel_id: rejoined_buffer.buffer.channel_id,
2876 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2877 },
2878 &session.peer,
2879 );
2880 }
2881
2882 response.send(proto::RejoinChannelBuffersResponse {
2883 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2884 })?;
2885
2886 Ok(())
2887}
2888
2889/// Stop editing the channel notes
2890async fn leave_channel_buffer(
2891 request: proto::LeaveChannelBuffer,
2892 response: Response<proto::LeaveChannelBuffer>,
2893 session: Session,
2894) -> Result<()> {
2895 let db = session.db().await;
2896 let channel_id = ChannelId::from_proto(request.channel_id);
2897
2898 let left_buffer = db
2899 .leave_channel_buffer(channel_id, session.connection_id)
2900 .await?;
2901
2902 response.send(Ack {})?;
2903
2904 channel_buffer_updated(
2905 session.connection_id,
2906 left_buffer.connections,
2907 &proto::UpdateChannelBufferCollaborators {
2908 channel_id: channel_id.to_proto(),
2909 collaborators: left_buffer.collaborators,
2910 },
2911 &session.peer,
2912 );
2913
2914 Ok(())
2915}
2916
2917fn channel_buffer_updated<T: EnvelopedMessage>(
2918 sender_id: ConnectionId,
2919 collaborators: impl IntoIterator<Item = ConnectionId>,
2920 message: &T,
2921 peer: &Peer,
2922) {
2923 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2924 peer.send(peer_id.into(), message.clone())
2925 });
2926}
2927
2928fn send_notifications(
2929 connection_pool: &ConnectionPool,
2930 peer: &Peer,
2931 notifications: db::NotificationBatch,
2932) {
2933 for (user_id, notification) in notifications {
2934 for connection_id in connection_pool.user_connection_ids(user_id) {
2935 if let Err(error) = peer.send(
2936 connection_id,
2937 proto::AddNotification {
2938 notification: Some(notification.clone()),
2939 },
2940 ) {
2941 tracing::error!(
2942 "failed to send notification to {:?} {}",
2943 connection_id,
2944 error
2945 );
2946 }
2947 }
2948 }
2949}
2950
2951/// Send a message to the channel
2952async fn send_channel_message(
2953 request: proto::SendChannelMessage,
2954 response: Response<proto::SendChannelMessage>,
2955 session: Session,
2956) -> Result<()> {
2957 // Validate the message body.
2958 let body = request.body.trim().to_string();
2959 if body.len() > MAX_MESSAGE_LEN {
2960 return Err(anyhow!("message is too long"))?;
2961 }
2962 if body.is_empty() {
2963 return Err(anyhow!("message can't be blank"))?;
2964 }
2965
2966 // TODO: adjust mentions if body is trimmed
2967
2968 let timestamp = OffsetDateTime::now_utc();
2969 let nonce = request
2970 .nonce
2971 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
2972
2973 let channel_id = ChannelId::from_proto(request.channel_id);
2974 let CreatedChannelMessage {
2975 message_id,
2976 participant_connection_ids,
2977 channel_members,
2978 notifications,
2979 } = session
2980 .db()
2981 .await
2982 .create_channel_message(
2983 channel_id,
2984 session.user_id,
2985 &body,
2986 &request.mentions,
2987 timestamp,
2988 nonce.clone().into(),
2989 )
2990 .await?;
2991 let message = proto::ChannelMessage {
2992 sender_id: session.user_id.to_proto(),
2993 id: message_id.to_proto(),
2994 body,
2995 mentions: request.mentions,
2996 timestamp: timestamp.unix_timestamp() as u64,
2997 nonce: Some(nonce),
2998 };
2999 broadcast(
3000 Some(session.connection_id),
3001 participant_connection_ids,
3002 |connection| {
3003 session.peer.send(
3004 connection,
3005 proto::ChannelMessageSent {
3006 channel_id: channel_id.to_proto(),
3007 message: Some(message.clone()),
3008 },
3009 )
3010 },
3011 );
3012 response.send(proto::SendChannelMessageResponse {
3013 message: Some(message),
3014 })?;
3015
3016 let pool = &*session.connection_pool().await;
3017 broadcast(
3018 None,
3019 channel_members
3020 .iter()
3021 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3022 |peer_id| {
3023 session.peer.send(
3024 peer_id.into(),
3025 proto::UpdateChannels {
3026 unseen_channel_messages: vec![proto::UnseenChannelMessage {
3027 channel_id: channel_id.to_proto(),
3028 message_id: message_id.to_proto(),
3029 }],
3030 ..Default::default()
3031 },
3032 )
3033 },
3034 );
3035 send_notifications(pool, &session.peer, notifications);
3036
3037 Ok(())
3038}
3039
3040/// Delete a channel message
3041async fn remove_channel_message(
3042 request: proto::RemoveChannelMessage,
3043 response: Response<proto::RemoveChannelMessage>,
3044 session: Session,
3045) -> Result<()> {
3046 let channel_id = ChannelId::from_proto(request.channel_id);
3047 let message_id = MessageId::from_proto(request.message_id);
3048 let connection_ids = session
3049 .db()
3050 .await
3051 .remove_channel_message(channel_id, message_id, session.user_id)
3052 .await?;
3053 broadcast(Some(session.connection_id), connection_ids, |connection| {
3054 session.peer.send(connection, request.clone())
3055 });
3056 response.send(proto::Ack {})?;
3057 Ok(())
3058}
3059
3060/// Mark a channel message as read
3061async fn acknowledge_channel_message(
3062 request: proto::AckChannelMessage,
3063 session: Session,
3064) -> Result<()> {
3065 let channel_id = ChannelId::from_proto(request.channel_id);
3066 let message_id = MessageId::from_proto(request.message_id);
3067 let notifications = session
3068 .db()
3069 .await
3070 .observe_channel_message(channel_id, session.user_id, message_id)
3071 .await?;
3072 send_notifications(
3073 &*session.connection_pool().await,
3074 &session.peer,
3075 notifications,
3076 );
3077 Ok(())
3078}
3079
3080/// Mark a buffer version as synced
3081async fn acknowledge_buffer_version(
3082 request: proto::AckBufferOperation,
3083 session: Session,
3084) -> Result<()> {
3085 let buffer_id = BufferId::from_proto(request.buffer_id);
3086 session
3087 .db()
3088 .await
3089 .observe_buffer_version(
3090 buffer_id,
3091 session.user_id,
3092 request.epoch as i32,
3093 &request.version,
3094 )
3095 .await?;
3096 Ok(())
3097}
3098
3099/// Start receiving chat updates for a channel
3100async fn join_channel_chat(
3101 request: proto::JoinChannelChat,
3102 response: Response<proto::JoinChannelChat>,
3103 session: Session,
3104) -> Result<()> {
3105 let channel_id = ChannelId::from_proto(request.channel_id);
3106
3107 let db = session.db().await;
3108 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3109 .await?;
3110 let messages = db
3111 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3112 .await?;
3113 response.send(proto::JoinChannelChatResponse {
3114 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3115 messages,
3116 })?;
3117 Ok(())
3118}
3119
3120/// Stop receiving chat updates for a channel
3121async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3122 let channel_id = ChannelId::from_proto(request.channel_id);
3123 session
3124 .db()
3125 .await
3126 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3127 .await?;
3128 Ok(())
3129}
3130
3131/// Retrieve the chat history for a channel
3132async fn get_channel_messages(
3133 request: proto::GetChannelMessages,
3134 response: Response<proto::GetChannelMessages>,
3135 session: Session,
3136) -> Result<()> {
3137 let channel_id = ChannelId::from_proto(request.channel_id);
3138 let messages = session
3139 .db()
3140 .await
3141 .get_channel_messages(
3142 channel_id,
3143 session.user_id,
3144 MESSAGE_COUNT_PER_PAGE,
3145 Some(MessageId::from_proto(request.before_message_id)),
3146 )
3147 .await?;
3148 response.send(proto::GetChannelMessagesResponse {
3149 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3150 messages,
3151 })?;
3152 Ok(())
3153}
3154
3155/// Retrieve specific chat messages
3156async fn get_channel_messages_by_id(
3157 request: proto::GetChannelMessagesById,
3158 response: Response<proto::GetChannelMessagesById>,
3159 session: Session,
3160) -> Result<()> {
3161 let message_ids = request
3162 .message_ids
3163 .iter()
3164 .map(|id| MessageId::from_proto(*id))
3165 .collect::<Vec<_>>();
3166 let messages = session
3167 .db()
3168 .await
3169 .get_channel_messages_by_id(session.user_id, &message_ids)
3170 .await?;
3171 response.send(proto::GetChannelMessagesResponse {
3172 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3173 messages,
3174 })?;
3175 Ok(())
3176}
3177
3178/// Retrieve the current users notifications
3179async fn get_notifications(
3180 request: proto::GetNotifications,
3181 response: Response<proto::GetNotifications>,
3182 session: Session,
3183) -> Result<()> {
3184 let notifications = session
3185 .db()
3186 .await
3187 .get_notifications(
3188 session.user_id,
3189 NOTIFICATION_COUNT_PER_PAGE,
3190 request
3191 .before_id
3192 .map(|id| db::NotificationId::from_proto(id)),
3193 )
3194 .await?;
3195 response.send(proto::GetNotificationsResponse {
3196 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3197 notifications,
3198 })?;
3199 Ok(())
3200}
3201
3202/// Mark notifications as read
3203async fn mark_notification_as_read(
3204 request: proto::MarkNotificationRead,
3205 response: Response<proto::MarkNotificationRead>,
3206 session: Session,
3207) -> Result<()> {
3208 let database = &session.db().await;
3209 let notifications = database
3210 .mark_notification_as_read_by_id(
3211 session.user_id,
3212 NotificationId::from_proto(request.notification_id),
3213 )
3214 .await?;
3215 send_notifications(
3216 &*session.connection_pool().await,
3217 &session.peer,
3218 notifications,
3219 );
3220 response.send(proto::Ack {})?;
3221 Ok(())
3222}
3223
3224/// Get the current users information
3225async fn get_private_user_info(
3226 _request: proto::GetPrivateUserInfo,
3227 response: Response<proto::GetPrivateUserInfo>,
3228 session: Session,
3229) -> Result<()> {
3230 let db = session.db().await;
3231
3232 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3233 let user = db
3234 .get_user_by_id(session.user_id)
3235 .await?
3236 .ok_or_else(|| anyhow!("user not found"))?;
3237 let flags = db.get_user_flags(session.user_id).await?;
3238
3239 response.send(proto::GetPrivateUserInfoResponse {
3240 metrics_id,
3241 staff: user.admin,
3242 flags,
3243 })?;
3244 Ok(())
3245}
3246
3247fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3248 match message {
3249 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3250 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3251 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3252 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3253 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3254 code: frame.code.into(),
3255 reason: frame.reason,
3256 })),
3257 }
3258}
3259
3260fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3261 match message {
3262 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3263 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3264 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3265 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3266 AxumMessage::Close(frame) => {
3267 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3268 code: frame.code.into(),
3269 reason: frame.reason,
3270 }))
3271 }
3272 }
3273}
3274
3275fn notify_membership_updated(
3276 connection_pool: &ConnectionPool,
3277 result: MembershipUpdated,
3278 user_id: UserId,
3279 peer: &Peer,
3280) {
3281 let mut update = build_channels_update(result.new_channels, vec![]);
3282 update.delete_channels = result
3283 .removed_channels
3284 .into_iter()
3285 .map(|id| id.to_proto())
3286 .collect();
3287 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3288
3289 for connection_id in connection_pool.user_connection_ids(user_id) {
3290 peer.send(connection_id, update.clone()).trace_err();
3291 }
3292}
3293
3294fn build_channels_update(
3295 channels: ChannelsForUser,
3296 channel_invites: Vec<db::Channel>,
3297) -> proto::UpdateChannels {
3298 let mut update = proto::UpdateChannels::default();
3299
3300 for channel in channels.channels {
3301 update.channels.push(channel.to_proto());
3302 }
3303
3304 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3305 update.unseen_channel_messages = channels.channel_messages;
3306
3307 for (channel_id, participants) in channels.channel_participants {
3308 update
3309 .channel_participants
3310 .push(proto::ChannelParticipants {
3311 channel_id: channel_id.to_proto(),
3312 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3313 });
3314 }
3315
3316 for channel in channel_invites {
3317 update.channel_invitations.push(channel.to_proto());
3318 }
3319
3320 update
3321}
3322
3323fn build_initial_contacts_update(
3324 contacts: Vec<db::Contact>,
3325 pool: &ConnectionPool,
3326) -> proto::UpdateContacts {
3327 let mut update = proto::UpdateContacts::default();
3328
3329 for contact in contacts {
3330 match contact {
3331 db::Contact::Accepted { user_id, busy } => {
3332 update.contacts.push(contact_for_user(user_id, busy, &pool));
3333 }
3334 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3335 db::Contact::Incoming { user_id } => {
3336 update
3337 .incoming_requests
3338 .push(proto::IncomingContactRequest {
3339 requester_id: user_id.to_proto(),
3340 })
3341 }
3342 }
3343 }
3344
3345 update
3346}
3347
3348fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3349 proto::Contact {
3350 user_id: user_id.to_proto(),
3351 online: pool.is_user_online(user_id),
3352 busy,
3353 }
3354}
3355
3356fn room_updated(room: &proto::Room, peer: &Peer) {
3357 broadcast(
3358 None,
3359 room.participants
3360 .iter()
3361 .filter_map(|participant| Some(participant.peer_id?.into())),
3362 |peer_id| {
3363 peer.send(
3364 peer_id.into(),
3365 proto::RoomUpdated {
3366 room: Some(room.clone()),
3367 },
3368 )
3369 },
3370 );
3371}
3372
3373fn channel_updated(
3374 channel_id: ChannelId,
3375 room: &proto::Room,
3376 channel_members: &[UserId],
3377 peer: &Peer,
3378 pool: &ConnectionPool,
3379) {
3380 let participants = room
3381 .participants
3382 .iter()
3383 .map(|p| p.user_id)
3384 .collect::<Vec<_>>();
3385
3386 broadcast(
3387 None,
3388 channel_members
3389 .iter()
3390 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3391 |peer_id| {
3392 peer.send(
3393 peer_id.into(),
3394 proto::UpdateChannels {
3395 channel_participants: vec![proto::ChannelParticipants {
3396 channel_id: channel_id.to_proto(),
3397 participant_user_ids: participants.clone(),
3398 }],
3399 ..Default::default()
3400 },
3401 )
3402 },
3403 );
3404}
3405
3406async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3407 let db = session.db().await;
3408
3409 let contacts = db.get_contacts(user_id).await?;
3410 let busy = db.is_user_busy(user_id).await?;
3411
3412 let pool = session.connection_pool().await;
3413 let updated_contact = contact_for_user(user_id, busy, &pool);
3414 for contact in contacts {
3415 if let db::Contact::Accepted {
3416 user_id: contact_user_id,
3417 ..
3418 } = contact
3419 {
3420 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3421 session
3422 .peer
3423 .send(
3424 contact_conn_id,
3425 proto::UpdateContacts {
3426 contacts: vec![updated_contact.clone()],
3427 remove_contacts: Default::default(),
3428 incoming_requests: Default::default(),
3429 remove_incoming_requests: Default::default(),
3430 outgoing_requests: Default::default(),
3431 remove_outgoing_requests: Default::default(),
3432 },
3433 )
3434 .trace_err();
3435 }
3436 }
3437 }
3438 Ok(())
3439}
3440
3441async fn leave_room_for_session(session: &Session) -> Result<()> {
3442 let mut contacts_to_update = HashSet::default();
3443
3444 let room_id;
3445 let canceled_calls_to_user_ids;
3446 let live_kit_room;
3447 let delete_live_kit_room;
3448 let room;
3449 let channel_members;
3450 let channel_id;
3451
3452 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3453 contacts_to_update.insert(session.user_id);
3454
3455 for project in left_room.left_projects.values() {
3456 project_left(project, session);
3457 }
3458
3459 room_id = RoomId::from_proto(left_room.room.id);
3460 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3461 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3462 delete_live_kit_room = left_room.deleted;
3463 room = mem::take(&mut left_room.room);
3464 channel_members = mem::take(&mut left_room.channel_members);
3465 channel_id = left_room.channel_id;
3466
3467 room_updated(&room, &session.peer);
3468 } else {
3469 return Ok(());
3470 }
3471
3472 if let Some(channel_id) = channel_id {
3473 channel_updated(
3474 channel_id,
3475 &room,
3476 &channel_members,
3477 &session.peer,
3478 &*session.connection_pool().await,
3479 );
3480 }
3481
3482 {
3483 let pool = session.connection_pool().await;
3484 for canceled_user_id in canceled_calls_to_user_ids {
3485 for connection_id in pool.user_connection_ids(canceled_user_id) {
3486 session
3487 .peer
3488 .send(
3489 connection_id,
3490 proto::CallCanceled {
3491 room_id: room_id.to_proto(),
3492 },
3493 )
3494 .trace_err();
3495 }
3496 contacts_to_update.insert(canceled_user_id);
3497 }
3498 }
3499
3500 for contact_user_id in contacts_to_update {
3501 update_user_contacts(contact_user_id, &session).await?;
3502 }
3503
3504 if let Some(live_kit) = session.live_kit_client.as_ref() {
3505 live_kit
3506 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3507 .await
3508 .trace_err();
3509
3510 if delete_live_kit_room {
3511 live_kit.delete_room(live_kit_room).await.trace_err();
3512 }
3513 }
3514
3515 Ok(())
3516}
3517
3518async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3519 let left_channel_buffers = session
3520 .db()
3521 .await
3522 .leave_channel_buffers(session.connection_id)
3523 .await?;
3524
3525 for left_buffer in left_channel_buffers {
3526 channel_buffer_updated(
3527 session.connection_id,
3528 left_buffer.connections,
3529 &proto::UpdateChannelBufferCollaborators {
3530 channel_id: left_buffer.channel_id.to_proto(),
3531 collaborators: left_buffer.collaborators,
3532 },
3533 &session.peer,
3534 );
3535 }
3536
3537 Ok(())
3538}
3539
3540fn project_left(project: &db::LeftProject, session: &Session) {
3541 for connection_id in &project.connection_ids {
3542 if project.host_user_id == session.user_id {
3543 session
3544 .peer
3545 .send(
3546 *connection_id,
3547 proto::UnshareProject {
3548 project_id: project.id.to_proto(),
3549 },
3550 )
3551 .trace_err();
3552 } else {
3553 session
3554 .peer
3555 .send(
3556 *connection_id,
3557 proto::RemoveProjectCollaborator {
3558 project_id: project.id.to_proto(),
3559 peer_id: Some(session.connection_id.into()),
3560 },
3561 )
3562 .trace_err();
3563 }
3564 }
3565}
3566
3567pub trait ResultExt {
3568 type Ok;
3569
3570 fn trace_err(self) -> Option<Self::Ok>;
3571}
3572
3573impl<T, E> ResultExt for Result<T, E>
3574where
3575 E: std::fmt::Debug,
3576{
3577 type Ok = T;
3578
3579 fn trace_err(self) -> Option<T> {
3580 match self {
3581 Ok(value) => Some(value),
3582 Err(error) => {
3583 tracing::error!("{:?}", error);
3584 None
3585 }
3586 }
3587 }
3588}