1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
9 RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
10 User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, Result,
14};
15use anyhow::anyhow;
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::ConnectionPool;
34use futures::{
35 channel::oneshot,
36 future::{self, BoxFuture},
37 stream::FuturesUnordered,
38 FutureExt, SinkExt, StreamExt, TryStreamExt,
39};
40use lazy_static::lazy_static;
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
45 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 fmt,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{field, info_span, instrument, Instrument};
69
/// How long `connection_lost` waits for a dropped client to come back before it
/// releases the user's room, channel buffers, and pending calls.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
/// How long `Server::start` waits before clearing resources that the database
/// still attributes to stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);

// NOTE(review): these three limits are not referenced in the visible part of this
// file; presumably they bound channel-message paging, message length, and
// notification paging in handlers defined further down — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

lazy_static! {
    /// Prometheus gauge set by `handle_metrics` to the number of non-admin connections.
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    /// Prometheus gauge set by `handle_metrics` from `project_count_excluding_admins`.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of the message it handles. It receives the boxed envelope and the
/// sender's `Session` and returns a boxed future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
89
90struct Response<R> {
91 peer: Arc<Peer>,
92 receipt: Receipt<R>,
93 responded: Arc<AtomicBool>,
94}
95
96impl<R: RequestMessage> Response<R> {
97 fn send(self, payload: R::Response) -> Result<()> {
98 self.responded.store(true, SeqCst);
99 self.peer.respond(self.receipt, payload)?;
100 Ok(())
101 }
102}
103
104#[derive(Clone)]
105struct Session {
106 zed_environment: Arc<str>,
107 user_id: UserId,
108 connection_id: ConnectionId,
109 db: Arc<tokio::sync::Mutex<DbHandle>>,
110 peer: Arc<Peer>,
111 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
112 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
113 _executor: Executor,
114}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135}
136
137impl fmt::Debug for Session {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 f.debug_struct("Session")
140 .field("user_id", &self.user_id)
141 .field("connection_id", &self.connection_id)
142 .finish()
143 }
144}
145
146struct DbHandle(Arc<Database>);
147
148impl Deref for DbHandle {
149 type Target = Database;
150
151 fn deref(&self) -> &Self::Target {
152 self.0.as_ref()
153 }
154}
155
/// The collab RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    /// This server's id; mutable only so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; each connection loop subscribes to it.
    teardown: watch::Sender<()>,
}
165
/// Lock guard over the shared `ConnectionPool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, which makes this guard `!Send` so it cannot be held across
    /// an `.await` inside a `Send` future.
    _not_send: PhantomData<Rc<()>>,
}
170
171#[derive(Serialize)]
172pub struct ServerSnapshot<'a> {
173 peer: &'a Peer,
174 #[serde(serialize_with = "serialize_deref")]
175 connection_pool: ConnectionPoolGuard<'a>,
176}
177
178pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
179where
180 S: Serializer,
181 T: Deref<Target = U>,
182 U: Serialize,
183{
184 Serialize::serialize(value.deref(), serializer)
185}
186
impl Server {
    /// Builds a server with the given `id` and registers every RPC handler it serves.
    ///
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; connections obtain receivers via `subscribe()`.
            teardown: watch::channel(()).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }

    /// Spawns a detached background task that cleans up after stale servers, then returns `Ok`.
    ///
    /// After `CLEANUP_TIMEOUT`, the task does the following:
    /// - Fetches the room and channel ids the database reports as stale for this
    ///   environment and server id.
    /// - Clears stale channel-buffer collaborators and notifies the remaining ones.
    /// - Clears stale room participants, broadcasts the refreshed room (and channel,
    ///   if any), and sends `CallCanceled` to users whose calls were canceled.
    /// - Pushes updated contact entries to affected users' accepted contacts.
    /// - Deletes the LiveKit room once it has no participants.
    /// - Finally calls `delete_stale_servers`.
    ///
    /// Errors inside the task are traced and skipped, never returned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        // Stays false if clearing the room failed, so nothing is deleted.
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Pool lock is scoped to this block; it is released before the awaits below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both lookups must succeed for this user to be re-announced.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    /// Shuts the server down.
    ///
    /// Tears down the peer, resets the connection pool, and notifies every
    /// `handle_connection` loop subscribed to `teardown`. Those loops then return
    /// `Ok(())` immediately.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(());
    }

    /// Test-only: tears the server down and restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
    }

    /// Test-only: returns the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    /// Registers `handler` for messages of type `M`.
    ///
    /// The stored closure downcasts the envelope, runs the handler inside a
    /// "handle message" span, and logs the outcome with its duration in ms.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                // The handler is looked up by the envelope's payload `TypeId`, so this downcast is expected to succeed.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let span = info_span!(
                    "handle message",
                    payload_type = envelope.payload_type_name()
                );
                span.in_scope(|| {
                    tracing::info!(
                        payload_type = envelope.payload_type_name(),
                        "message received"
                    );
                });
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    match result {
                        Err(error) => {
                            tracing::error!(%error, ?duration_ms, "error handling message")
                        }
                        Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
                    }
                }
                .instrument(span)
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message type twice");
        }
        self
    }

    /// Registers a fire-and-forget handler that receives only the payload.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    /// Registers a request handler that must reply through its `Response`.
    ///
    /// If the handler returns `Ok` without calling `Response::send`, an error is
    /// logged. If it returns `Err`, the error is converted to a proto error and
    /// sent back to the requester before being returned.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // NOTE(review): this runs even if the handler already called
                        // `Response::send` before failing, answering the same receipt
                        // twice — confirm `Peer` tolerates that.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    /// Drives a single client connection from handshake to sign-out.
    ///
    /// Handshake:
    /// - Registers the connection with the peer and sends `Hello` with the peer id.
    /// - On the user's first-ever connection, sends `ShowContacts`.
    /// - Adds the connection to the pool and sends the initial contacts and
    ///   channels, plus any pending incoming call.
    /// - Refreshes the user's contacts.
    ///
    /// It then dispatches incoming messages to the registered handlers until
    /// I/O fails, the stream ends, or the server is torn down.
    /// `send_connection_id`, when provided, receives the connection id right after
    /// `Hello` is sent.
    ///
    /// NOTE(review): any `?` failure before the message loop returns early without
    /// calling `connection_lost`. After line `pool.add_connection(...)` that appears
    /// to leave the connection registered in the peer and pool — verify.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        impersonator: Option<User>,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = Result<()>> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login;
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        let mut teardown = self.teardown.subscribe();
        async move {
            // `handle_io` drives the socket; `incoming_rx` yields decoded messages.
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });

            tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
            this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
            tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");

            if let Some(send_connection_id) = send_connection_id.take() {
                let _ = send_connection_id.send(connection_id);
            }

            if !user.connected_once {
                this.peer.send(connection_id, proto::ShowContacts {})?;
                this.app_state.db.set_user_connected_once(user_id, true).await?;
            }

            let (contacts, channels_for_user, channel_invites) = future::try_join3(
                this.app_state.db.get_contacts(user_id),
                this.app_state.db.get_channels_for_user(user_id),
                this.app_state.db.get_channel_invites_for_user(user_id),
            ).await?;

            {
                // Registration and the initial snapshot happen under one pool lock.
                let mut pool = this.connection_pool.lock();
                pool.add_connection(connection_id, user_id, user.admin);
                this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
                this.peer.send(connection_id, build_channels_update(
                    channels_for_user,
                    channel_invites
                ))?;
            }

            if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
                this.peer.send(connection_id, incoming_call)?;
            }

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                zed_environment: this.app_state.config.zed_environment.clone(),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                _executor: executor.clone()
            };
            update_user_contacts(user_id, &session).await?;

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight for this
            // connection. A permit is taken before reading each message and released when
            // its handler finishes, so reading pauses while the limit is reached.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown, then socket I/O, then progress on
                // in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown exits without running `connection_lost`.
                    _ = teardown.changed().fuse() => return Ok(()),
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
                            }
                        } else {
                            tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight.
            drop(foreground_message_handlers);
            tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
            }

            Ok(())
        }.instrument(span)
    }

    /// Notifies the inviter's connections that `invitee_id` redeemed their invite code.
    ///
    /// This happens only if the inviter exists and has an invite code. Each of the
    /// inviter's connections receives the invitee as a contact plus refreshed invite info.
    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                // NOTE(review): the `false` argument is presumably the "busy" flag — confirm.
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Sends refreshed invite info (link and remaining count) to all of `user_id`'s connections.
    ///
    /// This happens only if the user exists and has an invite code.
    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    /// Returns a serializable snapshot of the peer and connection pool.
    ///
    /// The pool stays locked until the snapshot is dropped.
    pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}
