1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::ConnectionPool;
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use lazy_static::lazy_static;
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67
/// How long `connection_lost` waits after a connection drops before it removes the
/// user from rooms and channel buffers, unless the server is torn down first.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// Delay `Server::start` waits before looking up and clearing stale rooms and
/// channel-buffer collaborators via `stale_server_resource_ids`.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): the three constants below are not referenced in this part of the
// file; the descriptions are inferred from their names — confirm against their uses.
/// Presumably the page size used when fetching channel message history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted channel message length (unit — bytes or chars — TODO confirm).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
74
// Prometheus gauges refreshed by `handle_metrics` each time `/metrics` is scraped.
// Registration happens on first access and panics (`unwrap`) if it fails.
lazy_static! {
    /// Number of non-admin connections currently in the connection pool.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Value of `project_count_excluding_admins` from the database.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
84
/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives a boxed envelope (downcast back to its concrete type inside the
/// handler) plus the session of the sending connection, and returns a boxed
/// `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
87
88struct Response<R> {
89 peer: Arc<Peer>,
90 receipt: Receipt<R>,
91 responded: Arc<AtomicBool>,
92}
93
94impl<R: RequestMessage> Response<R> {
95 fn send(self, payload: R::Response) -> Result<()> {
96 self.responded.store(true, SeqCst);
97 self.peer.respond(self.receipt, payload)?;
98 Ok(())
99 }
100}
101
/// Per-connection context handed to every message handler.
///
/// One `Session` is built per connection in `Server::handle_connection`, and each
/// dispatched message receives a clone. All `Arc` fields are therefore shared by
/// every handler running for that connection.
#[derive(Clone)]
struct Session {
    zed_environment: Arc<str>,
    user_id: UserId,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; because the mutex is created once per
    /// connection and shared by clones, handlers of one connection take turns using it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// The server-wide connection pool (same `Arc` as `Server::connection_pool`).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// `None` when no LiveKit server is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Not read in this part of the file; kept alive with the session.
    _executor: Executor,
}
113
impl Session {
    /// Locks this session's database handle.
    ///
    /// The mutex is shared by all clones of the session, so concurrently running
    /// handlers for the same connection serialize their database access here.
    async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
        // Test builds yield before and after acquiring the lock — presumably to make
        // interleavings between concurrent handlers more likely; confirm intent.
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the server-wide connection pool.
    ///
    /// The lock itself is a synchronous `parking_lot` mutex. The returned guard is
    /// `!Send` (see `ConnectionPoolGuard::_not_send`), so it cannot be held across an
    /// `.await` inside a `Send` handler future. Test builds yield once before locking.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }
}
134
135impl fmt::Debug for Session {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.debug_struct("Session")
138 .field("user_id", &self.user_id)
139 .field("connection_id", &self.connection_id)
140 .finish()
141 }
142}
143
144struct DbHandle(Arc<Database>);
145
146impl Deref for DbHandle {
147 type Target = Database;
148
149 fn deref(&self) -> &Self::Target {
150 self.0.as_ref()
151 }
152}
153
/// The collab RPC server: owns the `Peer`, the shared connection pool and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connections attached to this server; the same `Arc` is shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Fired by `teardown()`; connection loops and `connection_lost` subscribe to it.
    teardown: watch::Sender<()>,
}
163
/// Lock guard over the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`:
    // a future holding it across an `.await` cannot be `Send`, which keeps the
    // synchronous lock from being held while a task is suspended.
    _not_send: PhantomData<Rc<()>>,
}
168
/// Serializable view of the server's state.
///
/// It holds the connection-pool lock for as long as the snapshot is alive, so it
/// should be serialized and dropped promptly.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the locked `ConnectionPool` itself rather than the guard wrapper.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
175
176pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
177where
178 S: Serializer,
179 T: Deref<Target = U>,
180 U: Serialize,
181{
182 Serialize::serialize(value.deref(), serializer)
183}
184
185impl Server {
186 pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
187 let mut server = Self {
188 id: parking_lot::Mutex::new(id),
189 peer: Peer::new(id.0 as u32),
190 app_state,
191 executor,
192 connection_pool: Default::default(),
193 handlers: Default::default(),
194 teardown: watch::channel(()).0,
195 };
196
197 server
198 .add_request_handler(ping)
199 .add_request_handler(create_room)
200 .add_request_handler(join_room)
201 .add_request_handler(rejoin_room)
202 .add_request_handler(leave_room)
203 .add_request_handler(set_room_participant_role)
204 .add_request_handler(call)
205 .add_request_handler(cancel_call)
206 .add_message_handler(decline_call)
207 .add_request_handler(update_participant_location)
208 .add_request_handler(share_project)
209 .add_message_handler(unshare_project)
210 .add_request_handler(join_project)
211 .add_message_handler(leave_project)
212 .add_request_handler(update_project)
213 .add_request_handler(update_worktree)
214 .add_message_handler(start_language_server)
215 .add_message_handler(update_language_server)
216 .add_message_handler(update_diagnostic_summary)
217 .add_message_handler(update_worktree_settings)
218 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
219 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
220 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
221 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
222 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
223 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
224 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
225 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
226 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
227 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
228 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
229 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
230 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
231 .add_request_handler(
232 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
233 )
234 .add_request_handler(
235 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
236 )
237 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
238 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
239 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
240 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
241 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
242 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
243 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
244 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
245 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
246 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
247 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
248 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
249 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
250 .add_message_handler(create_buffer_for_peer)
251 .add_request_handler(update_buffer)
252 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
253 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
254 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
255 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
256 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
257 .add_request_handler(get_users)
258 .add_request_handler(fuzzy_search_users)
259 .add_request_handler(request_contact)
260 .add_request_handler(remove_contact)
261 .add_request_handler(respond_to_contact_request)
262 .add_request_handler(create_channel)
263 .add_request_handler(delete_channel)
264 .add_request_handler(invite_channel_member)
265 .add_request_handler(remove_channel_member)
266 .add_request_handler(set_channel_member_role)
267 .add_request_handler(set_channel_visibility)
268 .add_request_handler(rename_channel)
269 .add_request_handler(join_channel_buffer)
270 .add_request_handler(leave_channel_buffer)
271 .add_message_handler(update_channel_buffer)
272 .add_request_handler(rejoin_channel_buffers)
273 .add_request_handler(get_channel_members)
274 .add_request_handler(respond_to_channel_invite)
275 .add_request_handler(join_channel)
276 .add_request_handler(join_channel_chat)
277 .add_message_handler(leave_channel_chat)
278 .add_request_handler(send_channel_message)
279 .add_request_handler(remove_channel_message)
280 .add_request_handler(get_channel_messages)
281 .add_request_handler(get_channel_messages_by_id)
282 .add_request_handler(get_notifications)
283 .add_request_handler(mark_notification_as_read)
284 .add_request_handler(move_channel)
285 .add_request_handler(follow)
286 .add_message_handler(unfollow)
287 .add_message_handler(update_followers)
288 .add_request_handler(get_private_user_info)
289 .add_message_handler(acknowledge_channel_message)
290 .add_message_handler(acknowledge_buffer_version);
291
292 Arc::new(server)
293 }
294
295 pub async fn start(&self) -> Result<()> {
296 let server_id = *self.id.lock();
297 let app_state = self.app_state.clone();
298 let peer = self.peer.clone();
299 let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
300 let pool = self.connection_pool.clone();
301 let live_kit_client = self.app_state.live_kit_client.clone();
302
303 let span = info_span!("start server");
304 self.executor.spawn_detached(
305 async move {
306 tracing::info!("waiting for cleanup timeout");
307 timeout.await;
308 tracing::info!("cleanup timeout expired, retrieving stale rooms");
309 if let Some((room_ids, channel_ids)) = app_state
310 .db
311 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
312 .await
313 .trace_err()
314 {
315 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
316 tracing::info!(
317 stale_channel_buffer_count = channel_ids.len(),
318 "retrieved stale channel buffers"
319 );
320
321 for channel_id in channel_ids {
322 if let Some(refreshed_channel_buffer) = app_state
323 .db
324 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
325 .await
326 .trace_err()
327 {
328 for connection_id in refreshed_channel_buffer.connection_ids {
329 peer.send(
330 connection_id,
331 proto::UpdateChannelBufferCollaborators {
332 channel_id: channel_id.to_proto(),
333 collaborators: refreshed_channel_buffer
334 .collaborators
335 .clone(),
336 },
337 )
338 .trace_err();
339 }
340 }
341 }
342
343 for room_id in room_ids {
344 let mut contacts_to_update = HashSet::default();
345 let mut canceled_calls_to_user_ids = Vec::new();
346 let mut live_kit_room = String::new();
347 let mut delete_live_kit_room = false;
348
349 if let Some(mut refreshed_room) = app_state
350 .db
351 .clear_stale_room_participants(room_id, server_id)
352 .await
353 .trace_err()
354 {
355 tracing::info!(
356 room_id = room_id.0,
357 new_participant_count = refreshed_room.room.participants.len(),
358 "refreshed room"
359 );
360 room_updated(&refreshed_room.room, &peer);
361 if let Some(channel_id) = refreshed_room.channel_id {
362 channel_updated(
363 channel_id,
364 &refreshed_room.room,
365 &refreshed_room.channel_members,
366 &peer,
367 &*pool.lock(),
368 );
369 }
370 contacts_to_update
371 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
372 contacts_to_update
373 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
374 canceled_calls_to_user_ids =
375 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
376 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
377 delete_live_kit_room = refreshed_room.room.participants.is_empty();
378 }
379
380 {
381 let pool = pool.lock();
382 for canceled_user_id in canceled_calls_to_user_ids {
383 for connection_id in pool.user_connection_ids(canceled_user_id) {
384 peer.send(
385 connection_id,
386 proto::CallCanceled {
387 room_id: room_id.to_proto(),
388 },
389 )
390 .trace_err();
391 }
392 }
393 }
394
395 for user_id in contacts_to_update {
396 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
397 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
398 if let Some((busy, contacts)) = busy.zip(contacts) {
399 let pool = pool.lock();
400 let updated_contact = contact_for_user(user_id, busy, &pool);
401 for contact in contacts {
402 if let db::Contact::Accepted {
403 user_id: contact_user_id,
404 ..
405 } = contact
406 {
407 for contact_conn_id in
408 pool.user_connection_ids(contact_user_id)
409 {
410 peer.send(
411 contact_conn_id,
412 proto::UpdateContacts {
413 contacts: vec![updated_contact.clone()],
414 remove_contacts: Default::default(),
415 incoming_requests: Default::default(),
416 remove_incoming_requests: Default::default(),
417 outgoing_requests: Default::default(),
418 remove_outgoing_requests: Default::default(),
419 },
420 )
421 .trace_err();
422 }
423 }
424 }
425 }
426 }
427
428 if let Some(live_kit) = live_kit_client.as_ref() {
429 if delete_live_kit_room {
430 live_kit.delete_room(live_kit_room).await.trace_err();
431 }
432 }
433 }
434 }
435
436 app_state
437 .db
438 .delete_stale_servers(&app_state.config.zed_environment, server_id)
439 .await
440 .trace_err();
441 }
442 .instrument(span),
443 );
444 Ok(())
445 }
446
447 pub fn teardown(&self) {
448 self.peer.teardown();
449 self.connection_pool.lock().reset();
450 let _ = self.teardown.send(());
451 }
452
453 #[cfg(test)]
454 pub fn reset(&self, id: ServerId) {
455 self.teardown();
456 *self.id.lock() = id;
457 self.peer.reset(id.0 as u32);
458 }
459
460 #[cfg(test)]
461 pub fn id(&self) -> ServerId {
462 *self.id.lock()
463 }
464
465 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
466 where
467 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
468 Fut: 'static + Send + Future<Output = Result<()>>,
469 M: EnvelopedMessage,
470 {
471 let prev_handler = self.handlers.insert(
472 TypeId::of::<M>(),
473 Box::new(move |envelope, session| {
474 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
475 let span = info_span!(
476 "handle message",
477 payload_type = envelope.payload_type_name()
478 );
479 span.in_scope(|| {
480 tracing::info!(
481 payload_type = envelope.payload_type_name(),
482 "message received"
483 );
484 });
485 let start_time = Instant::now();
486 let future = (handler)(*envelope, session);
487 async move {
488 let result = future.await;
489 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
490 match result {
491 Err(error) => {
492 tracing::error!(%error, ?duration_ms, "error handling message")
493 }
494 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
495 }
496 }
497 .instrument(span)
498 .boxed()
499 }),
500 );
501 if prev_handler.is_some() {
502 panic!("registered a handler for the same message twice");
503 }
504 self
505 }
506
507 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
508 where
509 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
510 Fut: 'static + Send + Future<Output = Result<()>>,
511 M: EnvelopedMessage,
512 {
513 self.add_handler(move |envelope, session| handler(envelope.payload, session));
514 self
515 }
516
517 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
518 where
519 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
520 Fut: Send + Future<Output = Result<()>>,
521 M: RequestMessage,
522 {
523 let handler = Arc::new(handler);
524 self.add_handler(move |envelope, session| {
525 let receipt = envelope.receipt();
526 let handler = handler.clone();
527 async move {
528 let peer = session.peer.clone();
529 let responded = Arc::new(AtomicBool::default());
530 let response = Response {
531 peer: peer.clone(),
532 responded: responded.clone(),
533 receipt,
534 };
535 match (handler)(envelope.payload, response, session).await {
536 Ok(()) => {
537 if responded.load(std::sync::atomic::Ordering::SeqCst) {
538 Ok(())
539 } else {
540 Err(anyhow!("handler did not send a response"))?
541 }
542 }
543 Err(error) => {
544 let proto_err = match &error {
545 Error::Internal(err) => err.to_proto(),
546 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
547 };
548 peer.respond_with_error(receipt, proto_err)?;
549 Err(error)
550 }
551 }
552 }
553 })
554 }
555
556 pub fn handle_connection(
557 self: &Arc<Self>,
558 connection: Connection,
559 address: String,
560 user: User,
561 impersonator: Option<User>,
562 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
563 executor: Executor,
564 ) -> impl Future<Output = Result<()>> {
565 let this = self.clone();
566 let user_id = user.id;
567 let login = user.github_login;
568 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
569 if let Some(impersonator) = impersonator {
570 span.record("impersonator", &impersonator.github_login);
571 }
572 let mut teardown = self.teardown.subscribe();
573 async move {
574 let (connection_id, handle_io, mut incoming_rx) = this
575 .peer
576 .add_connection(connection, {
577 let executor = executor.clone();
578 move |duration| executor.sleep(duration)
579 });
580
581 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
582 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
583 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
584
585 if let Some(send_connection_id) = send_connection_id.take() {
586 let _ = send_connection_id.send(connection_id);
587 }
588
589 if !user.connected_once {
590 this.peer.send(connection_id, proto::ShowContacts {})?;
591 this.app_state.db.set_user_connected_once(user_id, true).await?;
592 }
593
594 let (contacts, channels_for_user, channel_invites) = future::try_join3(
595 this.app_state.db.get_contacts(user_id),
596 this.app_state.db.get_channels_for_user(user_id),
597 this.app_state.db.get_channel_invites_for_user(user_id),
598 ).await?;
599
600 {
601 let mut pool = this.connection_pool.lock();
602 pool.add_connection(connection_id, user_id, user.admin);
603 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
604 this.peer.send(connection_id, build_update_user_channels(&channels_for_user.channel_memberships))?;
605 this.peer.send(connection_id, build_channels_update(
606 channels_for_user,
607 channel_invites
608 ))?;
609 }
610
611 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
612 this.peer.send(connection_id, incoming_call)?;
613 }
614
615 let session = Session {
616 user_id,
617 connection_id,
618 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
619 zed_environment: this.app_state.config.zed_environment.clone(),
620 peer: this.peer.clone(),
621 connection_pool: this.connection_pool.clone(),
622 live_kit_client: this.app_state.live_kit_client.clone(),
623 _executor: executor.clone()
624 };
625 update_user_contacts(user_id, &session).await?;
626
627 let handle_io = handle_io.fuse();
628 futures::pin_mut!(handle_io);
629
630 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
631 // This prevents deadlocks when e.g., client A performs a request to client B and
632 // client B performs a request to client A. If both clients stop processing further
633 // messages until their respective request completes, they won't have a chance to
634 // respond to the other client's request and cause a deadlock.
635 //
636 // This arrangement ensures we will attempt to process earlier messages first, but fall
637 // back to processing messages arrived later in the spirit of making progress.
638 let mut foreground_message_handlers = FuturesUnordered::new();
639 let concurrent_handlers = Arc::new(Semaphore::new(256));
640 loop {
641 let next_message = async {
642 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
643 let message = incoming_rx.next().await;
644 (permit, message)
645 }.fuse();
646 futures::pin_mut!(next_message);
647 futures::select_biased! {
648 _ = teardown.changed().fuse() => return Ok(()),
649 result = handle_io => {
650 if let Err(error) = result {
651 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
652 }
653 break;
654 }
655 _ = foreground_message_handlers.next() => {}
656 next_message = next_message => {
657 let (permit, message) = next_message;
658 if let Some(message) = message {
659 let type_name = message.payload_type_name();
660 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
661 let span_enter = span.enter();
662 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
663 let is_background = message.is_background();
664 let handle_message = (handler)(message, session.clone());
665 drop(span_enter);
666
667 let handle_message = async move {
668 handle_message.await;
669 drop(permit);
670 }.instrument(span);
671 if is_background {
672 executor.spawn_detached(handle_message);
673 } else {
674 foreground_message_handlers.push(handle_message);
675 }
676 } else {
677 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
678 }
679 } else {
680 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
681 break;
682 }
683 }
684 }
685 }
686
687 drop(foreground_message_handlers);
688 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
689 if let Err(error) = connection_lost(session, teardown, executor).await {
690 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
691 }
692
693 Ok(())
694 }.instrument(span)
695 }
696
697 pub async fn invite_code_redeemed(
698 self: &Arc<Self>,
699 inviter_id: UserId,
700 invitee_id: UserId,
701 ) -> Result<()> {
702 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
703 if let Some(code) = &user.invite_code {
704 let pool = self.connection_pool.lock();
705 let invitee_contact = contact_for_user(invitee_id, false, &pool);
706 for connection_id in pool.user_connection_ids(inviter_id) {
707 self.peer.send(
708 connection_id,
709 proto::UpdateContacts {
710 contacts: vec![invitee_contact.clone()],
711 ..Default::default()
712 },
713 )?;
714 self.peer.send(
715 connection_id,
716 proto::UpdateInviteInfo {
717 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
718 count: user.invite_count as u32,
719 },
720 )?;
721 }
722 }
723 }
724 Ok(())
725 }
726
727 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
728 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
729 if let Some(invite_code) = &user.invite_code {
730 let pool = self.connection_pool.lock();
731 for connection_id in pool.user_connection_ids(user_id) {
732 self.peer.send(
733 connection_id,
734 proto::UpdateInviteInfo {
735 url: format!(
736 "{}{}",
737 self.app_state.config.invite_link_prefix, invite_code
738 ),
739 count: user.invite_count as u32,
740 },
741 )?;
742 }
743 }
744 }
745 Ok(())
746 }
747
748 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
749 ServerSnapshot {
750 connection_pool: ConnectionPoolGuard {
751 guard: self.connection_pool.lock(),
752 _not_send: PhantomData,
753 },
754 peer: &self.peer,
755 }
756 }
757}
758
759impl<'a> Deref for ConnectionPoolGuard<'a> {
760 type Target = ConnectionPool;
761
762 fn deref(&self) -> &Self::Target {
763 &*self.guard
764 }
765}
766
767impl<'a> DerefMut for ConnectionPoolGuard<'a> {
768 fn deref_mut(&mut self) -> &mut Self::Target {
769 &mut *self.guard
770 }
771}
772
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, validates the pool's invariants every time a guard is
    /// released. In other builds this is a no-op; the inner lock guard still
    /// releases the lock when the field is dropped.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
