1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage, Database,
7 InviteMemberResult, MembershipUpdated, MessageId, NotificationId, ProjectId,
8 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
9 },
10 executor::Executor,
11 AppState, Error, Result,
12};
13use anyhow::anyhow;
14use async_tungstenite::tungstenite::{
15 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
16};
17use axum::{
18 body::Body,
19 extract::{
20 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
21 ConnectInfo, WebSocketUpgrade,
22 },
23 headers::{Header, HeaderName},
24 http::StatusCode,
25 middleware,
26 response::IntoResponse,
27 routing::get,
28 Extension, Router, TypedHeader,
29};
30use collections::{HashMap, HashSet};
31pub use connection_pool::ConnectionPool;
32use futures::{
33 channel::oneshot,
34 future::{self, BoxFuture},
35 stream::FuturesUnordered,
36 FutureExt, SinkExt, StreamExt, TryStreamExt,
37};
38use lazy_static::lazy_static;
39use prometheus::{register_int_gauge, IntGauge};
40use rpc::{
41 proto::{
42 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
43 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
44 },
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46};
47use serde::{Serialize, Serializer};
48use std::{
49 any::TypeId,
50 fmt,
51 future::Future,
52 marker::PhantomData,
53 mem,
54 net::SocketAddr,
55 ops::{Deref, DerefMut},
56 rc::Rc,
57 sync::{
58 atomic::{AtomicBool, Ordering::SeqCst},
59 Arc,
60 },
61 time::{Duration, Instant},
62};
63use time::OffsetDateTime;
64use tokio::sync::{watch, Semaphore};
65use tower::ServiceBuilder;
66use tracing::{field, info_span, instrument, Instrument};
67use util::SemanticVersion;
68
/// How long `connection_lost` waits for a client to come back before releasing the resources
/// (room, channel buffers, pending call) tied to its previous connection.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Delay between `Server::start` and the sweep of stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Page size for channel-message queries (usage lies outside this view — presumably the
/// channel-message handlers).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Length limit for channel messages (usage and unit lie outside this view — confirm).
const MAX_MESSAGE_LEN: usize = 1_024;
/// Page size for notification queries (usage lies outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
75
// Prometheus gauges registered via `register_int_gauge!` and refreshed by `handle_metrics`
// each time the `/metrics` route is served.
lazy_static! {
    // Number of open connections, excluding admin connections (see `handle_metrics`).
    static ref METRIC_CONNECTIONS: IntGauge =
        register_int_gauge!("connections", "number of connections").unwrap();
    // Value of `db.project_count_excluding_admins()` at the last `/metrics` request.
    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
        "shared_projects",
        "number of open projects with one or more guests"
    )
    .unwrap();
}
85
86type MessageHandler =
87 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
88
/// Handle through which a request handler sends its reply.
///
/// `send` consumes the handle, so a handler can reply through it at most once. `responded`
/// is shared with the dispatcher built in `add_request_handler`, which inspects it after the
/// handler finishes.
struct Response<R> {
    /// Peer used to deliver the reply.
    peer: Arc<Peer>,
    /// Identifies the request being answered.
    receipt: Receipt<R>,
    /// Set to `true` by `send`.
    responded: Arc<AtomicBool>,
}
94
95impl<R: RequestMessage> Response<R> {
96 fn send(self, payload: R::Response) -> Result<()> {
97 self.responded.store(true, SeqCst);
98 self.peer.respond(self.receipt, payload)?;
99 Ok(())
100 }
101}
102
/// Per-connection context handed (by clone) to every message handler.
#[derive(Clone)]
struct Session {
    /// Deployment environment name, copied from the app config.
    zed_environment: Arc<str>,
    /// User that owns this connection.
    user_id: UserId,
    /// Peer-level identifier of this connection.
    connection_id: ConnectionId,
    /// Client version reported in the app-version header (default if absent).
    zed_version: SemanticVersion,
    /// Database handle behind a mutex created per connection in `handle_connection`, so
    /// handlers of the same session that go through `Session::db` take turns using it.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit API client, when one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    _executor: Executor,
}
115
116impl Session {
117 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
118 #[cfg(test)]
119 tokio::task::yield_now().await;
120 let guard = self.db.lock().await;
121 #[cfg(test)]
122 tokio::task::yield_now().await;
123 guard
124 }
125
126 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
127 #[cfg(test)]
128 tokio::task::yield_now().await;
129 let guard = self.connection_pool.lock();
130 ConnectionPoolGuard {
131 guard,
132 _not_send: PhantomData,
133 }
134 }
135
136 fn endpoint_removed_in(&self, endpoint: &str, version: SemanticVersion) -> anyhow::Result<()> {
137 if self.zed_version > version {
138 Err(anyhow!(
139 "{} was removed in {} (you're on {})",
140 endpoint,
141 version,
142 self.zed_version
143 ))
144 } else {
145 Ok(())
146 }
147 }
148}
149
150impl fmt::Debug for Session {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 f.debug_struct("Session")
153 .field("user_id", &self.user_id)
154 .field("connection_id", &self.connection_id)
155 .finish()
156 }
157}
158
159struct DbHandle(Arc<Database>);
160
161impl Deref for DbHandle {
162 type Target = Database;
163
164 fn deref(&self) -> &Self::Target {
165 self.0.as_ref()
166 }
167}
168
/// The collaboration RPC server: owns the peer, the connection pool and the table of
/// message handlers.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    executor: Executor,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Signalled by `teardown`; connection loops and `connection_lost` subscribe to it.
    teardown: watch::Sender<()>,
}
178
/// Lock guard over the connection pool.
///
/// `Rc<()>` is neither `Send` nor `Sync`, so the `_not_send` marker makes this guard `!Send`
/// (presumably so it cannot be held across an `.await` inside a `Send` future). In test builds
/// dropping the guard runs the pool's invariant checks.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
183
/// Serializable view of the server state, produced by `Server::snapshot`.
///
/// It holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref`, i.e. as the `ConnectionPool` itself rather than the guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
190
191pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
192where
193 S: Serializer,
194 T: Deref<Target = U>,
195 U: Serialize,
196{
197 Serialize::serialize(value.deref(), serializer)
198}
199
200impl Server {
    /// Builds the server and registers a handler for every supported RPC message.
    ///
    /// `add_request_handler` is used for messages that expect a response and
    /// `add_message_handler` for one-way messages. Registering two handlers for the same
    /// message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            executor,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(()).0,
        };

        server
            // Keep-alive, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and project metadata.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded through `forward_read_only_project_request`.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            // Project requests forwarded through `forward_mutating_project_request`.
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            // Buffers and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat and calls.
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel2)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(join_channel_call)
            .add_request_handler(leave_channel_call)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version);

        Arc::new(server)
    }
312
    /// Schedules the start-up cleanup of stale resources and returns immediately.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then asks the database for the room and channel
    /// ids it reports as stale for this environment and server id (presumably left over by
    /// earlier server instances — confirm against `stale_server_resource_ids`). For each stale
    /// channel buffer it clears stale collaborators and pushes the refreshed collaborator list
    /// to the remaining connections. For each stale room it clears stale participants and
    /// broadcasts the updated room (and channel, if any). It also notifies users whose calls
    /// were canceled, refreshes affected contacts, and deletes the LiveKit room when no
    /// participants remain. Finally it deletes stale server records. Every step only logs its
    /// errors (`trace_err`), so one failure does not abort the rest of the sweep.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and resend the collaborator
                    // list to every connection that is still attached.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        // Remove stale participants and broadcast the refreshed room.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel_id) = refreshed_room.channel_id {
                                channel_updated(
                                    channel_id,
                                    &refreshed_room.room,
                                    &refreshed_room.channel_members,
                                    &peer,
                                    &*pool.lock(),
                                );
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each affected callee that its call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted
                        // contacts. The database queries run before the pool is locked.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