759
760impl<'a> Deref for ConnectionPoolGuard<'a> {
761 type Target = ConnectionPool;
762
763 fn deref(&self) -> &Self::Target {
764 &*self.guard
765 }
766}
767
768impl<'a> DerefMut for ConnectionPoolGuard<'a> {
769 fn deref_mut(&mut self) -> &mut Self::Target {
770 &mut *self.guard
771 }
772}
773
774impl<'a> Drop for ConnectionPoolGuard<'a> {
775 fn drop(&mut self) {
776 #[cfg(test)]
777 self.check_invariants();
778 }
779}
780
781fn broadcast<F>(
782 sender_id: Option<ConnectionId>,
783 receiver_ids: impl IntoIterator<Item = ConnectionId>,
784 mut f: F,
785) where
786 F: FnMut(ConnectionId) -> anyhow::Result<()>,
787{
788 for receiver_id in receiver_ids {
789 if Some(receiver_id) != sender_id {
790 if let Err(error) = f(receiver_id) {
791 tracing::error!("failed to send to {:?} {}", receiver_id, error);
792 }
793 }
794 }
795}
796
797lazy_static! {
798 static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
799}
800
801pub struct ProtocolVersion(u32);
802
803impl Header for ProtocolVersion {
804 fn name() -> &'static HeaderName {
805 &ZED_PROTOCOL_VERSION
806 }
807
808 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
809 where
810 Self: Sized,
811 I: Iterator<Item = &'i axum::http::HeaderValue>,
812 {
813 let version = values
814 .next()
815 .ok_or_else(axum::headers::Error::invalid)?
816 .to_str()
817 .map_err(|_| axum::headers::Error::invalid())?
818 .parse()
819 .map_err(|_| axum::headers::Error::invalid())?;
820 Ok(Self(version))
821 }
822
823 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
824 values.extend([self.0.to_string().parse().unwrap()]);
825 }
826}
827
828pub fn routes(server: Arc<Server>) -> Router<Body> {
829 Router::new()
830 .route("/rpc", get(handle_websocket_request))
831 .layer(
832 ServiceBuilder::new()
833 .layer(Extension(server.app_state.clone()))
834 .layer(middleware::from_fn(auth::validate_header)),
835 )
836 .route("/metrics", get(handle_metrics))
837 .layer(Extension(server))
838}
839
840pub async fn handle_websocket_request(
841 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
842 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
843 Extension(server): Extension<Arc<Server>>,
844 Extension(user): Extension<User>,
845 Extension(impersonator): Extension<Impersonator>,
846 ws: WebSocketUpgrade,
847) -> axum::response::Response {
848 if protocol_version != rpc::PROTOCOL_VERSION {
849 return (
850 StatusCode::UPGRADE_REQUIRED,
851 "client must be upgraded".to_string(),
852 )
853 .into_response();
854 }
855 let socket_address = socket_address.to_string();
856 ws.on_upgrade(move |socket| {
857 use util::ResultExt;
858 let socket = socket
859 .map_ok(to_tungstenite_message)
860 .err_into()
861 .with(|message| async move { Ok(to_axum_message(message)) });
862 let connection = Connection::new(Box::pin(socket));
863 async move {
864 server
865 .handle_connection(
866 connection,
867 socket_address,
868 user,
869 impersonator.0,
870 None,
871 Executor::Production,
872 )
873 .await
874 .log_err();
875 }
876 })
877}
878
879pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
880 let connections = server
881 .connection_pool
882 .lock()
883 .connections()
884 .filter(|connection| !connection.admin)
885 .count();
886
887 METRIC_CONNECTIONS.set(connections as _);
888
889 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
890 METRIC_SHARED_PROJECTS.set(shared_projects as _);
891
892 let encoder = prometheus::TextEncoder::new();
893 let metric_families = prometheus::gather();
894 let encoded_metrics = encoder
895 .encode_to_string(&metric_families)
896 .map_err(|err| anyhow!("{}", err))?;
897 Ok(encoded_metrics)
898}
899
/// Cleans up after a client connection has ended.
///
/// Immediately, it does three things:
/// - Disconnects the peer.
/// - Removes the connection from the pool.
/// - Records the lost connection in the database; a failure here is traced, not returned.
///
/// It then waits for whichever comes first: `RECONNECT_TIMEOUT` elapsing or the
/// server being torn down. `select_biased!` polls the timeout branch first.
/// Only when the timeout elapses does it:
/// - Leave the user's room and channel buffers.
/// - Decline any pending call, if the user has no other online connection.
/// - Refresh the user's contacts.
///
/// # Errors
/// Returns the error from `remove_connection` or from `update_user_contacts`.
/// Failures of the other cleanup steps are traced and ignored.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not reconnect in time; release everything it held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Another connection of the same user may still be online; only decline
            // the pending call when none is.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // On server teardown, skip the delayed cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
945
946/// Acknowledges a ping from a client, used to keep the connection alive.
947async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
948 response.send(proto::Ack {})?;
949 Ok(())
950}
951
952/// Creates a new room for calling (outside of channels)
953async fn create_room(
954 _request: proto::CreateRoom,
955 response: Response<proto::CreateRoom>,
956 session: Session,
957) -> Result<()> {
958 let live_kit_room = nanoid::nanoid!(30);
959
960 let live_kit_connection_info = {
961 let live_kit_room = live_kit_room.clone();
962 let live_kit = session.live_kit_client.as_ref();
963
964 util::async_maybe!({
965 let live_kit = live_kit?;
966
967 let token = live_kit
968 .room_token(&live_kit_room, &session.user_id.to_string())
969 .trace_err()?;
970
971 Some(proto::LiveKitConnectionInfo {
972 server_url: live_kit.url().into(),
973 token,
974 can_publish: true,
975 })
976 })
977 }
978 .await;
979
980 let room = session
981 .db()
982 .await
983 .create_room(
984 session.user_id,
985 session.connection_id,
986 &live_kit_room,
987 &session.zed_environment,
988 )
989 .await?;
990
991 response.send(proto::CreateRoomResponse {
992 room: Some(room.clone()),
993 live_kit_connection_info,
994 })?;
995
996 update_user_contacts(session.user_id, &session).await?;
997 Ok(())
998}
999
1000/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1001async fn join_room(
1002 request: proto::JoinRoom,
1003 response: Response<proto::JoinRoom>,
1004 session: Session,
1005) -> Result<()> {
1006 let room_id = RoomId::from_proto(request.id);
1007
1008 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1009
1010 if let Some(channel_id) = channel_id {
1011 return join_channel_internal(channel_id, Box::new(response), session).await;
1012 }
1013
1014 let joined_room = {
1015 let room = session
1016 .db()
1017 .await
1018 .join_room(
1019 room_id,
1020 session.user_id,
1021 session.connection_id,
1022 session.zed_environment.as_ref(),
1023 )
1024 .await?;
1025 room_updated(&room.room, &session.peer);
1026 room.into_inner()
1027 };
1028
1029 for connection_id in session
1030 .connection_pool()
1031 .await
1032 .user_connection_ids(session.user_id)
1033 {
1034 session
1035 .peer
1036 .send(
1037 connection_id,
1038 proto::CallCanceled {
1039 room_id: room_id.to_proto(),
1040 },
1041 )
1042 .trace_err();
1043 }
1044
1045 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1046 if let Some(token) = live_kit
1047 .room_token(
1048 &joined_room.room.live_kit_room,
1049 &session.user_id.to_string(),
1050 )
1051 .trace_err()
1052 {
1053 Some(proto::LiveKitConnectionInfo {
1054 server_url: live_kit.url().into(),
1055 token,
1056 can_publish: true,
1057 })
1058 } else {
1059 None
1060 }
1061 } else {
1062 None
1063 };
1064
1065 response.send(proto::JoinRoomResponse {
1066 room: Some(joined_room.room),
1067 channel_id: None,
1068 live_kit_connection_info,
1069 })?;
1070
1071 update_user_contacts(session.user_id, &session).await?;
1072 Ok(())
1073}
1074
1075/// Rejoin room is used to reconnect to a room after connection errors.
1076async fn rejoin_room(
1077 request: proto::RejoinRoom,
1078 response: Response<proto::RejoinRoom>,
1079 session: Session,
1080) -> Result<()> {
1081 let room;
1082 let channel_id;
1083 let channel_members;
1084 {
1085 let mut rejoined_room = session
1086 .db()
1087 .await
1088 .rejoin_room(request, session.user_id, session.connection_id)
1089 .await?;
1090
1091 response.send(proto::RejoinRoomResponse {
1092 room: Some(rejoined_room.room.clone()),
1093 reshared_projects: rejoined_room
1094 .reshared_projects
1095 .iter()
1096 .map(|project| proto::ResharedProject {
1097 id: project.id.to_proto(),
1098 collaborators: project
1099 .collaborators
1100 .iter()
1101 .map(|collaborator| collaborator.to_proto())
1102 .collect(),
1103 })
1104 .collect(),
1105 rejoined_projects: rejoined_room
1106 .rejoined_projects
1107 .iter()
1108 .map(|rejoined_project| proto::RejoinedProject {
1109 id: rejoined_project.id.to_proto(),
1110 worktrees: rejoined_project
1111 .worktrees
1112 .iter()
1113 .map(|worktree| proto::WorktreeMetadata {
1114 id: worktree.id,
1115 root_name: worktree.root_name.clone(),
1116 visible: worktree.visible,
1117 abs_path: worktree.abs_path.clone(),
1118 })
1119 .collect(),
1120 collaborators: rejoined_project
1121 .collaborators
1122 .iter()
1123 .map(|collaborator| collaborator.to_proto())
1124 .collect(),
1125 language_servers: rejoined_project.language_servers.clone(),
1126 })
1127 .collect(),
1128 })?;
1129 room_updated(&rejoined_room.room, &session.peer);
1130
1131 for project in &rejoined_room.reshared_projects {
1132 for collaborator in &project.collaborators {
1133 session
1134 .peer
1135 .send(
1136 collaborator.connection_id,
1137 proto::UpdateProjectCollaborator {
1138 project_id: project.id.to_proto(),
1139 old_peer_id: Some(project.old_connection_id.into()),
1140 new_peer_id: Some(session.connection_id.into()),
1141 },
1142 )
1143 .trace_err();
1144 }
1145
1146 broadcast(
1147 Some(session.connection_id),
1148 project
1149 .collaborators
1150 .iter()
1151 .map(|collaborator| collaborator.connection_id),
1152 |connection_id| {
1153 session.peer.forward_send(
1154 session.connection_id,
1155 connection_id,
1156 proto::UpdateProject {
1157 project_id: project.id.to_proto(),
1158 worktrees: project.worktrees.clone(),
1159 },
1160 )
1161 },
1162 );
1163 }
1164
1165 for project in &rejoined_room.rejoined_projects {
1166 for collaborator in &project.collaborators {
1167 session
1168 .peer
1169 .send(
1170 collaborator.connection_id,
1171 proto::UpdateProjectCollaborator {
1172 project_id: project.id.to_proto(),
1173 old_peer_id: Some(project.old_connection_id.into()),
1174 new_peer_id: Some(session.connection_id.into()),
1175 },
1176 )
1177 .trace_err();
1178 }
1179 }
1180
1181 for project in &mut rejoined_room.rejoined_projects {
1182 for worktree in mem::take(&mut project.worktrees) {
1183 #[cfg(any(test, feature = "test-support"))]
1184 const MAX_CHUNK_SIZE: usize = 2;
1185 #[cfg(not(any(test, feature = "test-support")))]
1186 const MAX_CHUNK_SIZE: usize = 256;
1187
1188 // Stream this worktree's entries.
1189 let message = proto::UpdateWorktree {
1190 project_id: project.id.to_proto(),
1191 worktree_id: worktree.id,
1192 abs_path: worktree.abs_path.clone(),
1193 root_name: worktree.root_name,
1194 updated_entries: worktree.updated_entries,
1195 removed_entries: worktree.removed_entries,
1196 scan_id: worktree.scan_id,
1197 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1198 updated_repositories: worktree.updated_repositories,
1199 removed_repositories: worktree.removed_repositories,
1200 };
1201 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1202 session.peer.send(session.connection_id, update.clone())?;
1203 }
1204
1205 // Stream this worktree's diagnostics.
1206 for summary in worktree.diagnostic_summaries {
1207 session.peer.send(
1208 session.connection_id,
1209 proto::UpdateDiagnosticSummary {
1210 project_id: project.id.to_proto(),
1211 worktree_id: worktree.id,
1212 summary: Some(summary),
1213 },
1214 )?;
1215 }
1216
1217 for settings_file in worktree.settings_files {
1218 session.peer.send(
1219 session.connection_id,
1220 proto::UpdateWorktreeSettings {
1221 project_id: project.id.to_proto(),
1222 worktree_id: worktree.id,
1223 path: settings_file.path,
1224 content: Some(settings_file.content),
1225 },
1226 )?;
1227 }
1228 }
1229
1230 for language_server in &project.language_servers {
1231 session.peer.send(
1232 session.connection_id,
1233 proto::UpdateLanguageServer {
1234 project_id: project.id.to_proto(),
1235 language_server_id: language_server.id,
1236 variant: Some(
1237 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1238 proto::LspDiskBasedDiagnosticsUpdated {},
1239 ),
1240 ),
1241 },
1242 )?;
1243 }
1244 }
1245
1246 let rejoined_room = rejoined_room.into_inner();
1247
1248 room = rejoined_room.room;
1249 channel_id = rejoined_room.channel_id;
1250 channel_members = rejoined_room.channel_members;
1251 }
1252
1253 if let Some(channel_id) = channel_id {
1254 channel_updated(
1255 channel_id,
1256 &room,
1257 &channel_members,
1258 &session.peer,
1259 &*session.connection_pool().await,
1260 );
1261 }
1262
1263 update_user_contacts(session.user_id, &session).await?;
1264 Ok(())
1265}
1266
1267/// leave room disonnects from the room.
1268async fn leave_room(
1269 _: proto::LeaveRoom,
1270 response: Response<proto::LeaveRoom>,
1271 session: Session,
1272) -> Result<()> {
1273 leave_room_for_session(&session).await?;
1274 response.send(proto::Ack {})?;
1275 Ok(())
1276}
1277
1278/// Updates the permissions of someone else in the room.
1279async fn set_room_participant_role(
1280 request: proto::SetRoomParticipantRole,
1281 response: Response<proto::SetRoomParticipantRole>,
1282 session: Session,
1283) -> Result<()> {
1284 let (live_kit_room, can_publish) = {
1285 let room = session
1286 .db()
1287 .await
1288 .set_room_participant_role(
1289 session.user_id,
1290 RoomId::from_proto(request.room_id),
1291 UserId::from_proto(request.user_id),
1292 ChannelRole::from(request.role()),
1293 )
1294 .await?;
1295
1296 let live_kit_room = room.live_kit_room.clone();
1297 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1298 room_updated(&room, &session.peer);
1299 (live_kit_room, can_publish)
1300 };
1301
1302 if let Some(live_kit) = session.live_kit_client.as_ref() {
1303 live_kit
1304 .update_participant(
1305 live_kit_room.clone(),
1306 request.user_id.to_string(),
1307 live_kit_server::proto::ParticipantPermission {
1308 can_subscribe: true,
1309 can_publish,
1310 can_publish_data: can_publish,
1311 hidden: false,
1312 recorder: false,
1313 },
1314 )
1315 .await
1316 .trace_err();
1317 }
1318
1319 response.send(proto::Ack {})?;
1320 Ok(())
1321}
1322
/// Calls (rings) another user into the caller's current room.
///
/// The called user must be one of the caller's contacts. The call is recorded
/// in the database and the room update is broadcast. An incoming-call request
/// is then sent to every connection of the called user, and the caller is
/// acknowledged as soon as any one of those requests succeeds. If none
/// succeeds (including when the user has no connections), the call is marked
/// as failed in the database, the room is broadcast again, and an error is
/// returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id;
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call and broadcast the room. The incoming-call message is
    // moved out of the database result so it outlives this block's guard.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently. The connection
    // pool guard is a temporary of this statement, so it is released before
    // the requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful answer wins; failed requests are logged and skipped.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: mark the call as failed.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1391