779
780fn broadcast<F>(
781 sender_id: Option<ConnectionId>,
782 receiver_ids: impl IntoIterator<Item = ConnectionId>,
783 mut f: F,
784) where
785 F: FnMut(ConnectionId) -> anyhow::Result<()>,
786{
787 for receiver_id in receiver_ids {
788 if Some(receiver_id) != sender_id {
789 if let Err(error) = f(receiver_id) {
790 tracing::error!("failed to send to {:?} {}", receiver_id, error);
791 }
792 }
793 }
794}
795
lazy_static! {
    /// Request header clients use to announce their RPC protocol version.
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
}

/// Typed value of the `x-zed-protocol-version` header, compared against
/// `rpc::PROTOCOL_VERSION` in `handle_websocket_request`.
pub struct ProtocolVersion(u32);
801
802impl Header for ProtocolVersion {
803 fn name() -> &'static HeaderName {
804 &ZED_PROTOCOL_VERSION
805 }
806
807 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
808 where
809 Self: Sized,
810 I: Iterator<Item = &'i axum::http::HeaderValue>,
811 {
812 let version = values
813 .next()
814 .ok_or_else(axum::headers::Error::invalid)?
815 .to_str()
816 .map_err(|_| axum::headers::Error::invalid())?
817 .parse()
818 .map_err(|_| axum::headers::Error::invalid())?;
819 Ok(Self(version))
820 }
821
822 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
823 values.extend([self.0.to_string().parse().unwrap()]);
824 }
825}
826
/// Builds the HTTP router for the RPC server.
///
/// `/rpc` upgrades to a websocket and is wrapped by the `AppState` extension and the
/// `auth::validate_header` middleware. `/metrics` is registered after that layer, and
/// in axum `Router::layer` only wraps routes added before it, so `/metrics` is not
/// behind `validate_header`. The final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
838
839pub async fn handle_websocket_request(
840 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
841 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
842 Extension(server): Extension<Arc<Server>>,
843 Extension(user): Extension<User>,
844 Extension(impersonator): Extension<Impersonator>,
845 ws: WebSocketUpgrade,
846) -> axum::response::Response {
847 if protocol_version != rpc::PROTOCOL_VERSION {
848 return (
849 StatusCode::UPGRADE_REQUIRED,
850 "client must be upgraded".to_string(),
851 )
852 .into_response();
853 }
854 let socket_address = socket_address.to_string();
855 ws.on_upgrade(move |socket| {
856 use util::ResultExt;
857 let socket = socket
858 .map_ok(to_tungstenite_message)
859 .err_into()
860 .with(|message| async move { Ok(to_axum_message(message)) });
861 let connection = Connection::new(Box::pin(socket));
862 async move {
863 server
864 .handle_connection(
865 connection,
866 socket_address,
867 user,
868 impersonator.0,
869 None,
870 Executor::Production,
871 )
872 .await
873 .log_err();
874 }
875 })
876}
877
878pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
879 let connections = server
880 .connection_pool
881 .lock()
882 .connections()
883 .filter(|connection| !connection.admin)
884 .count();
885
886 METRIC_CONNECTIONS.set(connections as _);
887
888 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
889 METRIC_SHARED_PROJECTS.set(shared_projects as _);
890
891 let encoder = prometheus::TextEncoder::new();
892 let metric_families = prometheus::gather();
893 let encoded_metrics = encoder
894 .encode_to_string(&metric_families)
895 .map_err(|err| anyhow!("{}", err))?;
896 Ok(encoded_metrics)
897}
898
899#[instrument(err, skip(executor))]
900async fn connection_lost(
901 session: Session,
902 mut teardown: watch::Receiver<()>,
903 executor: Executor,
904) -> Result<()> {
905 session.peer.disconnect(session.connection_id);
906 session
907 .connection_pool()
908 .await
909 .remove_connection(session.connection_id)?;
910
911 session
912 .db()
913 .await
914 .connection_lost(session.connection_id)
915 .await
916 .trace_err();
917
918 futures::select_biased! {
919 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
920 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
921 leave_room_for_session(&session).await.trace_err();
922 leave_channel_buffers_for_session(&session)
923 .await
924 .trace_err();
925
926 if !session
927 .connection_pool()
928 .await
929 .is_user_online(session.user_id)
930 {
931 let db = session.db().await;
932 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
933 room_updated(&room, &session.peer);
934 }
935 }
936
937 update_user_contacts(session.user_id, &session).await?;
938 }
939 _ = teardown.changed().fuse() => {}
940 }
941
942 Ok(())
943}
944
945/// Acknowledges a ping from a client, used to keep the connection alive.
946async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
947 response.send(proto::Ack {})?;
948 Ok(())
949}
950
951/// Creates a new room for calling (outside of channels)
952async fn create_room(
953 _request: proto::CreateRoom,
954 response: Response<proto::CreateRoom>,
955 session: Session,
956) -> Result<()> {
957 let live_kit_room = nanoid::nanoid!(30);
958
959 let live_kit_connection_info = {
960 let live_kit_room = live_kit_room.clone();
961 let live_kit = session.live_kit_client.as_ref();
962
963 util::async_maybe!({
964 let live_kit = live_kit?;
965
966 let token = live_kit
967 .room_token(&live_kit_room, &session.user_id.to_string())
968 .trace_err()?;
969
970 Some(proto::LiveKitConnectionInfo {
971 server_url: live_kit.url().into(),
972 token,
973 can_publish: true,
974 })
975 })
976 }
977 .await;
978
979 let room = session
980 .db()
981 .await
982 .create_room(
983 session.user_id,
984 session.connection_id,
985 &live_kit_room,
986 &session.zed_environment,
987 )
988 .await?;
989
990 response.send(proto::CreateRoomResponse {
991 room: Some(room.clone()),
992 live_kit_connection_info,
993 })?;
994
995 update_user_contacts(session.user_id, &session).await?;
996 Ok(())
997}
998
999/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1000async fn join_room(
1001 request: proto::JoinRoom,
1002 response: Response<proto::JoinRoom>,
1003 session: Session,
1004) -> Result<()> {
1005 let room_id = RoomId::from_proto(request.id);
1006
1007 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1008
1009 if let Some(channel_id) = channel_id {
1010 return join_channel_internal(channel_id, Box::new(response), session).await;
1011 }
1012
1013 let joined_room = {
1014 let room = session
1015 .db()
1016 .await
1017 .join_room(
1018 room_id,
1019 session.user_id,
1020 session.connection_id,
1021 session.zed_environment.as_ref(),
1022 )
1023 .await?;
1024 room_updated(&room.room, &session.peer);
1025 room.into_inner()
1026 };
1027
1028 for connection_id in session
1029 .connection_pool()
1030 .await
1031 .user_connection_ids(session.user_id)
1032 {
1033 session
1034 .peer
1035 .send(
1036 connection_id,
1037 proto::CallCanceled {
1038 room_id: room_id.to_proto(),
1039 },
1040 )
1041 .trace_err();
1042 }
1043
1044 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1045 if let Some(token) = live_kit
1046 .room_token(
1047 &joined_room.room.live_kit_room,
1048 &session.user_id.to_string(),
1049 )
1050 .trace_err()
1051 {
1052 Some(proto::LiveKitConnectionInfo {
1053 server_url: live_kit.url().into(),
1054 token,
1055 can_publish: true,
1056 })
1057 } else {
1058 None
1059 }
1060 } else {
1061 None
1062 };
1063
1064 response.send(proto::JoinRoomResponse {
1065 room: Some(joined_room.room),
1066 channel_id: None,
1067 live_kit_connection_info,
1068 })?;
1069
1070 update_user_contacts(session.user_id, &session).await?;
1071 Ok(())
1072}
1073
1074/// Rejoin room is used to reconnect to a room after connection errors.
1075async fn rejoin_room(
1076 request: proto::RejoinRoom,
1077 response: Response<proto::RejoinRoom>,
1078 session: Session,
1079) -> Result<()> {
1080 let room;
1081 let channel_id;
1082 let channel_members;
1083 {
1084 let mut rejoined_room = session
1085 .db()
1086 .await
1087 .rejoin_room(request, session.user_id, session.connection_id)
1088 .await?;
1089
1090 response.send(proto::RejoinRoomResponse {
1091 room: Some(rejoined_room.room.clone()),
1092 reshared_projects: rejoined_room
1093 .reshared_projects
1094 .iter()
1095 .map(|project| proto::ResharedProject {
1096 id: project.id.to_proto(),
1097 collaborators: project
1098 .collaborators
1099 .iter()
1100 .map(|collaborator| collaborator.to_proto())
1101 .collect(),
1102 })
1103 .collect(),
1104 rejoined_projects: rejoined_room
1105 .rejoined_projects
1106 .iter()
1107 .map(|rejoined_project| proto::RejoinedProject {
1108 id: rejoined_project.id.to_proto(),
1109 worktrees: rejoined_project
1110 .worktrees
1111 .iter()
1112 .map(|worktree| proto::WorktreeMetadata {
1113 id: worktree.id,
1114 root_name: worktree.root_name.clone(),
1115 visible: worktree.visible,
1116 abs_path: worktree.abs_path.clone(),
1117 })
1118 .collect(),
1119 collaborators: rejoined_project
1120 .collaborators
1121 .iter()
1122 .map(|collaborator| collaborator.to_proto())
1123 .collect(),
1124 language_servers: rejoined_project.language_servers.clone(),
1125 })
1126 .collect(),
1127 })?;
1128 room_updated(&rejoined_room.room, &session.peer);
1129
1130 for project in &rejoined_room.reshared_projects {
1131 for collaborator in &project.collaborators {
1132 session
1133 .peer
1134 .send(
1135 collaborator.connection_id,
1136 proto::UpdateProjectCollaborator {
1137 project_id: project.id.to_proto(),
1138 old_peer_id: Some(project.old_connection_id.into()),
1139 new_peer_id: Some(session.connection_id.into()),
1140 },
1141 )
1142 .trace_err();
1143 }
1144
1145 broadcast(
1146 Some(session.connection_id),
1147 project
1148 .collaborators
1149 .iter()
1150 .map(|collaborator| collaborator.connection_id),
1151 |connection_id| {
1152 session.peer.forward_send(
1153 session.connection_id,
1154 connection_id,
1155 proto::UpdateProject {
1156 project_id: project.id.to_proto(),
1157 worktrees: project.worktrees.clone(),
1158 },
1159 )
1160 },
1161 );
1162 }
1163
1164 for project in &rejoined_room.rejoined_projects {
1165 for collaborator in &project.collaborators {
1166 session
1167 .peer
1168 .send(
1169 collaborator.connection_id,
1170 proto::UpdateProjectCollaborator {
1171 project_id: project.id.to_proto(),
1172 old_peer_id: Some(project.old_connection_id.into()),
1173 new_peer_id: Some(session.connection_id.into()),
1174 },
1175 )
1176 .trace_err();
1177 }
1178 }
1179
1180 for project in &mut rejoined_room.rejoined_projects {
1181 for worktree in mem::take(&mut project.worktrees) {
1182 #[cfg(any(test, feature = "test-support"))]
1183 const MAX_CHUNK_SIZE: usize = 2;
1184 #[cfg(not(any(test, feature = "test-support")))]
1185 const MAX_CHUNK_SIZE: usize = 256;
1186
1187 // Stream this worktree's entries.
1188 let message = proto::UpdateWorktree {
1189 project_id: project.id.to_proto(),
1190 worktree_id: worktree.id,
1191 abs_path: worktree.abs_path.clone(),
1192 root_name: worktree.root_name,
1193 updated_entries: worktree.updated_entries,
1194 removed_entries: worktree.removed_entries,
1195 scan_id: worktree.scan_id,
1196 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1197 updated_repositories: worktree.updated_repositories,
1198 removed_repositories: worktree.removed_repositories,
1199 };
1200 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1201 session.peer.send(session.connection_id, update.clone())?;
1202 }
1203
1204 // Stream this worktree's diagnostics.
1205 for summary in worktree.diagnostic_summaries {
1206 session.peer.send(
1207 session.connection_id,
1208 proto::UpdateDiagnosticSummary {
1209 project_id: project.id.to_proto(),
1210 worktree_id: worktree.id,
1211 summary: Some(summary),
1212 },
1213 )?;
1214 }
1215
1216 for settings_file in worktree.settings_files {
1217 session.peer.send(
1218 session.connection_id,
1219 proto::UpdateWorktreeSettings {
1220 project_id: project.id.to_proto(),
1221 worktree_id: worktree.id,
1222 path: settings_file.path,
1223 content: Some(settings_file.content),
1224 },
1225 )?;
1226 }
1227 }
1228
1229 for language_server in &project.language_servers {
1230 session.peer.send(
1231 session.connection_id,
1232 proto::UpdateLanguageServer {
1233 project_id: project.id.to_proto(),
1234 language_server_id: language_server.id,
1235 variant: Some(
1236 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1237 proto::LspDiskBasedDiagnosticsUpdated {},
1238 ),
1239 ),
1240 },
1241 )?;
1242 }
1243 }
1244
1245 let rejoined_room = rejoined_room.into_inner();
1246
1247 room = rejoined_room.room;
1248 channel_id = rejoined_room.channel_id;
1249 channel_members = rejoined_room.channel_members;
1250 }
1251
1252 if let Some(channel_id) = channel_id {
1253 channel_updated(
1254 channel_id,
1255 &room,
1256 &channel_members,
1257 &session.peer,
1258 &*session.connection_pool().await,
1259 );
1260 }
1261
1262 update_user_contacts(session.user_id, &session).await?;
1263 Ok(())
1264}
1265
1266/// leave room disconnects from the room.
1267async fn leave_room(
1268 _: proto::LeaveRoom,
1269 response: Response<proto::LeaveRoom>,
1270 session: Session,
1271) -> Result<()> {
1272 leave_room_for_session(&session).await?;
1273 response.send(proto::Ack {})?;
1274 Ok(())
1275}
1276
1277/// Updates the permissions of someone else in the room.
1278async fn set_room_participant_role(
1279 request: proto::SetRoomParticipantRole,
1280 response: Response<proto::SetRoomParticipantRole>,
1281 session: Session,
1282) -> Result<()> {
1283 let (live_kit_room, can_publish) = {
1284 let room = session
1285 .db()
1286 .await
1287 .set_room_participant_role(
1288 session.user_id,
1289 RoomId::from_proto(request.room_id),
1290 UserId::from_proto(request.user_id),
1291 ChannelRole::from(request.role()),
1292 )
1293 .await?;
1294
1295 let live_kit_room = room.live_kit_room.clone();
1296 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1297 room_updated(&room, &session.peer);
1298 (live_kit_room, can_publish)
1299 };
1300
1301 if let Some(live_kit) = session.live_kit_client.as_ref() {
1302 live_kit
1303 .update_participant(
1304 live_kit_room.clone(),
1305 request.user_id.to_string(),
1306 live_kit_server::proto::ParticipantPermission {
1307 can_subscribe: true,
1308 can_publish,
1309 can_publish_data: can_publish,
1310 hidden: false,
1311 recorder: false,
1312 },
1313 )
1314 .await
1315 .trace_err();
1316 }
1317
1318 response.send(proto::Ack {})?;
1319 Ok(())
1320}
1321
1322/// Call someone else into the current room
1323async fn call(
1324 request: proto::Call,
1325 response: Response<proto::Call>,
1326 session: Session,
1327) -> Result<()> {
1328 let room_id = RoomId::from_proto(request.room_id);
1329 let calling_user_id = session.user_id;
1330 let calling_connection_id = session.connection_id;
1331 let called_user_id = UserId::from_proto(request.called_user_id);
1332 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1333 if !session
1334 .db()
1335 .await
1336 .has_contact(calling_user_id, called_user_id)
1337 .await?
1338 {
1339 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1340 }
1341
1342 let incoming_call = {
1343 let (room, incoming_call) = &mut *session
1344 .db()
1345 .await
1346 .call(
1347 room_id,
1348 calling_user_id,
1349 calling_connection_id,
1350 called_user_id,
1351 initial_project_id,
1352 )
1353 .await?;
1354 room_updated(&room, &session.peer);
1355 mem::take(incoming_call)
1356 };
1357 update_user_contacts(called_user_id, &session).await?;
1358
1359 let mut calls = session
1360 .connection_pool()
1361 .await
1362 .user_connection_ids(called_user_id)
1363 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1364 .collect::<FuturesUnordered<_>>();
1365
1366 while let Some(call_response) = calls.next().await {
1367 match call_response.as_ref() {
1368 Ok(_) => {
1369 response.send(proto::Ack {})?;
1370 return Ok(());
1371 }
1372 Err(_) => {
1373 call_response.trace_err();
1374 }
1375 }
1376 }
1377
1378 {
1379 let room = session
1380 .db()
1381 .await
1382 .call_failed(room_id, called_user_id)
1383 .await?;
1384 room_updated(&room, &session.peer);
1385 }
1386 update_user_contacts(called_user_id, &session).await?;
1387
1388 Err(anyhow!("failed to ring user"))?
1389}
1390
1391/// Cancel an outgoing call.
1392async fn cancel_call(
1393 request: proto::CancelCall,
1394 response: Response<proto::CancelCall>,
1395 session: Session,
1396) -> Result<()> {
1397 let called_user_id = UserId::from_proto(request.called_user_id);
1398 let room_id = RoomId::from_proto(request.room_id);
1399 {
1400 let room = session
1401 .db()
1402 .await
1403 .cancel_call(room_id, session.connection_id, called_user_id)
1404 .await?;
1405 room_updated(&room, &session.peer);
1406 }
1407
1408 for connection_id in session
1409 .connection_pool()
1410 .await
1411 .user_connection_ids(called_user_id)
1412 {
1413 session
1414 .peer
1415 .send(
1416 connection_id,
1417 proto::CallCanceled {
1418 room_id: room_id.to_proto(),
1419 },
1420 )
1421 .trace_err();
1422 }
1423 response.send(proto::Ack {})?;
1424
1425 update_user_contacts(called_user_id, &session).await?;
1426 Ok(())
1427}
1428
1429/// Decline an incoming call.
1430async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1431 let room_id = RoomId::from_proto(message.room_id);
1432 {
1433 let room = session
1434 .db()
1435 .await
1436 .decline_call(Some(room_id), session.user_id)
1437 .await?
1438 .ok_or_else(|| anyhow!("failed to decline call"))?;
1439 room_updated(&room, &session.peer);
1440 }
1441
1442 for connection_id in session
1443 .connection_pool()
1444 .await
1445 .user_connection_ids(session.user_id)
1446 {
1447 session
1448 .peer
1449 .send(
1450 connection_id,
1451 proto::CallCanceled {
1452 room_id: room_id.to_proto(),
1453 },
1454 )
1455 .trace_err();
1456 }
1457 update_user_contacts(session.user_id, &session).await?;
1458 Ok(())
1459}
1460
1461/// Updates other participants in the room with your current location.
1462async fn update_participant_location(
1463 request: proto::UpdateParticipantLocation,
1464 response: Response<proto::UpdateParticipantLocation>,
1465 session: Session,
1466) -> Result<()> {
1467 let room_id = RoomId::from_proto(request.room_id);
1468 let location = request
1469 .location
1470 .ok_or_else(|| anyhow!("invalid location"))?;
1471
1472 let db = session.db().await;
1473 let room = db
1474 .update_room_participant_location(room_id, session.connection_id, location)
1475 .await?;
1476
1477 room_updated(&room, &session.peer);
1478 response.send(proto::Ack {})?;
1479 Ok(())
1480}
1481
1482/// Share a project into the room.
1483async fn share_project(
1484 request: proto::ShareProject,
1485 response: Response<proto::ShareProject>,
1486 session: Session,
1487) -> Result<()> {
1488 let (project_id, room) = &*session
1489 .db()
1490 .await
1491 .share_project(
1492 RoomId::from_proto(request.room_id),
1493 session.connection_id,
1494 &request.worktrees,
1495 )
1496 .await?;
1497 response.send(proto::ShareProjectResponse {
1498 project_id: project_id.to_proto(),
1499 })?;
1500 room_updated(&room, &session.peer);
1501
1502 Ok(())
1503}
1504
1505/// Unshare a project from the room.
1506async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1507 let project_id = ProjectId::from_proto(message.project_id);
1508
1509 let (room, guest_connection_ids) = &*session
1510 .db()
1511 .await
1512 .unshare_project(project_id, session.connection_id)
1513 .await?;
1514
1515 broadcast(
1516 Some(session.connection_id),
1517 guest_connection_ids.iter().copied(),
1518 |conn_id| session.peer.send(conn_id, message.clone()),
1519 );
1520 room_updated(&room, &session.peer);
1521
1522 Ok(())
1523}
1524
1525/// Join someone elses shared project.
1526async fn join_project(
1527 request: proto::JoinProject,
1528 response: Response<proto::JoinProject>,
1529 session: Session,
1530) -> Result<()> {
1531 let project_id = ProjectId::from_proto(request.project_id);
1532 let guest_user_id = session.user_id;
1533
1534 tracing::info!(%project_id, "join project");
1535
1536 let (project, replica_id) = &mut *session
1537 .db()
1538 .await
1539 .join_project(project_id, session.connection_id)
1540 .await?;
1541
1542 let collaborators = project
1543 .collaborators
1544 .iter()
1545 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1546 .map(|collaborator| collaborator.to_proto())
1547 .collect::<Vec<_>>();
1548
1549 let worktrees = project
1550 .worktrees
1551 .iter()
1552 .map(|(id, worktree)| proto::WorktreeMetadata {
1553 id: *id,
1554 root_name: worktree.root_name.clone(),
1555 visible: worktree.visible,
1556 abs_path: worktree.abs_path.clone(),
1557 })
1558 .collect::<Vec<_>>();
1559
1560 for collaborator in &collaborators {
1561 session
1562 .peer
1563 .send(
1564 collaborator.peer_id.unwrap().into(),
1565 proto::AddProjectCollaborator {
1566 project_id: project_id.to_proto(),
1567 collaborator: Some(proto::Collaborator {
1568 peer_id: Some(session.connection_id.into()),
1569 replica_id: replica_id.0 as u32,
1570 user_id: guest_user_id.to_proto(),
1571 }),
1572 },
1573 )
1574 .trace_err();
1575 }
1576
1577 // First, we send the metadata associated with each worktree.
1578 response.send(proto::JoinProjectResponse {
1579 worktrees: worktrees.clone(),
1580 replica_id: replica_id.0 as u32,
1581 collaborators: collaborators.clone(),
1582 language_servers: project.language_servers.clone(),
1583 })?;
1584
1585 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1586 #[cfg(any(test, feature = "test-support"))]
1587 const MAX_CHUNK_SIZE: usize = 2;
1588 #[cfg(not(any(test, feature = "test-support")))]
1589 const MAX_CHUNK_SIZE: usize = 256;
1590
1591 // Stream this worktree's entries.
1592 let message = proto::UpdateWorktree {
1593 project_id: project_id.to_proto(),
1594 worktree_id,
1595 abs_path: worktree.abs_path.clone(),
1596 root_name: worktree.root_name,
1597 updated_entries: worktree.entries,
1598 removed_entries: Default::default(),
1599 scan_id: worktree.scan_id,
1600 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1601 updated_repositories: worktree.repository_entries.into_values().collect(),
1602 removed_repositories: Default::default(),
1603 };
1604 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1605 session.peer.send(session.connection_id, update.clone())?;
1606 }
1607
1608 // Stream this worktree's diagnostics.
1609 for summary in worktree.diagnostic_summaries {
1610 session.peer.send(
1611 session.connection_id,
1612 proto::UpdateDiagnosticSummary {
1613 project_id: project_id.to_proto(),
1614 worktree_id: worktree.id,
1615 summary: Some(summary),
1616 },
1617 )?;
1618 }
1619
1620 for settings_file in worktree.settings_files {
1621 session.peer.send(
1622 session.connection_id,
1623 proto::UpdateWorktreeSettings {
1624 project_id: project_id.to_proto(),
1625 worktree_id: worktree.id,
1626 path: settings_file.path,
1627 content: Some(settings_file.content),
1628 },
1629 )?;
1630 }
1631 }
1632
1633 for language_server in &project.language_servers {
1634 session.peer.send(
1635 session.connection_id,
1636 proto::UpdateLanguageServer {
1637 project_id: project_id.to_proto(),
1638 language_server_id: language_server.id,
1639 variant: Some(
1640 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1641 proto::LspDiskBasedDiagnosticsUpdated {},
1642 ),
1643 ),
1644 },
1645 )?;
1646 }
1647
1648 Ok(())
1649}
1650
1651/// Leave someone elses shared project.
1652async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1653 let sender_id = session.connection_id;
1654 let project_id = ProjectId::from_proto(request.project_id);
1655
1656 let (room, project) = &*session
1657 .db()
1658 .await
1659 .leave_project(project_id, sender_id)
1660 .await?;
1661 tracing::info!(
1662 %project_id,
1663 host_user_id = %project.host_user_id,
1664 host_connection_id = %project.host_connection_id,
1665 "leave project"
1666 );
1667
1668 project_left(&project, &session);
1669 room_updated(&room, &session.peer);
1670
1671 Ok(())
1672}
1673
1674/// Updates other participants with changes to the project
1675async fn update_project(
1676 request: proto::UpdateProject,
1677 response: Response<proto::UpdateProject>,
1678 session: Session,
1679) -> Result<()> {
1680 let project_id = ProjectId::from_proto(request.project_id);
1681 let (room, guest_connection_ids) = &*session
1682 .db()
1683 .await
1684 .update_project(project_id, session.connection_id, &request.worktrees)
1685 .await?;
1686 broadcast(
1687 Some(session.connection_id),
1688 guest_connection_ids.iter().copied(),
1689 |connection_id| {
1690 session
1691 .peer
1692 .forward_send(session.connection_id, connection_id, request.clone())
1693 },
1694 );
1695 room_updated(&room, &session.peer);
1696 response.send(proto::Ack {})?;
1697
1698 Ok(())
1699}
1700
1701/// Updates other participants with changes to the worktree
1702async fn update_worktree(
1703 request: proto::UpdateWorktree,
1704 response: Response<proto::UpdateWorktree>,
1705 session: Session,
1706) -> Result<()> {
1707 let guest_connection_ids = session
1708 .db()
1709 .await
1710 .update_worktree(&request, session.connection_id)
1711 .await?;
1712
1713 broadcast(
1714 Some(session.connection_id),
1715 guest_connection_ids.iter().copied(),
1716 |connection_id| {
1717 session
1718 .peer
1719 .forward_send(session.connection_id, connection_id, request.clone())
1720 },
1721 );
1722 response.send(proto::Ack {})?;
1723 Ok(())
1724}
1725
1726/// Updates other participants with changes to the diagnostics
1727async fn update_diagnostic_summary(
1728 message: proto::UpdateDiagnosticSummary,
1729 session: Session,
1730) -> Result<()> {
1731 let guest_connection_ids = session
1732 .db()
1733 .await
1734 .update_diagnostic_summary(&message, session.connection_id)
1735 .await?;
1736
1737 broadcast(
1738 Some(session.connection_id),
1739 guest_connection_ids.iter().copied(),
1740 |connection_id| {
1741 session
1742 .peer
1743 .forward_send(session.connection_id, connection_id, message.clone())
1744 },
1745 );
1746
1747 Ok(())
1748}
1749
1750/// Updates other participants with changes to the worktree settings
1751async fn update_worktree_settings(
1752 message: proto::UpdateWorktreeSettings,
1753 session: Session,
1754) -> Result<()> {
1755 let guest_connection_ids = session
1756 .db()
1757 .await
1758 .update_worktree_settings(&message, session.connection_id)
1759 .await?;
1760
1761 broadcast(
1762 Some(session.connection_id),
1763 guest_connection_ids.iter().copied(),
1764 |connection_id| {
1765 session
1766 .peer
1767 .forward_send(session.connection_id, connection_id, message.clone())
1768 },
1769 );
1770
1771 Ok(())
1772}
1773
1774/// Notify other participants that a language server has started.
1775async fn start_language_server(
1776 request: proto::StartLanguageServer,
1777 session: Session,
1778) -> Result<()> {
1779 let guest_connection_ids = session
1780 .db()
1781 .await
1782 .start_language_server(&request, session.connection_id)
1783 .await?;
1784
1785 broadcast(
1786 Some(session.connection_id),
1787 guest_connection_ids.iter().copied(),
1788 |connection_id| {
1789 session
1790 .peer
1791 .forward_send(session.connection_id, connection_id, request.clone())
1792 },
1793 );
1794 Ok(())
1795}
1796
1797/// Notify other participants that a language server has changed.
1798async fn update_language_server(
1799 request: proto::UpdateLanguageServer,
1800 session: Session,
1801) -> Result<()> {
1802 let project_id = ProjectId::from_proto(request.project_id);
1803 let project_connection_ids = session
1804 .db()
1805 .await
1806 .project_connection_ids(project_id, session.connection_id)
1807 .await?;
1808 broadcast(
1809 Some(session.connection_id),
1810 project_connection_ids.iter().copied(),
1811 |connection_id| {
1812 session
1813 .peer
1814 .forward_send(session.connection_id, connection_id, request.clone())
1815 },
1816 );
1817 Ok(())
1818}
1819
1820/// forward a project request to the host. These requests should be read only
1821/// as guests are allowed to send them.
1822async fn forward_read_only_project_request<T>(
1823 request: T,
1824 response: Response<T>,
1825 session: Session,
1826) -> Result<()>
1827where
1828 T: EntityMessage + RequestMessage,
1829{
1830 let project_id = ProjectId::from_proto(request.remote_entity_id());
1831 let host_connection_id = session
1832 .db()
1833 .await
1834 .host_for_read_only_project_request(project_id, session.connection_id)
1835 .await?;
1836 let payload = session
1837 .peer
1838 .forward_request(session.connection_id, host_connection_id, request)
1839 .await?;
1840 response.send(payload)?;
1841 Ok(())
1842}
1843
1844/// forward a project request to the host. These requests are disallowed
1845/// for guests.
1846async fn forward_mutating_project_request<T>(
1847 request: T,
1848 response: Response<T>,
1849 session: Session,
1850) -> Result<()>
1851where
1852 T: EntityMessage + RequestMessage,
1853{
1854 let project_id = ProjectId::from_proto(request.remote_entity_id());
1855 let host_connection_id = session
1856 .db()
1857 .await
1858 .host_for_mutating_project_request(project_id, session.connection_id)
1859 .await?;
1860 let payload = session
1861 .peer
1862 .forward_request(session.connection_id, host_connection_id, request)
1863 .await?;
1864 response.send(payload)?;
1865 Ok(())
1866}
1867
1868/// Notify other participants that a new buffer has been created
1869async fn create_buffer_for_peer(
1870 request: proto::CreateBufferForPeer,
1871 session: Session,
1872) -> Result<()> {
1873 session
1874 .db()
1875 .await
1876 .check_user_is_project_host(
1877 ProjectId::from_proto(request.project_id),
1878 session.connection_id,
1879 )
1880 .await?;
1881 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1882 session
1883 .peer
1884 .forward_send(session.connection_id, peer_id.into(), request)?;
1885 Ok(())
1886}
1887
1888/// Notify other participants that a buffer has been updated. This is
1889/// allowed for guests as long as the update is limited to selections.
1890async fn update_buffer(
1891 request: proto::UpdateBuffer,
1892 response: Response<proto::UpdateBuffer>,
1893 session: Session,
1894) -> Result<()> {
1895 let project_id = ProjectId::from_proto(request.project_id);
1896 let mut guest_connection_ids;
1897 let mut host_connection_id = None;
1898
1899 let mut requires_write_permission = false;
1900
1901 for op in request.operations.iter() {
1902 match op.variant {
1903 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1904 Some(_) => requires_write_permission = true,
1905 }
1906 }
1907
1908 {
1909 let collaborators = session
1910 .db()
1911 .await
1912 .project_collaborators_for_buffer_update(
1913 project_id,
1914 session.connection_id,
1915 requires_write_permission,
1916 )
1917 .await?;
1918 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1919 for collaborator in collaborators.iter() {
1920 if collaborator.is_host {
1921 host_connection_id = Some(collaborator.connection_id);
1922 } else {
1923 guest_connection_ids.push(collaborator.connection_id);
1924 }
1925 }
1926 }
1927 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1928
1929 broadcast(
1930 Some(session.connection_id),
1931 guest_connection_ids,
1932 |connection_id| {
1933 session
1934 .peer
1935 .forward_send(session.connection_id, connection_id, request.clone())
1936 },
1937 );
1938 if host_connection_id != session.connection_id {
1939 session
1940 .peer
1941 .forward_request(session.connection_id, host_connection_id, request.clone())
1942 .await?;
1943 }
1944
1945 response.send(proto::Ack {})?;
1946 Ok(())
1947}
1948
1949/// Notify other participants that a project has been updated.
1950async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1951 request: T,
1952 session: Session,
1953) -> Result<()> {
1954 let project_id = ProjectId::from_proto(request.remote_entity_id());
1955 let project_connection_ids = session
1956 .db()
1957 .await
1958 .project_connection_ids(project_id, session.connection_id)
1959 .await?;
1960
1961 broadcast(
1962 Some(session.connection_id),
1963 project_connection_ids.iter().copied(),
1964 |connection_id| {
1965 session
1966 .peer
1967 .forward_send(session.connection_id, connection_id, request.clone())
1968 },
1969 );
1970 Ok(())
1971}
1972
1973/// Start following another user in a call.
1974async fn follow(
1975 request: proto::Follow,
1976 response: Response<proto::Follow>,
1977 session: Session,
1978) -> Result<()> {
1979 let room_id = RoomId::from_proto(request.room_id);
1980 let project_id = request.project_id.map(ProjectId::from_proto);
1981 let leader_id = request
1982 .leader_id
1983 .ok_or_else(|| anyhow!("invalid leader id"))?
1984 .into();
1985 let follower_id = session.connection_id;
1986
1987 session
1988 .db()
1989 .await
1990 .check_room_participants(room_id, leader_id, session.connection_id)
1991 .await?;
1992
1993 let response_payload = session
1994 .peer
1995 .forward_request(session.connection_id, leader_id, request)
1996 .await?;
1997 response.send(response_payload)?;
1998
1999 if let Some(project_id) = project_id {
2000 let room = session
2001 .db()
2002 .await
2003 .follow(room_id, project_id, leader_id, follower_id)
2004 .await?;
2005 room_updated(&room, &session.peer);
2006 }
2007
2008 Ok(())
2009}
2010
2011/// Stop following another user in a call.
2012async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2013 let room_id = RoomId::from_proto(request.room_id);
2014 let project_id = request.project_id.map(ProjectId::from_proto);
2015 let leader_id = request
2016 .leader_id
2017 .ok_or_else(|| anyhow!("invalid leader id"))?
2018 .into();
2019 let follower_id = session.connection_id;
2020
2021 session
2022 .db()
2023 .await
2024 .check_room_participants(room_id, leader_id, session.connection_id)
2025 .await?;
2026
2027 session
2028 .peer
2029 .forward_send(session.connection_id, leader_id, request)?;
2030
2031 if let Some(project_id) = project_id {
2032 let room = session
2033 .db()
2034 .await
2035 .unfollow(room_id, project_id, leader_id, follower_id)
2036 .await?;
2037 room_updated(&room, &session.peer);
2038 }
2039
2040 Ok(())
2041}
2042
2043/// Notify everyone following you of your current location.
2044async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2045 let room_id = RoomId::from_proto(request.room_id);
2046 let database = session.db.lock().await;
2047
2048 let connection_ids = if let Some(project_id) = request.project_id {
2049 let project_id = ProjectId::from_proto(project_id);
2050 database
2051 .project_connection_ids(project_id, session.connection_id)
2052 .await?
2053 } else {
2054 database
2055 .room_connection_ids(room_id, session.connection_id)
2056 .await?
2057 };
2058
2059 // For now, don't send view update messages back to that view's current leader.
2060 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2061 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2062 _ => None,
2063 });
2064
2065 for follower_peer_id in request.follower_ids.iter().copied() {
2066 let follower_connection_id = follower_peer_id.into();
2067 if Some(follower_peer_id) != connection_id_to_omit
2068 && connection_ids.contains(&follower_connection_id)
2069 {
2070 session.peer.forward_send(
2071 session.connection_id,
2072 follower_connection_id,
2073 request.clone(),
2074 )?;
2075 }
2076 }
2077 Ok(())
2078}
2079
2080/// Get public data about users.
2081async fn get_users(
2082 request: proto::GetUsers,
2083 response: Response<proto::GetUsers>,
2084 session: Session,
2085) -> Result<()> {
2086 let user_ids = request
2087 .user_ids
2088 .into_iter()
2089 .map(UserId::from_proto)
2090 .collect();
2091 let users = session
2092 .db()
2093 .await
2094 .get_users_by_ids(user_ids)
2095 .await?
2096 .into_iter()
2097 .map(|user| proto::User {
2098 id: user.id.to_proto(),
2099 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2100 github_login: user.github_login,
2101 })
2102 .collect();
2103 response.send(proto::UsersResponse { users })?;
2104 Ok(())
2105}
2106
2107/// Search for users (to invite) buy Github login
2108async fn fuzzy_search_users(
2109 request: proto::FuzzySearchUsers,
2110 response: Response<proto::FuzzySearchUsers>,
2111 session: Session,
2112) -> Result<()> {
2113 let query = request.query;
2114 let users = match query.len() {
2115 0 => vec![],
2116 1 | 2 => session
2117 .db()
2118 .await
2119 .get_user_by_github_login(&query)
2120 .await?
2121 .into_iter()
2122 .collect(),
2123 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2124 };
2125 let users = users
2126 .into_iter()
2127 .filter(|user| user.id != session.user_id)
2128 .map(|user| proto::User {
2129 id: user.id.to_proto(),
2130 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2131 github_login: user.github_login,
2132 })
2133 .collect();
2134 response.send(proto::UsersResponse { users })?;
2135 Ok(())
2136}
2137
2138/// Send a contact request to another user.
2139async fn request_contact(
2140 request: proto::RequestContact,
2141 response: Response<proto::RequestContact>,
2142 session: Session,
2143) -> Result<()> {
2144 let requester_id = session.user_id;
2145 let responder_id = UserId::from_proto(request.responder_id);
2146 if requester_id == responder_id {
2147 return Err(anyhow!("cannot add yourself as a contact"))?;
2148 }
2149
2150 let notifications = session
2151 .db()
2152 .await
2153 .send_contact_request(requester_id, responder_id)
2154 .await?;
2155
2156 // Update outgoing contact requests of requester
2157 let mut update = proto::UpdateContacts::default();
2158 update.outgoing_requests.push(responder_id.to_proto());
2159 for connection_id in session
2160 .connection_pool()
2161 .await
2162 .user_connection_ids(requester_id)
2163 {
2164 session.peer.send(connection_id, update.clone())?;
2165 }
2166
2167 // Update incoming contact requests of responder
2168 let mut update = proto::UpdateContacts::default();
2169 update
2170 .incoming_requests
2171 .push(proto::IncomingContactRequest {
2172 requester_id: requester_id.to_proto(),
2173 });
2174 let connection_pool = session.connection_pool().await;
2175 for connection_id in connection_pool.user_connection_ids(responder_id) {
2176 session.peer.send(connection_id, update.clone())?;
2177 }
2178
2179 send_notifications(&*connection_pool, &session.peer, notifications);
2180
2181 response.send(proto::Ack {})?;
2182 Ok(())
2183}
2184
2185/// Accept or decline a contact request
2186async fn respond_to_contact_request(
2187 request: proto::RespondToContactRequest,
2188 response: Response<proto::RespondToContactRequest>,
2189 session: Session,
2190) -> Result<()> {
2191 let responder_id = session.user_id;
2192 let requester_id = UserId::from_proto(request.requester_id);
2193 let db = session.db().await;
2194 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2195 db.dismiss_contact_notification(responder_id, requester_id)
2196 .await?;
2197 } else {
2198 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2199
2200 let notifications = db
2201 .respond_to_contact_request(responder_id, requester_id, accept)
2202 .await?;
2203 let requester_busy = db.is_user_busy(requester_id).await?;
2204 let responder_busy = db.is_user_busy(responder_id).await?;
2205
2206 let pool = session.connection_pool().await;
2207 // Update responder with new contact
2208 let mut update = proto::UpdateContacts::default();
2209 if accept {
2210 update
2211 .contacts
2212 .push(contact_for_user(requester_id, requester_busy, &pool));
2213 }
2214 update
2215 .remove_incoming_requests
2216 .push(requester_id.to_proto());
2217 for connection_id in pool.user_connection_ids(responder_id) {
2218 session.peer.send(connection_id, update.clone())?;
2219 }
2220
2221 // Update requester with new contact
2222 let mut update = proto::UpdateContacts::default();
2223 if accept {
2224 update
2225 .contacts
2226 .push(contact_for_user(responder_id, responder_busy, &pool));
2227 }
2228 update
2229 .remove_outgoing_requests
2230 .push(responder_id.to_proto());
2231
2232 for connection_id in pool.user_connection_ids(requester_id) {
2233 session.peer.send(connection_id, update.clone())?;
2234 }
2235
2236 send_notifications(&*pool, &session.peer, notifications);
2237 }
2238
2239 response.send(proto::Ack {})?;
2240 Ok(())
2241}
2242
2243/// Remove a contact.
2244async fn remove_contact(
2245 request: proto::RemoveContact,
2246 response: Response<proto::RemoveContact>,
2247 session: Session,
2248) -> Result<()> {
2249 let requester_id = session.user_id;
2250 let responder_id = UserId::from_proto(request.user_id);
2251 let db = session.db().await;
2252 let (contact_accepted, deleted_notification_id) =
2253 db.remove_contact(requester_id, responder_id).await?;
2254
2255 let pool = session.connection_pool().await;
2256 // Update outgoing contact requests of requester
2257 let mut update = proto::UpdateContacts::default();
2258 if contact_accepted {
2259 update.remove_contacts.push(responder_id.to_proto());
2260 } else {
2261 update
2262 .remove_outgoing_requests
2263 .push(responder_id.to_proto());
2264 }
2265 for connection_id in pool.user_connection_ids(requester_id) {
2266 session.peer.send(connection_id, update.clone())?;
2267 }
2268
2269 // Update incoming contact requests of responder
2270 let mut update = proto::UpdateContacts::default();
2271 if contact_accepted {
2272 update.remove_contacts.push(requester_id.to_proto());
2273 } else {
2274 update
2275 .remove_incoming_requests
2276 .push(requester_id.to_proto());
2277 }
2278 for connection_id in pool.user_connection_ids(responder_id) {
2279 session.peer.send(connection_id, update.clone())?;
2280 if let Some(notification_id) = deleted_notification_id {
2281 session.peer.send(
2282 connection_id,
2283 proto::DeleteNotification {
2284 notification_id: notification_id.to_proto(),
2285 },
2286 )?;
2287 }
2288 }
2289
2290 response.send(proto::Ack {})?;
2291 Ok(())
2292}
2293
2294/// Creates a new channel.
2295async fn create_channel(
2296 request: proto::CreateChannel,
2297 response: Response<proto::CreateChannel>,
2298 session: Session,
2299) -> Result<()> {
2300 let db = session.db().await;
2301
2302 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2303 let (channel, owner, channel_members) = db
2304 .create_channel(&request.name, parent_id, session.user_id)
2305 .await?;
2306
2307 response.send(proto::CreateChannelResponse {
2308 channel: Some(channel.to_proto()),
2309 parent_id: request.parent_id,
2310 })?;
2311
2312 let connection_pool = session.connection_pool().await;
2313 if let Some(owner) = owner {
2314 let update = proto::UpdateUserChannels {
2315 channel_memberships: vec![proto::ChannelMembership {
2316 channel_id: owner.channel_id.to_proto(),
2317 role: owner.role.into(),
2318 }],
2319 ..Default::default()
2320 };
2321 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2322 session.peer.send(connection_id, update.clone())?;
2323 }
2324 }
2325
2326 for channel_member in channel_members {
2327 if !channel_member.role.can_see_channel(channel.visibility) {
2328 continue;
2329 }
2330
2331 let update = proto::UpdateChannels {
2332 channels: vec![channel.to_proto()],
2333 ..Default::default()
2334 };
2335 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2336 session.peer.send(connection_id, update.clone())?;
2337 }
2338 }
2339
2340 Ok(())
2341}
2342
2343/// Delete a channel
2344async fn delete_channel(
2345 request: proto::DeleteChannel,
2346 response: Response<proto::DeleteChannel>,
2347 session: Session,
2348) -> Result<()> {
2349 let db = session.db().await;
2350
2351 let channel_id = request.channel_id;
2352 let (removed_channels, member_ids) = db
2353 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2354 .await?;
2355 response.send(proto::Ack {})?;
2356
2357 // Notify members of removed channels
2358 let mut update = proto::UpdateChannels::default();
2359 update
2360 .delete_channels
2361 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2362
2363 let connection_pool = session.connection_pool().await;
2364 for member_id in member_ids {
2365 for connection_id in connection_pool.user_connection_ids(member_id) {
2366 session.peer.send(connection_id, update.clone())?;
2367 }
2368 }
2369
2370 Ok(())
2371}
2372
2373/// Invite someone to join a channel.
2374async fn invite_channel_member(
2375 request: proto::InviteChannelMember,
2376 response: Response<proto::InviteChannelMember>,
2377 session: Session,
2378) -> Result<()> {
2379 let db = session.db().await;
2380 let channel_id = ChannelId::from_proto(request.channel_id);
2381 let invitee_id = UserId::from_proto(request.user_id);
2382 let InviteMemberResult {
2383 channel,
2384 notifications,
2385 } = db
2386 .invite_channel_member(
2387 channel_id,
2388 invitee_id,
2389 session.user_id,
2390 request.role().into(),
2391 )
2392 .await?;
2393
2394 let update = proto::UpdateChannels {
2395 channel_invitations: vec![channel.to_proto()],
2396 ..Default::default()
2397 };
2398
2399 let connection_pool = session.connection_pool().await;
2400 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2401 session.peer.send(connection_id, update.clone())?;
2402 }
2403
2404 send_notifications(&*connection_pool, &session.peer, notifications);
2405
2406 response.send(proto::Ack {})?;
2407 Ok(())
2408}
2409
2410/// remove someone from a channel
2411async fn remove_channel_member(
2412 request: proto::RemoveChannelMember,
2413 response: Response<proto::RemoveChannelMember>,
2414 session: Session,
2415) -> Result<()> {
2416 let db = session.db().await;
2417 let channel_id = ChannelId::from_proto(request.channel_id);
2418 let member_id = UserId::from_proto(request.user_id);
2419
2420 let RemoveChannelMemberResult {
2421 membership_update,
2422 notification_id,
2423 } = db
2424 .remove_channel_member(channel_id, member_id, session.user_id)
2425 .await?;
2426
2427 let connection_pool = &session.connection_pool().await;
2428 notify_membership_updated(
2429 &connection_pool,
2430 membership_update,
2431 member_id,
2432 &session.peer,
2433 );
2434 for connection_id in connection_pool.user_connection_ids(member_id) {
2435 if let Some(notification_id) = notification_id {
2436 session
2437 .peer
2438 .send(
2439 connection_id,
2440 proto::DeleteNotification {
2441 notification_id: notification_id.to_proto(),
2442 },
2443 )
2444 .trace_err();
2445 }
2446 }
2447
2448 response.send(proto::Ack {})?;
2449 Ok(())
2450}
2451
2452/// Toggle the channel between public and private.
2453/// Care is taken to maintain the invariant that public channels only descend from public channels,
2454/// (though members-only channels can appear at any point in the hierarchy).
2455async fn set_channel_visibility(
2456 request: proto::SetChannelVisibility,
2457 response: Response<proto::SetChannelVisibility>,
2458 session: Session,
2459) -> Result<()> {
2460 let db = session.db().await;
2461 let channel_id = ChannelId::from_proto(request.channel_id);
2462 let visibility = request.visibility().into();
2463
2464 let (channel, channel_members) = db
2465 .set_channel_visibility(channel_id, visibility, session.user_id)
2466 .await?;
2467
2468 let connection_pool = session.connection_pool().await;
2469 for member in channel_members {
2470 let update = if member.role.can_see_channel(channel.visibility) {
2471 proto::UpdateChannels {
2472 channels: vec![channel.to_proto()],
2473 ..Default::default()
2474 }
2475 } else {
2476 proto::UpdateChannels {
2477 delete_channels: vec![channel.id.to_proto()],
2478 ..Default::default()
2479 }
2480 };
2481
2482 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2483 session.peer.send(connection_id, update.clone())?;
2484 }
2485 }
2486
2487 response.send(proto::Ack {})?;
2488 Ok(())
2489}
2490
2491/// Alter the role for a user in the channel.
2492async fn set_channel_member_role(
2493 request: proto::SetChannelMemberRole,
2494 response: Response<proto::SetChannelMemberRole>,
2495 session: Session,
2496) -> Result<()> {
2497 let db = session.db().await;
2498 let channel_id = ChannelId::from_proto(request.channel_id);
2499 let member_id = UserId::from_proto(request.user_id);
2500 let result = db
2501 .set_channel_member_role(
2502 channel_id,
2503 session.user_id,
2504 member_id,
2505 request.role().into(),
2506 )
2507 .await?;
2508
2509 match result {
2510 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2511 let connection_pool = session.connection_pool().await;
2512 notify_membership_updated(
2513 &connection_pool,
2514 membership_update,
2515 member_id,
2516 &session.peer,
2517 )
2518 }
2519 db::SetMemberRoleResult::InviteUpdated(channel) => {
2520 let update = proto::UpdateChannels {
2521 channel_invitations: vec![channel.to_proto()],
2522 ..Default::default()
2523 };
2524
2525 for connection_id in session
2526 .connection_pool()
2527 .await
2528 .user_connection_ids(member_id)
2529 {
2530 session.peer.send(connection_id, update.clone())?;
2531 }
2532 }
2533 }
2534
2535 response.send(proto::Ack {})?;
2536 Ok(())
2537}
2538
2539/// Change the name of a channel
2540async fn rename_channel(
2541 request: proto::RenameChannel,
2542 response: Response<proto::RenameChannel>,
2543 session: Session,
2544) -> Result<()> {
2545 let db = session.db().await;
2546 let channel_id = ChannelId::from_proto(request.channel_id);
2547 let (channel, channel_members) = db
2548 .rename_channel(channel_id, session.user_id, &request.name)
2549 .await?;
2550
2551 response.send(proto::RenameChannelResponse {
2552 channel: Some(channel.to_proto()),
2553 })?;
2554
2555 let connection_pool = session.connection_pool().await;
2556 for channel_member in channel_members {
2557 if !channel_member.role.can_see_channel(channel.visibility) {
2558 continue;
2559 }
2560 let update = proto::UpdateChannels {
2561 channels: vec![channel.to_proto()],
2562 ..Default::default()
2563 };
2564 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2565 session.peer.send(connection_id, update.clone())?;
2566 }
2567 }
2568
2569 Ok(())
2570}
2571
2572/// Move a channel to a new parent.
2573async fn move_channel(
2574 request: proto::MoveChannel,
2575 response: Response<proto::MoveChannel>,
2576 session: Session,
2577) -> Result<()> {
2578 let channel_id = ChannelId::from_proto(request.channel_id);
2579 let to = ChannelId::from_proto(request.to);
2580
2581 let (channels, channel_members) = session
2582 .db()
2583 .await
2584 .move_channel(channel_id, to, session.user_id)
2585 .await?;
2586
2587 let connection_pool = session.connection_pool().await;
2588 for member in channel_members {
2589 let channels = channels
2590 .iter()
2591 .filter_map(|channel| {
2592 if member.role.can_see_channel(channel.visibility) {
2593 Some(channel.to_proto())
2594 } else {
2595 None
2596 }
2597 })
2598 .collect::<Vec<_>>();
2599 if channels.is_empty() {
2600 continue;
2601 }
2602
2603 let update = proto::UpdateChannels {
2604 channels,
2605 ..Default::default()
2606 };
2607
2608 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2609 session.peer.send(connection_id, update.clone())?;
2610 }
2611 }
2612
2613 response.send(Ack {})?;
2614 Ok(())
2615}
2616
2617/// Get the list of channel members
2618async fn get_channel_members(
2619 request: proto::GetChannelMembers,
2620 response: Response<proto::GetChannelMembers>,
2621 session: Session,
2622) -> Result<()> {
2623 let db = session.db().await;
2624 let channel_id = ChannelId::from_proto(request.channel_id);
2625 let members = db
2626 .get_channel_participant_details(channel_id, session.user_id)
2627 .await?;
2628 response.send(proto::GetChannelMembersResponse { members })?;
2629 Ok(())
2630}
2631
2632/// Accept or decline a channel invitation.
2633async fn respond_to_channel_invite(
2634 request: proto::RespondToChannelInvite,
2635 response: Response<proto::RespondToChannelInvite>,
2636 session: Session,
2637) -> Result<()> {
2638 let db = session.db().await;
2639 let channel_id = ChannelId::from_proto(request.channel_id);
2640 let RespondToChannelInvite {
2641 membership_update,
2642 notifications,
2643 } = db
2644 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2645 .await?;
2646
2647 let connection_pool = session.connection_pool().await;
2648 if let Some(membership_update) = membership_update {
2649 notify_membership_updated(
2650 &connection_pool,
2651 membership_update,
2652 session.user_id,
2653 &session.peer,
2654 );
2655 } else {
2656 let update = proto::UpdateChannels {
2657 remove_channel_invitations: vec![channel_id.to_proto()],
2658 ..Default::default()
2659 };
2660
2661 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2662 session.peer.send(connection_id, update.clone())?;
2663 }
2664 };
2665
2666 send_notifications(&*connection_pool, &session.peer, notifications);
2667
2668 response.send(proto::Ack {})?;
2669
2670 Ok(())
2671}
2672
2673/// Join the channels' room
2674async fn join_channel(
2675 request: proto::JoinChannel,
2676 response: Response<proto::JoinChannel>,
2677 session: Session,
2678) -> Result<()> {
2679 let channel_id = ChannelId::from_proto(request.channel_id);
2680 join_channel_internal(channel_id, Box::new(response), session).await
2681}
2682
/// A reply handle that `join_channel_internal` can answer with a
/// `proto::JoinRoomResponse`.
///
/// Implemented for the response handles of both `JoinChannel` and `JoinRoom`
/// requests, since both reply with a `JoinRoomResponse`. This lets one join
/// routine serve either request.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// The explicit `Response::<..>::send(self, ..)` calls below go through the
// path-qualified `Response` type rather than `self.send(..)`. That keeps it
// obvious they do not call back into this trait's method of the same name.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2696
/// Shared implementation behind joining a channel's room.
///
/// Steps, in order:
/// 1. Leave whatever room this session is currently in.
/// 2. Join the channel's room in the database. This also yields any
///    membership change and the user's role in the channel.
/// 3. Reply through `response` with the joined room and its channel id. When
///    a LiveKit client is configured, LiveKit connection info is included.
/// 4. Notify the user of any membership change and broadcast the room update.
/// 5. Once the database guard has been released, broadcast the channel update
///    and refresh the user's contacts.
///
/// Any type implementing `JoinChannelInternalResponse` can carry the reply.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database guard (`db`) and the first connection-pool guard are both
    // scoped to this block. Both are dropped before the connection pool is
    // acquired again for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Guests get a guest token and cannot publish; every other role gets a
        // room token and may publish. If there is no LiveKit client, or token
        // creation fails (`trace_err()` turns the error into `None`), the reply
        // carries no connection info instead of failing the join.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // `db` and the first pool guard are released at this point; take the pool again.
    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2777