464
465 pub fn teardown(&self) {
466 self.peer.teardown();
467 self.connection_pool.lock().reset();
468 let _ = self.teardown.send(());
469 }
470
471 #[cfg(test)]
472 pub fn reset(&self, id: ServerId) {
473 self.teardown();
474 *self.id.lock() = id;
475 self.peer.reset(id.0 as u32);
476 }
477
478 #[cfg(test)]
479 pub fn id(&self) -> ServerId {
480 *self.id.lock()
481 }
482
483 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
484 where
485 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
486 Fut: 'static + Send + Future<Output = Result<()>>,
487 M: EnvelopedMessage,
488 {
489 let prev_handler = self.handlers.insert(
490 TypeId::of::<M>(),
491 Box::new(move |envelope, session| {
492 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
493 let span = info_span!(
494 "handle message",
495 payload_type = envelope.payload_type_name()
496 );
497 span.in_scope(|| {
498 tracing::info!(
499 payload_type = envelope.payload_type_name(),
500 "message received"
501 );
502 });
503 let start_time = Instant::now();
504 let future = (handler)(*envelope, session);
505 async move {
506 let result = future.await;
507 let duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
508 match result {
509 Err(error) => {
510 tracing::error!(%error, ?duration_ms, "error handling message")
511 }
512 Ok(()) => tracing::info!(?duration_ms, "finished handling message"),
513 }
514 }
515 .instrument(span)
516 .boxed()
517 }),
518 );
519 if prev_handler.is_some() {
520 panic!("registered a handler for the same message twice");
521 }
522 self
523 }
524
525 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
526 where
527 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
528 Fut: 'static + Send + Future<Output = Result<()>>,
529 M: EnvelopedMessage,
530 {
531 self.add_handler(move |envelope, session| handler(envelope.payload, session));
532 self
533 }
534
535 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
536 where
537 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
538 Fut: Send + Future<Output = Result<()>>,
539 M: RequestMessage,
540 {
541 let handler = Arc::new(handler);
542 self.add_handler(move |envelope, session| {
543 let receipt = envelope.receipt();
544 let handler = handler.clone();
545 async move {
546 let peer = session.peer.clone();
547 let responded = Arc::new(AtomicBool::default());
548 let response = Response {
549 peer: peer.clone(),
550 responded: responded.clone(),
551 receipt,
552 };
553 match (handler)(envelope.payload, response, session).await {
554 Ok(()) => {
555 if responded.load(std::sync::atomic::Ordering::SeqCst) {
556 Ok(())
557 } else {
558 Err(anyhow!("handler did not send a response"))?
559 }
560 }
561 Err(error) => {
562 let proto_err = match &error {
563 Error::Internal(err) => err.to_proto(),
564 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
565 };
566 peer.respond_with_error(receipt, proto_err)?;
567 Err(error)
568 }
569 }
570 }
571 })
572 }
573
574 pub fn handle_connection(
575 self: &Arc<Self>,
576 connection: Connection,
577 address: String,
578 user: User,
579 zed_version: SemanticVersion,
580 impersonator: Option<User>,
581 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
582 executor: Executor,
583 ) -> impl Future<Output = Result<()>> {
584 let this = self.clone();
585 let user_id = user.id;
586 let login = user.github_login;
587 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty);
588 if let Some(impersonator) = impersonator {
589 span.record("impersonator", &impersonator.github_login);
590 }
591 let mut teardown = self.teardown.subscribe();
592 async move {
593 let (connection_id, handle_io, mut incoming_rx) = this
594 .peer
595 .add_connection(connection, {
596 let executor = executor.clone();
597 move |duration| executor.sleep(duration)
598 });
599
600 tracing::info!(%user_id, %login, %connection_id, %address, "connection opened");
601 this.peer.send(connection_id, proto::Hello { peer_id: Some(connection_id.into()) })?;
602 tracing::info!(%user_id, %login, %connection_id, %address, "sent hello message");
603
604 if let Some(send_connection_id) = send_connection_id.take() {
605 let _ = send_connection_id.send(connection_id);
606 }
607
608 if !user.connected_once {
609 this.peer.send(connection_id, proto::ShowContacts {})?;
610 this.app_state.db.set_user_connected_once(user_id, true).await?;
611 }
612
613 let (contacts, channels_for_user, channel_invites) = future::try_join3(
614 this.app_state.db.get_contacts(user_id),
615 this.app_state.db.get_channels_for_user(user_id),
616 this.app_state.db.get_channel_invites_for_user(user_id),
617 ).await?;
618
619 {
620 let mut pool = this.connection_pool.lock();
621 pool.add_connection(connection_id, user_id, user.admin);
622 this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?;
623 this.peer.send(connection_id, build_update_user_channels(&channels_for_user))?;
624 this.peer.send(connection_id, build_channels_update(
625 channels_for_user,
626 channel_invites
627 ))?;
628 }
629
630 if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? {
631 this.peer.send(connection_id, incoming_call)?;
632 }
633
634 let session = Session {
635 user_id,
636 connection_id,
637 zed_version,
638 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
639 zed_environment: this.app_state.config.zed_environment.clone(),
640 peer: this.peer.clone(),
641 connection_pool: this.connection_pool.clone(),
642 live_kit_client: this.app_state.live_kit_client.clone(),
643 _executor: executor.clone()
644 };
645 update_user_contacts(user_id, &session).await?;
646
647 let handle_io = handle_io.fuse();
648 futures::pin_mut!(handle_io);
649
650 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
651 // This prevents deadlocks when e.g., client A performs a request to client B and
652 // client B performs a request to client A. If both clients stop processing further
653 // messages until their respective request completes, they won't have a chance to
654 // respond to the other client's request and cause a deadlock.
655 //
656 // This arrangement ensures we will attempt to process earlier messages first, but fall
657 // back to processing messages arrived later in the spirit of making progress.
658 let mut foreground_message_handlers = FuturesUnordered::new();
659 let concurrent_handlers = Arc::new(Semaphore::new(256));
660 loop {
661 let next_message = async {
662 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
663 let message = incoming_rx.next().await;
664 (permit, message)
665 }.fuse();
666 futures::pin_mut!(next_message);
667 futures::select_biased! {
668 _ = teardown.changed().fuse() => return Ok(()),
669 result = handle_io => {
670 if let Err(error) = result {
671 tracing::error!(?error, %user_id, %login, %connection_id, %address, "error handling I/O");
672 }
673 break;
674 }
675 _ = foreground_message_handlers.next() => {}
676 next_message = next_message => {
677 let (permit, message) = next_message;
678 if let Some(message) = message {
679 let type_name = message.payload_type_name();
680 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
681 let span_enter = span.enter();
682 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
683 let is_background = message.is_background();
684 let handle_message = (handler)(message, session.clone());
685 drop(span_enter);
686
687 let handle_message = async move {
688 handle_message.await;
689 drop(permit);
690 }.instrument(span);
691 if is_background {
692 executor.spawn_detached(handle_message);
693 } else {
694 foreground_message_handlers.push(handle_message);
695 }
696 } else {
697 tracing::error!(%user_id, %login, %connection_id, %address, "no message handler");
698 }
699 } else {
700 tracing::info!(%user_id, %login, %connection_id, %address, "connection closed");
701 break;
702 }
703 }
704 }
705 }
706
707 drop(foreground_message_handlers);
708 tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
709 if let Err(error) = connection_lost(session, teardown, executor).await {
710 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
711 }
712
713 Ok(())
714 }.instrument(span)
715 }
716
717 pub async fn invite_code_redeemed(
718 self: &Arc<Self>,
719 inviter_id: UserId,
720 invitee_id: UserId,
721 ) -> Result<()> {
722 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
723 if let Some(code) = &user.invite_code {
724 let pool = self.connection_pool.lock();
725 let invitee_contact = contact_for_user(invitee_id, false, &pool);
726 for connection_id in pool.user_connection_ids(inviter_id) {
727 self.peer.send(
728 connection_id,
729 proto::UpdateContacts {
730 contacts: vec![invitee_contact.clone()],
731 ..Default::default()
732 },
733 )?;
734 self.peer.send(
735 connection_id,
736 proto::UpdateInviteInfo {
737 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
738 count: user.invite_count as u32,
739 },
740 )?;
741 }
742 }
743 }
744 Ok(())
745 }
746
747 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
748 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
749 if let Some(invite_code) = &user.invite_code {
750 let pool = self.connection_pool.lock();
751 for connection_id in pool.user_connection_ids(user_id) {
752 self.peer.send(
753 connection_id,
754 proto::UpdateInviteInfo {
755 url: format!(
756 "{}{}",
757 self.app_state.config.invite_link_prefix, invite_code
758 ),
759 count: user.invite_count as u32,
760 },
761 )?;
762 }
763 }
764 }
765 Ok(())
766 }
767
768 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
769 ServerSnapshot {
770 connection_pool: ConnectionPoolGuard {
771 guard: self.connection_pool.lock(),
772 _not_send: PhantomData,
773 },
774 peer: &self.peer,
775 }
776 }
777}
778
779impl<'a> Deref for ConnectionPoolGuard<'a> {
780 type Target = ConnectionPool;
781
782 fn deref(&self) -> &Self::Target {
783 &*self.guard
784 }
785}
786
787impl<'a> DerefMut for ConnectionPoolGuard<'a> {
788 fn deref_mut(&mut self) -> &mut Self::Target {
789 &mut *self.guard
790 }
791}
792
793impl<'a> Drop for ConnectionPoolGuard<'a> {
794 fn drop(&mut self) {
795 #[cfg(test)]
796 self.check_invariants();
797 }
798}
799
800fn broadcast<F>(
801 sender_id: Option<ConnectionId>,
802 receiver_ids: impl IntoIterator<Item = ConnectionId>,
803 mut f: F,
804) where
805 F: FnMut(ConnectionId) -> anyhow::Result<()>,
806{
807 for receiver_id in receiver_ids {
808 if Some(receiver_id) != sender_id {
809 if let Err(error) = f(receiver_id) {
810 tracing::error!("failed to send to {:?} {}", receiver_id, error);
811 }
812 }
813 }
814}
815
// Names of the custom request headers sent by the client on the `/rpc` upgrade request.
lazy_static! {
    // Carries the client's RPC protocol version (decoded by `ProtocolVersion`).
    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
    // Carries the client's app version (decoded by `AppVersionHeader`).
    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
}
820
821pub struct ProtocolVersion(u32);
822
823impl Header for ProtocolVersion {
824 fn name() -> &'static HeaderName {
825 &ZED_PROTOCOL_VERSION
826 }
827
828 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
829 where
830 Self: Sized,
831 I: Iterator<Item = &'i axum::http::HeaderValue>,
832 {
833 let version = values
834 .next()
835 .ok_or_else(axum::headers::Error::invalid)?
836 .to_str()
837 .map_err(|_| axum::headers::Error::invalid())?
838 .parse()
839 .map_err(|_| axum::headers::Error::invalid())?;
840 Ok(Self(version))
841 }
842
843 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
844 values.extend([self.0.to_string().parse().unwrap()]);
845 }
846}
847
848pub struct AppVersionHeader(SemanticVersion);
849impl Header for AppVersionHeader {
850 fn name() -> &'static HeaderName {
851 &ZED_APP_VERSION
852 }
853
854 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
855 where
856 Self: Sized,
857 I: Iterator<Item = &'i axum::http::HeaderValue>,
858 {
859 let version = values
860 .next()
861 .ok_or_else(axum::headers::Error::invalid)?
862 .to_str()
863 .map_err(|_| axum::headers::Error::invalid())?
864 .parse()
865 .map_err(|_| axum::headers::Error::invalid())?;
866 Ok(Self(version))
867 }
868
869 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
870 values.extend([self.0.to_string().parse().unwrap()]);
871 }
872}
873
/// Builds the HTTP router for the RPC server.
///
/// Route and layer order matter: per axum's `Router::layer` documentation, a layer wraps only
/// the routes registered before it. `/rpc` (websocket upgrade) is therefore wrapped by the
/// `AppState` extension and `auth::validate_header`, while `/metrics`, added afterwards, is not.
/// The final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
885
886pub async fn handle_websocket_request(
887 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
888 app_version_header: Option<TypedHeader<AppVersionHeader>>,
889 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
890 Extension(server): Extension<Arc<Server>>,
891 Extension(user): Extension<User>,
892 Extension(impersonator): Extension<Impersonator>,
893 ws: WebSocketUpgrade,
894) -> axum::response::Response {
895 if protocol_version != rpc::PROTOCOL_VERSION {
896 return (
897 StatusCode::UPGRADE_REQUIRED,
898 "client must be upgraded".to_string(),
899 )
900 .into_response();
901 }
902
903 // zed 0.122.x was the first version that sent an app header, so once that hits stable
904 // we can return UPGRADE_REQUIRED instead of unwrap_or_default();
905 let app_version = app_version_header
906 .map(|header| header.0 .0)
907 .unwrap_or_default();
908
909 let socket_address = socket_address.to_string();
910 ws.on_upgrade(move |socket| {
911 use util::ResultExt;
912 let socket = socket
913 .map_ok(to_tungstenite_message)
914 .err_into()
915 .with(|message| async move { Ok(to_axum_message(message)) });
916 let connection = Connection::new(Box::pin(socket));
917 async move {
918 server
919 .handle_connection(
920 connection,
921 socket_address,
922 user,
923 app_version,
924 impersonator.0,
925 None,
926 Executor::Production,
927 )
928 .await
929 .log_err();
930 }
931 })
932}
933
934pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
935 let connections = server
936 .connection_pool
937 .lock()
938 .connections()
939 .filter(|connection| !connection.admin)
940 .count();
941
942 METRIC_CONNECTIONS.set(connections as _);
943
944 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
945 METRIC_SHARED_PROJECTS.set(shared_projects as _);
946
947 let encoder = prometheus::TextEncoder::new();
948 let metric_families = prometheus::gather();
949 let encoded_metrics = encoder
950 .encode_to_string(&metric_families)
951 .map_err(|err| anyhow!("{}", err))?;
952 Ok(encoded_metrics)
953}
954
/// Releases a session after its connection has ended.
///
/// It first disconnects the connection from the peer and removes it from the connection
/// pool, returning early if that removal fails. It then calls the database's
/// `connection_lost` (errors are discarded via `trace_err`). Next it waits up to
/// `RECONNECT_TIMEOUT`. If the server's teardown signal fires first, nothing further happens.
/// Otherwise it:
/// - leaves the session's room and channel buffers;
/// - calls `decline_call(None, user_id)` when the user has no other online connection
///   (presumably declining a pending incoming call) and broadcasts any room it returns;
/// - refreshes the user's contacts.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<()>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Biased select: the reconnect timeout branch is checked before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only when no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1000
1001/// Acknowledges a ping from a client, used to keep the connection alive.
1002async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1003 response.send(proto::Ack {})?;
1004 Ok(())
1005}
1006
1007/// Creates a new room for calling (outside of channels)
1008async fn create_room(
1009 _request: proto::CreateRoom,
1010 response: Response<proto::CreateRoom>,
1011 session: Session,
1012) -> Result<()> {
1013 let live_kit_room = nanoid::nanoid!(30);
1014
1015 let live_kit_connection_info = {
1016 let live_kit_room = live_kit_room.clone();
1017 let live_kit = session.live_kit_client.as_ref();
1018
1019 util::async_maybe!({
1020 let live_kit = live_kit?;
1021
1022 let token = live_kit
1023 .room_token(&live_kit_room, &session.user_id.to_string())
1024 .trace_err()?;
1025
1026 Some(proto::LiveKitConnectionInfo {
1027 server_url: live_kit.url().into(),
1028 token,
1029 can_publish: true,
1030 })
1031 })
1032 }
1033 .await;
1034
1035 let room = session
1036 .db()
1037 .await
1038 .create_room(
1039 session.user_id,
1040 session.connection_id,
1041 &live_kit_room,
1042 &session.zed_environment,
1043 )
1044 .await?;
1045
1046 response.send(proto::CreateRoomResponse {
1047 room: Some(room.clone()),
1048 live_kit_connection_info,
1049 })?;
1050
1051 update_user_contacts(session.user_id, &session).await?;
1052 Ok(())
1053}
1054
1055/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1056async fn join_room(
1057 request: proto::JoinRoom,
1058 response: Response<proto::JoinRoom>,
1059 session: Session,
1060) -> Result<()> {
1061 let room_id = RoomId::from_proto(request.id);
1062
1063 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1064
1065 if let Some(channel_id) = channel_id {
1066 return join_channel_internal(channel_id, true, Box::new(response), session).await;
1067 }
1068
1069 let joined_room = {
1070 let room = session
1071 .db()
1072 .await
1073 .join_room(
1074 room_id,
1075 session.user_id,
1076 session.connection_id,
1077 session.zed_environment.as_ref(),
1078 )
1079 .await?;
1080 room_updated(&room.room, &session.peer);
1081 room.into_inner()
1082 };
1083
1084 for connection_id in session
1085 .connection_pool()
1086 .await
1087 .user_connection_ids(session.user_id)
1088 {
1089 session
1090 .peer
1091 .send(
1092 connection_id,
1093 proto::CallCanceled {
1094 room_id: room_id.to_proto(),
1095 },
1096 )
1097 .trace_err();
1098 }
1099
1100 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1101 if let Some(token) = live_kit
1102 .room_token(
1103 &joined_room.room.live_kit_room,
1104 &session.user_id.to_string(),
1105 )
1106 .trace_err()
1107 {
1108 Some(proto::LiveKitConnectionInfo {
1109 server_url: live_kit.url().into(),
1110 token,
1111 can_publish: true,
1112 })
1113 } else {
1114 None
1115 }
1116 } else {
1117 None
1118 };
1119
1120 response.send(proto::JoinRoomResponse {
1121 room: Some(joined_room.room),
1122 channel_id: None,
1123 live_kit_connection_info,
1124 })?;
1125
1126 update_user_contacts(session.user_id, &session).await?;
1127 Ok(())
1128}
1129
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database reattaches the caller's new connection to the room and reports
/// two sets of projects: `reshared_projects` and `rejoined_projects`.
/// NOTE(review): presumably these are the projects the caller hosts and the
/// projects the caller was a guest in, respectively — confirm in `db`.
///
/// The handler then, in order:
/// 1. replies with the room and the metadata of both project sets;
/// 2. broadcasts the room update;
/// 3. for each re-shared project, tells every collaborator that the caller's peer
///    id changed from `old_connection_id` to the new connection, and forwards an
///    `UpdateProject` carrying the project's worktrees;
/// 4. for each re-joined project, tells every collaborator about the same peer
///    id change;
/// 5. streams each re-joined project's worktree entries, diagnostic summaries and
///    settings files to the caller, then a disk-based-diagnostics-updated status
///    for each language server (a send failure here aborts the handler);
/// 6. if the room belongs to a channel, sends a channel update to its members;
/// 7. refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Extracted from the database result inside the scope below; the rest of the
    // handler only needs these three values.
    let room;
    let channel_id;
    let channel_members;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects: point collaborators at the caller's new connection and
        // resend the project's worktree list on the caller's behalf.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-joined projects: point collaborators at the caller's new connection.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Re-joined projects: resend the worktree state to the caller. Worktrees are
        // moved out with `mem::take` so their contents can be consumed without cloning.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Smaller chunks in test builds, presumably to exercise chunking.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server with a disk-based-diagnostics-updated status.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel_id = rejoined_room.channel_id;
        channel_members = rejoined_room.channel_members;
    }

    // Rooms that belong to a channel also refresh the channel's members.
    if let Some(channel_id) = channel_id {
        channel_updated(
            channel_id,
            &room,
            &channel_members,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1321
1322/// leave room disconnects from the room.
1323async fn leave_room(
1324 _: proto::LeaveRoom,
1325 response: Response<proto::LeaveRoom>,
1326 session: Session,
1327) -> Result<()> {
1328 leave_room_for_session(&session).await?;
1329 response.send(proto::Ack {})?;
1330 Ok(())
1331}
1332
1333/// Updates the permissions of someone else in the room.
1334async fn set_room_participant_role(
1335 request: proto::SetRoomParticipantRole,
1336 response: Response<proto::SetRoomParticipantRole>,
1337 session: Session,
1338) -> Result<()> {
1339 let (live_kit_room, can_publish) = {
1340 let room = session
1341 .db()
1342 .await
1343 .set_room_participant_role(
1344 session.user_id,
1345 RoomId::from_proto(request.room_id),
1346 UserId::from_proto(request.user_id),
1347 ChannelRole::from(request.role()),
1348 )
1349 .await?;
1350
1351 let live_kit_room = room.live_kit_room.clone();
1352 let can_publish = ChannelRole::from(request.role()).can_publish_to_rooms();
1353 room_updated(&room, &session.peer);
1354 (live_kit_room, can_publish)
1355 };
1356
1357 if let Some(live_kit) = session.live_kit_client.as_ref() {
1358 live_kit
1359 .update_participant(
1360 live_kit_room.clone(),
1361 request.user_id.to_string(),
1362 live_kit_server::proto::ParticipantPermission {
1363 can_subscribe: true,
1364 can_publish,
1365 can_publish_data: can_publish,
1366 hidden: false,
1367 recorder: false,
1368 },
1369 )
1370 .await
1371 .trace_err();
1372 }
1373
1374 response.send(proto::Ack {})?;
1375 Ok(())
1376}
1377
1378/// Call someone else into the current room
1379async fn call(
1380 request: proto::Call,
1381 response: Response<proto::Call>,
1382 session: Session,
1383) -> Result<()> {
1384 let room_id = RoomId::from_proto(request.room_id);
1385 let calling_user_id = session.user_id;
1386 let calling_connection_id = session.connection_id;
1387 let called_user_id = UserId::from_proto(request.called_user_id);
1388 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1389 if !session
1390 .db()
1391 .await
1392 .has_contact(calling_user_id, called_user_id)
1393 .await?
1394 {
1395 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1396 }
1397
1398 let incoming_call = {
1399 let (room, incoming_call) = &mut *session
1400 .db()
1401 .await
1402 .call(
1403 room_id,
1404 calling_user_id,
1405 calling_connection_id,
1406 called_user_id,
1407 initial_project_id,
1408 )
1409 .await?;
1410 room_updated(&room, &session.peer);
1411 mem::take(incoming_call)
1412 };
1413 update_user_contacts(called_user_id, &session).await?;
1414
1415 let mut calls = session
1416 .connection_pool()
1417 .await
1418 .user_connection_ids(called_user_id)
1419 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1420 .collect::<FuturesUnordered<_>>();
1421
1422 while let Some(call_response) = calls.next().await {
1423 match call_response.as_ref() {
1424 Ok(_) => {
1425 response.send(proto::Ack {})?;
1426 return Ok(());
1427 }
1428 Err(_) => {
1429 call_response.trace_err();
1430 }
1431 }
1432 }
1433
1434 {
1435 let room = session
1436 .db()
1437 .await
1438 .call_failed(room_id, called_user_id)
1439 .await?;
1440 room_updated(&room, &session.peer);
1441 }
1442 update_user_contacts(called_user_id, &session).await?;
1443
1444 Err(anyhow!("failed to ring user"))?
1445}
1446
1447/// Cancel an outgoing call.
1448async fn cancel_call(
1449 request: proto::CancelCall,
1450 response: Response<proto::CancelCall>,
1451 session: Session,
1452) -> Result<()> {
1453 let called_user_id = UserId::from_proto(request.called_user_id);
1454 let room_id = RoomId::from_proto(request.room_id);
1455 {
1456 let room = session
1457 .db()
1458 .await
1459 .cancel_call(room_id, session.connection_id, called_user_id)
1460 .await?;
1461 room_updated(&room, &session.peer);
1462 }
1463
1464 for connection_id in session
1465 .connection_pool()
1466 .await
1467 .user_connection_ids(called_user_id)
1468 {
1469 session
1470 .peer
1471 .send(
1472 connection_id,
1473 proto::CallCanceled {
1474 room_id: room_id.to_proto(),
1475 },
1476 )
1477 .trace_err();
1478 }
1479 response.send(proto::Ack {})?;
1480
1481 update_user_contacts(called_user_id, &session).await?;
1482 Ok(())
1483}
1484
1485/// Decline an incoming call.
1486async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1487 let room_id = RoomId::from_proto(message.room_id);
1488 {
1489 let room = session
1490 .db()
1491 .await
1492 .decline_call(Some(room_id), session.user_id)
1493 .await?
1494 .ok_or_else(|| anyhow!("failed to decline call"))?;
1495 room_updated(&room, &session.peer);
1496 }
1497
1498 for connection_id in session
1499 .connection_pool()
1500 .await
1501 .user_connection_ids(session.user_id)
1502 {
1503 session
1504 .peer
1505 .send(
1506 connection_id,
1507 proto::CallCanceled {
1508 room_id: room_id.to_proto(),
1509 },
1510 )
1511 .trace_err();
1512 }
1513 update_user_contacts(session.user_id, &session).await?;
1514 Ok(())
1515}
1516
1517/// Updates other participants in the room with your current location.
1518async fn update_participant_location(
1519 request: proto::UpdateParticipantLocation,
1520 response: Response<proto::UpdateParticipantLocation>,
1521 session: Session,
1522) -> Result<()> {
1523 let room_id = RoomId::from_proto(request.room_id);
1524 let location = request
1525 .location
1526 .ok_or_else(|| anyhow!("invalid location"))?;
1527
1528 let db = session.db().await;
1529 let room = db
1530 .update_room_participant_location(room_id, session.connection_id, location)
1531 .await?;
1532
1533 room_updated(&room, &session.peer);
1534 response.send(proto::Ack {})?;
1535 Ok(())
1536}
1537
1538/// Share a project into the room.
1539async fn share_project(
1540 request: proto::ShareProject,
1541 response: Response<proto::ShareProject>,
1542 session: Session,
1543) -> Result<()> {
1544 let (project_id, room) = &*session
1545 .db()
1546 .await
1547 .share_project(
1548 RoomId::from_proto(request.room_id),
1549 session.connection_id,
1550 &request.worktrees,
1551 )
1552 .await?;
1553 response.send(proto::ShareProjectResponse {
1554 project_id: project_id.to_proto(),
1555 })?;
1556 room_updated(&room, &session.peer);
1557
1558 Ok(())
1559}
1560
1561/// Unshare a project from the room.
1562async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1563 let project_id = ProjectId::from_proto(message.project_id);
1564
1565 let (room, guest_connection_ids) = &*session
1566 .db()
1567 .await
1568 .unshare_project(project_id, session.connection_id)
1569 .await?;
1570
1571 broadcast(
1572 Some(session.connection_id),
1573 guest_connection_ids.iter().copied(),
1574 |conn_id| session.peer.send(conn_id, message.clone()),
1575 );
1576 room_updated(&room, &session.peer);
1577
1578 Ok(())
1579}
1580
1581/// Join someone elses shared project.
1582async fn join_project(
1583 request: proto::JoinProject,
1584 response: Response<proto::JoinProject>,
1585 session: Session,
1586) -> Result<()> {
1587 let project_id = ProjectId::from_proto(request.project_id);
1588 let guest_user_id = session.user_id;
1589
1590 tracing::info!(%project_id, "join project");
1591
1592 let (project, replica_id) = &mut *session
1593 .db()
1594 .await
1595 .join_project(project_id, session.connection_id)
1596 .await?;
1597
1598 let collaborators = project
1599 .collaborators
1600 .iter()
1601 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1602 .map(|collaborator| collaborator.to_proto())
1603 .collect::<Vec<_>>();
1604
1605 let worktrees = project
1606 .worktrees
1607 .iter()
1608 .map(|(id, worktree)| proto::WorktreeMetadata {
1609 id: *id,
1610 root_name: worktree.root_name.clone(),
1611 visible: worktree.visible,
1612 abs_path: worktree.abs_path.clone(),
1613 })
1614 .collect::<Vec<_>>();
1615
1616 for collaborator in &collaborators {
1617 session
1618 .peer
1619 .send(
1620 collaborator.peer_id.unwrap().into(),
1621 proto::AddProjectCollaborator {
1622 project_id: project_id.to_proto(),
1623 collaborator: Some(proto::Collaborator {
1624 peer_id: Some(session.connection_id.into()),
1625 replica_id: replica_id.0 as u32,
1626 user_id: guest_user_id.to_proto(),
1627 }),
1628 },
1629 )
1630 .trace_err();
1631 }
1632
1633 // First, we send the metadata associated with each worktree.
1634 response.send(proto::JoinProjectResponse {
1635 worktrees: worktrees.clone(),
1636 replica_id: replica_id.0 as u32,
1637 collaborators: collaborators.clone(),
1638 language_servers: project.language_servers.clone(),
1639 })?;
1640
1641 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1642 #[cfg(any(test, feature = "test-support"))]
1643 const MAX_CHUNK_SIZE: usize = 2;
1644 #[cfg(not(any(test, feature = "test-support")))]
1645 const MAX_CHUNK_SIZE: usize = 256;
1646
1647 // Stream this worktree's entries.
1648 let message = proto::UpdateWorktree {
1649 project_id: project_id.to_proto(),
1650 worktree_id,
1651 abs_path: worktree.abs_path.clone(),
1652 root_name: worktree.root_name,
1653 updated_entries: worktree.entries,
1654 removed_entries: Default::default(),
1655 scan_id: worktree.scan_id,
1656 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1657 updated_repositories: worktree.repository_entries.into_values().collect(),
1658 removed_repositories: Default::default(),
1659 };
1660 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1661 session.peer.send(session.connection_id, update.clone())?;
1662 }
1663
1664 // Stream this worktree's diagnostics.
1665 for summary in worktree.diagnostic_summaries {
1666 session.peer.send(
1667 session.connection_id,
1668 proto::UpdateDiagnosticSummary {
1669 project_id: project_id.to_proto(),
1670 worktree_id: worktree.id,
1671 summary: Some(summary),
1672 },
1673 )?;
1674 }
1675
1676 for settings_file in worktree.settings_files {
1677 session.peer.send(
1678 session.connection_id,
1679 proto::UpdateWorktreeSettings {
1680 project_id: project_id.to_proto(),
1681 worktree_id: worktree.id,
1682 path: settings_file.path,
1683 content: Some(settings_file.content),
1684 },
1685 )?;
1686 }
1687 }
1688
1689 for language_server in &project.language_servers {
1690 session.peer.send(
1691 session.connection_id,
1692 proto::UpdateLanguageServer {
1693 project_id: project_id.to_proto(),
1694 language_server_id: language_server.id,
1695 variant: Some(
1696 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1697 proto::LspDiskBasedDiagnosticsUpdated {},
1698 ),
1699 ),
1700 },
1701 )?;
1702 }
1703
1704 Ok(())
1705}
1706
1707/// Leave someone elses shared project.
1708async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1709 let sender_id = session.connection_id;
1710 let project_id = ProjectId::from_proto(request.project_id);
1711
1712 let (room, project) = &*session
1713 .db()
1714 .await
1715 .leave_project(project_id, sender_id)
1716 .await?;
1717 tracing::info!(
1718 %project_id,
1719 host_user_id = %project.host_user_id,
1720 host_connection_id = ?project.host_connection_id,
1721 "leave project"
1722 );
1723
1724 project_left(&project, &session);
1725 room_updated(&room, &session.peer);
1726
1727 Ok(())
1728}
1729
1730/// Updates other participants with changes to the project
1731async fn update_project(
1732 request: proto::UpdateProject,
1733 response: Response<proto::UpdateProject>,
1734 session: Session,
1735) -> Result<()> {
1736 let project_id = ProjectId::from_proto(request.project_id);
1737 let (room, guest_connection_ids) = &*session
1738 .db()
1739 .await
1740 .update_project(project_id, session.connection_id, &request.worktrees)
1741 .await?;
1742 broadcast(
1743 Some(session.connection_id),
1744 guest_connection_ids.iter().copied(),
1745 |connection_id| {
1746 session
1747 .peer
1748 .forward_send(session.connection_id, connection_id, request.clone())
1749 },
1750 );
1751 room_updated(&room, &session.peer);
1752 response.send(proto::Ack {})?;
1753
1754 Ok(())
1755}
1756
1757/// Updates other participants with changes to the worktree
1758async fn update_worktree(
1759 request: proto::UpdateWorktree,
1760 response: Response<proto::UpdateWorktree>,
1761 session: Session,
1762) -> Result<()> {
1763 let guest_connection_ids = session
1764 .db()
1765 .await
1766 .update_worktree(&request, session.connection_id)
1767 .await?;
1768
1769 broadcast(
1770 Some(session.connection_id),
1771 guest_connection_ids.iter().copied(),
1772 |connection_id| {
1773 session
1774 .peer
1775 .forward_send(session.connection_id, connection_id, request.clone())
1776 },
1777 );
1778 response.send(proto::Ack {})?;
1779 Ok(())
1780}
1781
1782/// Updates other participants with changes to the diagnostics
1783async fn update_diagnostic_summary(
1784 message: proto::UpdateDiagnosticSummary,
1785 session: Session,
1786) -> Result<()> {
1787 let guest_connection_ids = session
1788 .db()
1789 .await
1790 .update_diagnostic_summary(&message, session.connection_id)
1791 .await?;
1792
1793 broadcast(
1794 Some(session.connection_id),
1795 guest_connection_ids.iter().copied(),
1796 |connection_id| {
1797 session
1798 .peer
1799 .forward_send(session.connection_id, connection_id, message.clone())
1800 },
1801 );
1802
1803 Ok(())
1804}
1805
1806/// Updates other participants with changes to the worktree settings
1807async fn update_worktree_settings(
1808 message: proto::UpdateWorktreeSettings,
1809 session: Session,
1810) -> Result<()> {
1811 let guest_connection_ids = session
1812 .db()
1813 .await
1814 .update_worktree_settings(&message, session.connection_id)
1815 .await?;
1816
1817 broadcast(
1818 Some(session.connection_id),
1819 guest_connection_ids.iter().copied(),
1820 |connection_id| {
1821 session
1822 .peer
1823 .forward_send(session.connection_id, connection_id, message.clone())
1824 },
1825 );
1826
1827 Ok(())
1828}
1829
1830/// Notify other participants that a language server has started.
1831async fn start_language_server(
1832 request: proto::StartLanguageServer,
1833 session: Session,
1834) -> Result<()> {
1835 let guest_connection_ids = session
1836 .db()
1837 .await
1838 .start_language_server(&request, session.connection_id)
1839 .await?;
1840
1841 broadcast(
1842 Some(session.connection_id),
1843 guest_connection_ids.iter().copied(),
1844 |connection_id| {
1845 session
1846 .peer
1847 .forward_send(session.connection_id, connection_id, request.clone())
1848 },
1849 );
1850 Ok(())
1851}
1852
1853/// Notify other participants that a language server has changed.
1854async fn update_language_server(
1855 request: proto::UpdateLanguageServer,
1856 session: Session,
1857) -> Result<()> {
1858 let project_id = ProjectId::from_proto(request.project_id);
1859 let project_connection_ids = session
1860 .db()
1861 .await
1862 .project_connection_ids(project_id, session.connection_id)
1863 .await?;
1864 broadcast(
1865 Some(session.connection_id),
1866 project_connection_ids.iter().copied(),
1867 |connection_id| {
1868 session
1869 .peer
1870 .forward_send(session.connection_id, connection_id, request.clone())
1871 },
1872 );
1873 Ok(())
1874}
1875
1876/// forward a project request to the host. These requests should be read only
1877/// as guests are allowed to send them.
1878async fn forward_read_only_project_request<T>(
1879 request: T,
1880 response: Response<T>,
1881 session: Session,
1882) -> Result<()>
1883where
1884 T: EntityMessage + RequestMessage,
1885{
1886 let project_id = ProjectId::from_proto(request.remote_entity_id());
1887 let host_connection_id = session
1888 .db()
1889 .await
1890 .host_for_read_only_project_request(project_id, session.connection_id)
1891 .await?;
1892 let payload = session
1893 .peer
1894 .forward_request(session.connection_id, host_connection_id, request)
1895 .await?;
1896 response.send(payload)?;
1897 Ok(())
1898}
1899
1900/// forward a project request to the host. These requests are disallowed
1901/// for guests.
1902async fn forward_mutating_project_request<T>(
1903 request: T,
1904 response: Response<T>,
1905 session: Session,
1906) -> Result<()>
1907where
1908 T: EntityMessage + RequestMessage,
1909{
1910 let project_id = ProjectId::from_proto(request.remote_entity_id());
1911 let host_connection_id = session
1912 .db()
1913 .await
1914 .host_for_mutating_project_request(project_id, session.connection_id)
1915 .await?;
1916 let payload = session
1917 .peer
1918 .forward_request(session.connection_id, host_connection_id, request)
1919 .await?;
1920 response.send(payload)?;
1921 Ok(())
1922}
1923
1924/// Notify other participants that a new buffer has been created
1925async fn create_buffer_for_peer(
1926 request: proto::CreateBufferForPeer,
1927 session: Session,
1928) -> Result<()> {
1929 session
1930 .db()
1931 .await
1932 .check_user_is_project_host(
1933 ProjectId::from_proto(request.project_id),
1934 session.connection_id,
1935 )
1936 .await?;
1937 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
1938 session
1939 .peer
1940 .forward_send(session.connection_id, peer_id.into(), request)?;
1941 Ok(())
1942}
1943
1944/// Notify other participants that a buffer has been updated. This is
1945/// allowed for guests as long as the update is limited to selections.
1946async fn update_buffer(
1947 request: proto::UpdateBuffer,
1948 response: Response<proto::UpdateBuffer>,
1949 session: Session,
1950) -> Result<()> {
1951 let project_id = ProjectId::from_proto(request.project_id);
1952 let mut guest_connection_ids;
1953 let mut host_connection_id = None;
1954
1955 let mut requires_write_permission = false;
1956
1957 for op in request.operations.iter() {
1958 match op.variant {
1959 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
1960 Some(_) => requires_write_permission = true,
1961 }
1962 }
1963
1964 {
1965 let collaborators = session
1966 .db()
1967 .await
1968 .project_collaborators_for_buffer_update(
1969 project_id,
1970 session.connection_id,
1971 requires_write_permission,
1972 )
1973 .await?;
1974 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
1975 for collaborator in collaborators.iter() {
1976 if collaborator.is_host {
1977 host_connection_id = Some(collaborator.connection_id);
1978 } else {
1979 guest_connection_ids.push(collaborator.connection_id);
1980 }
1981 }
1982 }
1983 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
1984
1985 broadcast(
1986 Some(session.connection_id),
1987 guest_connection_ids,
1988 |connection_id| {
1989 session
1990 .peer
1991 .forward_send(session.connection_id, connection_id, request.clone())
1992 },
1993 );
1994 if host_connection_id != session.connection_id {
1995 session
1996 .peer
1997 .forward_request(session.connection_id, host_connection_id, request.clone())
1998 .await?;
1999 }
2000
2001 response.send(proto::Ack {})?;
2002 Ok(())
2003}
2004
2005/// Notify other participants that a project has been updated.
2006async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2007 request: T,
2008 session: Session,
2009) -> Result<()> {
2010 let project_id = ProjectId::from_proto(request.remote_entity_id());
2011 let project_connection_ids = session
2012 .db()
2013 .await
2014 .project_connection_ids(project_id, session.connection_id)
2015 .await?;
2016
2017 broadcast(
2018 Some(session.connection_id),
2019 project_connection_ids.iter().copied(),
2020 |connection_id| {
2021 session
2022 .peer
2023 .forward_send(session.connection_id, connection_id, request.clone())
2024 },
2025 );
2026 Ok(())
2027}
2028
2029/// Start following another user in a call.
2030async fn follow(
2031 request: proto::Follow,
2032 response: Response<proto::Follow>,
2033 session: Session,
2034) -> Result<()> {
2035 let room_id = RoomId::from_proto(request.room_id);
2036 let project_id = request.project_id.map(ProjectId::from_proto);
2037 let leader_id = request
2038 .leader_id
2039 .ok_or_else(|| anyhow!("invalid leader id"))?
2040 .into();
2041 let follower_id = session.connection_id;
2042
2043 session
2044 .db()
2045 .await
2046 .check_room_participants(room_id, leader_id, session.connection_id)
2047 .await?;
2048
2049 let response_payload = session
2050 .peer
2051 .forward_request(session.connection_id, leader_id, request)
2052 .await?;
2053 response.send(response_payload)?;
2054
2055 if let Some(project_id) = project_id {
2056 let room = session
2057 .db()
2058 .await
2059 .follow(room_id, project_id, leader_id, follower_id)
2060 .await?;
2061 room_updated(&room, &session.peer);
2062 }
2063
2064 Ok(())
2065}
2066
2067/// Stop following another user in a call.
2068async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2069 let room_id = RoomId::from_proto(request.room_id);
2070 let project_id = request.project_id.map(ProjectId::from_proto);
2071 let leader_id = request
2072 .leader_id
2073 .ok_or_else(|| anyhow!("invalid leader id"))?
2074 .into();
2075 let follower_id = session.connection_id;
2076
2077 session
2078 .db()
2079 .await
2080 .check_room_participants(room_id, leader_id, session.connection_id)
2081 .await?;
2082
2083 session
2084 .peer
2085 .forward_send(session.connection_id, leader_id, request)?;
2086
2087 if let Some(project_id) = project_id {
2088 let room = session
2089 .db()
2090 .await
2091 .unfollow(room_id, project_id, leader_id, follower_id)
2092 .await?;
2093 room_updated(&room, &session.peer);
2094 }
2095
2096 Ok(())
2097}
2098
2099/// Notify everyone following you of your current location.
2100async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2101 let room_id = RoomId::from_proto(request.room_id);
2102 let database = session.db.lock().await;
2103
2104 let connection_ids = if let Some(project_id) = request.project_id {
2105 let project_id = ProjectId::from_proto(project_id);
2106 database
2107 .project_connection_ids(project_id, session.connection_id)
2108 .await?
2109 } else {
2110 database
2111 .room_connection_ids(room_id, session.connection_id)
2112 .await?
2113 };
2114
2115 // For now, don't send view update messages back to that view's current leader.
2116 let connection_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2117 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2118 _ => None,
2119 });
2120
2121 for follower_peer_id in request.follower_ids.iter().copied() {
2122 let follower_connection_id = follower_peer_id.into();
2123 if Some(follower_peer_id) != connection_id_to_omit
2124 && connection_ids.contains(&follower_connection_id)
2125 {
2126 session.peer.forward_send(
2127 session.connection_id,
2128 follower_connection_id,
2129 request.clone(),
2130 )?;
2131 }
2132 }
2133 Ok(())
2134}
2135
2136/// Get public data about users.
2137async fn get_users(
2138 request: proto::GetUsers,
2139 response: Response<proto::GetUsers>,
2140 session: Session,
2141) -> Result<()> {
2142 let user_ids = request
2143 .user_ids
2144 .into_iter()
2145 .map(UserId::from_proto)
2146 .collect();
2147 let users = session
2148 .db()
2149 .await
2150 .get_users_by_ids(user_ids)
2151 .await?
2152 .into_iter()
2153 .map(|user| proto::User {
2154 id: user.id.to_proto(),
2155 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2156 github_login: user.github_login,
2157 })
2158 .collect();
2159 response.send(proto::UsersResponse { users })?;
2160 Ok(())
2161}
2162
2163/// Search for users (to invite) buy Github login
2164async fn fuzzy_search_users(
2165 request: proto::FuzzySearchUsers,
2166 response: Response<proto::FuzzySearchUsers>,
2167 session: Session,
2168) -> Result<()> {
2169 let query = request.query;
2170 let users = match query.len() {
2171 0 => vec![],
2172 1 | 2 => session
2173 .db()
2174 .await
2175 .get_user_by_github_login(&query)
2176 .await?
2177 .into_iter()
2178 .collect(),
2179 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2180 };
2181 let users = users
2182 .into_iter()
2183 .filter(|user| user.id != session.user_id)
2184 .map(|user| proto::User {
2185 id: user.id.to_proto(),
2186 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2187 github_login: user.github_login,
2188 })
2189 .collect();
2190 response.send(proto::UsersResponse { users })?;
2191 Ok(())
2192}
2193
2194/// Send a contact request to another user.
2195async fn request_contact(
2196 request: proto::RequestContact,
2197 response: Response<proto::RequestContact>,
2198 session: Session,
2199) -> Result<()> {
2200 let requester_id = session.user_id;
2201 let responder_id = UserId::from_proto(request.responder_id);
2202 if requester_id == responder_id {
2203 return Err(anyhow!("cannot add yourself as a contact"))?;
2204 }
2205
2206 let notifications = session
2207 .db()
2208 .await
2209 .send_contact_request(requester_id, responder_id)
2210 .await?;
2211
2212 // Update outgoing contact requests of requester
2213 let mut update = proto::UpdateContacts::default();
2214 update.outgoing_requests.push(responder_id.to_proto());
2215 for connection_id in session
2216 .connection_pool()
2217 .await
2218 .user_connection_ids(requester_id)
2219 {
2220 session.peer.send(connection_id, update.clone())?;
2221 }
2222
2223 // Update incoming contact requests of responder
2224 let mut update = proto::UpdateContacts::default();
2225 update
2226 .incoming_requests
2227 .push(proto::IncomingContactRequest {
2228 requester_id: requester_id.to_proto(),
2229 });
2230 let connection_pool = session.connection_pool().await;
2231 for connection_id in connection_pool.user_connection_ids(responder_id) {
2232 session.peer.send(connection_id, update.clone())?;
2233 }
2234
2235 send_notifications(&*connection_pool, &session.peer, notifications);
2236
2237 response.send(proto::Ack {})?;
2238 Ok(())
2239}
2240
2241/// Accept or decline a contact request
2242async fn respond_to_contact_request(
2243 request: proto::RespondToContactRequest,
2244 response: Response<proto::RespondToContactRequest>,
2245 session: Session,
2246) -> Result<()> {
2247 let responder_id = session.user_id;
2248 let requester_id = UserId::from_proto(request.requester_id);
2249 let db = session.db().await;
2250 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2251 db.dismiss_contact_notification(responder_id, requester_id)
2252 .await?;
2253 } else {
2254 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2255
2256 let notifications = db
2257 .respond_to_contact_request(responder_id, requester_id, accept)
2258 .await?;
2259 let requester_busy = db.is_user_busy(requester_id).await?;
2260 let responder_busy = db.is_user_busy(responder_id).await?;
2261
2262 let pool = session.connection_pool().await;
2263 // Update responder with new contact
2264 let mut update = proto::UpdateContacts::default();
2265 if accept {
2266 update
2267 .contacts
2268 .push(contact_for_user(requester_id, requester_busy, &pool));
2269 }
2270 update
2271 .remove_incoming_requests
2272 .push(requester_id.to_proto());
2273 for connection_id in pool.user_connection_ids(responder_id) {
2274 session.peer.send(connection_id, update.clone())?;
2275 }
2276
2277 // Update requester with new contact
2278 let mut update = proto::UpdateContacts::default();
2279 if accept {
2280 update
2281 .contacts
2282 .push(contact_for_user(responder_id, responder_busy, &pool));
2283 }
2284 update
2285 .remove_outgoing_requests
2286 .push(responder_id.to_proto());
2287
2288 for connection_id in pool.user_connection_ids(requester_id) {
2289 session.peer.send(connection_id, update.clone())?;
2290 }
2291
2292 send_notifications(&*pool, &session.peer, notifications);
2293 }
2294
2295 response.send(proto::Ack {})?;
2296 Ok(())
2297}
2298
2299/// Remove a contact.
2300async fn remove_contact(
2301 request: proto::RemoveContact,
2302 response: Response<proto::RemoveContact>,
2303 session: Session,
2304) -> Result<()> {
2305 let requester_id = session.user_id;
2306 let responder_id = UserId::from_proto(request.user_id);
2307 let db = session.db().await;
2308 let (contact_accepted, deleted_notification_id) =
2309 db.remove_contact(requester_id, responder_id).await?;
2310
2311 let pool = session.connection_pool().await;
2312 // Update outgoing contact requests of requester
2313 let mut update = proto::UpdateContacts::default();
2314 if contact_accepted {
2315 update.remove_contacts.push(responder_id.to_proto());
2316 } else {
2317 update
2318 .remove_outgoing_requests
2319 .push(responder_id.to_proto());
2320 }
2321 for connection_id in pool.user_connection_ids(requester_id) {
2322 session.peer.send(connection_id, update.clone())?;
2323 }
2324
2325 // Update incoming contact requests of responder
2326 let mut update = proto::UpdateContacts::default();
2327 if contact_accepted {
2328 update.remove_contacts.push(requester_id.to_proto());
2329 } else {
2330 update
2331 .remove_incoming_requests
2332 .push(requester_id.to_proto());
2333 }
2334 for connection_id in pool.user_connection_ids(responder_id) {
2335 session.peer.send(connection_id, update.clone())?;
2336 if let Some(notification_id) = deleted_notification_id {
2337 session.peer.send(
2338 connection_id,
2339 proto::DeleteNotification {
2340 notification_id: notification_id.to_proto(),
2341 },
2342 )?;
2343 }
2344 }
2345
2346 response.send(proto::Ack {})?;
2347 Ok(())
2348}
2349
2350/// Creates a new channel.
2351async fn create_channel(
2352 request: proto::CreateChannel,
2353 response: Response<proto::CreateChannel>,
2354 session: Session,
2355) -> Result<()> {
2356 let db = session.db().await;
2357
2358 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2359 let (channel, owner, channel_members) = db
2360 .create_channel(&request.name, parent_id, session.user_id)
2361 .await?;
2362
2363 response.send(proto::CreateChannelResponse {
2364 channel: Some(channel.to_proto()),
2365 parent_id: request.parent_id,
2366 })?;
2367
2368 let connection_pool = session.connection_pool().await;
2369 if let Some(owner) = owner {
2370 let update = proto::UpdateUserChannels {
2371 channel_memberships: vec![proto::ChannelMembership {
2372 channel_id: owner.channel_id.to_proto(),
2373 role: owner.role.into(),
2374 }],
2375 ..Default::default()
2376 };
2377 for connection_id in connection_pool.user_connection_ids(owner.user_id) {
2378 session.peer.send(connection_id, update.clone())?;
2379 }
2380 }
2381
2382 for channel_member in channel_members {
2383 if !channel_member.role.can_see_channel(channel.visibility) {
2384 continue;
2385 }
2386
2387 let update = proto::UpdateChannels {
2388 channels: vec![channel.to_proto()],
2389 ..Default::default()
2390 };
2391 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2392 session.peer.send(connection_id, update.clone())?;
2393 }
2394 }
2395
2396 Ok(())
2397}
2398
2399/// Delete a channel
2400async fn delete_channel(
2401 request: proto::DeleteChannel,
2402 response: Response<proto::DeleteChannel>,
2403 session: Session,
2404) -> Result<()> {
2405 let db = session.db().await;
2406
2407 let channel_id = request.channel_id;
2408 let (removed_channels, member_ids) = db
2409 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2410 .await?;
2411 response.send(proto::Ack {})?;
2412
2413 // Notify members of removed channels
2414 let mut update = proto::UpdateChannels::default();
2415 update
2416 .delete_channels
2417 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2418
2419 let connection_pool = session.connection_pool().await;
2420 for member_id in member_ids {
2421 for connection_id in connection_pool.user_connection_ids(member_id) {
2422 session.peer.send(connection_id, update.clone())?;
2423 }
2424 }
2425
2426 Ok(())
2427}
2428
2429/// Invite someone to join a channel.
2430async fn invite_channel_member(
2431 request: proto::InviteChannelMember,
2432 response: Response<proto::InviteChannelMember>,
2433 session: Session,
2434) -> Result<()> {
2435 let db = session.db().await;
2436 let channel_id = ChannelId::from_proto(request.channel_id);
2437 let invitee_id = UserId::from_proto(request.user_id);
2438 let InviteMemberResult {
2439 channel,
2440 notifications,
2441 } = db
2442 .invite_channel_member(
2443 channel_id,
2444 invitee_id,
2445 session.user_id,
2446 request.role().into(),
2447 )
2448 .await?;
2449
2450 let update = proto::UpdateChannels {
2451 channel_invitations: vec![channel.to_proto()],
2452 ..Default::default()
2453 };
2454
2455 let connection_pool = session.connection_pool().await;
2456 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2457 session.peer.send(connection_id, update.clone())?;
2458 }
2459
2460 send_notifications(&*connection_pool, &session.peer, notifications);
2461
2462 response.send(proto::Ack {})?;
2463 Ok(())
2464}
2465
2466/// remove someone from a channel
2467async fn remove_channel_member(
2468 request: proto::RemoveChannelMember,
2469 response: Response<proto::RemoveChannelMember>,
2470 session: Session,
2471) -> Result<()> {
2472 let db = session.db().await;
2473 let channel_id = ChannelId::from_proto(request.channel_id);
2474 let member_id = UserId::from_proto(request.user_id);
2475
2476 let RemoveChannelMemberResult {
2477 membership_update,
2478 notification_id,
2479 } = db
2480 .remove_channel_member(channel_id, member_id, session.user_id)
2481 .await?;
2482
2483 let connection_pool = &session.connection_pool().await;
2484 notify_membership_updated(
2485 &connection_pool,
2486 membership_update,
2487 member_id,
2488 &session.peer,
2489 );
2490 for connection_id in connection_pool.user_connection_ids(member_id) {
2491 if let Some(notification_id) = notification_id {
2492 session
2493 .peer
2494 .send(
2495 connection_id,
2496 proto::DeleteNotification {
2497 notification_id: notification_id.to_proto(),
2498 },
2499 )
2500 .trace_err();
2501 }
2502 }
2503
2504 response.send(proto::Ack {})?;
2505 Ok(())
2506}
2507
2508/// Toggle the channel between public and private.
2509/// Care is taken to maintain the invariant that public channels only descend from public channels,
2510/// (though members-only channels can appear at any point in the hierarchy).
2511async fn set_channel_visibility(
2512 request: proto::SetChannelVisibility,
2513 response: Response<proto::SetChannelVisibility>,
2514 session: Session,
2515) -> Result<()> {
2516 let db = session.db().await;
2517 let channel_id = ChannelId::from_proto(request.channel_id);
2518 let visibility = request.visibility().into();
2519
2520 let (channel, channel_members) = db
2521 .set_channel_visibility(channel_id, visibility, session.user_id)
2522 .await?;
2523
2524 let connection_pool = session.connection_pool().await;
2525 for member in channel_members {
2526 let update = if member.role.can_see_channel(channel.visibility) {
2527 proto::UpdateChannels {
2528 channels: vec![channel.to_proto()],
2529 ..Default::default()
2530 }
2531 } else {
2532 proto::UpdateChannels {
2533 delete_channels: vec![channel.id.to_proto()],
2534 ..Default::default()
2535 }
2536 };
2537
2538 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2539 session.peer.send(connection_id, update.clone())?;
2540 }
2541 }
2542
2543 response.send(proto::Ack {})?;
2544 Ok(())
2545}
2546
2547/// Alter the role for a user in the channel.
2548async fn set_channel_member_role(
2549 request: proto::SetChannelMemberRole,
2550 response: Response<proto::SetChannelMemberRole>,
2551 session: Session,
2552) -> Result<()> {
2553 let db = session.db().await;
2554 let channel_id = ChannelId::from_proto(request.channel_id);
2555 let member_id = UserId::from_proto(request.user_id);
2556 let result = db
2557 .set_channel_member_role(
2558 channel_id,
2559 session.user_id,
2560 member_id,
2561 request.role().into(),
2562 )
2563 .await?;
2564
2565 match result {
2566 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2567 let connection_pool = session.connection_pool().await;
2568 notify_membership_updated(
2569 &connection_pool,
2570 membership_update,
2571 member_id,
2572 &session.peer,
2573 )
2574 }
2575 db::SetMemberRoleResult::InviteUpdated(channel) => {
2576 let update = proto::UpdateChannels {
2577 channel_invitations: vec![channel.to_proto()],
2578 ..Default::default()
2579 };
2580
2581 for connection_id in session
2582 .connection_pool()
2583 .await
2584 .user_connection_ids(member_id)
2585 {
2586 session.peer.send(connection_id, update.clone())?;
2587 }
2588 }
2589 }
2590
2591 response.send(proto::Ack {})?;
2592 Ok(())
2593}
2594
2595/// Change the name of a channel
2596async fn rename_channel(
2597 request: proto::RenameChannel,
2598 response: Response<proto::RenameChannel>,
2599 session: Session,
2600) -> Result<()> {
2601 let db = session.db().await;
2602 let channel_id = ChannelId::from_proto(request.channel_id);
2603 let (channel, channel_members) = db
2604 .rename_channel(channel_id, session.user_id, &request.name)
2605 .await?;
2606
2607 response.send(proto::RenameChannelResponse {
2608 channel: Some(channel.to_proto()),
2609 })?;
2610
2611 let connection_pool = session.connection_pool().await;
2612 for channel_member in channel_members {
2613 if !channel_member.role.can_see_channel(channel.visibility) {
2614 continue;
2615 }
2616 let update = proto::UpdateChannels {
2617 channels: vec![channel.to_proto()],
2618 ..Default::default()
2619 };
2620 for connection_id in connection_pool.user_connection_ids(channel_member.user_id) {
2621 session.peer.send(connection_id, update.clone())?;
2622 }
2623 }
2624
2625 Ok(())
2626}
2627
2628/// Move a channel to a new parent.
2629async fn move_channel(
2630 request: proto::MoveChannel,
2631 response: Response<proto::MoveChannel>,
2632 session: Session,
2633) -> Result<()> {
2634 let channel_id = ChannelId::from_proto(request.channel_id);
2635 let to = ChannelId::from_proto(request.to);
2636
2637 let (channels, channel_members) = session
2638 .db()
2639 .await
2640 .move_channel(channel_id, to, session.user_id)
2641 .await?;
2642
2643 let connection_pool = session.connection_pool().await;
2644 for member in channel_members {
2645 let channels = channels
2646 .iter()
2647 .filter_map(|channel| {
2648 if member.role.can_see_channel(channel.visibility) {
2649 Some(channel.to_proto())
2650 } else {
2651 None
2652 }
2653 })
2654 .collect::<Vec<_>>();
2655 if channels.is_empty() {
2656 continue;
2657 }
2658
2659 let update = proto::UpdateChannels {
2660 channels,
2661 ..Default::default()
2662 };
2663
2664 for connection_id in connection_pool.user_connection_ids(member.user_id) {
2665 session.peer.send(connection_id, update.clone())?;
2666 }
2667 }
2668
2669 response.send(Ack {})?;
2670 Ok(())
2671}
2672
2673/// Get the list of channel members
2674async fn get_channel_members(
2675 request: proto::GetChannelMembers,
2676 response: Response<proto::GetChannelMembers>,
2677 session: Session,
2678) -> Result<()> {
2679 let db = session.db().await;
2680 let channel_id = ChannelId::from_proto(request.channel_id);
2681 let members = db
2682 .get_channel_participant_details(channel_id, session.user_id)
2683 .await?;
2684 response.send(proto::GetChannelMembersResponse { members })?;
2685 Ok(())
2686}
2687
2688/// Accept or decline a channel invitation.
2689async fn respond_to_channel_invite(
2690 request: proto::RespondToChannelInvite,
2691 response: Response<proto::RespondToChannelInvite>,
2692 session: Session,
2693) -> Result<()> {
2694 let db = session.db().await;
2695 let channel_id = ChannelId::from_proto(request.channel_id);
2696 let RespondToChannelInvite {
2697 membership_update,
2698 notifications,
2699 } = db
2700 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2701 .await?;
2702
2703 let connection_pool = session.connection_pool().await;
2704 if let Some(membership_update) = membership_update {
2705 notify_membership_updated(
2706 &connection_pool,
2707 membership_update,
2708 session.user_id,
2709 &session.peer,
2710 );
2711 } else {
2712 let update = proto::UpdateChannels {
2713 remove_channel_invitations: vec![channel_id.to_proto()],
2714 ..Default::default()
2715 };
2716
2717 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2718 session.peer.send(connection_id, update.clone())?;
2719 }
2720 };
2721
2722 send_notifications(&*connection_pool, &session.peer, notifications);
2723
2724 response.send(proto::Ack {})?;
2725
2726 Ok(())
2727}
2728
2729/// Join the channels' call
2730async fn join_channel(
2731 request: proto::JoinChannel,
2732 response: Response<proto::JoinChannel>,
2733 session: Session,
2734) -> Result<()> {
2735 session.endpoint_removed_in("join_channel", "0.123.0".parse().unwrap())?;
2736
2737 let channel_id = ChannelId::from_proto(request.channel_id);
2738 join_channel_internal(channel_id, true, Box::new(response), session).await
2739}
2740
2741async fn join_channel2(
2742 request: proto::JoinChannel2,
2743 response: Response<proto::JoinChannel2>,
2744 session: Session,
2745) -> Result<()> {
2746 let channel_id = ChannelId::from_proto(request.channel_id);
2747 join_channel_internal(channel_id, false, Box::new(response), session).await
2748}
2749
2750async fn join_channel_call(
2751 request: proto::JoinChannelCall,
2752 response: Response<proto::JoinChannelCall>,
2753 session: Session,
2754) -> Result<()> {
2755 let channel_id = ChannelId::from_proto(request.channel_id);
2756 let db = session.db().await;
2757 let (joined_room, role) = db
2758 .set_in_channel_call(channel_id, session.user_id, true)
2759 .await?;
2760
2761 let Some(connection_info) = session.live_kit_client.as_ref().and_then(|live_kit| {
2762 live_kit_info_for_user(live_kit, &session.user_id, role, &joined_room.live_kit_room)
2763 }) else {
2764 Err(anyhow!("no live kit token info"))?
2765 };
2766
2767 room_updated(&joined_room, &session.peer);
2768 response.send(proto::JoinChannelCallResponse {
2769 live_kit_connection_info: Some(connection_info),
2770 })?;
2771
2772 Ok(())
2773}
2774
2775async fn leave_channel_call(
2776 request: proto::LeaveChannelCall,
2777 response: Response<proto::LeaveChannelCall>,
2778 session: Session,
2779) -> Result<()> {
2780 let channel_id = ChannelId::from_proto(request.channel_id);
2781 let db = session.db().await;
2782 let (joined_room, _) = db
2783 .set_in_channel_call(channel_id, session.user_id, false)
2784 .await?;
2785
2786 room_updated(&joined_room, &session.peer);
2787 response.send(proto::Ack {})?;
2788
2789 Ok(())
2790}
2791
2792trait JoinChannelInternalResponse {
2793 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2794}
2795impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2796 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2797 Response::<proto::JoinChannel>::send(self, result)
2798 }
2799}
2800impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2801 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2802 Response::<proto::JoinRoom>::send(self, result)
2803 }
2804}
2805impl JoinChannelInternalResponse for Response<proto::JoinChannel2> {
2806 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2807 Response::<proto::JoinChannel2>::send(self, result)
2808 }
2809}
2810
/// Shared implementation of the channel-joining RPCs.
///
/// Steps, in order:
/// 1. Leave whatever room the session is currently in (`leave_room_for_session`).
/// 2. Join the channel's room in the database, passing `autojoin` through.
/// 3. Reply to the caller with the joined room. LiveKit connection info is
///    included only when `autojoin` is true and the server has a LiveKit
///    client; minting failures yield `None`.
/// 4. Broadcast any membership change and the updated room.
/// 5. Broadcast the channel update and refresh the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    autojoin: bool,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database handle and the connection-pool guard acquired in this
    // block are dropped when it ends, before the pool is acquired again for
    // `channel_updated` below. Presumably this avoids re-locking a pool that
    // is still held — confirm before reordering.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(
                channel_id,
                session.user_id,
                autojoin,
                session.connection_id,
                session.zed_environment.as_ref(),
            )
            .await?;

        // Only an autojoin into the call gets LiveKit credentials.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            if !autojoin {
                return None;
            }
            live_kit_info_for_user(
                live_kit,
                &session.user_id,
                role,
                &joined_room.room.live_kit_room,
            )
        });

        // The caller is answered before any broadcasts go out.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room.channel_id.map(|id| id.to_proto()),
            live_kit_connection_info,
        })?;

        let connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        channel_id,
        &joined_room.room,
        &joined_room.channel_members,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2875
