1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
7 Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
8 ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
9 User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 future::Future,
53 marker::PhantomData,
54 mem,
55 net::SocketAddr,
56 ops::{Deref, DerefMut},
57 rc::Rc,
58 sync::{
59 atomic::{AtomicBool, Ordering::SeqCst},
60 Arc, OnceLock,
61 },
62 time::{Duration, Instant},
63};
64use time::OffsetDateTime;
65use tokio::sync::{watch, Semaphore};
66use tower::ServiceBuilder;
67use tracing::{field, info_span, instrument, Instrument};
68use util::{http::IsahcHttpClient, SemanticVersion};
69
/// Reconnection timeout of 30 seconds.
// NOTE(review): the code that consumes this timeout is not visible in this part of the file;
// presumably it bounds how long a dropped client may take to reconnect — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` sleeps before it cleans up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits.
// NOTE(review): their consumers are outside the visible part of the file; the names suggest
// chat-message paging, a maximum chat-message length, and notification paging — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
78
79type MessageHandler =
80 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
81
82struct Response<R> {
83 peer: Arc<Peer>,
84 receipt: Receipt<R>,
85 responded: Arc<AtomicBool>,
86}
87
88impl<R: RequestMessage> Response<R> {
89 fn send(self, payload: R::Response) -> Result<()> {
90 self.responded.store(true, SeqCst);
91 self.peer.respond(self.receipt, payload)?;
92 Ok(())
93 }
94}
95
96struct StreamingResponse<R: RequestMessage> {
97 peer: Arc<Peer>,
98 receipt: Receipt<R>,
99}
100
101impl<R: RequestMessage> StreamingResponse<R> {
102 fn send(&self, payload: R::Response) -> Result<()> {
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone)]
109struct Session {
110 user_id: UserId,
111 connection_id: ConnectionId,
112 db: Arc<tokio::sync::Mutex<DbHandle>>,
113 peer: Arc<Peer>,
114 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
115 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
116 http_client: IsahcHttpClient,
117 rate_limiter: Arc<RateLimiter>,
118 _executor: Executor,
119}
120
121impl Session {
122 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
123 #[cfg(test)]
124 tokio::task::yield_now().await;
125 let guard = self.db.lock().await;
126 #[cfg(test)]
127 tokio::task::yield_now().await;
128 guard
129 }
130
131 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
132 #[cfg(test)]
133 tokio::task::yield_now().await;
134 let guard = self.connection_pool.lock();
135 ConnectionPoolGuard {
136 guard,
137 _not_send: PhantomData,
138 }
139 }
140}
141
142impl Debug for Session {
143 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
144 f.debug_struct("Session")
145 .field("user_id", &self.user_id)
146 .field("connection_id", &self.connection_id)
147 .finish()
148 }
149}
150
151struct DbHandle(Arc<Database>);
152
153impl Deref for DbHandle {
154 type Target = Database;
155
156 fn deref(&self) -> &Self::Target {
157 self.0.as_ref()
158 }
159}
160
/// The RPC server: owns the peer, the connection pool, and the table of message handlers.
pub struct Server {
    // Kept behind a mutex because the test-only `reset` replaces it through `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message type, keyed by the `TypeId` of the message.
    handlers: HashMap<TypeId, MessageHandler>,
    // Starts out `false`; `teardown` sends `true` and the test-only `reset` sends `false` again.
    teardown: watch::Sender<bool>,
}
169
/// Lock guard over the `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, because `Rc` is `!Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
174
/// Serializable view of the server: a borrow of the peer plus a locked connection pool.
///
/// The pool stays locked for as long as the snapshot is alive, since the guard is stored here.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target, i.e. as the `ConnectionPool` itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
181
182pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
183where
184 S: Serializer,
185 T: Deref<Target = U>,
186 U: Serialize,
187{
188 Serialize::serialize(value.deref(), serializer)
189}
190
191impl Server {
    /// Creates a server with the given id and registers a handler for every supported message.
    ///
    /// The peer is created from the same id, the connection pool and handler table start out
    /// empty, and the teardown channel starts at `false`. Registering two handlers for the same
    /// message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created on demand via `subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_request_handler(join_hosted_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // The two language-model handlers need API keys from the config, so they are wrapped
            // in closures that capture their own clone of `app_state`.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                }
            });

        Arc::new(server)
    }
323
    /// Spawns a detached background task that cleans up after stale servers, then returns
    /// `Ok(())` immediately without waiting for that task.
    ///
    /// The task first sleeps for `CLEANUP_TIMEOUT`. It then asks the database for the stale room
    /// and channel ids belonging to other servers in this environment and, if that succeeds:
    ///
    /// * clears stale channel-buffer collaborators and sends the refreshed collaborator list to
    ///   each remaining connection;
    /// * clears stale room participants, broadcasts the refreshed room (and its channel, if any),
    ///   sends `CallCanceled` to users whose calls were canceled, and pushes an updated contact
    ///   entry to the accepted contacts of every affected user;
    /// * deletes the LiveKit room when the refreshed room has no participants left.
    ///
    /// Finally it deletes the stale server records; this step runs even when the stale ids could
    /// not be retrieved. Every fallible step goes through `trace_err`, so a failure does not
    /// abort the rest of the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and only awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These keep their empty defaults when the room could not be refreshed,
                        // which turns the notification steps below into no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is confined to this block so that it is released before
                        // the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only proceed when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts are notified.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs whether or not the stale resource ids could be retrieved above.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