1392/// Cancel an outgoing call.
1393async fn cancel_call(
1394 request: proto::CancelCall,
1395 response: Response<proto::CancelCall>,
1396 session: Session,
1397) -> Result<()> {
1398 let called_user_id = UserId::from_proto(request.called_user_id);
1399 let room_id = RoomId::from_proto(request.room_id);
1400 {
1401 let room = session
1402 .db()
1403 .await
1404 .cancel_call(room_id, session.connection_id, called_user_id)
1405 .await?;
1406 room_updated(&room, &session.peer);
1407 }
1408
1409 for connection_id in session
1410 .connection_pool()
1411 .await
1412 .user_connection_ids(called_user_id)
1413 {
1414 session
1415 .peer
1416 .send(
1417 connection_id,
1418 proto::CallCanceled {
1419 room_id: room_id.to_proto(),
1420 },
1421 )
1422 .trace_err();
1423 }
1424 response.send(proto::Ack {})?;
1425
1426 update_user_contacts(called_user_id, &session).await?;
1427 Ok(())
1428}
1429
1430/// Decline an incoming call.
1431async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1432 let room_id = RoomId::from_proto(message.room_id);
1433 {
1434 let room = session
1435 .db()
1436 .await
1437 .decline_call(Some(room_id), session.user_id)
1438 .await?
1439 .ok_or_else(|| anyhow!("failed to decline call"))?;
1440 room_updated(&room, &session.peer);
1441 }
1442
1443 for connection_id in session
1444 .connection_pool()
1445 .await
1446 .user_connection_ids(session.user_id)
1447 {
1448 session
1449 .peer
1450 .send(
1451 connection_id,
1452 proto::CallCanceled {
1453 room_id: room_id.to_proto(),
1454 },
1455 )
1456 .trace_err();
1457 }
1458 update_user_contacts(session.user_id, &session).await?;
1459 Ok(())
1460}
1461
1462/// Updates other participants in the room with your current location.
1463async fn update_participant_location(
1464 request: proto::UpdateParticipantLocation,
1465 response: Response<proto::UpdateParticipantLocation>,
1466 session: Session,
1467) -> Result<()> {
1468 let room_id = RoomId::from_proto(request.room_id);
1469 let location = request
1470 .location
1471 .ok_or_else(|| anyhow!("invalid location"))?;
1472
1473 let db = session.db().await;
1474 let room = db
1475 .update_room_participant_location(room_id, session.connection_id, location)
1476 .await?;
1477
1478 room_updated(&room, &session.peer);
1479 response.send(proto::Ack {})?;
1480 Ok(())
1481}
1482
1483/// Share a project into the room.
1484async fn share_project(
1485 request: proto::ShareProject,
1486 response: Response<proto::ShareProject>,
1487 session: Session,
1488) -> Result<()> {
1489 let (project_id, room) = &*session
1490 .db()
1491 .await
1492 .share_project(
1493 RoomId::from_proto(request.room_id),
1494 session.connection_id,
1495 &request.worktrees,
1496 )
1497 .await?;
1498 response.send(proto::ShareProjectResponse {
1499 project_id: project_id.to_proto(),
1500 })?;
1501 room_updated(&room, &session.peer);
1502
1503 Ok(())
1504}
1505
1506/// Unshare a project from the room.
1507async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1508 let project_id = ProjectId::from_proto(message.project_id);
1509
1510 let (room, guest_connection_ids) = &*session
1511 .db()
1512 .await
1513 .unshare_project(project_id, session.connection_id)
1514 .await?;
1515
1516 broadcast(
1517 Some(session.connection_id),
1518 guest_connection_ids.iter().copied(),
1519 |conn_id| session.peer.send(conn_id, message.clone()),
1520 );
1521 room_updated(&room, &session.peer);
1522
1523 Ok(())
1524}
1525
1526/// Join someone elses shared project.
1527async fn join_project(
1528 request: proto::JoinProject,
1529 response: Response<proto::JoinProject>,
1530 session: Session,
1531) -> Result<()> {
1532 let project_id = ProjectId::from_proto(request.project_id);
1533 let guest_user_id = session.user_id;
1534
1535 tracing::info!(%project_id, "join project");
1536
1537 let (project, replica_id) = &mut *session
1538 .db()
1539 .await
1540 .join_project(project_id, session.connection_id)
1541 .await?;
1542
1543 let collaborators = project
1544 .collaborators
1545 .iter()
1546 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1547 .map(|collaborator| collaborator.to_proto())
1548 .collect::<Vec<_>>();
1549
1550 let worktrees = project
1551 .worktrees
1552 .iter()
1553 .map(|(id, worktree)| proto::WorktreeMetadata {
1554 id: *id,
1555 root_name: worktree.root_name.clone(),
1556 visible: worktree.visible,
1557 abs_path: worktree.abs_path.clone(),
1558 })
1559 .collect::<Vec<_>>();
1560
1561 for collaborator in &collaborators {
1562 session
1563 .peer
1564 .send(
1565 collaborator.peer_id.unwrap().into(),
1566 proto::AddProjectCollaborator {
1567 project_id: project_id.to_proto(),
1568 collaborator: Some(proto::Collaborator {
1569 peer_id: Some(session.connection_id.into()),
1570 replica_id: replica_id.0 as u32,
1571 user_id: guest_user_id.to_proto(),
1572 }),
1573 },
1574 )
1575 .trace_err();
1576 }
1577
1578 // First, we send the metadata associated with each worktree.
1579 response.send(proto::JoinProjectResponse {
1580 worktrees: worktrees.clone(),
1581 replica_id: replica_id.0 as u32,
1582 collaborators: collaborators.clone(),
1583 language_servers: project.language_servers.clone(),
1584 })?;
1585
1586 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1587 #[cfg(any(test, feature = "test-support"))]
1588 const MAX_CHUNK_SIZE: usize = 2;
1589 #[cfg(not(any(test, feature = "test-support")))]
1590 const MAX_CHUNK_SIZE: usize = 256;
1591
1592 // Stream this worktree's entries.
1593 let message = proto::UpdateWorktree {
1594 project_id: project_id.to_proto(),
1595 worktree_id,
1596 abs_path: worktree.abs_path.clone(),
1597 root_name: worktree.root_name,
1598 updated_entries: worktree.entries,
1599 removed_entries: Default::default(),
1600 scan_id: worktree.scan_id,
1601 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1602 updated_repositories: worktree.repository_entries.into_values().collect(),
1603 removed_repositories: Default::default(),
1604 };
1605 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1606 session.peer.send(session.connection_id, update.clone())?;
1607 }
1608
1609 // Stream this worktree's diagnostics.
1610 for summary in worktree.diagnostic_summaries {
1611 session.peer.send(
1612 session.connection_id,
1613 proto::UpdateDiagnosticSummary {
1614 project_id: project_id.to_proto(),
1615 worktree_id: worktree.id,
1616 summary: Some(summary),
1617 },
1618 )?;
1619 }
1620
1621 for settings_file in worktree.settings_files {
1622 session.peer.send(
1623 session.connection_id,
1624 proto::UpdateWorktreeSettings {
1625 project_id: project_id.to_proto(),
1626 worktree_id: worktree.id,
1627 path: settings_file.path,
1628 content: Some(settings_file.content),
1629 },
1630 )?;
1631 }
1632 }
1633
1634 for language_server in &project.language_servers {
1635 session.peer.send(
1636 session.connection_id,
1637 proto::UpdateLanguageServer {
1638 project_id: project_id.to_proto(),
1639 language_server_id: language_server.id,
1640 variant: Some(
1641 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1642 proto::LspDiskBasedDiagnosticsUpdated {},
1643 ),
1644 ),
1645 },
1646 )?;
1647 }
1648
1649 Ok(())
1650}
1651
1652/// Leave someone elses shared project.
1653async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1654 let sender_id = session.connection_id;
1655 let project_id = ProjectId::from_proto(request.project_id);
1656
1657 let (room, project) = &*session
1658 .db()
1659 .await
1660 .leave_project(project_id, sender_id)
1661 .await?;
1662 tracing::info!(
1663 %project_id,
1664 host_user_id = %project.host_user_id,
1665 host_connection_id = %project.host_connection_id,
1666 "leave project"
1667 );
1668
1669 project_left(&project, &session);
1670 room_updated(&room, &session.peer);
1671
1672 Ok(())
1673}
1674
1675/// Updates other participants with changes to the project
1676async fn update_project(
1677 request: proto::UpdateProject,
1678 response: Response<proto::UpdateProject>,
1679 session: Session,
1680) -> Result<()> {
1681 let project_id = ProjectId::from_proto(request.project_id);
1682 let (room, guest_connection_ids) = &*session
1683 .db()
1684 .await
1685 .update_project(project_id, session.connection_id, &request.worktrees)
1686 .await?;
1687 broadcast(
1688 Some(session.connection_id),
1689 guest_connection_ids.iter().copied(),
1690 |connection_id| {
1691 session
1692 .peer
1693 .forward_send(session.connection_id, connection_id, request.clone())
1694 },
1695 );
1696 room_updated(&room, &session.peer);
1697 response.send(proto::Ack {})?;
1698
1699 Ok(())
1700}
1701
1702/// Updates other participants with changes to the worktree
1703async fn update_worktree(
1704 request: proto::UpdateWorktree,
1705 response: Response<proto::UpdateWorktree>,
1706 session: Session,
1707) -> Result<()> {
1708 let guest_connection_ids = session
1709 .db()
1710 .await
1711 .update_worktree(&request, session.connection_id)
1712 .await?;
1713
1714 broadcast(
1715 Some(session.connection_id),
1716 guest_connection_ids.iter().copied(),
1717 |connection_id| {
1718 session
1719 .peer
1720 .forward_send(session.connection_id, connection_id, request.clone())
1721 },
1722 );
1723 response.send(proto::Ack {})?;
1724 Ok(())
1725}
1726
1727/// Updates other participants with changes to the diagnostics
1728async fn update_diagnostic_summary(
1729 message: proto::UpdateDiagnosticSummary,
1730 session: Session,
1731) -> Result<()> {
1732 let guest_connection_ids = session
1733 .db()
1734 .await
1735 .update_diagnostic_summary(&message, session.connection_id)
1736 .await?;
1737
1738 broadcast(
1739 Some(session.connection_id),
1740 guest_connection_ids.iter().copied(),
1741 |connection_id| {
1742 session
1743 .peer
1744 .forward_send(session.connection_id, connection_id, message.clone())
1745 },
1746 );
1747
1748 Ok(())
1749}
1750
1751/// Updates other participants with changes to the worktree settings
1752async fn update_worktree_settings(
1753 message: proto::UpdateWorktreeSettings,
1754 session: Session,
1755) -> Result<()> {
1756 let guest_connection_ids = session
1757 .db()
1758 .await
1759 .update_worktree_settings(&message, session.connection_id)
1760 .await?;
1761
1762 broadcast(
1763 Some(session.connection_id),
1764 guest_connection_ids.iter().copied(),
1765 |connection_id| {
1766 session
1767 .peer
1768 .forward_send(session.connection_id, connection_id, message.clone())
1769 },
1770 );
1771
1772 Ok(())
1773}
1774
1775/// Notify other participants that a language server has started.
1776async fn start_language_server(
1777 request: proto::StartLanguageServer,
1778 session: Session,
1779) -> Result<()> {
1780 let guest_connection_ids = session
1781 .db()
1782 .await
1783 .start_language_server(&request, session.connection_id)
1784 .await?;
1785
1786 broadcast(
1787 Some(session.connection_id),
1788 guest_connection_ids.iter().copied(),
1789 |connection_id| {
1790 session
1791 .peer
1792 .forward_send(session.connection_id, connection_id, request.clone())
1793 },
1794 );
1795 Ok(())
1796}
1797
1798/// Notify other participants that a language server has changed.
1799async fn update_language_server(
1800 request: proto::UpdateLanguageServer,
1801 session: Session,
1802) -> Result<()> {
1803 let project_id = ProjectId::from_proto(request.project_id);
1804 let project_connection_ids = session
1805 .db()
1806 .await
1807 .project_connection_ids(project_id, session.connection_id)
1808 .await?;
1809 broadcast(
1810 Some(session.connection_id),
1811 project_connection_ids.iter().copied(),
1812 |connection_id| {
1813 session
1814 .peer
1815 .forward_send(session.connection_id, connection_id, request.clone())
1816 },
1817 );
1818 Ok(())
1819}
1820
1821/// forward a project request to the host. These requests should be read only
1822/// as guests are allowed to send them.
1823async fn forward_read_only_project_request<T>(
1824 request: T,
1825 response: Response<T>,
1826 session: Session,
1827) -> Result<()>
1828where
1829 T: EntityMessage + RequestMessage,
1830{
1831 let project_id = ProjectId::from_proto(request.remote_entity_id());
1832 let host_connection_id = session
1833 .db()
1834 .await
1835 .host_for_read_only_project_request(project_id, session.connection_id)
1836 .await?;
1837 let payload = session
1838 .peer
1839 .forward_request(session.connection_id, host_connection_id, request)
1840 .await?;
1841 response.send(payload)?;
1842 Ok(())
1843}
1844
1845/// forward a project request to the host. These requests are disallowed
1846/// for guests.
1847async fn forward_mutating_project_request<T>(
1848 request: T,
1849 response: Response<T>,
1850 session: Session,
1851) -> Result<()>
1852where
1853 T: EntityMessage + RequestMessage,
1854{
1855 let project_id = ProjectId::from_proto(request.remote_entity_id());
1856 let host_connection_id = session
1857 .db()
1858 .await
1859 .host_for_mutating_project_request(project_id, session.connection_id)
1860 .await?;
1861 let payload = session
1862 .peer
1863 .forward_request(session.connection_id, host_connection_id, request)
1864 .await?;
1865 response.send(payload)?;
1866 Ok(())
1867}
1868
1869/// Notify other participants that a new buffer has been created
1870async fn create_buffer_for_peer(
1871 request: proto::CreateBufferForPeer,
1872 session: Session,
1873) -> Result<()> {
1874 session
1875 .db()
1876 .await
1877 .check_user_is_project_host(
1878 ProjectId::from_proto(request.project_id),
1879 session.connection_id,
1880 )
1881 .await?;
1882 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1883 session
1884 .peer
1885 .forward_send(session.connection_id, peer_id.into(), request)?;
1886 Ok(())
1887}
1888
1889/// Notify other participants that a buffer has been updated. This is
1890/// allowed for guests as long as the update is limited to selections.
1891async fn update_buffer(
1892 request: proto::UpdateBuffer,
1893 response: Response<proto::UpdateBuffer>,
1894 session: Session,
1895) -> Result<()> {
1896 let project_id = ProjectId::from_proto(request.project_id);
1897 let mut guest_connection_ids;
1898 let mut host_connection_id = None;
1899
1900 let mut requires_write_permission = false;
1901
1902 for op in request.operations.iter() {
1903 match op.variant {
1904 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1905 Some(_) => requires_write_permission = true,
1906 }
1907 }
1908
1909 {
1910 let collaborators = session
1911 .db()
1912 .await
1913 .project_collaborators_for_buffer_update(
1914 project_id,
1915 session.connection_id,
1916 requires_write_permission,
1917 )
1918 .await?;
1919 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1920 for collaborator in collaborators.iter() {
1921 if collaborator.is_host {
1922 host_connection_id = Some(collaborator.connection_id);
1923 } else {
1924 guest_connection_ids.push(collaborator.connection_id);
1925 }
1926 }
1927 }
1928 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1929
1930 broadcast(
1931 Some(session.connection_id),
1932 guest_connection_ids,
1933 |connection_id| {
1934 session
1935 .peer
1936 .forward_send(session.connection_id, connection_id, request.clone())
1937 },
1938 );
1939 if host_connection_id != session.connection_id {
1940 session
1941 .peer
1942 .forward_request(session.connection_id, host_connection_id, request.clone())
1943 .await?;
1944 }
1945
1946 response.send(proto::Ack {})?;
1947 Ok(())
1948}
1949
1950/// Notify other participants that a project has been updated.
1951async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
1952 request: T,
1953 session: Session,
1954) -> Result<()> {
1955 let project_id = ProjectId::from_proto(request.remote_entity_id());
1956 let project_connection_ids = session
1957 .db()
1958 .await
1959 .project_connection_ids(project_id, session.connection_id)
1960 .await?;
1961
1962 broadcast(
1963 Some(session.connection_id),
1964 project_connection_ids.iter().copied(),
1965 |connection_id| {
1966 session
1967 .peer
1968 .forward_send(session.connection_id, connection_id, request.clone())
1969 },
1970 );
1971 Ok(())
1972}
1973
1974/// Start following another user in a call.
1975async fn follow(
1976 request: proto::Follow,
1977 response: Response<proto::Follow>,
1978 session: Session,
1979) -> Result<()> {
1980 let room_id = RoomId::from_proto(request.room_id);
1981 let project_id = request.project_id.map(ProjectId::from_proto);
1982 let leader_id = request
1983 .leader_id
1984 .ok_or_else(|| anyhow!("invalid leader id"))?
1985 .into();
1986 let follower_id = session.connection_id;
1987
1988 session
1989 .db()
1990 .await
1991 .check_room_participants(room_id, leader_id, session.connection_id)
1992 .await?;
1993
1994 let response_payload = session
1995 .peer
1996 .forward_request(session.connection_id, leader_id, request)
1997 .await?;
1998 response.send(response_payload)?;
1999
2000 if let Some(project_id) = project_id {
2001 let room = session
2002 .db()
2003 .await
2004 .follow(room_id, project_id, leader_id, follower_id)
2005 .await?;
2006 room_updated(&room, &session.peer);
2007 }
2008
2009 Ok(())
2010}
2011
2012/// Stop following another user in a call.
2013async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2014 let room_id = RoomId::from_proto(request.room_id);
2015 let project_id = request.project_id.map(ProjectId::from_proto);
2016 let leader_id = request
2017 .leader_id
2018 .ok_or_else(|| anyhow!("invalid leader id"))?
2019 .into();
2020 let follower_id = session.connection_id;
2021
2022 session
2023 .db()
2024 .await
2025 .check_room_participants(room_id, leader_id, session.connection_id)
2026 .await?;
2027
2028 session
2029 .peer
2030 .forward_send(session.connection_id, leader_id, request)?;
2031
2032 if let Some(project_id) = project_id {
2033 let room = session
2034 .db()
2035 .await
2036 .unfollow(room_id, project_id, leader_id, follower_id)
2037 .await?;
2038 room_updated(&room, &session.peer);
2039 }
2040
2041 Ok(())
2042}
2043
2044/// Notify everyone following you of your current location.
2045async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2046 let room_id = RoomId::from_proto(request.room_id);
2047 let database = session.db.lock().await;
2048
2049 let connection_ids = if let Some(project_id) = request.project_id {
2050 let project_id = ProjectId::from_proto(project_id);
2051 database
2052 .project_connection_ids(project_id, session.connection_id)
2053 .await?
2054 } else {
2055 database
2056 .room_connection_ids(room_id, session.connection_id)
2057 .await?
2058 };
2059
2060 // For now, don't send view update messages back to that view's current leader.
2061 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2062 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2063 _ => None,
2064 });
2065
2066 for follower_peer_id in request.follower_ids.iter().copied() {
2067 let follower_connection_id = follower_peer_id.into();
2068 if Some(follower_peer_id) != connection_id_to_omit
2069 && connection_ids.contains(&follower_connection_id)
2070 {
2071 session.peer.forward_send(
2072 session.connection_id,
2073 follower_connection_id,
2074 request.clone(),
2075 )?;
2076 }
2077 }
2078 Ok(())
2079}
2080
2081/// Get public data about users.
2082async fn get_users(
2083 request: proto::GetUsers,
2084 response: Response<proto::GetUsers>,
2085 session: Session,
2086) -> Result<()> {
2087 let user_ids = request
2088 .user_ids
2089 .into_iter()
2090 .map(UserId::from_proto)
2091 .collect();
2092 let users = session
2093 .db()
2094 .await
2095 .get_users_by_ids(user_ids)
2096 .await?
2097 .into_iter()
2098 .map(|user| proto::User {
2099 id: user.id.to_proto(),
2100 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2101 github_login: user.github_login,
2102 })
2103 .collect();
2104 response.send(proto::UsersResponse { users })?;
2105 Ok(())
2106}
2107
2108/// Search for users (to invite) buy Github login
2109async fn fuzzy_search_users(
2110 request: proto::FuzzySearchUsers,
2111 response: Response<proto::FuzzySearchUsers>,
2112 session: Session,
2113) -> Result<()> {
2114 let query = request.query;
2115 let users = match query.len() {
2116 0 => vec![],
2117 1 | 2 => session
2118 .db()
2119 .await
2120 .get_user_by_github_login(&query)
2121 .await?
2122 .into_iter()
2123 .collect(),
2124 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2125 };
2126 let users = users
2127 .into_iter()
2128 .filter(|user| user.id != session.user_id)
2129 .map(|user| proto::User {
2130 id: user.id.to_proto(),
2131 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2132 github_login: user.github_login,
2133 })
2134 .collect();
2135 response.send(proto::UsersResponse { users })?;
2136 Ok(())
2137}
2138
2139/// Send a contact request to another user.
2140async fn request_contact(
2141 request: proto::RequestContact,
2142 response: Response<proto::RequestContact>,
2143 session: Session,
2144) -> Result<()> {
2145 let requester_id = session.user_id;
2146 let responder_id = UserId::from_proto(request.responder_id);
2147 if requester_id == responder_id {
2148 return Err(anyhow!("cannot add yourself as a contact"))?;
2149 }
2150
2151 let notifications = session
2152 .db()
2153 .await
2154 .send_contact_request(requester_id, responder_id)
2155 .await?;
2156
2157 // Update outgoing contact requests of requester
2158 let mut update = proto::UpdateContacts::default();
2159 update.outgoing_requests.push(responder_id.to_proto());
2160 for connection_id in session
2161 .connection_pool()
2162 .await
2163 .user_connection_ids(requester_id)
2164 {
2165 session.peer.send(connection_id, update.clone())?;
2166 }
2167
2168 // Update incoming contact requests of responder
2169 let mut update = proto::UpdateContacts::default();
2170 update
2171 .incoming_requests
2172 .push(proto::IncomingContactRequest {
2173 requester_id: requester_id.to_proto(),
2174 });
2175 let connection_pool = session.connection_pool().await;
2176 for connection_id in connection_pool.user_connection_ids(responder_id) {
2177 session.peer.send(connection_id, update.clone())?;
2178 }
2179
2180 send_notifications(&*connection_pool, &session.peer, notifications);
2181
2182 response.send(proto::Ack {})?;
2183 Ok(())
2184}
2185
2186/// Accept or decline a contact request
2187async fn respond_to_contact_request(
2188 request: proto::RespondToContactRequest,
2189 response: Response<proto::RespondToContactRequest>,
2190 session: Session,
2191) -> Result<()> {
2192 let responder_id = session.user_id;
2193 let requester_id = UserId::from_proto(request.requester_id);
2194 let db = session.db().await;
2195 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2196 db.dismiss_contact_notification(responder_id, requester_id)
2197 .await?;
2198 } else {
2199 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2200
2201 let notifications = db
2202 .respond_to_contact_request(responder_id, requester_id, accept)
2203 .await?;
2204 let requester_busy = db.is_user_busy(requester_id).await?;
2205 let responder_busy = db.is_user_busy(responder_id).await?;
2206
2207 let pool = session.connection_pool().await;
2208 // Update responder with new contact
2209 let mut update = proto::UpdateContacts::default();
2210 if accept {
2211 update
2212 .contacts
2213 .push(contact_for_user(requester_id, requester_busy, &pool));
2214 }
2215 update
2216 .remove_incoming_requests
2217 .push(requester_id.to_proto());
2218 for connection_id in pool.user_connection_ids(responder_id) {
2219 session.peer.send(connection_id, update.clone())?;
2220 }
2221
2222 // Update requester with new contact
2223 let mut update = proto::UpdateContacts::default();
2224 if accept {
2225 update
2226 .contacts
2227 .push(contact_for_user(responder_id, responder_busy, &pool));
2228 }
2229 update
2230 .remove_outgoing_requests
2231 .push(responder_id.to_proto());
2232
2233 for connection_id in pool.user_connection_ids(requester_id) {
2234 session.peer.send(connection_id, update.clone())?;
2235 }
2236
2237 send_notifications(&*pool, &session.peer, notifications);
2238 }
2239
2240 response.send(proto::Ack {})?;
2241 Ok(())
2242}
2243
2244/// Remove a contact.
2245async fn remove_contact(
2246 request: proto::RemoveContact,
2247 response: Response<proto::RemoveContact>,
2248 session: Session,
2249) -> Result<()> {
2250 let requester_id = session.user_id;
2251 let responder_id = UserId::from_proto(request.user_id);
2252 let db = session.db().await;
2253 let (contact_accepted, deleted_notification_id) =
2254 db.remove_contact(requester_id, responder_id).await?;
2255
2256 let pool = session.connection_pool().await;
2257 // Update outgoing contact requests of requester
2258 let mut update = proto::UpdateContacts::default();
2259 if contact_accepted {
2260 update.remove_contacts.push(responder_id.to_proto());
2261 } else {
2262 update
2263 .remove_outgoing_requests
2264 .push(responder_id.to_proto());
2265 }
2266 for connection_id in pool.user_connection_ids(requester_id) {
2267 session.peer.send(connection_id, update.clone())?;
2268 }
2269
2270 // Update incoming contact requests of responder
2271 let mut update = proto::UpdateContacts::default();
2272 if contact_accepted {
2273 update.remove_contacts.push(requester_id.to_proto());
2274 } else {
2275 update
2276 .remove_incoming_requests
2277 .push(requester_id.to_proto());
2278 }
2279 for connection_id in pool.user_connection_ids(responder_id) {
2280 session.peer.send(connection_id, update.clone())?;
2281 if let Some(notification_id) = deleted_notification_id {
2282 session.peer.send(
2283 connection_id,
2284 proto::DeleteNotification {
2285 notification_id: notification_id.to_proto(),
2286 },
2287 )?;
2288 }
2289 }
2290
2291 response.send(proto::Ack {})?;
2292 Ok(())
2293}
2294
2295/// Creates a new channel.
2296async fn create_channel(
2297 request: proto::CreateChannel,
2298 response: Response<proto::CreateChannel>,
2299 session: Session,
2300) -> Result<()> {
2301 let db = session.db().await;
2302
2303 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2304 let CreateChannelResult {
2305 channel,
2306 participants_to_update,
2307 } = db
2308 .create_channel(&request.name, parent_id, session.user_id)
2309 .await?;
2310
2311 response.send(proto::CreateChannelResponse {
2312 channel: Some(channel.to_proto()),
2313 parent_id: request.parent_id,
2314 })?;
2315
2316 let connection_pool = session.connection_pool().await;
2317 for (user_id, channels) in participants_to_update {
2318 let update = build_channels_update(channels, vec![]);
2319 for connection_id in connection_pool.user_connection_ids(user_id) {
2320 if user_id == session.user_id {
2321 continue;
2322 }
2323 session.peer.send(connection_id, update.clone())?;
2324 }
2325 }
2326
2327 Ok(())
2328}
2329
2330/// Delete a channel
2331async fn delete_channel(
2332 request: proto::DeleteChannel,
2333 response: Response<proto::DeleteChannel>,
2334 session: Session,
2335) -> Result<()> {
2336 let db = session.db().await;
2337
2338 let channel_id = request.channel_id;
2339 let (removed_channels, member_ids) = db
2340 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2341 .await?;
2342 response.send(proto::Ack {})?;
2343
2344 // Notify members of removed channels
2345 let mut update = proto::UpdateChannels::default();
2346 update
2347 .delete_channels
2348 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2349
2350 let connection_pool = session.connection_pool().await;
2351 for member_id in member_ids {
2352 for connection_id in connection_pool.user_connection_ids(member_id) {
2353 session.peer.send(connection_id, update.clone())?;
2354 }
2355 }
2356
2357 Ok(())
2358}
2359
2360/// Invite someone to join a channel.
2361async fn invite_channel_member(
2362 request: proto::InviteChannelMember,
2363 response: Response<proto::InviteChannelMember>,
2364 session: Session,
2365) -> Result<()> {
2366 let db = session.db().await;
2367 let channel_id = ChannelId::from_proto(request.channel_id);
2368 let invitee_id = UserId::from_proto(request.user_id);
2369 let InviteMemberResult {
2370 channel,
2371 notifications,
2372 } = db
2373 .invite_channel_member(
2374 channel_id,
2375 invitee_id,
2376 session.user_id,
2377 request.role().into(),
2378 )
2379 .await?;
2380
2381 let update = proto::UpdateChannels {
2382 channel_invitations: vec![channel.to_proto()],
2383 ..Default::default()
2384 };
2385
2386 let connection_pool = session.connection_pool().await;
2387 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2388 session.peer.send(connection_id, update.clone())?;
2389 }
2390
2391 send_notifications(&*connection_pool, &session.peer, notifications);
2392
2393 response.send(proto::Ack {})?;
2394 Ok(())
2395}
2396
2397/// remove someone from a channel
2398async fn remove_channel_member(
2399 request: proto::RemoveChannelMember,
2400 response: Response<proto::RemoveChannelMember>,
2401 session: Session,
2402) -> Result<()> {
2403 let db = session.db().await;
2404 let channel_id = ChannelId::from_proto(request.channel_id);
2405 let member_id = UserId::from_proto(request.user_id);
2406
2407 let RemoveChannelMemberResult {
2408 membership_update,
2409 notification_id,
2410 } = db
2411 .remove_channel_member(channel_id, member_id, session.user_id)
2412 .await?;
2413
2414 let connection_pool = &session.connection_pool().await;
2415 notify_membership_updated(
2416 &connection_pool,
2417 membership_update,
2418 member_id,
2419 &session.peer,
2420 );
2421 for connection_id in connection_pool.user_connection_ids(member_id) {
2422 if let Some(notification_id) = notification_id {
2423 session
2424 .peer
2425 .send(
2426 connection_id,
2427 proto::DeleteNotification {
2428 notification_id: notification_id.to_proto(),
2429 },
2430 )
2431 .trace_err();
2432 }
2433 }
2434
2435 response.send(proto::Ack {})?;
2436 Ok(())
2437}
2438
2439/// Toggle the channel between public and private
2440async fn set_channel_visibility(
2441 request: proto::SetChannelVisibility,
2442 response: Response<proto::SetChannelVisibility>,
2443 session: Session,
2444) -> Result<()> {
2445 let db = session.db().await;
2446 let channel_id = ChannelId::from_proto(request.channel_id);
2447 let visibility = request.visibility().into();
2448
2449 let SetChannelVisibilityResult {
2450 participants_to_update,
2451 participants_to_remove,
2452 channels_to_remove,
2453 } = db
2454 .set_channel_visibility(channel_id, visibility, session.user_id)
2455 .await?;
2456
2457 let connection_pool = session.connection_pool().await;
2458 for (user_id, channels) in participants_to_update {
2459 let update = build_channels_update(channels, vec![]);
2460 for connection_id in connection_pool.user_connection_ids(user_id) {
2461 session.peer.send(connection_id, update.clone())?;
2462 }
2463 }
2464 for user_id in participants_to_remove {
2465 let update = proto::UpdateChannels {
2466 delete_channels: channels_to_remove.iter().map(|id| id.to_proto()).collect(),
2467 ..Default::default()
2468 };
2469 for connection_id in connection_pool.user_connection_ids(user_id) {
2470 session.peer.send(connection_id, update.clone())?;
2471 }
2472 }
2473
2474 response.send(proto::Ack {})?;
2475 Ok(())
2476}
2477
2478/// Alter the role for a user in the channel
2479async fn set_channel_member_role(
2480 request: proto::SetChannelMemberRole,
2481 response: Response<proto::SetChannelMemberRole>,
2482 session: Session,
2483) -> Result<()> {
2484 let db = session.db().await;
2485 let channel_id = ChannelId::from_proto(request.channel_id);
2486 let member_id = UserId::from_proto(request.user_id);
2487 let result = db
2488 .set_channel_member_role(
2489 channel_id,
2490 session.user_id,
2491 member_id,
2492 request.role().into(),
2493 )
2494 .await?;
2495
2496 match result {
2497 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2498 let connection_pool = session.connection_pool().await;
2499 notify_membership_updated(
2500 &connection_pool,
2501 membership_update,
2502 member_id,
2503 &session.peer,
2504 )
2505 }
2506 db::SetMemberRoleResult::InviteUpdated(channel) => {
2507 let update = proto::UpdateChannels {
2508 channel_invitations: vec![channel.to_proto()],
2509 ..Default::default()
2510 };
2511
2512 for connection_id in session
2513 .connection_pool()
2514 .await
2515 .user_connection_ids(member_id)
2516 {
2517 session.peer.send(connection_id, update.clone())?;
2518 }
2519 }
2520 }
2521
2522 response.send(proto::Ack {})?;
2523 Ok(())
2524}
2525
2526/// Change the name of a channel
2527async fn rename_channel(
2528 request: proto::RenameChannel,
2529 response: Response<proto::RenameChannel>,
2530 session: Session,
2531) -> Result<()> {
2532 let db = session.db().await;
2533 let channel_id = ChannelId::from_proto(request.channel_id);
2534 let RenameChannelResult {
2535 channel,
2536 participants_to_update,
2537 } = db
2538 .rename_channel(channel_id, session.user_id, &request.name)
2539 .await?;
2540
2541 response.send(proto::RenameChannelResponse {
2542 channel: Some(channel.to_proto()),
2543 })?;
2544
2545 let connection_pool = session.connection_pool().await;
2546 for (user_id, channel) in participants_to_update {
2547 for connection_id in connection_pool.user_connection_ids(user_id) {
2548 let update = proto::UpdateChannels {
2549 channels: vec![channel.to_proto()],
2550 ..Default::default()
2551 };
2552
2553 session.peer.send(connection_id, update.clone())?;
2554 }
2555 }
2556
2557 Ok(())
2558}
2559
2560/// Move a channel to a new parent.
2561async fn move_channel(
2562 request: proto::MoveChannel,
2563 response: Response<proto::MoveChannel>,
2564 session: Session,
2565) -> Result<()> {
2566 let channel_id = ChannelId::from_proto(request.channel_id);
2567 let to = request.to.map(ChannelId::from_proto);
2568
2569 let result = session
2570 .db()
2571 .await
2572 .move_channel(channel_id, to, session.user_id)
2573 .await?;
2574
2575 notify_channel_moved(result, session).await?;
2576
2577 response.send(Ack {})?;
2578 Ok(())
2579}
2580
2581async fn notify_channel_moved(result: Option<MoveChannelResult>, session: Session) -> Result<()> {
2582 let Some(MoveChannelResult {
2583 participants_to_remove,
2584 participants_to_update,
2585 moved_channels,
2586 }) = result
2587 else {
2588 return Ok(());
2589 };
2590 let moved_channels: Vec<u64> = moved_channels.iter().map(|id| id.to_proto()).collect();
2591
2592 let connection_pool = session.connection_pool().await;
2593 for (user_id, channels) in participants_to_update {
2594 let mut update = build_channels_update(channels, vec![]);
2595 update.delete_channels = moved_channels.clone();
2596 for connection_id in connection_pool.user_connection_ids(user_id) {
2597 session.peer.send(connection_id, update.clone())?;
2598 }
2599 }
2600
2601 for user_id in participants_to_remove {
2602 let update = proto::UpdateChannels {
2603 delete_channels: moved_channels.clone(),
2604 ..Default::default()
2605 };
2606 for connection_id in connection_pool.user_connection_ids(user_id) {
2607 session.peer.send(connection_id, update.clone())?;
2608 }
2609 }
2610 Ok(())
2611}
2612
2613/// Get the list of channel members
2614async fn get_channel_members(
2615 request: proto::GetChannelMembers,
2616 response: Response<proto::GetChannelMembers>,
2617 session: Session,
2618) -> Result<()> {
2619 let db = session.db().await;
2620 let channel_id = ChannelId::from_proto(request.channel_id);
2621 let members = db
2622 .get_channel_participant_details(channel_id, session.user_id)
2623 .await?;
2624 response.send(proto::GetChannelMembersResponse { members })?;
2625 Ok(())
2626}
2627
2628/// Accept or decline a channel invitation.
2629async fn respond_to_channel_invite(
2630 request: proto::RespondToChannelInvite,
2631 response: Response<proto::RespondToChannelInvite>,
2632 session: Session,
2633) -> Result<()> {
2634 let db = session.db().await;
2635 let channel_id = ChannelId::from_proto(request.channel_id);
2636 let RespondToChannelInvite {
2637 membership_update,
2638 notifications,
2639 } = db
2640 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2641 .await?;
2642
2643 let connection_pool = session.connection_pool().await;
2644 if let Some(membership_update) = membership_update {
2645 notify_membership_updated(
2646 &connection_pool,
2647 membership_update,
2648 session.user_id,
2649 &session.peer,
2650 );
2651 } else {
2652 let update = proto::UpdateChannels {
2653 remove_channel_invitations: vec![channel_id.to_proto()],
2654 ..Default::default()
2655 };
2656
2657 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2658 session.peer.send(connection_id, update.clone())?;
2659 }
2660 };
2661
2662 send_notifications(&*connection_pool, &session.peer, notifications);
2663
2664 response.send(proto::Ack {})?;
2665
2666 Ok(())
2667}
2668
2669/// Join the channels' room
2670async fn join_channel(
2671 request: proto::JoinChannel,
2672 response: Response<proto::JoinChannel>,
2673 session: Session,
2674) -> Result<()> {
2675 let channel_id = ChannelId::from_proto(request.channel_id);
2676 join_channel_internal(channel_id, Box::new(response), session).await
2677}
2678
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply through
/// either one.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Lets a `JoinChannel` request be answered through `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::<proto::JoinChannel>::send` (presumably the
        // inherent method on `Response`).
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Lets a `JoinRoom` request be answered through `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to `Response::<proto::JoinRoom>::send` (presumably the
        // inherent method on `Response`).
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2692
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Leave whatever room this session is currently in.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership change, and the user's role in the channel.
/// 3. If a LiveKit client is configured, mint a connection token. Guests get
///    a guest token with `can_publish: false`; everyone else gets a room token
///    with `can_publish: true`. A token error is passed to `trace_err` and
///    results in no LiveKit info rather than a failed join.
/// 4. Reply with a `JoinRoomResponse`, push any membership change to the
///    user, and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The block scopes `db` and the first connection-pool guard, so both are
    // released before the channel and contact broadcasts below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // `None` if there is no LiveKit client or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2773
2774/// Start editing the channel notes
2775async fn join_channel_buffer(
2776 request: proto::JoinChannelBuffer,
2777 response: Response<proto::JoinChannelBuffer>,
2778 session: Session,
2779) -> Result<()> {
2780 let db = session.db().await;
2781 let channel_id = ChannelId::from_proto(request.channel_id);
2782
2783 let open_response = db
2784 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2785 .await?;
2786
2787 let collaborators = open_response.collaborators.clone();
2788 response.send(open_response)?;
2789
2790 let update = UpdateChannelBufferCollaborators {
2791 channel_id: channel_id.to_proto(),
2792 collaborators: collaborators.clone(),
2793 };
2794 channel_buffer_updated(
2795 session.connection_id,
2796 collaborators
2797 .iter()
2798 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2799 &update,
2800 &session.peer,
2801 );
2802
2803 Ok(())
2804}
2805
/// Apply a batch of edits to a channel's notes buffer and fan them out.
///
/// The operations are persisted through the database, which returns:
/// - the collaborating connections,
/// - the user ids in `non_collaborators`,
/// - the buffer's epoch and version.
///
/// The collaborating connections receive the raw operations; this connection
/// is passed to `channel_buffer_updated` as the sender. Every connection of
/// the `non_collaborators` users receives an `UpdateChannels` that reports an
/// unseen buffer change at the new epoch/version.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, non_collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id, &request.operations)
        .await?;

    // Relay the operations themselves to the collaborating connections.
    channel_buffer_updated(
        session.connection_id,
        collaborators,
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Non-collaborating users only learn that the buffer changed, tagged with
    // the new epoch and version.
    broadcast(
        None,
        non_collaborators
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_buffer_changes: vec![proto::UnseenChannelBufferChange {
                        channel_id: channel_id.to_proto(),
                        epoch: epoch as u64,
                        version: version.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );

    Ok(())
}
2852
2853/// Rejoin the channel notes after a connection blip
2854async fn rejoin_channel_buffers(
2855 request: proto::RejoinChannelBuffers,
2856 response: Response<proto::RejoinChannelBuffers>,
2857 session: Session,
2858) -> Result<()> {
2859 let db = session.db().await;
2860 let buffers = db
2861 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2862 .await?;
2863
2864 for rejoined_buffer in &buffers {
2865 let collaborators_to_notify = rejoined_buffer
2866 .buffer
2867 .collaborators
2868 .iter()
2869 .filter_map(|c| Some(c.peer_id?.into()));
2870 channel_buffer_updated(
2871 session.connection_id,
2872 collaborators_to_notify,
2873 &proto::UpdateChannelBufferCollaborators {
2874 channel_id: rejoined_buffer.buffer.channel_id,
2875 collaborators: rejoined_buffer.buffer.collaborators.clone(),
2876 },
2877 &session.peer,
2878 );
2879 }
2880
2881 response.send(proto::RejoinChannelBuffersResponse {
2882 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
2883 })?;
2884
2885 Ok(())
2886}
2887
2888/// Stop editing the channel notes
2889async fn leave_channel_buffer(
2890 request: proto::LeaveChannelBuffer,
2891 response: Response<proto::LeaveChannelBuffer>,
2892 session: Session,
2893) -> Result<()> {
2894 let db = session.db().await;
2895 let channel_id = ChannelId::from_proto(request.channel_id);
2896
2897 let left_buffer = db
2898 .leave_channel_buffer(channel_id, session.connection_id)
2899 .await?;
2900
2901 response.send(Ack {})?;
2902
2903 channel_buffer_updated(
2904 session.connection_id,
2905 left_buffer.connections,
2906 &proto::UpdateChannelBufferCollaborators {
2907 channel_id: channel_id.to_proto(),
2908 collaborators: left_buffer.collaborators,
2909 },
2910 &session.peer,
2911 );
2912
2913 Ok(())
2914}
2915
/// Send a clone of `message` to each connection in `collaborators`.
///
/// Delivery goes through `broadcast`, with `sender_id` passed as the sender.
/// Presumably `broadcast` skips the sender and handles send errors itself;
/// TODO confirm.
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
        peer.send(peer_id.into(), message.clone())
    });
}
2926
2927fn send_notifications(
2928 connection_pool: &ConnectionPool,
2929 peer: &Peer,
2930 notifications: db::NotificationBatch,
2931) {
2932 for (user_id, notification) in notifications {
2933 for connection_id in connection_pool.user_connection_ids(user_id) {
2934 if let Err(error) = peer.send(
2935 connection_id,
2936 proto::AddNotification {
2937 notification: Some(notification.clone()),
2938 },
2939 ) {
2940 tracing::error!(
2941 "failed to send notification to {:?} {}",
2942 connection_id,
2943 error
2944 );
2945 }
2946 }
2947 }
2948}
2949
/// Send a message to the channel.
///
/// Validation:
/// - the body is trimmed;
/// - it must be non-empty and at most `MAX_MESSAGE_LEN` bytes;
/// - a nonce is required.
///
/// The message is then stored with the current UTC time, and:
/// - it is relayed as `ChannelMessageSent` to the participant connections
///   returned by the database (this connection is passed to `broadcast` as
///   the sender);
/// - it is returned to the caller in the response;
/// - it is announced as an unseen message to every connection of the users in
///   `channel_members`;
/// - any notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // `len()` counts UTF-8 bytes, not characters.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is forwarded unchanged below. If mention
    // ranges are offsets into the untrimmed body, leading whitespace removed by
    // `trim()` would shift them out of alignment with `body`.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        channel_members,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id,
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
        )
        .await?;
    let message = proto::ChannelMessage {
        sender_id: session.user_id.to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
    };
    // Deliver the full message to the chat participants.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel members only receive an "unseen message" marker.
    let pool = &*session.connection_pool().await;
    broadcast(
        None,
        channel_members
            .iter()
            .flat_map(|user_id| pool.user_connection_ids(*user_id)),
        |peer_id| {
            session.peer.send(
                peer_id.into(),
                proto::UpdateChannels {
                    unseen_channel_messages: vec![proto::UnseenChannelMessage {
                        channel_id: channel_id.to_proto(),
                        message_id: message_id.to_proto(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3038
3039/// Delete a channel message
3040async fn remove_channel_message(
3041 request: proto::RemoveChannelMessage,
3042 response: Response<proto::RemoveChannelMessage>,
3043 session: Session,
3044) -> Result<()> {
3045 let channel_id = ChannelId::from_proto(request.channel_id);
3046 let message_id = MessageId::from_proto(request.message_id);
3047 let connection_ids = session
3048 .db()
3049 .await
3050 .remove_channel_message(channel_id, message_id, session.user_id)
3051 .await?;
3052 broadcast(Some(session.connection_id), connection_ids, |connection| {
3053 session.peer.send(connection, request.clone())
3054 });
3055 response.send(proto::Ack {})?;
3056 Ok(())
3057}
3058
3059/// Mark a channel message as read
3060async fn acknowledge_channel_message(
3061 request: proto::AckChannelMessage,
3062 session: Session,
3063) -> Result<()> {
3064 let channel_id = ChannelId::from_proto(request.channel_id);
3065 let message_id = MessageId::from_proto(request.message_id);
3066 let notifications = session
3067 .db()
3068 .await
3069 .observe_channel_message(channel_id, session.user_id, message_id)
3070 .await?;
3071 send_notifications(
3072 &*session.connection_pool().await,
3073 &session.peer,
3074 notifications,
3075 );
3076 Ok(())
3077}
3078
3079/// Mark a buffer version as synced
3080async fn acknowledge_buffer_version(
3081 request: proto::AckBufferOperation,
3082 session: Session,
3083) -> Result<()> {
3084 let buffer_id = BufferId::from_proto(request.buffer_id);
3085 session
3086 .db()
3087 .await
3088 .observe_buffer_version(
3089 buffer_id,
3090 session.user_id,
3091 request.epoch as i32,
3092 &request.version,
3093 )
3094 .await?;
3095 Ok(())
3096}
3097
3098/// Start receiving chat updates for a channel
3099async fn join_channel_chat(
3100 request: proto::JoinChannelChat,
3101 response: Response<proto::JoinChannelChat>,
3102 session: Session,
3103) -> Result<()> {
3104 let channel_id = ChannelId::from_proto(request.channel_id);
3105
3106 let db = session.db().await;
3107 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3108 .await?;
3109 let messages = db
3110 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3111 .await?;
3112 response.send(proto::JoinChannelChatResponse {
3113 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3114 messages,
3115 })?;
3116 Ok(())
3117}
3118
3119/// Stop receiving chat updates for a channel
3120async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3121 let channel_id = ChannelId::from_proto(request.channel_id);
3122 session
3123 .db()
3124 .await
3125 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3126 .await?;
3127 Ok(())
3128}
3129
3130/// Retrieve the chat history for a channel
3131async fn get_channel_messages(
3132 request: proto::GetChannelMessages,
3133 response: Response<proto::GetChannelMessages>,
3134 session: Session,
3135) -> Result<()> {
3136 let channel_id = ChannelId::from_proto(request.channel_id);
3137 let messages = session
3138 .db()
3139 .await
3140 .get_channel_messages(
3141 channel_id,
3142 session.user_id,
3143 MESSAGE_COUNT_PER_PAGE,
3144 Some(MessageId::from_proto(request.before_message_id)),
3145 )
3146 .await?;
3147 response.send(proto::GetChannelMessagesResponse {
3148 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3149 messages,
3150 })?;
3151 Ok(())
3152}
3153
3154/// Retrieve specific chat messages
3155async fn get_channel_messages_by_id(
3156 request: proto::GetChannelMessagesById,
3157 response: Response<proto::GetChannelMessagesById>,
3158 session: Session,
3159) -> Result<()> {
3160 let message_ids = request
3161 .message_ids
3162 .iter()
3163 .map(|id| MessageId::from_proto(*id))
3164 .collect::<Vec<_>>();
3165 let messages = session
3166 .db()
3167 .await
3168 .get_channel_messages_by_id(session.user_id, &message_ids)
3169 .await?;
3170 response.send(proto::GetChannelMessagesResponse {
3171 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3172 messages,
3173 })?;
3174 Ok(())
3175}
3176
3177/// Retrieve the current users notifications
3178async fn get_notifications(
3179 request: proto::GetNotifications,
3180 response: Response<proto::GetNotifications>,
3181 session: Session,
3182) -> Result<()> {
3183 let notifications = session
3184 .db()
3185 .await
3186 .get_notifications(
3187 session.user_id,
3188 NOTIFICATION_COUNT_PER_PAGE,
3189 request
3190 .before_id
3191 .map(|id| db::NotificationId::from_proto(id)),
3192 )
3193 .await?;
3194 response.send(proto::GetNotificationsResponse {
3195 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3196 notifications,
3197 })?;
3198 Ok(())
3199}
3200
3201/// Mark notifications as read
3202async fn mark_notification_as_read(
3203 request: proto::MarkNotificationRead,
3204 response: Response<proto::MarkNotificationRead>,
3205 session: Session,
3206) -> Result<()> {
3207 let database = &session.db().await;
3208 let notifications = database
3209 .mark_notification_as_read_by_id(
3210 session.user_id,
3211 NotificationId::from_proto(request.notification_id),
3212 )
3213 .await?;
3214 send_notifications(
3215 &*session.connection_pool().await,
3216 &session.peer,
3217 notifications,
3218 );
3219 response.send(proto::Ack {})?;
3220 Ok(())
3221}
3222
3223/// Get the current users information
3224async fn get_private_user_info(
3225 _request: proto::GetPrivateUserInfo,
3226 response: Response<proto::GetPrivateUserInfo>,
3227 session: Session,
3228) -> Result<()> {
3229 let db = session.db().await;
3230
3231 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3232 let user = db
3233 .get_user_by_id(session.user_id)
3234 .await?
3235 .ok_or_else(|| anyhow!("user not found"))?;
3236 let flags = db.get_user_flags(session.user_id).await?;
3237
3238 response.send(proto::GetPrivateUserInfoResponse {
3239 metrics_id,
3240 staff: user.admin,
3241 flags,
3242 })?;
3243 Ok(())
3244}
3245
3246fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3247 match message {
3248 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3249 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3250 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3251 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3252 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3253 code: frame.code.into(),
3254 reason: frame.reason,
3255 })),
3256 }
3257}
3258
3259fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3260 match message {
3261 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3262 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3263 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3264 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3265 AxumMessage::Close(frame) => {
3266 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3267 code: frame.code.into(),
3268 reason: frame.reason,
3269 }))
3270 }
3271 }
3272}
3273
3274fn notify_membership_updated(
3275 connection_pool: &ConnectionPool,
3276 result: MembershipUpdated,
3277 user_id: UserId,
3278 peer: &Peer,
3279) {
3280 let mut update = build_channels_update(result.new_channels, vec![]);
3281 update.delete_channels = result
3282 .removed_channels
3283 .into_iter()
3284 .map(|id| id.to_proto())
3285 .collect();
3286 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3287
3288 for connection_id in connection_pool.user_connection_ids(user_id) {
3289 peer.send(connection_id, update.clone()).trace_err();
3290 }
3291}
3292
3293fn build_channels_update(
3294 channels: ChannelsForUser,
3295 channel_invites: Vec<db::Channel>,
3296) -> proto::UpdateChannels {
3297 let mut update = proto::UpdateChannels::default();
3298
3299 for channel in channels.channels {
3300 update.channels.push(channel.to_proto());
3301 }
3302
3303 update.unseen_channel_buffer_changes = channels.unseen_buffer_changes;
3304 update.unseen_channel_messages = channels.channel_messages;
3305
3306 for (channel_id, participants) in channels.channel_participants {
3307 update
3308 .channel_participants
3309 .push(proto::ChannelParticipants {
3310 channel_id: channel_id.to_proto(),
3311 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3312 });
3313 }
3314
3315 for channel in channel_invites {
3316 update.channel_invitations.push(channel.to_proto());
3317 }
3318
3319 update
3320}
3321
3322fn build_initial_contacts_update(
3323 contacts: Vec<db::Contact>,
3324 pool: &ConnectionPool,
3325) -> proto::UpdateContacts {
3326 let mut update = proto::UpdateContacts::default();
3327
3328 for contact in contacts {
3329 match contact {
3330 db::Contact::Accepted { user_id, busy } => {
3331 update.contacts.push(contact_for_user(user_id, busy, &pool));
3332 }
3333 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3334 db::Contact::Incoming { user_id } => {
3335 update
3336 .incoming_requests
3337 .push(proto::IncomingContactRequest {
3338 requester_id: user_id.to_proto(),
3339 })
3340 }
3341 }
3342 }
3343
3344 update
3345}
3346
3347fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3348 proto::Contact {
3349 user_id: user_id.to_proto(),
3350 online: pool.is_user_online(user_id),
3351 busy,
3352 }
3353}
3354
3355fn room_updated(room: &proto::Room, peer: &Peer) {
3356 broadcast(
3357 None,
3358 room.participants
3359 .iter()
3360 .filter_map(|participant| Some(participant.peer_id?.into())),
3361 |peer_id| {
3362 peer.send(
3363 peer_id.into(),
3364 proto::RoomUpdated {
3365 room: Some(room.clone()),
3366 },
3367 )
3368 },
3369 );
3370}
3371
3372fn channel_updated(
3373 channel_id: ChannelId,
3374 room: &proto::Room,
3375 channel_members: &[UserId],
3376 peer: &Peer,
3377 pool: &ConnectionPool,
3378) {
3379 let participants = room
3380 .participants
3381 .iter()
3382 .map(|p| p.user_id)
3383 .collect::<Vec<_>>();
3384
3385 broadcast(
3386 None,
3387 channel_members
3388 .iter()
3389 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3390 |peer_id| {
3391 peer.send(
3392 peer_id.into(),
3393 proto::UpdateChannels {
3394 channel_participants: vec![proto::ChannelParticipants {
3395 channel_id: channel_id.to_proto(),
3396 participant_user_ids: participants.clone(),
3397 }],
3398 ..Default::default()
3399 },
3400 )
3401 },
3402 );
3403}
3404
3405async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3406 let db = session.db().await;
3407
3408 let contacts = db.get_contacts(user_id).await?;
3409 let busy = db.is_user_busy(user_id).await?;
3410
3411 let pool = session.connection_pool().await;
3412 let updated_contact = contact_for_user(user_id, busy, &pool);
3413 for contact in contacts {
3414 if let db::Contact::Accepted {
3415 user_id: contact_user_id,
3416 ..
3417 } = contact
3418 {
3419 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3420 session
3421 .peer
3422 .send(
3423 contact_conn_id,
3424 proto::UpdateContacts {
3425 contacts: vec![updated_contact.clone()],
3426 remove_contacts: Default::default(),
3427 incoming_requests: Default::default(),
3428 remove_incoming_requests: Default::default(),
3429 outgoing_requests: Default::default(),
3430 remove_outgoing_requests: Default::default(),
3431 },
3432 )
3433 .trace_err();
3434 }
3435 }
3436 }
3437 Ok(())
3438}
3439
3440async fn leave_room_for_session(session: &Session) -> Result<()> {
3441 let mut contacts_to_update = HashSet::default();
3442
3443 let room_id;
3444 let canceled_calls_to_user_ids;
3445 let live_kit_room;
3446 let delete_live_kit_room;
3447 let room;
3448 let channel_members;
3449 let channel_id;
3450
3451 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3452 contacts_to_update.insert(session.user_id);
3453
3454 for project in left_room.left_projects.values() {
3455 project_left(project, session);
3456 }
3457
3458 room_id = RoomId::from_proto(left_room.room.id);
3459 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3460 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3461 delete_live_kit_room = left_room.deleted;
3462 room = mem::take(&mut left_room.room);
3463 channel_members = mem::take(&mut left_room.channel_members);
3464 channel_id = left_room.channel_id;
3465
3466 room_updated(&room, &session.peer);
3467 } else {
3468 return Ok(());
3469 }
3470
3471 if let Some(channel_id) = channel_id {
3472 channel_updated(
3473 channel_id,
3474 &room,
3475 &channel_members,
3476 &session.peer,
3477 &*session.connection_pool().await,
3478 );
3479 }
3480
3481 {
3482 let pool = session.connection_pool().await;
3483 for canceled_user_id in canceled_calls_to_user_ids {
3484 for connection_id in pool.user_connection_ids(canceled_user_id) {
3485 session
3486 .peer
3487 .send(
3488 connection_id,
3489 proto::CallCanceled {
3490 room_id: room_id.to_proto(),
3491 },
3492 )
3493 .trace_err();
3494 }
3495 contacts_to_update.insert(canceled_user_id);
3496 }
3497 }
3498
3499 for contact_user_id in contacts_to_update {
3500 update_user_contacts(contact_user_id, &session).await?;
3501 }
3502
3503 if let Some(live_kit) = session.live_kit_client.as_ref() {
3504 live_kit
3505 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3506 .await
3507 .trace_err();
3508
3509 if delete_live_kit_room {
3510 live_kit.delete_room(live_kit_room).await.trace_err();
3511 }
3512 }
3513
3514 Ok(())
3515}
3516
3517async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3518 let left_channel_buffers = session
3519 .db()
3520 .await
3521 .leave_channel_buffers(session.connection_id)
3522 .await?;
3523
3524 for left_buffer in left_channel_buffers {
3525 channel_buffer_updated(
3526 session.connection_id,
3527 left_buffer.connections,
3528 &proto::UpdateChannelBufferCollaborators {
3529 channel_id: left_buffer.channel_id.to_proto(),
3530 collaborators: left_buffer.collaborators,
3531 },
3532 &session.peer,
3533 );
3534 }
3535
3536 Ok(())
3537}
3538
3539fn project_left(project: &db::LeftProject, session: &Session) {
3540 for connection_id in &project.connection_ids {
3541 if project.host_user_id == session.user_id {
3542 session
3543 .peer
3544 .send(
3545 *connection_id,
3546 proto::UnshareProject {
3547 project_id: project.id.to_proto(),
3548 },
3549 )
3550 .trace_err();
3551 } else {
3552 session
3553 .peer
3554 .send(
3555 *connection_id,
3556 proto::RemoveProjectCollaborator {
3557 project_id: project.id.to_proto(),
3558 peer_id: Some(session.connection_id.into()),
3559 },
3560 )
3561 .trace_err();
3562 }
3563 }
3564}
3565
3566pub trait ResultExt {
3567 type Ok;
3568
3569 fn trace_err(self) -> Option<Self::Ok>;
3570}
3571
3572impl<T, E> ResultExt for Result<T, E>
3573where
3574 E: std::fmt::Debug,
3575{
3576 type Ok = T;
3577
3578 fn trace_err(self) -> Option<T> {
3579 match self {
3580 Ok(value) => Some(value),
3581 Err(error) => {
3582 tracing::error!("{:?}", error);
3583 None
3584 }
3585 }
3586 }
3587}