2876fn live_kit_info_for_user(
2877 live_kit: &Arc<dyn live_kit_server::api::Client>,
2878 user_id: &UserId,
2879 role: ChannelRole,
2880 live_kit_room: &String,
2881) -> Option<LiveKitConnectionInfo> {
2882 let (can_publish, token) = if role == ChannelRole::Guest {
2883 (
2884 false,
2885 live_kit
2886 .guest_token(live_kit_room, &user_id.to_string())
2887 .trace_err()?,
2888 )
2889 } else {
2890 (
2891 true,
2892 live_kit
2893 .room_token(live_kit_room, &user_id.to_string())
2894 .trace_err()?,
2895 )
2896 };
2897
2898 Some(LiveKitConnectionInfo {
2899 server_url: live_kit.url().into(),
2900 token,
2901 can_publish,
2902 })
2903}
2904
2905/// Start editing the channel notes
2906async fn join_channel_buffer(
2907 request: proto::JoinChannelBuffer,
2908 response: Response<proto::JoinChannelBuffer>,
2909 session: Session,
2910) -> Result<()> {
2911 let db = session.db().await;
2912 let channel_id = ChannelId::from_proto(request.channel_id);
2913
2914 let open_response = db
2915 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2916 .await?;
2917
2918 let collaborators = open_response.collaborators.clone();
2919 response.send(open_response)?;
2920
2921 let update = UpdateChannelBufferCollaborators {
2922 channel_id: channel_id.to_proto(),
2923 collaborators: collaborators.clone(),
2924 };
2925 channel_buffer_updated(
2926 session.connection_id,
2927 collaborators
2928 .iter()
2929 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2930 &update,
2931 &session.peer,
2932 );
2933
2934 Ok(())
2935}
2936
2937/// Edit the channel notes
2938async fn update_channel_buffer(
2939 request: proto::UpdateChannelBuffer,
2940 session: Session,
2941) -> Result<()> {
2942 let db = session.db().await;
2943 let channel_id = ChannelId::from_proto(request.channel_id);
2944
2945 let (collaborators, non_collaborators, epoch, version) = db
2946 .update_channel_buffer(channel_id, session.user_id, &request.operations)
2947 .await?;
2948
2949 channel_buffer_updated(
2950 session.connection_id,
2951 collaborators,
2952 &proto::UpdateChannelBuffer {
2953 channel_id: channel_id.to_proto(),
2954 operations: request.operations,
2955 },
2956 &session.peer,
2957 );
2958
2959 let pool = &*session.connection_pool().await;
2960
2961 broadcast(
2962 None,
2963 non_collaborators
2964 .iter()
2965 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
2966 |peer_id| {
2967 session.peer.send(
2968 peer_id.into(),
2969 proto::UpdateChannels {
2970 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
2971 channel_id: channel_id.to_proto(),
2972 epoch: epoch as u64,
2973 version: version.clone(),
2974 }],
2975 ..Default::default()
2976 },
2977 )
2978 },
2979 );
2980
2981 Ok(())
2982}
2983
2984/// Rejoin the channel notes after a connection blip
2985async fn rejoin_channel_buffers(
2986 request: proto::RejoinChannelBuffers,
2987 response: Response<proto::RejoinChannelBuffers>,
2988 session: Session,
2989) -> Result<()> {
2990 let db = session.db().await;
2991 let buffers = db
2992 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
2993 .await?;
2994
2995 for rejoined_buffer in &buffers {
2996 let collaborators_to_notify = rejoined_buffer
2997 .buffer
2998 .collaborators
2999 .iter()
3000 .filter_map(|c| Some(c.peer_id?.into()));
3001 channel_buffer_updated(
3002 session.connection_id,
3003 collaborators_to_notify,
3004 &proto::UpdateChannelBufferCollaborators {
3005 channel_id: rejoined_buffer.buffer.channel_id,
3006 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3007 },
3008 &session.peer,
3009 );
3010 }
3011
3012 response.send(proto::RejoinChannelBuffersResponse {
3013 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3014 })?;
3015
3016 Ok(())
3017}
3018
3019/// Stop editing the channel notes
3020async fn leave_channel_buffer(
3021 request: proto::LeaveChannelBuffer,
3022 response: Response<proto::LeaveChannelBuffer>,
3023 session: Session,
3024) -> Result<()> {
3025 let db = session.db().await;
3026 let channel_id = ChannelId::from_proto(request.channel_id);
3027
3028 let left_buffer = db
3029 .leave_channel_buffer(channel_id, session.connection_id)
3030 .await?;
3031
3032 response.send(Ack {})?;
3033
3034 channel_buffer_updated(
3035 session.connection_id,
3036 left_buffer.connections,
3037 &proto::UpdateChannelBufferCollaborators {
3038 channel_id: channel_id.to_proto(),
3039 collaborators: left_buffer.collaborators,
3040 },
3041 &session.peer,
3042 );
3043
3044 Ok(())
3045}
3046
3047fn channel_buffer_updated<T: EnvelopedMessage>(
3048 sender_id: ConnectionId,
3049 collaborators: impl IntoIterator<Item = ConnectionId>,
3050 message: &T,
3051 peer: &Peer,
3052) {
3053 broadcast(Some(sender_id), collaborators.into_iter(), |peer_id| {
3054 peer.send(peer_id.into(), message.clone())
3055 });
3056}
3057
3058fn send_notifications(
3059 connection_pool: &ConnectionPool,
3060 peer: &Peer,
3061 notifications: db::NotificationBatch,
3062) {
3063 for (user_id, notification) in notifications {
3064 for connection_id in connection_pool.user_connection_ids(user_id) {
3065 if let Err(error) = peer.send(
3066 connection_id,
3067 proto::AddNotification {
3068 notification: Some(notification.clone()),
3069 },
3070 ) {
3071 tracing::error!(
3072 "failed to send notification to {:?} {}",
3073 connection_id,
3074 error
3075 );
3076 }
3077 }
3078 }
3079}
3080
3081/// Send a message to the channel
3082async fn send_channel_message(
3083 request: proto::SendChannelMessage,
3084 response: Response<proto::SendChannelMessage>,
3085 session: Session,
3086) -> Result<()> {
3087 // Validate the message body.
3088 let body = request.body.trim().to_string();
3089 if body.len() > MAX_MESSAGE_LEN {
3090 return Err(anyhow!("message is too long"))?;
3091 }
3092 if body.is_empty() {
3093 return Err(anyhow!("message can't be blank"))?;
3094 }
3095
3096 // TODO: adjust mentions if body is trimmed
3097
3098 let timestamp = OffsetDateTime::now_utc();
3099 let nonce = request
3100 .nonce
3101 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3102
3103 let channel_id = ChannelId::from_proto(request.channel_id);
3104 let CreatedChannelMessage {
3105 message_id,
3106 participant_connection_ids,
3107 channel_members,
3108 notifications,
3109 } = session
3110 .db()
3111 .await
3112 .create_channel_message(
3113 channel_id,
3114 session.user_id,
3115 &body,
3116 &request.mentions,
3117 timestamp,
3118 nonce.clone().into(),
3119 match request.reply_to_message_id {
3120 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3121 None => None,
3122 },
3123 )
3124 .await?;
3125 let message = proto::ChannelMessage {
3126 sender_id: session.user_id.to_proto(),
3127 id: message_id.to_proto(),
3128 body,
3129 mentions: request.mentions,
3130 timestamp: timestamp.unix_timestamp() as u64,
3131 nonce: Some(nonce),
3132 reply_to_message_id: request.reply_to_message_id,
3133 };
3134 broadcast(
3135 Some(session.connection_id),
3136 participant_connection_ids,
3137 |connection| {
3138 session.peer.send(
3139 connection,
3140 proto::ChannelMessageSent {
3141 channel_id: channel_id.to_proto(),
3142 message: Some(message.clone()),
3143 },
3144 )
3145 },
3146 );
3147 response.send(proto::SendChannelMessageResponse {
3148 message: Some(message),
3149 })?;
3150
3151 let pool = &*session.connection_pool().await;
3152 broadcast(
3153 None,
3154 channel_members
3155 .iter()
3156 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3157 |peer_id| {
3158 session.peer.send(
3159 peer_id.into(),
3160 proto::UpdateChannels {
3161 latest_channel_message_ids: vec![proto::ChannelMessageId {
3162 channel_id: channel_id.to_proto(),
3163 message_id: message_id.to_proto(),
3164 }],
3165 ..Default::default()
3166 },
3167 )
3168 },
3169 );
3170 send_notifications(pool, &session.peer, notifications);
3171
3172 Ok(())
3173}
3174
3175/// Delete a channel message
3176async fn remove_channel_message(
3177 request: proto::RemoveChannelMessage,
3178 response: Response<proto::RemoveChannelMessage>,
3179 session: Session,
3180) -> Result<()> {
3181 let channel_id = ChannelId::from_proto(request.channel_id);
3182 let message_id = MessageId::from_proto(request.message_id);
3183 let connection_ids = session
3184 .db()
3185 .await
3186 .remove_channel_message(channel_id, message_id, session.user_id)
3187 .await?;
3188 broadcast(Some(session.connection_id), connection_ids, |connection| {
3189 session.peer.send(connection, request.clone())
3190 });
3191 response.send(proto::Ack {})?;
3192 Ok(())
3193}
3194
3195/// Mark a channel message as read
3196async fn acknowledge_channel_message(
3197 request: proto::AckChannelMessage,
3198 session: Session,
3199) -> Result<()> {
3200 let channel_id = ChannelId::from_proto(request.channel_id);
3201 let message_id = MessageId::from_proto(request.message_id);
3202 let notifications = session
3203 .db()
3204 .await
3205 .observe_channel_message(channel_id, session.user_id, message_id)
3206 .await?;
3207 send_notifications(
3208 &*session.connection_pool().await,
3209 &session.peer,
3210 notifications,
3211 );
3212 Ok(())
3213}
3214
3215/// Mark a buffer version as synced
3216async fn acknowledge_buffer_version(
3217 request: proto::AckBufferOperation,
3218 session: Session,
3219) -> Result<()> {
3220 let buffer_id = BufferId::from_proto(request.buffer_id);
3221 session
3222 .db()
3223 .await
3224 .observe_buffer_version(
3225 buffer_id,
3226 session.user_id,
3227 request.epoch as i32,
3228 &request.version,
3229 )
3230 .await?;
3231 Ok(())
3232}
3233
3234/// Start receiving chat updates for a channel
3235async fn join_channel_chat(
3236 request: proto::JoinChannelChat,
3237 response: Response<proto::JoinChannelChat>,
3238 session: Session,
3239) -> Result<()> {
3240 let channel_id = ChannelId::from_proto(request.channel_id);
3241
3242 let db = session.db().await;
3243 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3244 .await?;
3245 let messages = db
3246 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3247 .await?;
3248 response.send(proto::JoinChannelChatResponse {
3249 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3250 messages,
3251 })?;
3252 Ok(())
3253}
3254
3255/// Stop receiving chat updates for a channel
3256async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3257 let channel_id = ChannelId::from_proto(request.channel_id);
3258 session
3259 .db()
3260 .await
3261 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3262 .await?;
3263 Ok(())
3264}
3265
3266/// Retrieve the chat history for a channel
3267async fn get_channel_messages(
3268 request: proto::GetChannelMessages,
3269 response: Response<proto::GetChannelMessages>,
3270 session: Session,
3271) -> Result<()> {
3272 let channel_id = ChannelId::from_proto(request.channel_id);
3273 let messages = session
3274 .db()
3275 .await
3276 .get_channel_messages(
3277 channel_id,
3278 session.user_id,
3279 MESSAGE_COUNT_PER_PAGE,
3280 Some(MessageId::from_proto(request.before_message_id)),
3281 )
3282 .await?;
3283 response.send(proto::GetChannelMessagesResponse {
3284 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3285 messages,
3286 })?;
3287 Ok(())
3288}
3289
3290/// Retrieve specific chat messages
3291async fn get_channel_messages_by_id(
3292 request: proto::GetChannelMessagesById,
3293 response: Response<proto::GetChannelMessagesById>,
3294 session: Session,
3295) -> Result<()> {
3296 let message_ids = request
3297 .message_ids
3298 .iter()
3299 .map(|id| MessageId::from_proto(*id))
3300 .collect::<Vec<_>>();
3301 let messages = session
3302 .db()
3303 .await
3304 .get_channel_messages_by_id(session.user_id, &message_ids)
3305 .await?;
3306 response.send(proto::GetChannelMessagesResponse {
3307 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3308 messages,
3309 })?;
3310 Ok(())
3311}
3312
3313/// Retrieve the current users notifications
3314async fn get_notifications(
3315 request: proto::GetNotifications,
3316 response: Response<proto::GetNotifications>,
3317 session: Session,
3318) -> Result<()> {
3319 let notifications = session
3320 .db()
3321 .await
3322 .get_notifications(
3323 session.user_id,
3324 NOTIFICATION_COUNT_PER_PAGE,
3325 request
3326 .before_id
3327 .map(|id| db::NotificationId::from_proto(id)),
3328 )
3329 .await?;
3330 response.send(proto::GetNotificationsResponse {
3331 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3332 notifications,
3333 })?;
3334 Ok(())
3335}
3336
3337/// Mark notifications as read
3338async fn mark_notification_as_read(
3339 request: proto::MarkNotificationRead,
3340 response: Response<proto::MarkNotificationRead>,
3341 session: Session,
3342) -> Result<()> {
3343 let database = &session.db().await;
3344 let notifications = database
3345 .mark_notification_as_read_by_id(
3346 session.user_id,
3347 NotificationId::from_proto(request.notification_id),
3348 )
3349 .await?;
3350 send_notifications(
3351 &*session.connection_pool().await,
3352 &session.peer,
3353 notifications,
3354 );
3355 response.send(proto::Ack {})?;
3356 Ok(())
3357}
3358
3359/// Get the current users information
3360async fn get_private_user_info(
3361 _request: proto::GetPrivateUserInfo,
3362 response: Response<proto::GetPrivateUserInfo>,
3363 session: Session,
3364) -> Result<()> {
3365 let db = session.db().await;
3366
3367 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3368 let user = db
3369 .get_user_by_id(session.user_id)
3370 .await?
3371 .ok_or_else(|| anyhow!("user not found"))?;
3372 let flags = db.get_user_flags(session.user_id).await?;
3373
3374 response.send(proto::GetPrivateUserInfoResponse {
3375 metrics_id,
3376 staff: user.admin,
3377 flags,
3378 })?;
3379 Ok(())
3380}
3381
3382fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3383 match message {
3384 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3385 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3386 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3387 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3388 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3389 code: frame.code.into(),
3390 reason: frame.reason,
3391 })),
3392 }
3393}
3394
3395fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3396 match message {
3397 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3398 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3399 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3400 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3401 AxumMessage::Close(frame) => {
3402 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3403 code: frame.code.into(),
3404 reason: frame.reason,
3405 }))
3406 }
3407 }
3408}
3409
3410fn notify_membership_updated(
3411 connection_pool: &ConnectionPool,
3412 result: MembershipUpdated,
3413 user_id: UserId,
3414 peer: &Peer,
3415) {
3416 let user_channels_update = proto::UpdateUserChannels {
3417 channel_memberships: result
3418 .new_channels
3419 .channel_memberships
3420 .iter()
3421 .map(|cm| proto::ChannelMembership {
3422 channel_id: cm.channel_id.to_proto(),
3423 role: cm.role.into(),
3424 })
3425 .collect(),
3426 ..Default::default()
3427 };
3428 let mut update = build_channels_update(result.new_channels, vec![]);
3429 update.delete_channels = result
3430 .removed_channels
3431 .into_iter()
3432 .map(|id| id.to_proto())
3433 .collect();
3434 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3435
3436 for connection_id in connection_pool.user_connection_ids(user_id) {
3437 peer.send(connection_id, user_channels_update.clone())
3438 .trace_err();
3439 peer.send(connection_id, update.clone()).trace_err();
3440 }
3441}
3442
3443fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3444 proto::UpdateUserChannels {
3445 channel_memberships: channels
3446 .channel_memberships
3447 .iter()
3448 .map(|m| proto::ChannelMembership {
3449 channel_id: m.channel_id.to_proto(),
3450 role: m.role.into(),
3451 })
3452 .collect(),
3453 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3454 observed_channel_message_id: channels.observed_channel_messages.clone(),
3455 ..Default::default()
3456 }
3457}
3458
3459fn build_channels_update(
3460 channels: ChannelsForUser,
3461 channel_invites: Vec<db::Channel>,
3462) -> proto::UpdateChannels {
3463 let mut update = proto::UpdateChannels::default();
3464
3465 for channel in channels.channels {
3466 update.channels.push(channel.to_proto());
3467 }
3468
3469 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3470 update.latest_channel_message_ids = channels.latest_channel_messages;
3471
3472 for (channel_id, participants) in channels.channel_participants {
3473 update
3474 .channel_participants
3475 .push(proto::ChannelParticipants {
3476 channel_id: channel_id.to_proto(),
3477 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3478 });
3479 }
3480
3481 for channel in channel_invites {
3482 update.channel_invitations.push(channel.to_proto());
3483 }
3484
3485 update
3486}
3487
3488fn build_initial_contacts_update(
3489 contacts: Vec<db::Contact>,
3490 pool: &ConnectionPool,
3491) -> proto::UpdateContacts {
3492 let mut update = proto::UpdateContacts::default();
3493
3494 for contact in contacts {
3495 match contact {
3496 db::Contact::Accepted { user_id, busy } => {
3497 update.contacts.push(contact_for_user(user_id, busy, &pool));
3498 }
3499 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3500 db::Contact::Incoming { user_id } => {
3501 update
3502 .incoming_requests
3503 .push(proto::IncomingContactRequest {
3504 requester_id: user_id.to_proto(),
3505 })
3506 }
3507 }
3508 }
3509
3510 update
3511}
3512
3513fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3514 proto::Contact {
3515 user_id: user_id.to_proto(),
3516 online: pool.is_user_online(user_id),
3517 busy,
3518 }
3519}
3520
3521fn room_updated(room: &proto::Room, peer: &Peer) {
3522 broadcast(
3523 None,
3524 room.participants
3525 .iter()
3526 .filter_map(|participant| Some(participant.peer_id?.into())),
3527 |peer_id| {
3528 peer.send(
3529 peer_id.into(),
3530 proto::RoomUpdated {
3531 room: Some(room.clone()),
3532 },
3533 )
3534 },
3535 );
3536}
3537
3538fn channel_updated(
3539 channel_id: ChannelId,
3540 room: &proto::Room,
3541 channel_members: &[UserId],
3542 peer: &Peer,
3543 pool: &ConnectionPool,
3544) {
3545 let participants = room
3546 .participants
3547 .iter()
3548 .map(|p| p.user_id)
3549 .collect::<Vec<_>>();
3550
3551 broadcast(
3552 None,
3553 channel_members
3554 .iter()
3555 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3556 |peer_id| {
3557 peer.send(
3558 peer_id.into(),
3559 proto::UpdateChannels {
3560 channel_participants: vec![proto::ChannelParticipants {
3561 channel_id: channel_id.to_proto(),
3562 participant_user_ids: participants.clone(),
3563 }],
3564 ..Default::default()
3565 },
3566 )
3567 },
3568 );
3569}
3570
3571async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3572 let db = session.db().await;
3573
3574 let contacts = db.get_contacts(user_id).await?;
3575 let busy = db.is_user_busy(user_id).await?;
3576
3577 let pool = session.connection_pool().await;
3578 let updated_contact = contact_for_user(user_id, busy, &pool);
3579 for contact in contacts {
3580 if let db::Contact::Accepted {
3581 user_id: contact_user_id,
3582 ..
3583 } = contact
3584 {
3585 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3586 session
3587 .peer
3588 .send(
3589 contact_conn_id,
3590 proto::UpdateContacts {
3591 contacts: vec![updated_contact.clone()],
3592 remove_contacts: Default::default(),
3593 incoming_requests: Default::default(),
3594 remove_incoming_requests: Default::default(),
3595 outgoing_requests: Default::default(),
3596 remove_outgoing_requests: Default::default(),
3597 },
3598 )
3599 .trace_err();
3600 }
3601 }
3602 }
3603 Ok(())
3604}
3605
3606async fn leave_room_for_session(session: &Session) -> Result<()> {
3607 let mut contacts_to_update = HashSet::default();
3608
3609 let room_id;
3610 let canceled_calls_to_user_ids;
3611 let live_kit_room;
3612 let delete_live_kit_room;
3613 let room;
3614 let channel_members;
3615 let channel_id;
3616
3617 if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
3618 contacts_to_update.insert(session.user_id);
3619
3620 for project in left_room.left_projects.values() {
3621 project_left(project, session);
3622 }
3623
3624 room_id = RoomId::from_proto(left_room.room.id);
3625 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
3626 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
3627 delete_live_kit_room = left_room.deleted;
3628 room = mem::take(&mut left_room.room);
3629 channel_members = mem::take(&mut left_room.channel_members);
3630 channel_id = left_room.channel_id;
3631
3632 room_updated(&room, &session.peer);
3633 } else {
3634 return Ok(());
3635 }
3636
3637 if let Some(channel_id) = channel_id {
3638 channel_updated(
3639 channel_id,
3640 &room,
3641 &channel_members,
3642 &session.peer,
3643 &*session.connection_pool().await,
3644 );
3645 }
3646
3647 {
3648 let pool = session.connection_pool().await;
3649 for canceled_user_id in canceled_calls_to_user_ids {
3650 for connection_id in pool.user_connection_ids(canceled_user_id) {
3651 session
3652 .peer
3653 .send(
3654 connection_id,
3655 proto::CallCanceled {
3656 room_id: room_id.to_proto(),
3657 },
3658 )
3659 .trace_err();
3660 }
3661 contacts_to_update.insert(canceled_user_id);
3662 }
3663 }
3664
3665 for contact_user_id in contacts_to_update {
3666 update_user_contacts(contact_user_id, &session).await?;
3667 }
3668
3669 if let Some(live_kit) = session.live_kit_client.as_ref() {
3670 live_kit
3671 .remove_participant(live_kit_room.clone(), session.user_id.to_string())
3672 .await
3673 .trace_err();
3674
3675 if delete_live_kit_room {
3676 live_kit.delete_room(live_kit_room).await.trace_err();
3677 }
3678 }
3679
3680 Ok(())
3681}
3682
3683async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3684 let left_channel_buffers = session
3685 .db()
3686 .await
3687 .leave_channel_buffers(session.connection_id)
3688 .await?;
3689
3690 for left_buffer in left_channel_buffers {
3691 channel_buffer_updated(
3692 session.connection_id,
3693 left_buffer.connections,
3694 &proto::UpdateChannelBufferCollaborators {
3695 channel_id: left_buffer.channel_id.to_proto(),
3696 collaborators: left_buffer.collaborators,
3697 },
3698 &session.peer,
3699 );
3700 }
3701
3702 Ok(())
3703}
3704
3705fn project_left(project: &db::LeftProject, session: &Session) {
3706 for connection_id in &project.connection_ids {
3707 if project.host_user_id == session.user_id {
3708 session
3709 .peer
3710 .send(
3711 *connection_id,
3712 proto::UnshareProject {
3713 project_id: project.id.to_proto(),
3714 },
3715 )
3716 .trace_err();
3717 } else {
3718 session
3719 .peer
3720 .send(
3721 *connection_id,
3722 proto::RemoveProjectCollaborator {
3723 project_id: project.id.to_proto(),
3724 peer_id: Some(session.connection_id.into()),
3725 },
3726 )
3727 .trace_err();
3728 }
3729 }
3730}
3731
/// Extension for fallible values whose errors should be logged rather than propagated.
pub trait ResultExt {
    /// The success type produced when there is no error.
    type Ok;

    /// Convert into an `Option`, reporting the error (if any) instead of returning it.
    fn trace_err(self) -> Option<Self::Ok>;
}
3737
3738impl<T, E> ResultExt for Result<T, E>
3739where
3740 E: std::fmt::Debug,
3741{
3742 type Ok = T;
3743
3744 fn trace_err(self) -> Option<T> {
3745 match self {
3746 Ok(value) => Some(value),
3747 Err(error) => {
3748 tracing::error!("{:?}", error);
3749 None
3750 }
3751 }
3752 }
3753}