469
470 pub fn teardown(&self) {
471 self.peer.teardown();
472 self.connection_pool.lock().reset();
473 let _ = self.teardown.send(true);
474 }
475
476 #[cfg(test)]
477 pub fn reset(&self, id: ServerId) {
478 self.teardown();
479 *self.id.lock() = id;
480 self.peer.reset(id.0 as u32);
481 let _ = self.teardown.send(false);
482 }
483
484 #[cfg(test)]
485 pub fn id(&self) -> ServerId {
486 *self.id.lock()
487 }
488
489 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
490 where
491 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
492 Fut: 'static + Send + Future<Output = Result<()>>,
493 M: EnvelopedMessage,
494 {
495 let prev_handler = self.handlers.insert(
496 TypeId::of::<M>(),
497 Box::new(move |envelope, session| {
498 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
499 let received_at = envelope.received_at;
500 tracing::info!(
501 "message received"
502 );
503 let start_time = Instant::now();
504 let future = (handler)(*envelope, session);
505 async move {
506 let result = future.await;
507 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
508 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
509 let queue_duration_ms = total_duration_ms - processing_duration_ms;
510 match result {
511 Err(error) => {
512 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
513 }
514 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
515 }
516 }
517 .boxed()
518 }),
519 );
520 if prev_handler.is_some() {
521 panic!("registered a handler for the same message twice");
522 }
523 self
524 }
525
526 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
527 where
528 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
529 Fut: 'static + Send + Future<Output = Result<()>>,
530 M: EnvelopedMessage,
531 {
532 self.add_handler(move |envelope, session| handler(envelope.payload, session));
533 self
534 }
535
536 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
537 where
538 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
539 Fut: Send + Future<Output = Result<()>>,
540 M: RequestMessage,
541 {
542 let handler = Arc::new(handler);
543 self.add_handler(move |envelope, session| {
544 let receipt = envelope.receipt();
545 let handler = handler.clone();
546 async move {
547 let peer = session.peer.clone();
548 let responded = Arc::new(AtomicBool::default());
549 let response = Response {
550 peer: peer.clone(),
551 responded: responded.clone(),
552 receipt,
553 };
554 match (handler)(envelope.payload, response, session).await {
555 Ok(()) => {
556 if responded.load(std::sync::atomic::Ordering::SeqCst) {
557 Ok(())
558 } else {
559 Err(anyhow!("handler did not send a response"))?
560 }
561 }
562 Err(error) => {
563 let proto_err = match &error {
564 Error::Internal(err) => err.to_proto(),
565 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
566 };
567 peer.respond_with_error(receipt, proto_err)?;
568 Err(error)
569 }
570 }
571 }
572 })
573 }
574
575 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
576 where
577 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
578 Fut: Send + Future<Output = Result<()>>,
579 M: RequestMessage,
580 {
581 let handler = Arc::new(handler);
582 self.add_handler(move |envelope, session| {
583 let receipt = envelope.receipt();
584 let handler = handler.clone();
585 async move {
586 let peer = session.peer.clone();
587 let response = StreamingResponse {
588 peer: peer.clone(),
589 receipt,
590 };
591 match (handler)(envelope.payload, response, session).await {
592 Ok(()) => {
593 peer.end_stream(receipt)?;
594 Ok(())
595 }
596 Err(error) => {
597 let proto_err = match &error {
598 Error::Internal(err) => err.to_proto(),
599 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
600 };
601 peer.respond_with_error(receipt, proto_err)?;
602 Err(error)
603 }
604 }
605 }
606 })
607 }
608
    /// Returns a future that serves one client connection from handshake to sign-out.
    ///
    /// The future:
    /// 1. returns right away if the server is already tearing down;
    /// 2. registers the connection with the peer and builds the per-connection `Session`;
    /// 3. sends the initial client update (see `send_initial_client_update`), passing
    ///    `send_connection_id` along so the caller can learn the assigned connection id;
    /// 4. loops, dispatching each incoming message to its registered handler, until the I/O
    ///    future finishes, the incoming stream ends, or the teardown signal changes;
    /// 5. runs `connection_lost` for the session — except when the loop was left because of the
    ///    teardown signal, in which case it returns without signing out.
    ///
    /// At most 256 handlers are in flight per connection: a semaphore permit is acquired before
    /// the next message is read and released only once that message's handler has finished.
    // The function takes one argument more than clippy's default limit allows.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        user: User,
        zed_version: ZedVersion,
        impersonator: Option<User>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        let user_id = user.id;
        let login = user.github_login.clone();
        // `impersonator` and `connection_id` start out empty and are recorded once known.
        let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
        if let Some(impersonator) = impersonator {
            span.record("impersonator", &impersonator.github_login);
        }
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => http_client,
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let session = Session {
                user_id,
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
            };

            if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, they would have no chance to
            // respond to the other client's request, and both would deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* the next message is read, so no further message
                // is pulled off the stream while all permits are held by running handlers.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written: teardown first, then
                // I/O, then in-flight foreground handlers, and only then a new message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // The entered-span guard is dropped before `span` is moved into
                                // `instrument` below.
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the permit only after the handler has finished.
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set drops any foreground handlers that have not finished yet.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
737
738 async fn send_initial_client_update(
739 &self,
740 connection_id: ConnectionId,
741 user: User,
742 zed_version: ZedVersion,
743 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
744 session: &Session,
745 ) -> Result<()> {
746 self.peer.send(
747 connection_id,
748 proto::Hello {
749 peer_id: Some(connection_id.into()),
750 },
751 )?;
752 tracing::info!("sent hello message");
753
754 if let Some(send_connection_id) = send_connection_id.take() {
755 let _ = send_connection_id.send(connection_id);
756 }
757
758 if !user.connected_once {
759 self.peer.send(connection_id, proto::ShowContacts {})?;
760 self.app_state
761 .db
762 .set_user_connected_once(user.id, true)
763 .await?;
764 }
765
766 let (contacts, channels_for_user, channel_invites) = future::try_join3(
767 self.app_state.db.get_contacts(user.id),
768 self.app_state.db.get_channels_for_user(user.id),
769 self.app_state.db.get_channel_invites_for_user(user.id),
770 )
771 .await?;
772
773 {
774 let mut pool = self.connection_pool.lock();
775 pool.add_connection(connection_id, user.id, user.admin, zed_version);
776 for membership in &channels_for_user.channel_memberships {
777 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
778 }
779 self.peer.send(
780 connection_id,
781 build_initial_contacts_update(contacts, &pool),
782 )?;
783 self.peer.send(
784 connection_id,
785 build_update_user_channels(&channels_for_user),
786 )?;
787 self.peer.send(
788 connection_id,
789 build_channels_update(channels_for_user, channel_invites),
790 )?;
791 }
792
793 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
794 self.peer.send(connection_id, incoming_call)?;
795 }
796
797 update_user_contacts(user.id, &session).await?;
798 Ok(())
799 }
800
801 pub async fn invite_code_redeemed(
802 self: &Arc<Self>,
803 inviter_id: UserId,
804 invitee_id: UserId,
805 ) -> Result<()> {
806 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
807 if let Some(code) = &user.invite_code {
808 let pool = self.connection_pool.lock();
809 let invitee_contact = contact_for_user(invitee_id, false, &pool);
810 for connection_id in pool.user_connection_ids(inviter_id) {
811 self.peer.send(
812 connection_id,
813 proto::UpdateContacts {
814 contacts: vec![invitee_contact.clone()],
815 ..Default::default()
816 },
817 )?;
818 self.peer.send(
819 connection_id,
820 proto::UpdateInviteInfo {
821 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
822 count: user.invite_count as u32,
823 },
824 )?;
825 }
826 }
827 }
828 Ok(())
829 }
830
831 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
832 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
833 if let Some(invite_code) = &user.invite_code {
834 let pool = self.connection_pool.lock();
835 for connection_id in pool.user_connection_ids(user_id) {
836 self.peer.send(
837 connection_id,
838 proto::UpdateInviteInfo {
839 url: format!(
840 "{}{}",
841 self.app_state.config.invite_link_prefix, invite_code
842 ),
843 count: user.invite_count as u32,
844 },
845 )?;
846 }
847 }
848 }
849 Ok(())
850 }
851
852 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
853 ServerSnapshot {
854 connection_pool: ConnectionPoolGuard {
855 guard: self.connection_pool.lock(),
856 _not_send: PhantomData,
857 },
858 peer: &self.peer,
859 }
860 }
861}
862
863impl<'a> Deref for ConnectionPoolGuard<'a> {
864 type Target = ConnectionPool;
865
866 fn deref(&self) -> &Self::Target {
867 &self.guard
868 }
869}
870
871impl<'a> DerefMut for ConnectionPoolGuard<'a> {
872 fn deref_mut(&mut self) -> &mut Self::Target {
873 &mut self.guard
874 }
875}
876
877impl<'a> Drop for ConnectionPoolGuard<'a> {
878 fn drop(&mut self) {
879 #[cfg(test)]
880 self.check_invariants();
881 }
882}
883
884fn broadcast<F>(
885 sender_id: Option<ConnectionId>,
886 receiver_ids: impl IntoIterator<Item = ConnectionId>,
887 mut f: F,
888) where
889 F: FnMut(ConnectionId) -> anyhow::Result<()>,
890{
891 for receiver_id in receiver_ids {
892 if Some(receiver_id) != sender_id {
893 if let Err(error) = f(receiver_id) {
894 tracing::error!("failed to send to {:?} {}", receiver_id, error);
895 }
896 }
897 }
898}
899
900pub struct ProtocolVersion(u32);
901
902impl Header for ProtocolVersion {
903 fn name() -> &'static HeaderName {
904 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
905 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
906 }
907
908 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
909 where
910 Self: Sized,
911 I: Iterator<Item = &'i axum::http::HeaderValue>,
912 {
913 let version = values
914 .next()
915 .ok_or_else(axum::headers::Error::invalid)?
916 .to_str()
917 .map_err(|_| axum::headers::Error::invalid())?
918 .parse()
919 .map_err(|_| axum::headers::Error::invalid())?;
920 Ok(Self(version))
921 }
922
923 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
924 values.extend([self.0.to_string().parse().unwrap()]);
925 }
926}
927
928pub struct AppVersionHeader(SemanticVersion);
929impl Header for AppVersionHeader {
930 fn name() -> &'static HeaderName {
931 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
932 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
933 }
934
935 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
936 where
937 Self: Sized,
938 I: Iterator<Item = &'i axum::http::HeaderValue>,
939 {
940 let version = values
941 .next()
942 .ok_or_else(axum::headers::Error::invalid)?
943 .to_str()
944 .map_err(|_| axum::headers::Error::invalid())?
945 .parse()
946 .map_err(|_| axum::headers::Error::invalid())?;
947 Ok(Self(version))
948 }
949
950 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
951 values.extend([self.0.to_string().parse().unwrap()]);
952 }
953}
954
955pub fn routes(server: Arc<Server>) -> Router<(), Body> {
956 Router::new()
957 .route("/rpc", get(handle_websocket_request))
958 .layer(
959 ServiceBuilder::new()
960 .layer(Extension(server.app_state.clone()))
961 .layer(middleware::from_fn(auth::validate_header)),
962 )
963 .route("/metrics", get(handle_metrics))
964 .layer(Extension(server))
965}
966
967pub async fn handle_websocket_request(
968 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
969 app_version_header: Option<TypedHeader<AppVersionHeader>>,
970 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
971 Extension(server): Extension<Arc<Server>>,
972 Extension(user): Extension<User>,
973 Extension(impersonator): Extension<Impersonator>,
974 ws: WebSocketUpgrade,
975) -> axum::response::Response {
976 if protocol_version != rpc::PROTOCOL_VERSION {
977 return (
978 StatusCode::UPGRADE_REQUIRED,
979 "client must be upgraded".to_string(),
980 )
981 .into_response();
982 }
983
984 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
985 return (
986 StatusCode::UPGRADE_REQUIRED,
987 "no version header found".to_string(),
988 )
989 .into_response();
990 };
991
992 if !version.can_collaborate() {
993 return (
994 StatusCode::UPGRADE_REQUIRED,
995 "client must be upgraded".to_string(),
996 )
997 .into_response();
998 }
999
1000 let socket_address = socket_address.to_string();
1001 ws.on_upgrade(move |socket| {
1002 let socket = socket
1003 .map_ok(to_tungstenite_message)
1004 .err_into()
1005 .with(|message| async move { Ok(to_axum_message(message)) });
1006 let connection = Connection::new(Box::pin(socket));
1007 async move {
1008 server
1009 .handle_connection(
1010 connection,
1011 socket_address,
1012 user,
1013 version,
1014 impersonator.0,
1015 None,
1016 Executor::Production,
1017 )
1018 .await;
1019 }
1020 })
1021}
1022
1023pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1024 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1025 let connections_metric = CONNECTIONS_METRIC
1026 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1027
1028 let connections = server
1029 .connection_pool
1030 .lock()
1031 .connections()
1032 .filter(|connection| !connection.admin)
1033 .count();
1034 connections_metric.set(connections as _);
1035
1036 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1037 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1038 register_int_gauge!(
1039 "shared_projects",
1040 "number of open projects with one or more guests"
1041 )
1042 .unwrap()
1043 });
1044
1045 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1046 shared_projects_metric.set(shared_projects as _);
1047
1048 let encoder = prometheus::TextEncoder::new();
1049 let metric_families = prometheus::gather();
1050 let encoded_metrics = encoder
1051 .encode_to_string(&metric_families)
1052 .map_err(|err| anyhow!("{}", err))?;
1053 Ok(encoded_metrics)
1054}
1055
1056#[instrument(err, skip(executor))]
1057async fn connection_lost(
1058 session: Session,
1059 mut teardown: watch::Receiver<bool>,
1060 executor: Executor,
1061) -> Result<()> {
1062 session.peer.disconnect(session.connection_id);
1063 session
1064 .connection_pool()
1065 .await
1066 .remove_connection(session.connection_id)?;
1067
1068 session
1069 .db()
1070 .await
1071 .connection_lost(session.connection_id)
1072 .await
1073 .trace_err();
1074
1075 futures::select_biased! {
1076 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1077 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
1078 leave_room_for_session(&session).await.trace_err();
1079 leave_channel_buffers_for_session(&session)
1080 .await
1081 .trace_err();
1082
1083 if !session
1084 .connection_pool()
1085 .await
1086 .is_user_online(session.user_id)
1087 {
1088 let db = session.db().await;
1089 if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
1090 room_updated(&room, &session.peer);
1091 }
1092 }
1093
1094 update_user_contacts(session.user_id, &session).await?;
1095 }
1096 _ = teardown.changed().fuse() => {}
1097 }
1098
1099 Ok(())
1100}
1101
1102/// Acknowledges a ping from a client, used to keep the connection alive.
1103async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1104 response.send(proto::Ack {})?;
1105 Ok(())
1106}
1107
1108/// Creates a new room for calling (outside of channels)
1109async fn create_room(
1110 _request: proto::CreateRoom,
1111 response: Response<proto::CreateRoom>,
1112 session: Session,
1113) -> Result<()> {
1114 let live_kit_room = nanoid::nanoid!(30);
1115
1116 let live_kit_connection_info = {
1117 let live_kit_room = live_kit_room.clone();
1118 let live_kit = session.live_kit_client.as_ref();
1119
1120 util::async_maybe!({
1121 let live_kit = live_kit?;
1122
1123 let token = live_kit
1124 .room_token(&live_kit_room, &session.user_id.to_string())
1125 .trace_err()?;
1126
1127 Some(proto::LiveKitConnectionInfo {
1128 server_url: live_kit.url().into(),
1129 token,
1130 can_publish: true,
1131 })
1132 })
1133 }
1134 .await;
1135
1136 let room = session
1137 .db()
1138 .await
1139 .create_room(session.user_id, session.connection_id, &live_kit_room)
1140 .await?;
1141
1142 response.send(proto::CreateRoomResponse {
1143 room: Some(room.clone()),
1144 live_kit_connection_info,
1145 })?;
1146
1147 update_user_contacts(session.user_id, &session).await?;
1148 Ok(())
1149}
1150
1151/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1152async fn join_room(
1153 request: proto::JoinRoom,
1154 response: Response<proto::JoinRoom>,
1155 session: Session,
1156) -> Result<()> {
1157 let room_id = RoomId::from_proto(request.id);
1158
1159 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1160
1161 if let Some(channel_id) = channel_id {
1162 return join_channel_internal(channel_id, Box::new(response), session).await;
1163 }
1164
1165 let joined_room = {
1166 let room = session
1167 .db()
1168 .await
1169 .join_room(room_id, session.user_id, session.connection_id)
1170 .await?;
1171 room_updated(&room.room, &session.peer);
1172 room.into_inner()
1173 };
1174
1175 for connection_id in session
1176 .connection_pool()
1177 .await
1178 .user_connection_ids(session.user_id)
1179 {
1180 session
1181 .peer
1182 .send(
1183 connection_id,
1184 proto::CallCanceled {
1185 room_id: room_id.to_proto(),
1186 },
1187 )
1188 .trace_err();
1189 }
1190
1191 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1192 if let Some(token) = live_kit
1193 .room_token(
1194 &joined_room.room.live_kit_room,
1195 &session.user_id.to_string(),
1196 )
1197 .trace_err()
1198 {
1199 Some(proto::LiveKitConnectionInfo {
1200 server_url: live_kit.url().into(),
1201 token,
1202 can_publish: true,
1203 })
1204 } else {
1205 None
1206 }
1207 } else {
1208 None
1209 };
1210
1211 response.send(proto::JoinRoomResponse {
1212 room: Some(joined_room.room),
1213 channel_id: None,
1214 live_kit_connection_info,
1215 })?;
1216
1217 update_user_contacts(session.user_id, &session).await?;
1218 Ok(())
1219}
1220
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The handler re-registers the caller in the room through the database,
/// answers the request with the room plus the projects that were reshared and
/// rejoined, and then brings everyone up to date:
///
/// * the room update is published via `room_updated`;
/// * collaborators of reshared and rejoined projects are told that the
///   caller's old peer id was replaced by the current connection id;
/// * for reshared projects, an `UpdateProject` carrying the project's
///   worktrees is passed to `broadcast` for the project's collaborators;
/// * for rejoined projects, worktree entries, diagnostic summaries, settings
///   files and language-server notifications are streamed to the caller.
///
/// Finally the channel (if the room has one) is updated and the caller's
/// contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Both values are assigned at the end of the block below, once the value
    // returned by the database has been unwrapped with `into_inner`.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        // Answer the request first, before any notification is sent out.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Reshared projects: tell each collaborator about the caller's new
        // peer id, then pass the project's current worktrees to `broadcast`.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Rejoined projects: tell each collaborator about the caller's new
        // peer id. Send failures are logged and otherwise ignored.
        for project in &rejoined_room.rejoined_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream the state of every rejoined project back to the caller. The
        // worktrees are taken out of each project so their fields can be
        // moved into the outgoing messages.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Test builds use a tiny chunk size.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` notification to the caller
            // for each of the project's language servers.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        // Unwrap the value returned by the database so the room and channel
        // can outlive this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1409