2778/// Start editing the channel notes
2779async fn join_channel_buffer(
2780 request: proto::JoinChannelBuffer,
2781 response: Response<proto::JoinChannelBuffer>,
2782 session: Session,
2783) -> Result<()> {
2784 let db = session.db().await;
2785 let channel_id = ChannelId::from_proto(request.channel_id);
2786
2787 let open_response = db
2788 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2789 .await?;
2790
2791 let collaborators = open_response.collaborators.clone();
2792 response.send(open_response)?;
2793
2794 let update = UpdateChannelBufferCollaborators {
2795 channel_id: channel_id.to_proto(),
2796 collaborators: collaborators.clone(),
2797 };
2798 channel_buffer_updated(
2799 session.connection_id,
2800 collaborators
2801 .iter()
2802 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2803 &update,
2804 &session.peer,
2805 );
2806
2807 Ok(())
2808}
2809
2810/// Edit the channel notes
2811async fn update_channel_buffer(
2812 request: proto::UpdateChannelBuffer,
2813 session: Session,
2814) -> Result<()> {
2815 let db = session.db().await;
2816 let channel_id = ChannelId::from_proto(request.channel_id);
2817
2818 let (collaborators, non_collaborators, epoch, version) = db
2819 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2820 .await?;
2821
2822 channel_buffer_updated(
2823 session.connection_id,
2824 collaborators,
2825 &proto::UpdateChannelBuffer {
2826 channel_id: channel_id.to_proto(),
2827 operations: request.operations,
2828 },
2829 &session.peer,
2830 );
2831
2832 let pool = &*session.connection_pool().await;
2833
2834 broadcast(
2835 None,
2836 non_collaborators
2837 .iter()
2838 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2839 |peer_id| {
2840 session.peer.send(
2841 peer_id.into(),
2842 proto::UpdateChannels {
2843 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2844 channel_id: channel_id.to_proto(),
2845 epoch: epoch as u64,
2846 version: version.clone(),
2847 }],
2848 ..Default::default()
2849 },
2850 )
2851 },
2852 );
2853
2854 Ok(())
2855}
2856
2857/// Rejoin the channel notes after a connection blip
2858async fn rejoin_channel_buffers(
2859 request: proto::RejoinChannelBuffers,
2860 response: Response<proto::RejoinChannelBuffers>,
2861 session: Session,
2862) -> Result<()> {
2863 let db = session.db().await;
2864 let buffers = db
2865 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2866 .await?;
2867
2868 for rejoined_buffer in &buffers {
2869 let collaborators_to_notify = rejoined_buffer
2870 .buffer
2871 .collaborators
2872 .iter()
2873 .filter_map(|c| Some(c.peer_id?.into()));
2874 channel_buffer_updated(
2875 session.connection_id,
2876 collaborators_to_notify,
2877 &proto::UpdateChannelBufferCollaborators {
2878 channel_id: rejoined_buffer.buffer.channel_id,
2879 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2880 },
2881 &session.peer,
2882 );
2883 }
2884
2885 response.send(proto::RejoinChannelBuffersResponse {
2886 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2887 })?;
2888
2889 Ok(())
2890}
2891
2892/// Stop editing the channel notes
2893async fn leave_channel_buffer(
2894 request: proto::LeaveChannelBuffer,
2895 response: Response<proto::LeaveChannelBuffer>,
2896 session: Session,
2897) -> Result<()> {
2898 let db = session.db().await;
2899 let channel_id = ChannelId::from_proto(request.channel_id);
2900
2901 let left_buffer = db
2902 .leave_channel_buffer(channel_id, session.connection_id)
2903 .await?;
2904
2905 response.send(Ack {})?;
2906
2907 channel_buffer_updated(
2908 session.connection_id,
2909 left_buffer.connections,
2910 &proto::UpdateChannelBufferCollaborators {
2911 channel_id: channel_id.to_proto(),
2912 collaborators: left_buffer.collaborators,
2913 },
2914 &session.peer,
2915 );
2916
2917 Ok(())
2918}
2919
2920fn channel_buffer_updated<T: EnvelopedMessage>(
2921 sender_id: ConnectionId,
2922 collaborators: impl IntoIterator<Item = ConnectionId>,
2923 message: &T,
2924 peer: &Peer,
2925) {
2926 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
2927 peer.send(peer_id.into(), message.clone())
2928 });
2929}
2930
2931fn send_notifications(
2932 connection_pool: &ConnectionPool,
2933 peer: &Peer,
2934 notifications: db::NotificationBatch,
2935) {
2936 for (user_id, notification) in notifications {
2937 for connection_id in connection_pool.user_connection_ids(user_id) {
2938 if let Err(error) = peer.send(
2939 connection_id,
2940 proto::AddNotification {
2941 notification: Some(notification.clone()),
2942 },
2943 ) {
2944 tracing::error!(
2945 "failed to send notification to {:?} {}",
2946 connection_id,
2947 error
2948 );
2949 }
2950 }
2951 }
2952}
2953
/// Send a message to the channel
///
/// Validates and stores the message, then:
/// - broadcasts `ChannelMessageSent` to the participant connections returned
///   by the database, passing the sender's connection to `broadcast` as the
///   originating connection;
/// - replies to the sender with the stored message;
/// - sends every connection of the returned channel members an
///   `UpdateChannels` carrying the channel's latest message id;
/// - delivers any notifications the database produced.
///
/// # Errors
/// Fails if the trimmed body is empty or longer than `MAX_MESSAGE_LEN` bytes,
/// if the nonce is missing, if the database insert fails, or if sending the
/// reply fails.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // `String::len` counts UTF-8 bytes, so the length limit is in bytes.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged below. If
    // mentions carry offsets into the untrimmed body, trimming leading
    // whitespace shifts them out of alignment with `body`. Confirm.

    let timestamp = OffsetDateTime::now_utc();
    // The client-supplied nonce is required. It is stored with the message and
    // echoed back in the `ChannelMessage`.
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        channel_members,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel members listed by the database learn the channel's latest message id.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        channel_members
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    latest_channel_message_ids: vec![proto::ChannelMessageId {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3042
3043/// Delete a channel message
3044async fn remove_channel_message(
3045 request: proto::RemoveChannelMessage,
3046 response: Response<proto::RemoveChannelMessage>,
3047 session: Session,
3048) -> Result<()> {
3049 let channel_id = ChannelId::from_proto(request.channel_id);
3050 let message_id = MessageId::from_proto(request.message_id);
3051 let connection_ids = session
3052 .db()
3053 .await
3054 .remove_channel_message(channel_id, message_id, session.user_id)
3055 .await?;
3056 broadcast(Some(session.connection_id), connection_ids, |connection| {
3057 session.peer.send(connection, request.clone())
3058 });
3059 response.send(proto::Ack {})?;
3060 Ok(())
3061}
3062
3063/// Mark a channel message as read
3064async fn acknowledge_channel_message(
3065 request: proto::AckChannelMessage,
3066 session: Session,
3067) -> Result<()> {
3068 let channel_id = ChannelId::from_proto(request.channel_id);
3069 let message_id = MessageId::from_proto(request.message_id);
3070 let notifications = session
3071 .db()
3072 .await
3073 .observe_channel_message(channel_id, session.user_id, message_id)
3074 .await?;
3075 send_notifications(
3076 &*session.connection_pool().await,
3077 &session.peer,
3078 notifications,
3079 );
3080 Ok(())
3081}
3082
3083/// Mark a buffer version as synced
3084async fn acknowledge_buffer_version(
3085 request: proto::AckBufferOperation,
3086 session: Session,
3087) -> Result<()> {
3088 let buffer_id = BufferId::from_proto(request.buffer_id);
3089 session
3090 .db()
3091 .await
3092 .observe_buffer_version(
3093 buffer_id,
3094 session.user_id,
3095 request.epoch as i32,
3096 &request.version,
3097 )
3098 .await?;
3099 Ok(())
3100}
3101
3102/// Start receiving chat updates for a channel
3103async fn join_channel_chat(
3104 request: proto::JoinChannelChat,
3105 response: Response<proto::JoinChannelChat>,
3106 session: Session,
3107) -> Result<()> {
3108 let channel_id = ChannelId::from_proto(request.channel_id);
3109
3110 let db = session.db().await;
3111 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3112 .await?;
3113 let messages = db
3114 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3115 .await?;
3116 response.send(proto::JoinChannelChatResponse {
3117 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3118 messages,
3119 })?;
3120 Ok(())
3121}
3122
3123/// Stop receiving chat updates for a channel
3124async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3125 let channel_id = ChannelId::from_proto(request.channel_id);
3126 session
3127 .db()
3128 .await
3129 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3130 .await?;
3131 Ok(())
3132}
3133
3134/// Retrieve the chat history for a channel
3135async fn get_channel_messages(
3136 request: proto::GetChannelMessages,
3137 response: Response<proto::GetChannelMessages>,
3138 session: Session,
3139) -> Result<()> {
3140 let channel_id = ChannelId::from_proto(request.channel_id);
3141 let messages = session
3142 .db()
3143 .await
3144 .get_channel_messages(
3145 channel_id,
3146 session.user_id,
3147 MESSAGE_COUNT_PER_PAGE,
3148 Some(MessageId::from_proto(request.before_message_id)),
3149 )
3150 .await?;
3151 response.send(proto::GetChannelMessagesResponse {
3152 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3153 messages,
3154 })?;
3155 Ok(())
3156}
3157
3158/// Retrieve specific chat messages
3159async fn get_channel_messages_by_id(
3160 request: proto::GetChannelMessagesById,
3161 response: Response<proto::GetChannelMessagesById>,
3162 session: Session,
3163) -> Result<()> {
3164 let message_ids = request
3165 .message_ids
3166 .iter()
3167 .map(|id| MessageId::from_proto(*id))
3168 .collect::<Vec<_>>();
3169 let messages = session
3170 .db()
3171 .await
3172 .get_channel_messages_by_id(session.user_id, &message_ids)
3173 .await?;
3174 response.send(proto::GetChannelMessagesResponse {
3175 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3176 messages,
3177 })?;
3178 Ok(())
3179}
3180
3181/// Retrieve the current users notifications
3182async fn get_notifications(
3183 request: proto::GetNotifications,
3184 response: Response<proto::GetNotifications>,
3185 session: Session,
3186) -> Result<()> {
3187 let notifications = session
3188 .db()
3189 .await
3190 .get_notifications(
3191 session.user_id,
3192 NOTIFICATION_COUNT_PER_PAGE,
3193 request
3194 .before_id
3195 .map(|id| db::NotificationId::from_proto(id)),
3196 )
3197 .await?;
3198 response.send(proto::GetNotificationsResponse {
3199 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3200 notifications,
3201 })?;
3202 Ok(())
3203}
3204
3205/// Mark notifications as read
3206async fn mark_notification_as_read(
3207 request: proto::MarkNotificationRead,
3208 response: Response<proto::MarkNotificationRead>,
3209 session: Session,
3210) -> Result<()> {
3211 let database = &session.db().await;
3212 let notifications = database
3213 .mark_notification_as_read_by_id(
3214 session.user_id,
3215 NotificationId::from_proto(request.notification_id),
3216 )
3217 .await?;
3218 send_notifications(
3219 &*session.connection_pool().await,
3220 &session.peer,
3221 notifications,
3222 );
3223 response.send(proto::Ack {})?;
3224 Ok(())
3225}
3226
3227/// Get the current users information
3228async fn get_private_user_info(
3229 _request: proto::GetPrivateUserInfo,
3230 response: Response<proto::GetPrivateUserInfo>,
3231 session: Session,
3232) -> Result<()> {
3233 let db = session.db().await;
3234
3235 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3236 let user = db
3237 .get_user_by_id(session.user_id)
3238 .await?
3239 .ok_or_else(|| anyhow!("user not found"))?;
3240 let flags = db.get_user_flags(session.user_id).await?;
3241
3242 response.send(proto::GetPrivateUserInfoResponse {
3243 metrics_id,
3244 staff: user.admin,
3245 flags,
3246 })?;
3247 Ok(())
3248}
3249
3250fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3251 match message {
3252 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3253 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3254 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3255 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3256 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3257 code: frame.code.into(),
3258 reason: frame.reason,
3259 })),
3260 }
3261}
3262
3263fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3264 match message {
3265 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3266 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3267 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3268 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3269 AxumMessage::Close(frame) => {
3270 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3271 code: frame.code.into(),
3272 reason: frame.reason,
3273 }))
3274 }
3275 }
3276}
3277
3278fn notify_membership_updated(
3279 connection_pool: &ConnectionPool,
3280 result: MembershipUpdated,
3281 user_id: UserId,
3282 peer: &Peer,
3283) {
3284 let user_channels_update = proto::UpdateUserChannels {
3285 channel_memberships: result
3286 .new_channels
3287 .channel_memberships
3288 .iter()
3289 .map(|cm| proto::ChannelMembership {
3290 channel_id: cm.channel_id.to_proto(),
3291 role: cm.role.into(),
3292 })
3293 .collect(),
3294 ..Default::default()
3295 };
3296 let mut update = build_channels_update(result.new_channels, vec![]);
3297 update.delete_channels = result
3298 .removed_channels
3299 .into_iter()
3300 .map(|id| id.to_proto())
3301 .collect();
3302 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3303
3304 for connection_id in connection_pool.user_connection_ids(user_id) {
3305 peer.send(connection_id, user_channels_update.clone())
3306 .trace_err();
3307 peer.send(connection_id, update.clone()).trace_err();
3308 }
3309}
3310
3311fn build_update_user_channels(
3312 memberships: &Vec<db::channel_member::Model>,
3313) -> proto::UpdateUserChannels {
3314 proto::UpdateUserChannels {
3315 channel_memberships: memberships
3316 .iter()
3317 .map(|m| proto::ChannelMembership {
3318 channel_id: m.channel_id.to_proto(),
3319 role: m.role.into(),
3320 })
3321 .collect(),
3322 ..Default::default()
3323 }
3324}
3325
3326fn build_channels_update(
3327 channels: ChannelsForUser,
3328 channel_invites: Vec<db::Channel>,
3329) -> proto::UpdateChannels {
3330 let mut update = proto::UpdateChannels::default();
3331
3332 for channel in channels.channels {
3333 update.channels.push(channel.to_proto());
3334 }
3335
3336 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3337 update.latest_channel_message_ids = channels.latest_channel_messages;
3338
3339 for (channel_id, participants) in channels.channel_participants {
3340 update
3341 .channel_participants
3342 .push(proto::ChannelParticipants {
3343 channel_id: channel_id.to_proto(),
3344 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3345 });
3346 }
3347
3348 for channel in channel_invites {
3349 update.channel_invitations.push(channel.to_proto());
3350 }
3351
3352 update
3353}
3354
3355fn build_initial_contacts_update(
3356 contacts: Vec<db::Contact>,
3357 pool: &ConnectionPool,
3358) -> proto::UpdateContacts {
3359 let mut update = proto::UpdateContacts::default();
3360
3361 for contact in contacts {
3362 match contact {
3363 db::Contact::Accepted { user_id, busy } => {
3364 update.contacts.push(contact_for_user(user_id, busy, &pool));
3365 }
3366 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3367 db::Contact::Incoming { user_id } => {
3368 update
3369 .incoming_requests
3370 .push(proto::IncomingContactRequest {
3371 requester_id: user_id.to_proto(),
3372 })
3373 }
3374 }
3375 }
3376
3377 update
3378}
3379
3380fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3381 proto::Contact {
3382 user_id: user_id.to_proto(),
3383 online: pool.is_user_online(user_id),
3384 busy,
3385 }
3386}
3387
3388fn room_updated(room: &proto::Room, peer: &Peer) {
3389 broadcast(
3390 None,
3391 room.participants
3392 .iter()
3393 .filter_map(|participant| Some(participant.peer_id?.into())),
3394 |peer_id| {
3395 peer.send(
3396 peer_id.into(),
3397 proto::RoomUpdated {
3398 room: Some(room.clone()),
3399 },
3400 )
3401 },
3402 );
3403}
3404
3405fn channel_updated(
3406 channel_id: ChannelId,
3407 room: &proto::Room,
3408 channel_members: &[UserId],
3409 peer: &Peer,
3410 pool: &ConnectionPool,
3411) {
3412 let participants = room
3413 .participants
3414 .iter()
3415 .map(|p| p.user_id)
3416 .collect::<Vec<_>>();
3417
3418 broadcast(
3419 None,
3420 channel_members
3421 .iter()
3422 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3423 |peer_id| {
3424 peer.send(
3425 peer_id.into(),
3426 proto::UpdateChannels {
3427 channel_participants: vec![proto::ChannelParticipants {
3428 channel_id: channel_id.to_proto(),
3429 participant_user_ids: participants.clone(),
3430 }],
3431 ..Default::default()
3432 },
3433 )
3434 },
3435 );
3436}
3437
3438async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3439 let db = session.db().await;
3440
3441 let contacts = db.get_contacts(user_id).await?;
3442 let busy = db.is_user_busy(user_id).await?;
3443
3444 let pool = session.connection_pool().await;
3445 let updated_contact = contact_for_user(user_id, busy, &pool);
3446 for contact in contacts {
3447 if let db::Contact::Accepted {
3448 user_id: contact_user_id,
3449 ..
3450 } = contact
3451 {
3452 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3453 session
3454 .peer
3455 .send(
3456 contact_conn_id,
3457 proto::UpdateContacts {
3458 contacts: vec![updated_contact.clone()],
3459 remove_contacts: Default::default(),
3460 incoming_requests: Default::default(),
3461 remove_incoming_requests: Default::default(),
3462 outgoing_requests: Default::default(),
3463 remove_outgoing_requests: Default::default(),
3464 },
3465 )
3466 .trace_err();
3467 }
3468 }
3469 }
3470 Ok(())
3471}
3472
3473async fn leave_room_for_session(session: &Session) -> Result<()> {
3474 let mut contacts_to_update = HashSet::default();
3475
3476 let room_id;
3477 let canceled_calls_to_user_ids;
3478 let live_kit_room;
3479 let delete_live_kit_room;
3480 let room;
3481 let channel_members;
3482 let channel_id;
3483
3484 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3485 contacts_to_update.insert(session.user_id);
3486
3487 for project in left_room.left_projects.values() {
3488 project_left(project, session);
3489 }
3490
3491 room_id = RoomId::from_proto(left_room.room.id);
3492 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3493 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3494 delete_live_kit_room = left_room.deleted;
3495 room = mem::take(&mut left_room.room);
3496 channel_members = mem::take(&mut left_room.channel_members);
3497 channel_id = left_room.channel_id;
3498
3499 room_updated(&room, &session.peer);
3500 } else {
3501 return Ok(());
3502 }
3503
3504 if let Some(channel_id) = channel_id {
3505 channel_updated(
3506 channel_id,
3507 &room,
3508 &channel_members,
3509 &session.peer,
3510 &*session.connection_pool().await,
3511 );
3512 }
3513
3514 {
3515 let pool = session.connection_pool().await;
3516 for canceled_user_id in canceled_calls_to_user_ids {
3517 for connection_id in pool.user_connection_ids(canceled_user_id) {
3518 session
3519 .peer
3520 .send(
3521 connection_id,
3522 proto::CallCanceled {
3523 room_id: room_id.to_proto(),
3524 },
3525 )
3526 .trace_err();
3527 }
3528 contacts_to_update.insert(canceled_user_id);
3529 }
3530 }
3531
3532 for contact_user_id in contacts_to_update {
3533 update_user_contacts(contact_user_id, &session).await?;
3534 }
3535
3536 if let Some(live_kit) = session.live_kit_client.as_ref() {
3537 live_kit
3538 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3539 .await
3540 .trace_err();
3541
3542 if delete_live_kit_room {
3543 live_kit.delete_room(live_kit_room).await.trace_err();
3544 }
3545 }
3546
3547 Ok(())
3548}
3549
3550async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3551 let left_channel_buffers = session
3552 .db()
3553 .await
3554 .leave_channel_buffers(session.connection_id)
3555 .await?;
3556
3557 for left_buffer in left_channel_buffers {
3558 channel_buffer_updated(
3559 session.connection_id,
3560 left_buffer.connections,
3561 &proto::UpdateChannelBufferCollaborators {
3562 channel_id: left_buffer.channel_id.to_proto(),
3563 collaborators: left_buffer.collaborators,
3564 },
3565 &session.peer,
3566 );
3567 }
3568
3569 Ok(())
3570}
3571
3572fn project_left(project: &db::LeftProject, session: &Session) {
3573 for connection_id in &project.connection_ids {
3574 if project.host_user_id == session.user_id {
3575 session
3576 .peer
3577 .send(
3578 *connection_id,
3579 proto::UnshareProject {
3580 project_id: project.id.to_proto(),
3581 },
3582 )
3583 .trace_err();
3584 } else {
3585 session
3586 .peer
3587 .send(
3588 *connection_id,
3589 proto::RemoveProjectCollaborator {
3590 project_id: project.id.to_proto(),
3591 peer_id: Some(session.connection_id.into()),
3592 },
3593 )
3594 .trace_err();
3595 }
3596 }
3597}
3598
3599pub trait ResultExt {
3600 type Ok;
3601
3602 fn trace_err(self) -> Option<Self::Ok>;
3603}
3604
3605impl<T, E> ResultExt for Result<T, E>
3606where
3607 E: std::fmt::Debug,
3608{
3609 type Ok = T;
3610
3611 fn trace_err(self) -> Option<T> {
3612 match self {
3613 Ok(value) => Some(value),
3614 Err(error) => {
3615 tracing::error!("{:?}", error);
3616 None
3617 }
3618 }
3619 }
3620}