1410/// leave room disconnects from the room.
1411async fn leave_room(
1412 _: proto::LeaveRoom,
1413 response: Response<proto::LeaveRoom>,
1414 session: Session,
1415) -> Result<()> {
1416 leave_room_for_session(&session).await?;
1417 response.send(proto::Ack {})?;
1418 Ok(())
1419}
1420
1421/// Updates the permissions of someone else in the room.
1422async fn set_room_participant_role(
1423 request: proto::SetRoomParticipantRole,
1424 response: Response<proto::SetRoomParticipantRole>,
1425 session: Session,
1426) -> Result<()> {
1427 let user_id = UserId::from_proto(request.user_id);
1428 let role = ChannelRole::from(request.role());
1429
1430 let (live_kit_room, can_publish) = {
1431 let room = session
1432 .db()
1433 .await
1434 .set_room_participant_role(
1435 session.user_id,
1436 RoomId::from_proto(request.room_id),
1437 user_id,
1438 role,
1439 )
1440 .await?;
1441
1442 let live_kit_room = room.live_kit_room.clone();
1443 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1444 room_updated(&room, &session.peer);
1445 (live_kit_room, can_publish)
1446 };
1447
1448 if let Some(live_kit) = session.live_kit_client.as_ref() {
1449 live_kit
1450 .update_participant(
1451 live_kit_room.clone(),
1452 request.user_id.to_string(),
1453 live_kit_server::proto::ParticipantPermission {
1454 can_subscribe: true,
1455 can_publish,
1456 can_publish_data: can_publish,
1457 hidden: false,
1458 recorder: false,
1459 },
1460 )
1461 .await
1462 .trace_err();
1463 }
1464
1465 response.send(proto::Ack {})?;
1466 Ok(())
1467}
1468
1469/// Call someone else into the current room
1470async fn call(
1471 request: proto::Call,
1472 response: Response<proto::Call>,
1473 session: Session,
1474) -> Result<()> {
1475 let room_id = RoomId::from_proto(request.room_id);
1476 let calling_user_id = session.user_id;
1477 let calling_connection_id = session.connection_id;
1478 let called_user_id = UserId::from_proto(request.called_user_id);
1479 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1480 if !session
1481 .db()
1482 .await
1483 .has_contact(calling_user_id, called_user_id)
1484 .await?
1485 {
1486 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1487 }
1488
1489 let incoming_call = {
1490 let (room, incoming_call) = &mut *session
1491 .db()
1492 .await
1493 .call(
1494 room_id,
1495 calling_user_id,
1496 calling_connection_id,
1497 called_user_id,
1498 initial_project_id,
1499 )
1500 .await?;
1501 room_updated(&room, &session.peer);
1502 mem::take(incoming_call)
1503 };
1504 update_user_contacts(called_user_id, &session).await?;
1505
1506 let mut calls = session
1507 .connection_pool()
1508 .await
1509 .user_connection_ids(called_user_id)
1510 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1511 .collect::<FuturesUnordered<_>>();
1512
1513 while let Some(call_response) = calls.next().await {
1514 match call_response.as_ref() {
1515 Ok(_) => {
1516 response.send(proto::Ack {})?;
1517 return Ok(());
1518 }
1519 Err(_) => {
1520 call_response.trace_err();
1521 }
1522 }
1523 }
1524
1525 {
1526 let room = session
1527 .db()
1528 .await
1529 .call_failed(room_id, called_user_id)
1530 .await?;
1531 room_updated(&room, &session.peer);
1532 }
1533 update_user_contacts(called_user_id, &session).await?;
1534
1535 Err(anyhow!("failed to ring user"))?
1536}
1537
1538/// Cancel an outgoing call.
1539async fn cancel_call(
1540 request: proto::CancelCall,
1541 response: Response<proto::CancelCall>,
1542 session: Session,
1543) -> Result<()> {
1544 let called_user_id = UserId::from_proto(request.called_user_id);
1545 let room_id = RoomId::from_proto(request.room_id);
1546 {
1547 let room = session
1548 .db()
1549 .await
1550 .cancel_call(room_id, session.connection_id, called_user_id)
1551 .await?;
1552 room_updated(&room, &session.peer);
1553 }
1554
1555 for connection_id in session
1556 .connection_pool()
1557 .await
1558 .user_connection_ids(called_user_id)
1559 {
1560 session
1561 .peer
1562 .send(
1563 connection_id,
1564 proto::CallCanceled {
1565 room_id: room_id.to_proto(),
1566 },
1567 )
1568 .trace_err();
1569 }
1570 response.send(proto::Ack {})?;
1571
1572 update_user_contacts(called_user_id, &session).await?;
1573 Ok(())
1574}
1575
1576/// Decline an incoming call.
1577async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1578 let room_id = RoomId::from_proto(message.room_id);
1579 {
1580 let room = session
1581 .db()
1582 .await
1583 .decline_call(Some(room_id), session.user_id)
1584 .await?
1585 .ok_or_else(|| anyhow!("failed to decline call"))?;
1586 room_updated(&room, &session.peer);
1587 }
1588
1589 for connection_id in session
1590 .connection_pool()
1591 .await
1592 .user_connection_ids(session.user_id)
1593 {
1594 session
1595 .peer
1596 .send(
1597 connection_id,
1598 proto::CallCanceled {
1599 room_id: room_id.to_proto(),
1600 },
1601 )
1602 .trace_err();
1603 }
1604 update_user_contacts(session.user_id, &session).await?;
1605 Ok(())
1606}
1607
1608/// Updates other participants in the room with your current location.
1609async fn update_participant_location(
1610 request: proto::UpdateParticipantLocation,
1611 response: Response<proto::UpdateParticipantLocation>,
1612 session: Session,
1613) -> Result<()> {
1614 let room_id = RoomId::from_proto(request.room_id);
1615 let location = request
1616 .location
1617 .ok_or_else(|| anyhow!("invalid location"))?;
1618
1619 let db = session.db().await;
1620 let room = db
1621 .update_room_participant_location(room_id, session.connection_id, location)
1622 .await?;
1623
1624 room_updated(&room, &session.peer);
1625 response.send(proto::Ack {})?;
1626 Ok(())
1627}
1628
1629/// Share a project into the room.
1630async fn share_project(
1631 request: proto::ShareProject,
1632 response: Response<proto::ShareProject>,
1633 session: Session,
1634) -> Result<()> {
1635 let (project_id, room) = &*session
1636 .db()
1637 .await
1638 .share_project(
1639 RoomId::from_proto(request.room_id),
1640 session.connection_id,
1641 &request.worktrees,
1642 )
1643 .await?;
1644 response.send(proto::ShareProjectResponse {
1645 project_id: project_id.to_proto(),
1646 })?;
1647 room_updated(&room, &session.peer);
1648
1649 Ok(())
1650}
1651
1652/// Unshare a project from the room.
1653async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1654 let project_id = ProjectId::from_proto(message.project_id);
1655
1656 let (room, guest_connection_ids) = &*session
1657 .db()
1658 .await
1659 .unshare_project(project_id, session.connection_id)
1660 .await?;
1661
1662 broadcast(
1663 Some(session.connection_id),
1664 guest_connection_ids.iter().copied(),
1665 |conn_id| session.peer.send(conn_id, message.clone()),
1666 );
1667 room_updated(&room, &session.peer);
1668
1669 Ok(())
1670}
1671
1672/// Join someone elses shared project.
1673async fn join_project(
1674 request: proto::JoinProject,
1675 response: Response<proto::JoinProject>,
1676 session: Session,
1677) -> Result<()> {
1678 let project_id = ProjectId::from_proto(request.project_id);
1679
1680 tracing::info!(%project_id, "join project");
1681
1682 let (project, replica_id) = &mut *session
1683 .db()
1684 .await
1685 .join_project_in_room(project_id, session.connection_id)
1686 .await?;
1687
1688 join_project_internal(response, session, project, replica_id)
1689}
1690
/// Abstraction over the response handles that can answer with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Consumes the response handle and sends `result` to the requester.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to `Response::send`, presumably spelled out so
        // it is not confused with this trait method of the same name.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Same delegation as above, for the hosted-project request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1704
1705fn join_project_internal(
1706 response: impl JoinProjectInternalResponse,
1707 session: Session,
1708 project: &mut Project,
1709 replica_id: &ReplicaId,
1710) -> Result<()> {
1711 let collaborators = project
1712 .collaborators
1713 .iter()
1714 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1715 .map(|collaborator| collaborator.to_proto())
1716 .collect::<Vec<_>>();
1717 let project_id = project.id;
1718 let guest_user_id = session.user_id;
1719
1720 let worktrees = project
1721 .worktrees
1722 .iter()
1723 .map(|(id, worktree)| proto::WorktreeMetadata {
1724 id: *id,
1725 root_name: worktree.root_name.clone(),
1726 visible: worktree.visible,
1727 abs_path: worktree.abs_path.clone(),
1728 })
1729 .collect::<Vec<_>>();
1730
1731 for collaborator in &collaborators {
1732 session
1733 .peer
1734 .send(
1735 collaborator.peer_id.unwrap().into(),
1736 proto::AddProjectCollaborator {
1737 project_id: project_id.to_proto(),
1738 collaborator: Some(proto::Collaborator {
1739 peer_id: Some(session.connection_id.into()),
1740 replica_id: replica_id.0 as u32,
1741 user_id: guest_user_id.to_proto(),
1742 }),
1743 },
1744 )
1745 .trace_err();
1746 }
1747
1748 // First, we send the metadata associated with each worktree.
1749 response.send(proto::JoinProjectResponse {
1750 project_id: project.id.0 as u64,
1751 worktrees: worktrees.clone(),
1752 replica_id: replica_id.0 as u32,
1753 collaborators: collaborators.clone(),
1754 language_servers: project.language_servers.clone(),
1755 role: project.role.into(), // todo
1756 })?;
1757
1758 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1759 #[cfg(any(test, feature = "test-support"))]
1760 const MAX_CHUNK_SIZE: usize = 2;
1761 #[cfg(not(any(test, feature = "test-support")))]
1762 const MAX_CHUNK_SIZE: usize = 256;
1763
1764 // Stream this worktree's entries.
1765 let message = proto::UpdateWorktree {
1766 project_id: project_id.to_proto(),
1767 worktree_id,
1768 abs_path: worktree.abs_path.clone(),
1769 root_name: worktree.root_name,
1770 updated_entries: worktree.entries,
1771 removed_entries: Default::default(),
1772 scan_id: worktree.scan_id,
1773 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1774 updated_repositories: worktree.repository_entries.into_values().collect(),
1775 removed_repositories: Default::default(),
1776 };
1777 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1778 session.peer.send(session.connection_id, update.clone())?;
1779 }
1780
1781 // Stream this worktree's diagnostics.
1782 for summary in worktree.diagnostic_summaries {
1783 session.peer.send(
1784 session.connection_id,
1785 proto::UpdateDiagnosticSummary {
1786 project_id: project_id.to_proto(),
1787 worktree_id: worktree.id,
1788 summary: Some(summary),
1789 },
1790 )?;
1791 }
1792
1793 for settings_file in worktree.settings_files {
1794 session.peer.send(
1795 session.connection_id,
1796 proto::UpdateWorktreeSettings {
1797 project_id: project_id.to_proto(),
1798 worktree_id: worktree.id,
1799 path: settings_file.path,
1800 content: Some(settings_file.content),
1801 },
1802 )?;
1803 }
1804 }
1805
1806 for language_server in &project.language_servers {
1807 session.peer.send(
1808 session.connection_id,
1809 proto::UpdateLanguageServer {
1810 project_id: project_id.to_proto(),
1811 language_server_id: language_server.id,
1812 variant: Some(
1813 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1814 proto::LspDiskBasedDiagnosticsUpdated {},
1815 ),
1816 ),
1817 },
1818 )?;
1819 }
1820
1821 Ok(())
1822}
1823
1824/// Leave someone elses shared project.
1825async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1826 let sender_id = session.connection_id;
1827 let project_id = ProjectId::from_proto(request.project_id);
1828 let db = session.db().await;
1829 if db.is_hosted_project(project_id).await? {
1830 let project = db.leave_hosted_project(project_id, sender_id).await?;
1831 project_left(&project, &session);
1832 return Ok(());
1833 }
1834
1835 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1836 tracing::info!(
1837 %project_id,
1838 host_user_id = ?project.host_user_id,
1839 host_connection_id = ?project.host_connection_id,
1840 "leave project"
1841 );
1842
1843 project_left(&project, &session);
1844 room_updated(&room, &session.peer);
1845
1846 Ok(())
1847}
1848
1849async fn join_hosted_project(
1850 request: proto::JoinHostedProject,
1851 response: Response<proto::JoinHostedProject>,
1852 session: Session,
1853) -> Result<()> {
1854 let (mut project, replica_id) = session
1855 .db()
1856 .await
1857 .join_hosted_project(
1858 ProjectId(request.project_id as i32),
1859 session.user_id,
1860 session.connection_id,
1861 )
1862 .await?;
1863
1864 join_project_internal(response, session, &mut project, &replica_id)
1865}
1866
1867/// Updates other participants with changes to the project
1868async fn update_project(
1869 request: proto::UpdateProject,
1870 response: Response<proto::UpdateProject>,
1871 session: Session,
1872) -> Result<()> {
1873 let project_id = ProjectId::from_proto(request.project_id);
1874 let (room, guest_connection_ids) = &*session
1875 .db()
1876 .await
1877 .update_project(project_id, session.connection_id, &request.worktrees)
1878 .await?;
1879 broadcast(
1880 Some(session.connection_id),
1881 guest_connection_ids.iter().copied(),
1882 |connection_id| {
1883 session
1884 .peer
1885 .forward_send(session.connection_id, connection_id, request.clone())
1886 },
1887 );
1888 room_updated(&room, &session.peer);
1889 response.send(proto::Ack {})?;
1890
1891 Ok(())
1892}
1893
1894/// Updates other participants with changes to the worktree
1895async fn update_worktree(
1896 request: proto::UpdateWorktree,
1897 response: Response<proto::UpdateWorktree>,
1898 session: Session,
1899) -> Result<()> {
1900 let guest_connection_ids = session
1901 .db()
1902 .await
1903 .update_worktree(&request, session.connection_id)
1904 .await?;
1905
1906 broadcast(
1907 Some(session.connection_id),
1908 guest_connection_ids.iter().copied(),
1909 |connection_id| {
1910 session
1911 .peer
1912 .forward_send(session.connection_id, connection_id, request.clone())
1913 },
1914 );
1915 response.send(proto::Ack {})?;
1916 Ok(())
1917}
1918
1919/// Updates other participants with changes to the diagnostics
1920async fn update_diagnostic_summary(
1921 message: proto::UpdateDiagnosticSummary,
1922 session: Session,
1923) -> Result<()> {
1924 let guest_connection_ids = session
1925 .db()
1926 .await
1927 .update_diagnostic_summary(&message, session.connection_id)
1928 .await?;
1929
1930 broadcast(
1931 Some(session.connection_id),
1932 guest_connection_ids.iter().copied(),
1933 |connection_id| {
1934 session
1935 .peer
1936 .forward_send(session.connection_id, connection_id, message.clone())
1937 },
1938 );
1939
1940 Ok(())
1941}
1942
1943/// Updates other participants with changes to the worktree settings
1944async fn update_worktree_settings(
1945 message: proto::UpdateWorktreeSettings,
1946 session: Session,
1947) -> Result<()> {
1948 let guest_connection_ids = session
1949 .db()
1950 .await
1951 .update_worktree_settings(&message, session.connection_id)
1952 .await?;
1953
1954 broadcast(
1955 Some(session.connection_id),
1956 guest_connection_ids.iter().copied(),
1957 |connection_id| {
1958 session
1959 .peer
1960 .forward_send(session.connection_id, connection_id, message.clone())
1961 },
1962 );
1963
1964 Ok(())
1965}
1966
1967/// Notify other participants that a language server has started.
1968async fn start_language_server(
1969 request: proto::StartLanguageServer,
1970 session: Session,
1971) -> Result<()> {
1972 let guest_connection_ids = session
1973 .db()
1974 .await
1975 .start_language_server(&request, session.connection_id)
1976 .await?;
1977
1978 broadcast(
1979 Some(session.connection_id),
1980 guest_connection_ids.iter().copied(),
1981 |connection_id| {
1982 session
1983 .peer
1984 .forward_send(session.connection_id, connection_id, request.clone())
1985 },
1986 );
1987 Ok(())
1988}
1989
1990/// Notify other participants that a language server has changed.
1991async fn update_language_server(
1992 request: proto::UpdateLanguageServer,
1993 session: Session,
1994) -> Result<()> {
1995 let project_id = ProjectId::from_proto(request.project_id);
1996 let project_connection_ids = session
1997 .db()
1998 .await
1999 .project_connection_ids(project_id, session.connection_id)
2000 .await?;
2001 broadcast(
2002 Some(session.connection_id),
2003 project_connection_ids.iter().copied(),
2004 |connection_id| {
2005 session
2006 .peer
2007 .forward_send(session.connection_id, connection_id, request.clone())
2008 },
2009 );
2010 Ok(())
2011}
2012
2013/// forward a project request to the host. These requests should be read only
2014/// as guests are allowed to send them.
2015async fn forward_read_only_project_request<T>(
2016 request: T,
2017 response: Response<T>,
2018 session: Session,
2019) -> Result<()>
2020where
2021 T: EntityMessage + RequestMessage,
2022{
2023 let project_id = ProjectId::from_proto(request.remote_entity_id());
2024 let host_connection_id = session
2025 .db()
2026 .await
2027 .host_for_read_only_project_request(project_id, session.connection_id)
2028 .await?;
2029 let payload = session
2030 .peer
2031 .forward_request(session.connection_id, host_connection_id, request)
2032 .await?;
2033 response.send(payload)?;
2034 Ok(())
2035}
2036
2037/// forward a project request to the host. These requests are disallowed
2038/// for guests.
2039async fn forward_mutating_project_request<T>(
2040 request: T,
2041 response: Response<T>,
2042 session: Session,
2043) -> Result<()>
2044where
2045 T: EntityMessage + RequestMessage,
2046{
2047 let project_id = ProjectId::from_proto(request.remote_entity_id());
2048 let host_connection_id = session
2049 .db()
2050 .await
2051 .host_for_mutating_project_request(project_id, session.connection_id)
2052 .await?;
2053 let payload = session
2054 .peer
2055 .forward_request(session.connection_id, host_connection_id, request)
2056 .await?;
2057 response.send(payload)?;
2058 Ok(())
2059}
2060
2061/// Notify other participants that a new buffer has been created
2062async fn create_buffer_for_peer(
2063 request: proto::CreateBufferForPeer,
2064 session: Session,
2065) -> Result<()> {
2066 session
2067 .db()
2068 .await
2069 .check_user_is_project_host(
2070 ProjectId::from_proto(request.project_id),
2071 session.connection_id,
2072 )
2073 .await?;
2074 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2075 session
2076 .peer
2077 .forward_send(session.connection_id, peer_id.into(), request)?;
2078 Ok(())
2079}
2080
2081/// Notify other participants that a buffer has been updated. This is
2082/// allowed for guests as long as the update is limited to selections.
2083async fn update_buffer(
2084 request: proto::UpdateBuffer,
2085 response: Response<proto::UpdateBuffer>,
2086 session: Session,
2087) -> Result<()> {
2088 let project_id = ProjectId::from_proto(request.project_id);
2089 let mut guest_connection_ids;
2090 let mut host_connection_id = None;
2091
2092 let mut requires_write_permission = false;
2093
2094 for op in request.operations.iter() {
2095 match op.variant {
2096 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2097 Some(_) => requires_write_permission = true,
2098 }
2099 }
2100
2101 {
2102 let collaborators = session
2103 .db()
2104 .await
2105 .project_collaborators_for_buffer_update(
2106 project_id,
2107 session.connection_id,
2108 requires_write_permission,
2109 )
2110 .await?;
2111 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2112 for collaborator in collaborators.iter() {
2113 if collaborator.is_host {
2114 host_connection_id = Some(collaborator.connection_id);
2115 } else {
2116 guest_connection_ids.push(collaborator.connection_id);
2117 }
2118 }
2119 }
2120 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2121
2122 broadcast(
2123 Some(session.connection_id),
2124 guest_connection_ids,
2125 |connection_id| {
2126 session
2127 .peer
2128 .forward_send(session.connection_id, connection_id, request.clone())
2129 },
2130 );
2131 if host_connection_id != session.connection_id {
2132 session
2133 .peer
2134 .forward_request(session.connection_id, host_connection_id, request.clone())
2135 .await?;
2136 }
2137
2138 response.send(proto::Ack {})?;
2139 Ok(())
2140}
2141
2142/// Notify other participants that a project has been updated.
2143async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2144 request: T,
2145 session: Session,
2146) -> Result<()> {
2147 let project_id = ProjectId::from_proto(request.remote_entity_id());
2148 let project_connection_ids = session
2149 .db()
2150 .await
2151 .project_connection_ids(project_id, session.connection_id)
2152 .await?;
2153
2154 broadcast(
2155 Some(session.connection_id),
2156 project_connection_ids.iter().copied(),
2157 |connection_id| {
2158 session
2159 .peer
2160 .forward_send(session.connection_id, connection_id, request.clone())
2161 },
2162 );
2163 Ok(())
2164}
2165
2166/// Start following another user in a call.
2167async fn follow(
2168 request: proto::Follow,
2169 response: Response<proto::Follow>,
2170 session: Session,
2171) -> Result<()> {
2172 let room_id = RoomId::from_proto(request.room_id);
2173 let project_id = request.project_id.map(ProjectId::from_proto);
2174 let leader_id = request
2175 .leader_id
2176 .ok_or_else(|| anyhow!("invalid leader id"))?
2177 .into();
2178 let follower_id = session.connection_id;
2179
2180 session
2181 .db()
2182 .await
2183 .check_room_participants(room_id, leader_id, session.connection_id)
2184 .await?;
2185
2186 let response_payload = session
2187 .peer
2188 .forward_request(session.connection_id, leader_id, request)
2189 .await?;
2190 response.send(response_payload)?;
2191
2192 if let Some(project_id) = project_id {
2193 let room = session
2194 .db()
2195 .await
2196 .follow(room_id, project_id, leader_id, follower_id)
2197 .await?;
2198 room_updated(&room, &session.peer);
2199 }
2200
2201 Ok(())
2202}
2203
2204/// Stop following another user in a call.
2205async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2206 let room_id = RoomId::from_proto(request.room_id);
2207 let project_id = request.project_id.map(ProjectId::from_proto);
2208 let leader_id = request
2209 .leader_id
2210 .ok_or_else(|| anyhow!("invalid leader id"))?
2211 .into();
2212 let follower_id = session.connection_id;
2213
2214 session
2215 .db()
2216 .await
2217 .check_room_participants(room_id, leader_id, session.connection_id)
2218 .await?;
2219
2220 session
2221 .peer
2222 .forward_send(session.connection_id, leader_id, request)?;
2223
2224 if let Some(project_id) = project_id {
2225 let room = session
2226 .db()
2227 .await
2228 .unfollow(room_id, project_id, leader_id, follower_id)
2229 .await?;
2230 room_updated(&room, &session.peer);
2231 }
2232
2233 Ok(())
2234}
2235
2236/// Notify everyone following you of your current location.
2237async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2238 let room_id = RoomId::from_proto(request.room_id);
2239 let database = session.db.lock().await;
2240
2241 let connection_ids = if let Some(project_id) = request.project_id {
2242 let project_id = ProjectId::from_proto(project_id);
2243 database
2244 .project_connection_ids(project_id, session.connection_id)
2245 .await?
2246 } else {
2247 database
2248 .room_connection_ids(room_id, session.connection_id)
2249 .await?
2250 };
2251
2252 // For now, don't send view update messages back to that view's current leader.
2253 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2254 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2255 _ => None,
2256 });
2257
2258 for connection_id in connection_ids.iter().cloned() {
2259 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2260 session
2261 .peer
2262 .forward_send(session.connection_id, connection_id, request.clone())?;
2263 }
2264 }
2265 Ok(())
2266}
2267
2268/// Get public data about users.
2269async fn get_users(
2270 request: proto::GetUsers,
2271 response: Response<proto::GetUsers>,
2272 session: Session,
2273) -> Result<()> {
2274 let user_ids = request
2275 .user_ids
2276 .into_iter()
2277 .map(UserId::from_proto)
2278 .collect();
2279 let users = session
2280 .db()
2281 .await
2282 .get_users_by_ids(user_ids)
2283 .await?
2284 .into_iter()
2285 .map(|user| proto::User {
2286 id: user.id.to_proto(),
2287 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2288 github_login: user.github_login,
2289 })
2290 .collect();
2291 response.send(proto::UsersResponse { users })?;
2292 Ok(())
2293}
2294
2295/// Search for users (to invite) buy Github login
2296async fn fuzzy_search_users(
2297 request: proto::FuzzySearchUsers,
2298 response: Response<proto::FuzzySearchUsers>,
2299 session: Session,
2300) -> Result<()> {
2301 let query = request.query;
2302 let users = match query.len() {
2303 0 => vec![],
2304 1 | 2 => session
2305 .db()
2306 .await
2307 .get_user_by_github_login(&query)
2308 .await?
2309 .into_iter()
2310 .collect(),
2311 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2312 };
2313 let users = users
2314 .into_iter()
2315 .filter(|user| user.id != session.user_id)
2316 .map(|user| proto::User {
2317 id: user.id.to_proto(),
2318 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2319 github_login: user.github_login,
2320 })
2321 .collect();
2322 response.send(proto::UsersResponse { users })?;
2323 Ok(())
2324}
2325
2326/// Send a contact request to another user.
2327async fn request_contact(
2328 request: proto::RequestContact,
2329 response: Response<proto::RequestContact>,
2330 session: Session,
2331) -> Result<()> {
2332 let requester_id = session.user_id;
2333 let responder_id = UserId::from_proto(request.responder_id);
2334 if requester_id == responder_id {
2335 return Err(anyhow!("cannot add yourself as a contact"))?;
2336 }
2337
2338 let notifications = session
2339 .db()
2340 .await
2341 .send_contact_request(requester_id, responder_id)
2342 .await?;
2343
2344 // Update outgoing contact requests of requester
2345 let mut update = proto::UpdateContacts::default();
2346 update.outgoing_requests.push(responder_id.to_proto());
2347 for connection_id in session
2348 .connection_pool()
2349 .await
2350 .user_connection_ids(requester_id)
2351 {
2352 session.peer.send(connection_id, update.clone())?;
2353 }
2354
2355 // Update incoming contact requests of responder
2356 let mut update = proto::UpdateContacts::default();
2357 update
2358 .incoming_requests
2359 .push(proto::IncomingContactRequest {
2360 requester_id: requester_id.to_proto(),
2361 });
2362 let connection_pool = session.connection_pool().await;
2363 for connection_id in connection_pool.user_connection_ids(responder_id) {
2364 session.peer.send(connection_id, update.clone())?;
2365 }
2366
2367 send_notifications(&connection_pool, &session.peer, notifications);
2368
2369 response.send(proto::Ack {})?;
2370 Ok(())
2371}
2372
2373/// Accept or decline a contact request
2374async fn respond_to_contact_request(
2375 request: proto::RespondToContactRequest,
2376 response: Response<proto::RespondToContactRequest>,
2377 session: Session,
2378) -> Result<()> {
2379 let responder_id = session.user_id;
2380 let requester_id = UserId::from_proto(request.requester_id);
2381 let db = session.db().await;
2382 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2383 db.dismiss_contact_notification(responder_id, requester_id)
2384 .await?;
2385 } else {
2386 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2387
2388 let notifications = db
2389 .respond_to_contact_request(responder_id, requester_id, accept)
2390 .await?;
2391 let requester_busy = db.is_user_busy(requester_id).await?;
2392 let responder_busy = db.is_user_busy(responder_id).await?;
2393
2394 let pool = session.connection_pool().await;
2395 // Update responder with new contact
2396 let mut update = proto::UpdateContacts::default();
2397 if accept {
2398 update
2399 .contacts
2400 .push(contact_for_user(requester_id, requester_busy, &pool));
2401 }
2402 update
2403 .remove_incoming_requests
2404 .push(requester_id.to_proto());
2405 for connection_id in pool.user_connection_ids(responder_id) {
2406 session.peer.send(connection_id, update.clone())?;
2407 }
2408
2409 // Update requester with new contact
2410 let mut update = proto::UpdateContacts::default();
2411 if accept {
2412 update
2413 .contacts
2414 .push(contact_for_user(responder_id, responder_busy, &pool));
2415 }
2416 update
2417 .remove_outgoing_requests
2418 .push(responder_id.to_proto());
2419
2420 for connection_id in pool.user_connection_ids(requester_id) {
2421 session.peer.send(connection_id, update.clone())?;
2422 }
2423
2424 send_notifications(&pool, &session.peer, notifications);
2425 }
2426
2427 response.send(proto::Ack {})?;
2428 Ok(())
2429}
2430
2431/// Remove a contact.
2432async fn remove_contact(
2433 request: proto::RemoveContact,
2434 response: Response<proto::RemoveContact>,
2435 session: Session,
2436) -> Result<()> {
2437 let requester_id = session.user_id;
2438 let responder_id = UserId::from_proto(request.user_id);
2439 let db = session.db().await;
2440 let (contact_accepted, deleted_notification_id) =
2441 db.remove_contact(requester_id, responder_id).await?;
2442
2443 let pool = session.connection_pool().await;
2444 // Update outgoing contact requests of requester
2445 let mut update = proto::UpdateContacts::default();
2446 if contact_accepted {
2447 update.remove_contacts.push(responder_id.to_proto());
2448 } else {
2449 update
2450 .remove_outgoing_requests
2451 .push(responder_id.to_proto());
2452 }
2453 for connection_id in pool.user_connection_ids(requester_id) {
2454 session.peer.send(connection_id, update.clone())?;
2455 }
2456
2457 // Update incoming contact requests of responder
2458 let mut update = proto::UpdateContacts::default();
2459 if contact_accepted {
2460 update.remove_contacts.push(requester_id.to_proto());
2461 } else {
2462 update
2463 .remove_incoming_requests
2464 .push(requester_id.to_proto());
2465 }
2466 for connection_id in pool.user_connection_ids(responder_id) {
2467 session.peer.send(connection_id, update.clone())?;
2468 if let Some(notification_id) = deleted_notification_id {
2469 session.peer.send(
2470 connection_id,
2471 proto::DeleteNotification {
2472 notification_id: notification_id.to_proto(),
2473 },
2474 )?;
2475 }
2476 }
2477
2478 response.send(proto::Ack {})?;
2479 Ok(())
2480}
2481
2482/// Creates a new channel.
2483async fn create_channel(
2484 request: proto::CreateChannel,
2485 response: Response<proto::CreateChannel>,
2486 session: Session,
2487) -> Result<()> {
2488 let db = session.db().await;
2489
2490 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2491 let (channel, membership) = db
2492 .create_channel(&request.name, parent_id, session.user_id)
2493 .await?;
2494
2495 let root_id = channel.root_id();
2496 let channel = Channel::from_model(channel);
2497
2498 response.send(proto::CreateChannelResponse {
2499 channel: Some(channel.to_proto()),
2500 parent_id: request.parent_id,
2501 })?;
2502
2503 let mut connection_pool = session.connection_pool().await;
2504 if let Some(membership) = membership {
2505 connection_pool.subscribe_to_channel(
2506 membership.user_id,
2507 membership.channel_id,
2508 membership.role,
2509 );
2510 let update = proto::UpdateUserChannels {
2511 channel_memberships: vec![proto::ChannelMembership {
2512 channel_id: membership.channel_id.to_proto(),
2513 role: membership.role.into(),
2514 }],
2515 ..Default::default()
2516 };
2517 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2518 session.peer.send(connection_id, update.clone())?;
2519 }
2520 }
2521
2522 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2523 if !role.can_see_channel(channel.visibility) {
2524 continue;
2525 }
2526
2527 let update = proto::UpdateChannels {
2528 channels: vec![channel.to_proto()],
2529 ..Default::default()
2530 };
2531 session.peer.send(connection_id, update.clone())?;
2532 }
2533
2534 Ok(())
2535}
2536
2537/// Delete a channel
2538async fn delete_channel(
2539 request: proto::DeleteChannel,
2540 response: Response<proto::DeleteChannel>,
2541 session: Session,
2542) -> Result<()> {
2543 let db = session.db().await;
2544
2545 let channel_id = request.channel_id;
2546 let (root_channel, removed_channels) = db
2547 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2548 .await?;
2549 response.send(proto::Ack {})?;
2550
2551 // Notify members of removed channels
2552 let mut update = proto::UpdateChannels::default();
2553 update
2554 .delete_channels
2555 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2556
2557 let connection_pool = session.connection_pool().await;
2558 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2559 session.peer.send(connection_id, update.clone())?;
2560 }
2561
2562 Ok(())
2563}
2564
2565/// Invite someone to join a channel.
2566async fn invite_channel_member(
2567 request: proto::InviteChannelMember,
2568 response: Response<proto::InviteChannelMember>,
2569 session: Session,
2570) -> Result<()> {
2571 let db = session.db().await;
2572 let channel_id = ChannelId::from_proto(request.channel_id);
2573 let invitee_id = UserId::from_proto(request.user_id);
2574 let InviteMemberResult {
2575 channel,
2576 notifications,
2577 } = db
2578 .invite_channel_member(
2579 channel_id,
2580 invitee_id,
2581 session.user_id,
2582 request.role().into(),
2583 )
2584 .await?;
2585
2586 let update = proto::UpdateChannels {
2587 channel_invitations: vec![channel.to_proto()],
2588 ..Default::default()
2589 };
2590
2591 let connection_pool = session.connection_pool().await;
2592 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2593 session.peer.send(connection_id, update.clone())?;
2594 }
2595
2596 send_notifications(&connection_pool, &session.peer, notifications);
2597
2598 response.send(proto::Ack {})?;
2599 Ok(())
2600}
2601
2602/// remove someone from a channel
2603async fn remove_channel_member(
2604 request: proto::RemoveChannelMember,
2605 response: Response<proto::RemoveChannelMember>,
2606 session: Session,
2607) -> Result<()> {
2608 let db = session.db().await;
2609 let channel_id = ChannelId::from_proto(request.channel_id);
2610 let member_id = UserId::from_proto(request.user_id);
2611
2612 let RemoveChannelMemberResult {
2613 membership_update,
2614 notification_id,
2615 } = db
2616 .remove_channel_member(channel_id, member_id, session.user_id)
2617 .await?;
2618
2619 let mut connection_pool = session.connection_pool().await;
2620 notify_membership_updated(
2621 &mut connection_pool,
2622 membership_update,
2623 member_id,
2624 &session.peer,
2625 );
2626 for connection_id in connection_pool.user_connection_ids(member_id) {
2627 if let Some(notification_id) = notification_id {
2628 session
2629 .peer
2630 .send(
2631 connection_id,
2632 proto::DeleteNotification {
2633 notification_id: notification_id.to_proto(),
2634 },
2635 )
2636 .trace_err();
2637 }
2638 }
2639
2640 response.send(proto::Ack {})?;
2641 Ok(())
2642}
2643
2644/// Toggle the channel between public and private.
2645/// Care is taken to maintain the invariant that public channels only descend from public channels,
2646/// (though members-only channels can appear at any point in the hierarchy).
2647async fn set_channel_visibility(
2648 request: proto::SetChannelVisibility,
2649 response: Response<proto::SetChannelVisibility>,
2650 session: Session,
2651) -> Result<()> {
2652 let db = session.db().await;
2653 let channel_id = ChannelId::from_proto(request.channel_id);
2654 let visibility = request.visibility().into();
2655
2656 let channel_model = db
2657 .set_channel_visibility(channel_id, visibility, session.user_id)
2658 .await?;
2659 let root_id = channel_model.root_id();
2660 let channel = Channel::from_model(channel_model);
2661
2662 let mut connection_pool = session.connection_pool().await;
2663 for (user_id, role) in connection_pool
2664 .channel_user_ids(root_id)
2665 .collect::<Vec<_>>()
2666 .into_iter()
2667 {
2668 let update = if role.can_see_channel(channel.visibility) {
2669 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2670 proto::UpdateChannels {
2671 channels: vec![channel.to_proto()],
2672 ..Default::default()
2673 }
2674 } else {
2675 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2676 proto::UpdateChannels {
2677 delete_channels: vec![channel.id.to_proto()],
2678 ..Default::default()
2679 }
2680 };
2681
2682 for connection_id in connection_pool.user_connection_ids(user_id) {
2683 session.peer.send(connection_id, update.clone())?;
2684 }
2685 }
2686
2687 response.send(proto::Ack {})?;
2688 Ok(())
2689}
2690
2691/// Alter the role for a user in the channel.
2692async fn set_channel_member_role(
2693 request: proto::SetChannelMemberRole,
2694 response: Response<proto::SetChannelMemberRole>,
2695 session: Session,
2696) -> Result<()> {
2697 let db = session.db().await;
2698 let channel_id = ChannelId::from_proto(request.channel_id);
2699 let member_id = UserId::from_proto(request.user_id);
2700 let result = db
2701 .set_channel_member_role(
2702 channel_id,
2703 session.user_id,
2704 member_id,
2705 request.role().into(),
2706 )
2707 .await?;
2708
2709 match result {
2710 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2711 let mut connection_pool = session.connection_pool().await;
2712 notify_membership_updated(
2713 &mut connection_pool,
2714 membership_update,
2715 member_id,
2716 &session.peer,
2717 )
2718 }
2719 db::SetMemberRoleResult::InviteUpdated(channel) => {
2720 let update = proto::UpdateChannels {
2721 channel_invitations: vec![channel.to_proto()],
2722 ..Default::default()
2723 };
2724
2725 for connection_id in session
2726 .connection_pool()
2727 .await
2728 .user_connection_ids(member_id)
2729 {
2730 session.peer.send(connection_id, update.clone())?;
2731 }
2732 }
2733 }
2734
2735 response.send(proto::Ack {})?;
2736 Ok(())
2737}
2738
2739/// Change the name of a channel
2740async fn rename_channel(
2741 request: proto::RenameChannel,
2742 response: Response<proto::RenameChannel>,
2743 session: Session,
2744) -> Result<()> {
2745 let db = session.db().await;
2746 let channel_id = ChannelId::from_proto(request.channel_id);
2747 let channel_model = db
2748 .rename_channel(channel_id, session.user_id, &request.name)
2749 .await?;
2750 let root_id = channel_model.root_id();
2751 let channel = Channel::from_model(channel_model);
2752
2753 response.send(proto::RenameChannelResponse {
2754 channel: Some(channel.to_proto()),
2755 })?;
2756
2757 let connection_pool = session.connection_pool().await;
2758 let update = proto::UpdateChannels {
2759 channels: vec![channel.to_proto()],
2760 ..Default::default()
2761 };
2762 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2763 if role.can_see_channel(channel.visibility) {
2764 session.peer.send(connection_id, update.clone())?;
2765 }
2766 }
2767
2768 Ok(())
2769}
2770
2771/// Move a channel to a new parent.
2772async fn move_channel(
2773 request: proto::MoveChannel,
2774 response: Response<proto::MoveChannel>,
2775 session: Session,
2776) -> Result<()> {
2777 let channel_id = ChannelId::from_proto(request.channel_id);
2778 let to = ChannelId::from_proto(request.to);
2779
2780 let (root_id, channels) = session
2781 .db()
2782 .await
2783 .move_channel(channel_id, to, session.user_id)
2784 .await?;
2785
2786 let connection_pool = session.connection_pool().await;
2787 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2788 let channels = channels
2789 .iter()
2790 .filter_map(|channel| {
2791 if role.can_see_channel(channel.visibility) {
2792 Some(channel.to_proto())
2793 } else {
2794 None
2795 }
2796 })
2797 .collect::<Vec<_>>();
2798 if channels.is_empty() {
2799 continue;
2800 }
2801
2802 let update = proto::UpdateChannels {
2803 channels,
2804 ..Default::default()
2805 };
2806
2807 session.peer.send(connection_id, update.clone())?;
2808 }
2809
2810 response.send(Ack {})?;
2811 Ok(())
2812}
2813
2814/// Get the list of channel members
2815async fn get_channel_members(
2816 request: proto::GetChannelMembers,
2817 response: Response<proto::GetChannelMembers>,
2818 session: Session,
2819) -> Result<()> {
2820 let db = session.db().await;
2821 let channel_id = ChannelId::from_proto(request.channel_id);
2822 let members = db
2823 .get_channel_participant_details(channel_id, session.user_id)
2824 .await?;
2825 response.send(proto::GetChannelMembersResponse { members })?;
2826 Ok(())
2827}
2828
2829/// Accept or decline a channel invitation.
2830async fn respond_to_channel_invite(
2831 request: proto::RespondToChannelInvite,
2832 response: Response<proto::RespondToChannelInvite>,
2833 session: Session,
2834) -> Result<()> {
2835 let db = session.db().await;
2836 let channel_id = ChannelId::from_proto(request.channel_id);
2837 let RespondToChannelInvite {
2838 membership_update,
2839 notifications,
2840 } = db
2841 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2842 .await?;
2843
2844 let mut connection_pool = session.connection_pool().await;
2845 if let Some(membership_update) = membership_update {
2846 notify_membership_updated(
2847 &mut connection_pool,
2848 membership_update,
2849 session.user_id,
2850 &session.peer,
2851 );
2852 } else {
2853 let update = proto::UpdateChannels {
2854 remove_channel_invitations: vec![channel_id.to_proto()],
2855 ..Default::default()
2856 };
2857
2858 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2859 session.peer.send(connection_id, update.clone())?;
2860 }
2861 };
2862
2863 send_notifications(&connection_pool, &session.peer, notifications);
2864
2865 response.send(proto::Ack {})?;
2866
2867 Ok(())
2868}
2869
2870/// Join the channels' room
2871async fn join_channel(
2872 request: proto::JoinChannel,
2873 response: Response<proto::JoinChannel>,
2874 session: Session,
2875) -> Result<()> {
2876 let channel_id = ChannelId::from_proto(request.channel_id);
2877 join_channel_internal(channel_id, Box::new(response), session).await
2878}
2879
/// A response handle that can be completed with a `proto::JoinRoomResponse`.
///
/// `join_channel_internal` accepts any implementor of this trait, which lets
/// it reply through handles belonging to different request types.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    /// Delegates to `Response::<proto::JoinChannel>::send`. The path is fully
    /// qualified rather than written as `self.send(..)`, presumably to make it
    /// explicit that this is not a call back into this trait method.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    /// Delegates to `Response::<proto::JoinRoom>::send`. The path is fully
    /// qualified rather than written as `self.send(..)`, presumably to make it
    /// explicit that this is not a call back into this trait method.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
2893
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. `leave_room_for_session` runs for this session.
/// 2. The database joins the channel, returning the joined room, an optional
///    membership update and the user's role.
/// 3. If a LiveKit client is configured, a token is created for the room.
/// 4. The `JoinRoomResponse` is sent through `response`.
/// 5. Any membership update is passed to `notify_membership_updated`, and the
///    room is passed to `room_updated`.
/// 6. `channel_updated` and `update_user_contacts` run last.
///
/// Fails with "channel not returned" if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The inner block scopes the `db` and `connection_pool` guards: both are
    // dropped at its end, before the pool is acquired again further down.
    let joined_room = {
        // Presumably ensures the session is not left in a previous room —
        // confirm against `leave_room_for_session`.
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation
        // fails (the error is handed to `trace_err` and the closure bails out).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and `can_publish == false`; every other
            // role gets a room token and `can_publish == true`.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The pool is acquired afresh here, only for the duration of this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2973
2974/// Start editing the channel notes
2975async fn join_channel_buffer(
2976 request: proto::JoinChannelBuffer,
2977 response: Response<proto::JoinChannelBuffer>,
2978 session: Session,
2979) -> Result<()> {
2980 let db = session.db().await;
2981 let channel_id = ChannelId::from_proto(request.channel_id);
2982
2983 let open_response = db
2984 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2985 .await?;
2986
2987 let collaborators = open_response.collaborators.clone();
2988 response.send(open_response)?;
2989
2990 let update = UpdateChannelBufferCollaborators {
2991 channel_id: channel_id.to_proto(),
2992 collaborators: collaborators.clone(),
2993 };
2994 channel_buffer_updated(
2995 session.connection_id,
2996 collaborators
2997 .iter()
2998 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
2999 &update,
3000 &session.peer,
3001 );
3002
3003 Ok(())
3004}
3005
3006/// Edit the channel notes
3007async fn update_channel_buffer(
3008 request: proto::UpdateChannelBuffer,
3009 session: Session,
3010) -> Result<()> {
3011 let db = session.db().await;
3012 let channel_id = ChannelId::from_proto(request.channel_id);
3013
3014 let (collaborators, non_collaborators, epoch, version) = db
3015 .update_channel_buffer(channel_id, session.user_id, &request.operations)
3016 .await?;
3017
3018 channel_buffer_updated(
3019 session.connection_id,
3020 collaborators,
3021 &proto::UpdateChannelBuffer {
3022 channel_id: channel_id.to_proto(),
3023 operations: request.operations,
3024 },
3025 &session.peer,
3026 );
3027
3028 let pool = &*session.connection_pool().await;
3029
3030 broadcast(
3031 None,
3032 non_collaborators
3033 .iter()
3034 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3035 |peer_id| {
3036 session.peer.send(
3037 peer_id,
3038 proto::UpdateChannels {
3039 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3040 channel_id: channel_id.to_proto(),
3041 epoch: epoch as u64,
3042 version: version.clone(),
3043 }],
3044 ..Default::default()
3045 },
3046 )
3047 },
3048 );
3049
3050 Ok(())
3051}
3052
3053/// Rejoin the channel notes after a connection blip
3054async fn rejoin_channel_buffers(
3055 request: proto::RejoinChannelBuffers,
3056 response: Response<proto::RejoinChannelBuffers>,
3057 session: Session,
3058) -> Result<()> {
3059 let db = session.db().await;
3060 let buffers = db
3061 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
3062 .await?;
3063
3064 for rejoined_buffer in &buffers {
3065 let collaborators_to_notify = rejoined_buffer
3066 .buffer
3067 .collaborators
3068 .iter()
3069 .filter_map(|c| Some(c.peer_id?.into()));
3070 channel_buffer_updated(
3071 session.connection_id,
3072 collaborators_to_notify,
3073 &proto::UpdateChannelBufferCollaborators {
3074 channel_id: rejoined_buffer.buffer.channel_id,
3075 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3076 },
3077 &session.peer,
3078 );
3079 }
3080
3081 response.send(proto::RejoinChannelBuffersResponse {
3082 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3083 })?;
3084
3085 Ok(())
3086}
3087
3088/// Stop editing the channel notes
3089async fn leave_channel_buffer(
3090 request: proto::LeaveChannelBuffer,
3091 response: Response<proto::LeaveChannelBuffer>,
3092 session: Session,
3093) -> Result<()> {
3094 let db = session.db().await;
3095 let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097 let left_buffer = db
3098 .leave_channel_buffer(channel_id, session.connection_id)
3099 .await?;
3100
3101 response.send(Ack {})?;
3102
3103 channel_buffer_updated(
3104 session.connection_id,
3105 left_buffer.connections,
3106 &proto::UpdateChannelBufferCollaborators {
3107 channel_id: channel_id.to_proto(),
3108 collaborators: left_buffer.collaborators,
3109 },
3110 &session.peer,
3111 );
3112
3113 Ok(())
3114}
3115
3116fn channel_buffer_updated<T: EnvelopedMessage>(
3117 sender_id: ConnectionId,
3118 collaborators: impl IntoIterator<Item = ConnectionId>,
3119 message: &T,
3120 peer: &Peer,
3121) {
3122 broadcast(Some(sender_id), collaborators, |peer_id| {
3123 peer.send(peer_id, message.clone())
3124 });
3125}
3126
3127fn send_notifications(
3128 connection_pool: &ConnectionPool,
3129 peer: &Peer,
3130 notifications: db::NotificationBatch,
3131) {
3132 for (user_id, notification) in notifications {
3133 for connection_id in connection_pool.user_connection_ids(user_id) {
3134 if let Err(error) = peer.send(
3135 connection_id,
3136 proto::AddNotification {
3137 notification: Some(notification.clone()),
3138 },
3139 ) {
3140 tracing::error!(
3141 "failed to send notification to {:?} {}",
3142 connection_id,
3143 error
3144 );
3145 }
3146 }
3147 }
3148}
3149
3150/// Send a message to the channel
3151async fn send_channel_message(
3152 request: proto::SendChannelMessage,
3153 response: Response<proto::SendChannelMessage>,
3154 session: Session,
3155) -> Result<()> {
3156 // Validate the message body.
3157 let body = request.body.trim().to_string();
3158 if body.len() > MAX_MESSAGE_LEN {
3159 return Err(anyhow!("message is too long"))?;
3160 }
3161 if body.is_empty() {
3162 return Err(anyhow!("message can't be blank"))?;
3163 }
3164
3165 // TODO: adjust mentions if body is trimmed
3166
3167 let timestamp = OffsetDateTime::now_utc();
3168 let nonce = request
3169 .nonce
3170 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3171
3172 let channel_id = ChannelId::from_proto(request.channel_id);
3173 let CreatedChannelMessage {
3174 message_id,
3175 participant_connection_ids,
3176 channel_members,
3177 notifications,
3178 } = session
3179 .db()
3180 .await
3181 .create_channel_message(
3182 channel_id,
3183 session.user_id,
3184 &body,
3185 &request.mentions,
3186 timestamp,
3187 nonce.clone().into(),
3188 match request.reply_to_message_id {
3189 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3190 None => None,
3191 },
3192 )
3193 .await?;
3194 let message = proto::ChannelMessage {
3195 sender_id: session.user_id.to_proto(),
3196 id: message_id.to_proto(),
3197 body,
3198 mentions: request.mentions,
3199 timestamp: timestamp.unix_timestamp() as u64,
3200 nonce: Some(nonce),
3201 reply_to_message_id: request.reply_to_message_id,
3202 };
3203 broadcast(
3204 Some(session.connection_id),
3205 participant_connection_ids,
3206 |connection| {
3207 session.peer.send(
3208 connection,
3209 proto::ChannelMessageSent {
3210 channel_id: channel_id.to_proto(),
3211 message: Some(message.clone()),
3212 },
3213 )
3214 },
3215 );
3216 response.send(proto::SendChannelMessageResponse {
3217 message: Some(message),
3218 })?;
3219
3220 let pool = &*session.connection_pool().await;
3221 broadcast(
3222 None,
3223 channel_members
3224 .iter()
3225 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3226 |peer_id| {
3227 session.peer.send(
3228 peer_id,
3229 proto::UpdateChannels {
3230 latest_channel_message_ids: vec![proto::ChannelMessageId {
3231 channel_id: channel_id.to_proto(),
3232 message_id: message_id.to_proto(),
3233 }],
3234 ..Default::default()
3235 },
3236 )
3237 },
3238 );
3239 send_notifications(pool, &session.peer, notifications);
3240
3241 Ok(())
3242}
3243
3244/// Delete a channel message
3245async fn remove_channel_message(
3246 request: proto::RemoveChannelMessage,
3247 response: Response<proto::RemoveChannelMessage>,
3248 session: Session,
3249) -> Result<()> {
3250 let channel_id = ChannelId::from_proto(request.channel_id);
3251 let message_id = MessageId::from_proto(request.message_id);
3252 let connection_ids = session
3253 .db()
3254 .await
3255 .remove_channel_message(channel_id, message_id, session.user_id)
3256 .await?;
3257 broadcast(Some(session.connection_id), connection_ids, |connection| {
3258 session.peer.send(connection, request.clone())
3259 });
3260 response.send(proto::Ack {})?;
3261 Ok(())
3262}
3263
3264/// Mark a channel message as read
3265async fn acknowledge_channel_message(
3266 request: proto::AckChannelMessage,
3267 session: Session,
3268) -> Result<()> {
3269 let channel_id = ChannelId::from_proto(request.channel_id);
3270 let message_id = MessageId::from_proto(request.message_id);
3271 let notifications = session
3272 .db()
3273 .await
3274 .observe_channel_message(channel_id, session.user_id, message_id)
3275 .await?;
3276 send_notifications(
3277 &*session.connection_pool().await,
3278 &session.peer,
3279 notifications,
3280 );
3281 Ok(())
3282}
3283
3284/// Mark a buffer version as synced
3285async fn acknowledge_buffer_version(
3286 request: proto::AckBufferOperation,
3287 session: Session,
3288) -> Result<()> {
3289 let buffer_id = BufferId::from_proto(request.buffer_id);
3290 session
3291 .db()
3292 .await
3293 .observe_buffer_version(
3294 buffer_id,
3295 session.user_id,
3296 request.epoch as i32,
3297 &request.version,
3298 )
3299 .await?;
3300 Ok(())
3301}
3302
3303struct CompleteWithLanguageModelRateLimit;
3304
3305impl RateLimit for CompleteWithLanguageModelRateLimit {
3306 fn capacity() -> usize {
3307 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3308 .ok()
3309 .and_then(|v| v.parse().ok())
3310 .unwrap_or(120) // Picked arbitrarily
3311 }
3312
3313 fn refill_duration() -> chrono::Duration {
3314 chrono::Duration::hours(1)
3315 }
3316
3317 fn db_name() -> &'static str {
3318 "complete-with-language-model"
3319 }
3320}
3321
3322async fn complete_with_language_model(
3323 request: proto::CompleteWithLanguageModel,
3324 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3325 session: Session,
3326 open_ai_api_key: Option<Arc<str>>,
3327 google_ai_api_key: Option<Arc<str>>,
3328) -> Result<()> {
3329 authorize_access_to_language_models(&session).await?;
3330 session
3331 .rate_limiter
3332 .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
3333 .await?;
3334
3335 if request.model.starts_with("gpt") {
3336 let api_key =
3337 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3338 complete_with_open_ai(request, response, session, api_key).await?;
3339 } else if request.model.starts_with("gemini") {
3340 let api_key = google_ai_api_key
3341 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3342 complete_with_google_ai(request, response, session, api_key).await?;
3343 }
3344
3345 Ok(())
3346}
3347
3348async fn complete_with_open_ai(
3349 request: proto::CompleteWithLanguageModel,
3350 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3351 session: Session,
3352 api_key: Arc<str>,
3353) -> Result<()> {
3354 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3355
3356 let mut completion_stream = open_ai::stream_completion(
3357 &session.http_client,
3358 OPEN_AI_API_URL,
3359 &api_key,
3360 crate::ai::language_model_request_to_open_ai(request)?,
3361 )
3362 .await
3363 .context("open_ai::stream_completion request failed")?;
3364
3365 while let Some(event) = completion_stream.next().await {
3366 let event = event?;
3367 response.send(proto::LanguageModelResponse {
3368 choices: event
3369 .choices
3370 .into_iter()
3371 .map(|choice| proto::LanguageModelChoiceDelta {
3372 index: choice.index,
3373 delta: Some(proto::LanguageModelResponseMessage {
3374 role: choice.delta.role.map(|role| match role {
3375 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3376 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3377 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3378 } as i32),
3379 content: choice.delta.content,
3380 }),
3381 finish_reason: choice.finish_reason,
3382 })
3383 .collect(),
3384 })?;
3385 }
3386
3387 Ok(())
3388}
3389
3390async fn complete_with_google_ai(
3391 request: proto::CompleteWithLanguageModel,
3392 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3393 session: Session,
3394 api_key: Arc<str>,
3395) -> Result<()> {
3396 let mut stream = google_ai::stream_generate_content(
3397 &session.http_client,
3398 google_ai::API_URL,
3399 api_key.as_ref(),
3400 crate::ai::language_model_request_to_google_ai(request)?,
3401 )
3402 .await
3403 .context("google_ai::stream_generate_content request failed")?;
3404
3405 while let Some(event) = stream.next().await {
3406 let event = event?;
3407 response.send(proto::LanguageModelResponse {
3408 choices: event
3409 .candidates
3410 .unwrap_or_default()
3411 .into_iter()
3412 .map(|candidate| proto::LanguageModelChoiceDelta {
3413 index: candidate.index as u32,
3414 delta: Some(proto::LanguageModelResponseMessage {
3415 role: Some(match candidate.content.role {
3416 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3417 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3418 } as i32),
3419 content: Some(
3420 candidate
3421 .content
3422 .parts
3423 .into_iter()
3424 .filter_map(|part| match part {
3425 google_ai::Part::TextPart(part) => Some(part.text),
3426 google_ai::Part::InlineDataPart(_) => None,
3427 })
3428 .collect(),
3429 ),
3430 }),
3431 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3432 })
3433 .collect(),
3434 })?;
3435 }
3436
3437 Ok(())
3438}
3439
3440struct CountTokensWithLanguageModelRateLimit;
3441
3442impl RateLimit for CountTokensWithLanguageModelRateLimit {
3443 fn capacity() -> usize {
3444 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3445 .ok()
3446 .and_then(|v| v.parse().ok())
3447 .unwrap_or(600) // Picked arbitrarily
3448 }
3449
3450 fn refill_duration() -> chrono::Duration {
3451 chrono::Duration::hours(1)
3452 }
3453
3454 fn db_name() -> &'static str {
3455 "count-tokens-with-language-model"
3456 }
3457}
3458
3459async fn count_tokens_with_language_model(
3460 request: proto::CountTokensWithLanguageModel,
3461 response: Response<proto::CountTokensWithLanguageModel>,
3462 session: Session,
3463 google_ai_api_key: Option<Arc<str>>,
3464) -> Result<()> {
3465 authorize_access_to_language_models(&session).await?;
3466
3467 if !request.model.starts_with("gemini") {
3468 return Err(anyhow!(
3469 "counting tokens for model: {:?} is not supported",
3470 request.model
3471 ))?;
3472 }
3473
3474 session
3475 .rate_limiter
3476 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
3477 .await?;
3478
3479 let api_key = google_ai_api_key
3480 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3481 let tokens_response = google_ai::count_tokens(
3482 &session.http_client,
3483 google_ai::API_URL,
3484 &api_key,
3485 crate::ai::count_tokens_request_to_google_ai(request)?,
3486 )
3487 .await?;
3488 response.send(proto::CountTokensResponse {
3489 token_count: tokens_response.total_tokens as u32,
3490 })?;
3491 Ok(())
3492}
3493
3494async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
3495 let db = session.db().await;
3496 let flags = db.get_user_flags(session.user_id).await?;
3497 if flags.iter().any(|flag| flag == "language-models") {
3498 Ok(())
3499 } else {
3500 Err(anyhow!("permission denied"))?
3501 }
3502}
3503
3504/// Start receiving chat updates for a channel
3505async fn join_channel_chat(
3506 request: proto::JoinChannelChat,
3507 response: Response<proto::JoinChannelChat>,
3508 session: Session,
3509) -> Result<()> {
3510 let channel_id = ChannelId::from_proto(request.channel_id);
3511
3512 let db = session.db().await;
3513 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3514 .await?;
3515 let messages = db
3516 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3517 .await?;
3518 response.send(proto::JoinChannelChatResponse {
3519 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3520 messages,
3521 })?;
3522 Ok(())
3523}
3524
3525/// Stop receiving chat updates for a channel
3526async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3527 let channel_id = ChannelId::from_proto(request.channel_id);
3528 session
3529 .db()
3530 .await
3531 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3532 .await?;
3533 Ok(())
3534}
3535
3536/// Retrieve the chat history for a channel
3537async fn get_channel_messages(
3538 request: proto::GetChannelMessages,
3539 response: Response<proto::GetChannelMessages>,
3540 session: Session,
3541) -> Result<()> {
3542 let channel_id = ChannelId::from_proto(request.channel_id);
3543 let messages = session
3544 .db()
3545 .await
3546 .get_channel_messages(
3547 channel_id,
3548 session.user_id,
3549 MESSAGE_COUNT_PER_PAGE,
3550 Some(MessageId::from_proto(request.before_message_id)),
3551 )
3552 .await?;
3553 response.send(proto::GetChannelMessagesResponse {
3554 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3555 messages,
3556 })?;
3557 Ok(())
3558}
3559
3560/// Retrieve specific chat messages
3561async fn get_channel_messages_by_id(
3562 request: proto::GetChannelMessagesById,
3563 response: Response<proto::GetChannelMessagesById>,
3564 session: Session,
3565) -> Result<()> {
3566 let message_ids = request
3567 .message_ids
3568 .iter()
3569 .map(|id| MessageId::from_proto(*id))
3570 .collect::<Vec<_>>();
3571 let messages = session
3572 .db()
3573 .await
3574 .get_channel_messages_by_id(session.user_id, &message_ids)
3575 .await?;
3576 response.send(proto::GetChannelMessagesResponse {
3577 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3578 messages,
3579 })?;
3580 Ok(())
3581}
3582
3583/// Retrieve the current users notifications
3584async fn get_notifications(
3585 request: proto::GetNotifications,
3586 response: Response<proto::GetNotifications>,
3587 session: Session,
3588) -> Result<()> {
3589 let notifications = session
3590 .db()
3591 .await
3592 .get_notifications(
3593 session.user_id,
3594 NOTIFICATION_COUNT_PER_PAGE,
3595 request
3596 .before_id
3597 .map(|id| db::NotificationId::from_proto(id)),
3598 )
3599 .await?;
3600 response.send(proto::GetNotificationsResponse {
3601 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3602 notifications,
3603 })?;
3604 Ok(())
3605}
3606
3607/// Mark notifications as read
3608async fn mark_notification_as_read(
3609 request: proto::MarkNotificationRead,
3610 response: Response<proto::MarkNotificationRead>,
3611 session: Session,
3612) -> Result<()> {
3613 let database = &session.db().await;
3614 let notifications = database
3615 .mark_notification_as_read_by_id(
3616 session.user_id,
3617 NotificationId::from_proto(request.notification_id),
3618 )
3619 .await?;
3620 send_notifications(
3621 &*session.connection_pool().await,
3622 &session.peer,
3623 notifications,
3624 );
3625 response.send(proto::Ack {})?;
3626 Ok(())
3627}
3628
3629/// Get the current users information
3630async fn get_private_user_info(
3631 _request: proto::GetPrivateUserInfo,
3632 response: Response<proto::GetPrivateUserInfo>,
3633 session: Session,
3634) -> Result<()> {
3635 let db = session.db().await;
3636
3637 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3638 let user = db
3639 .get_user_by_id(session.user_id)
3640 .await?
3641 .ok_or_else(|| anyhow!("user not found"))?;
3642 let flags = db.get_user_flags(session.user_id).await?;
3643
3644 response.send(proto::GetPrivateUserInfoResponse {
3645 metrics_id,
3646 staff: user.admin,
3647 flags,
3648 })?;
3649 Ok(())
3650}
3651
3652fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3653 match message {
3654 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3655 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3656 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3657 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3658 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3659 code: frame.code.into(),
3660 reason: frame.reason,
3661 })),
3662 }
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666 match message {
3667 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3668 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3669 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3670 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3671 AxumMessage::Close(frame) => {
3672 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673 code: frame.code.into(),
3674 reason: frame.reason,
3675 }))
3676 }
3677 }
3678}
3679
3680fn notify_membership_updated(
3681 connection_pool: &mut ConnectionPool,
3682 result: MembershipUpdated,
3683 user_id: UserId,
3684 peer: &Peer,
3685) {
3686 for membership in &result.new_channels.channel_memberships {
3687 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3688 }
3689 for channel_id in &result.removed_channels {
3690 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3691 }
3692
3693 let user_channels_update = proto::UpdateUserChannels {
3694 channel_memberships: result
3695 .new_channels
3696 .channel_memberships
3697 .iter()
3698 .map(|cm| proto::ChannelMembership {
3699 channel_id: cm.channel_id.to_proto(),
3700 role: cm.role.into(),
3701 })
3702 .collect(),
3703 ..Default::default()
3704 };
3705
3706 let mut update = build_channels_update(result.new_channels, vec![]);
3707 update.delete_channels = result
3708 .removed_channels
3709 .into_iter()
3710 .map(|id| id.to_proto())
3711 .collect();
3712 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3713
3714 for connection_id in connection_pool.user_connection_ids(user_id) {
3715 peer.send(connection_id, user_channels_update.clone())
3716 .trace_err();
3717 peer.send(connection_id, update.clone()).trace_err();
3718 }
3719}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722 proto::UpdateUserChannels {
3723 channel_memberships: channels
3724 .channel_memberships
3725 .iter()
3726 .map(|m| proto::ChannelMembership {
3727 channel_id: m.channel_id.to_proto(),
3728 role: m.role.into(),
3729 })
3730 .collect(),
3731 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732 observed_channel_message_id: channels.observed_channel_messages.clone(),
3733 }
3734}
3735
3736fn build_channels_update(
3737 channels: ChannelsForUser,
3738 channel_invites: Vec<db::Channel>,
3739) -> proto::UpdateChannels {
3740 let mut update = proto::UpdateChannels::default();
3741
3742 for channel in channels.channels {
3743 update.channels.push(channel.to_proto());
3744 }
3745
3746 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3747 update.latest_channel_message_ids = channels.latest_channel_messages;
3748
3749 for (channel_id, participants) in channels.channel_participants {
3750 update
3751 .channel_participants
3752 .push(proto::ChannelParticipants {
3753 channel_id: channel_id.to_proto(),
3754 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3755 });
3756 }
3757
3758 for channel in channel_invites {
3759 update.channel_invitations.push(channel.to_proto());
3760 }
3761 for project in channels.hosted_projects {
3762 update.hosted_projects.push(project);
3763 }
3764
3765 update
3766}
3767
3768fn build_initial_contacts_update(
3769 contacts: Vec<db::Contact>,
3770 pool: &ConnectionPool,
3771) -> proto::UpdateContacts {
3772 let mut update = proto::UpdateContacts::default();
3773
3774 for contact in contacts {
3775 match contact {
3776 db::Contact::Accepted { user_id, busy } => {
3777 update.contacts.push(contact_for_user(user_id, busy, &pool));
3778 }
3779 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3780 db::Contact::Incoming { user_id } => {
3781 update
3782 .incoming_requests
3783 .push(proto::IncomingContactRequest {
3784 requester_id: user_id.to_proto(),
3785 })
3786 }
3787 }
3788 }
3789
3790 update
3791}
3792
3793fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3794 proto::Contact {
3795 user_id: user_id.to_proto(),
3796 online: pool.is_user_online(user_id),
3797 busy,
3798 }
3799}
3800
3801fn room_updated(room: &proto::Room, peer: &Peer) {
3802 broadcast(
3803 None,
3804 room.participants
3805 .iter()
3806 .filter_map(|participant| Some(participant.peer_id?.into())),
3807 |peer_id| {
3808 peer.send(
3809 peer_id,
3810 proto::RoomUpdated {
3811 room: Some(room.clone()),
3812 },
3813 )
3814 },
3815 );
3816}
3817
3818fn channel_updated(
3819 channel: &db::channel::Model,
3820 room: &proto::Room,
3821 peer: &Peer,
3822 pool: &ConnectionPool,
3823) {
3824 let participants = room
3825 .participants
3826 .iter()
3827 .map(|p| p.user_id)
3828 .collect::<Vec<_>>();
3829
3830 broadcast(
3831 None,
3832 pool.channel_connection_ids(channel.root_id())
3833 .filter_map(|(channel_id, role)| {
3834 role.can_see_channel(channel.visibility).then(|| channel_id)
3835 }),
3836 |peer_id| {
3837 peer.send(
3838 peer_id,
3839 proto::UpdateChannels {
3840 channel_participants: vec![proto::ChannelParticipants {
3841 channel_id: channel.id.to_proto(),
3842 participant_user_ids: participants.clone(),
3843 }],
3844 ..Default::default()
3845 },
3846 )
3847 },
3848 );
3849}
3850
3851async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3852 let db = session.db().await;
3853
3854 let contacts = db.get_contacts(user_id).await?;
3855 let busy = db.is_user_busy(user_id).await?;
3856
3857 let pool = session.connection_pool().await;
3858 let updated_contact = contact_for_user(user_id, busy, &pool);
3859 for contact in contacts {
3860 if let db::Contact::Accepted {
3861 user_id: contact_user_id,
3862 ..
3863 } = contact
3864 {
3865 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3866 session
3867 .peer
3868 .send(
3869 contact_conn_id,
3870 proto::UpdateContacts {
3871 contacts: vec![updated_contact.clone()],
3872 remove_contacts: Default::default(),
3873 incoming_requests: Default::default(),
3874 remove_incoming_requests: Default::default(),
3875 outgoing_requests: Default::default(),
3876 remove_outgoing_requests: Default::default(),
3877 },
3878 )
3879 .trace_err();
3880 }
3881 }
3882 }
3883 Ok(())
3884}
3885
/// Removes the session's connection from its room and notifies everyone affected.
///
/// If the database reports that the connection was not in a room (`leave_room` returns
/// `None`), this returns `Ok(())` without doing anything else. Otherwise it:
///
/// 1. calls `project_left` for every project that was left,
/// 2. broadcasts the updated room via `room_updated` and, when the room belongs to a channel,
///    via `channel_updated`,
/// 3. sends `proto::CallCanceled` to every connection of each user whose call was canceled,
/// 4. refreshes contacts (via `update_user_contacts`) for the session's user and for each
///    user whose call was canceled,
/// 5. when a LiveKit client is configured, removes the participant from the LiveKit room and
///    deletes that room if the database marked the room as deleted.
///
/// Failures of individual sends and of the LiveKit calls are logged through `trace_err`
/// and do not fail the function; database errors and `update_user_contacts` errors are
/// propagated.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Values extracted from the `leave_room` result so they can be used after the `if let`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is moved out of `left_room.room` before `room` itself
        // is taken, so `room.live_kit_room` holds its default value in the broadcasts below —
        // confirm that receivers of these updates do not rely on it.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // The pool guard is confined to this block so that it is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        // Best effort: LiveKit failures are logged, not propagated.
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
3959
3960async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3961 let left_channel_buffers = session
3962 .db()
3963 .await
3964 .leave_channel_buffers(session.connection_id)
3965 .await?;
3966
3967 for left_buffer in left_channel_buffers {
3968 channel_buffer_updated(
3969 session.connection_id,
3970 left_buffer.connections,
3971 &proto::UpdateChannelBufferCollaborators {
3972 channel_id: left_buffer.channel_id.to_proto(),
3973 collaborators: left_buffer.collaborators,
3974 },
3975 &session.peer,
3976 );
3977 }
3978
3979 Ok(())
3980}
3981
3982fn project_left(project: &db::LeftProject, session: &Session) {
3983 for connection_id in &project.connection_ids {
3984 if project.host_user_id == Some(session.user_id) {
3985 session
3986 .peer
3987 .send(
3988 *connection_id,
3989 proto::UnshareProject {
3990 project_id: project.id.to_proto(),
3991 },
3992 )
3993 .trace_err();
3994 } else {
3995 session
3996 .peer
3997 .send(
3998 *connection_id,
3999 proto::RemoveProjectCollaborator {
4000 project_id: project.id.to_proto(),
4001 peer_id: Some(session.connection_id.into()),
4002 },
4003 )
4004 .trace_err();
4005 }
4006 }
4007}
4008
4009pub trait ResultExt {
4010 type Ok;
4011
4012 fn trace_err(self) -> Option<Self::Ok>;
4013}
4014
4015impl<T, E> ResultExt for Result<T, E>
4016where
4017 E: std::fmt::Debug,
4018{
4019 type Ok = T;
4020
4021 fn trace_err(self) -> Option<T> {
4022 match self {
4023 Ok(value) => Some(value),
4024 Err(error) => {
4025 tracing::error!("{:?}", error);
4026 None
4027 }
4028 }
4029 }
4030}