1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 future::Future,
53 marker::PhantomData,
54 mem,
55 net::SocketAddr,
56 ops::{Deref, DerefMut},
57 rc::Rc,
58 sync::{
59 atomic::{AtomicBool, Ordering::SeqCst},
60 Arc, OnceLock,
61 },
62 time::{Duration, Instant},
63};
64use time::OffsetDateTime;
65use tokio::sync::{watch, Semaphore};
66use tower::ServiceBuilder;
67use tracing::{
68 field::{self},
69 info_span, instrument, Instrument,
70};
71use util::{http::IsahcHttpClient, SemanticVersion};
72
/// Reconnection window (30s). Not used in the visible part of this file — presumably how long
/// `connection_lost` keeps a disconnected client's state around for it to reconnect; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources, which is why `Server::start` waits this long before clearing stale state.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not read in this part of the file; the meanings
// given here are inferred from their names — confirm against the handlers that use them.
/// Presumably the page size for channel-message history requests.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification requests.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: receives a boxed envelope plus the
/// sender's [`Session`] and returns a `'static` future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
84
85struct Response<R> {
86 peer: Arc<Peer>,
87 receipt: Receipt<R>,
88 responded: Arc<AtomicBool>,
89}
90
91impl<R: RequestMessage> Response<R> {
92 fn send(self, payload: R::Response) -> Result<()> {
93 self.responded.store(true, SeqCst);
94 self.peer.respond(self.receipt, payload)?;
95 Ok(())
96 }
97}
98
99struct StreamingResponse<R: RequestMessage> {
100 peer: Arc<Peer>,
101 receipt: Receipt<R>,
102}
103
104impl<R: RequestMessage> StreamingResponse<R> {
105 fn send(&self, payload: R::Response) -> Result<()> {
106 self.peer.respond(self.receipt, payload)?;
107 Ok(())
108 }
109}
110
111#[derive(Clone, Debug)]
112pub enum Principal {
113 User(User),
114 Impersonated { user: User, admin: User },
115 DevServer(dev_server::Model),
116}
117
118impl Principal {
119 fn update_span(&self, span: &tracing::Span) {
120 match &self {
121 Principal::User(user) => {
122 span.record("user_id", &user.id.0);
123 span.record("login", &user.github_login);
124 }
125 Principal::Impersonated { user, admin } => {
126 span.record("user_id", &user.id.0);
127 span.record("login", &user.github_login);
128 span.record("impersonator", &admin.github_login);
129 }
130 Principal::DevServer(dev_server) => {
131 span.record("dev_server_id", &dev_server.id.0);
132 }
133 }
134 }
135}
136
137#[derive(Clone)]
138struct Session {
139 principal: Principal,
140 connection_id: ConnectionId,
141 db: Arc<tokio::sync::Mutex<DbHandle>>,
142 peer: Arc<Peer>,
143 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
144 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
145 http_client: IsahcHttpClient,
146 rate_limiter: Arc<RateLimiter>,
147 _executor: Executor,
148}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn user_id(&self) -> Option<UserId> {
175 match &self.principal {
176 Principal::User(user) => Some(user.id),
177 Principal::Impersonated { user, .. } => Some(user.id),
178 Principal::DevServer(_) => None,
179 }
180 }
181}
182
183impl Debug for Session {
184 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
185 let mut result = f.debug_struct("Session");
186 match &self.principal {
187 Principal::User(user) => {
188 result.field("user", &user.github_login);
189 }
190 Principal::Impersonated { user, admin } => {
191 result.field("user", &user.github_login);
192 result.field("impersonator", &admin.github_login);
193 }
194 Principal::DevServer(dev_server) => {
195 result.field("dev_server", &dev_server.id);
196 }
197 }
198 result.field("connection_id", &self.connection_id).finish()
199 }
200}
201
202struct UserSession(Session);
203
204impl UserSession {
205 pub fn new(s: Session) -> Option<Self> {
206 s.user_id().map(|_| UserSession(s))
207 }
208 pub fn user_id(&self) -> UserId {
209 self.0.user_id().unwrap()
210 }
211}
212
213impl Deref for UserSession {
214 type Target = Session;
215
216 fn deref(&self) -> &Self::Target {
217 &self.0
218 }
219}
220impl DerefMut for UserSession {
221 fn deref_mut(&mut self) -> &mut Self::Target {
222 &mut self.0
223 }
224}
225
226fn user_handler<M: RequestMessage, Fut>(
227 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
228) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
229where
230 Fut: Send + Future<Output = Result<()>>,
231{
232 let handler = Arc::new(handler);
233 move |message, response, session| {
234 let handler = handler.clone();
235 Box::pin(async move {
236 if let Some(user_session) = session.for_user() {
237 Ok(handler(message, response, user_session).await?)
238 } else {
239 Err(Error::Internal(anyhow!("must be a user")))
240 }
241 })
242 }
243}
244
245fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
246 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
247) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
248where
249 InnertRetFut: Send + Future<Output = Result<()>>,
250{
251 let handler = Arc::new(handler);
252 move |message, session| {
253 let handler = handler.clone();
254 Box::pin(async move {
255 if let Some(user_session) = session.for_user() {
256 Ok(handler(message, user_session).await?)
257 } else {
258 Err(Error::Internal(anyhow!("must be a user")))
259 }
260 })
261 }
262}
263
264struct DbHandle(Arc<Database>);
265
266impl Deref for DbHandle {
267 type Target = Database;
268
269 fn deref(&self) -> &Self::Target {
270 self.0.as_ref()
271 }
272}
273
/// The collaboration RPC server: owns the [`Peer`], the table of registered message
/// handlers, and the pool of live connections.
pub struct Server {
    /// This instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type each one handles.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection loops subscribe to it and exit when it changes.
    teardown: watch::Sender<bool>,
}
282
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, which makes the guard `!Send`. It therefore cannot be held across
    // an `.await` inside a future that must be `Send`.
    _not_send: PhantomData<Rc<()>>,
}
287
/// Serializable view of the server's peer and connection pool. It holds the pool lock for
/// as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
294
295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
296where
297 S: Serializer,
298 T: Deref<Target = U>,
299 U: Serialize,
300{
301 Serialize::serialize(value.deref(), serializer)
302}
303
304impl Server {
305 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
306 let mut server = Self {
307 id: parking_lot::Mutex::new(id),
308 peer: Peer::new(id.0 as u32),
309 app_state: app_state.clone(),
310 connection_pool: Default::default(),
311 handlers: Default::default(),
312 teardown: watch::channel(false).0,
313 };
314
315 server
316 .add_request_handler(ping)
317 .add_request_handler(user_handler(create_room))
318 .add_request_handler(user_handler(join_room))
319 .add_request_handler(user_handler(rejoin_room))
320 .add_request_handler(user_handler(leave_room))
321 .add_request_handler(user_handler(set_room_participant_role))
322 .add_request_handler(user_handler(call))
323 .add_request_handler(user_handler(cancel_call))
324 .add_message_handler(user_message_handler(decline_call))
325 .add_request_handler(user_handler(update_participant_location))
326 .add_request_handler(share_project)
327 .add_message_handler(unshare_project)
328 .add_request_handler(user_handler(join_project))
329 .add_request_handler(user_handler(join_hosted_project))
330 .add_message_handler(user_message_handler(leave_project))
331 .add_request_handler(update_project)
332 .add_request_handler(update_worktree)
333 .add_message_handler(start_language_server)
334 .add_message_handler(update_language_server)
335 .add_message_handler(update_diagnostic_summary)
336 .add_message_handler(update_worktree_settings)
337 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
338 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
339 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
340 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
341 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
342 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
344 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
345 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
346 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
347 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
348 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
349 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
350 .add_request_handler(
351 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
352 )
353 .add_request_handler(
354 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
355 )
356 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
357 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
358 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
359 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
360 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
361 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
362 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
363 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
365 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
368 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
369 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
370 .add_message_handler(create_buffer_for_peer)
371 .add_request_handler(update_buffer)
372 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
373 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
375 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
376 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
377 .add_request_handler(get_users)
378 .add_request_handler(user_handler(fuzzy_search_users))
379 .add_request_handler(user_handler(request_contact))
380 .add_request_handler(user_handler(remove_contact))
381 .add_request_handler(user_handler(respond_to_contact_request))
382 .add_request_handler(user_handler(create_channel))
383 .add_request_handler(user_handler(delete_channel))
384 .add_request_handler(user_handler(invite_channel_member))
385 .add_request_handler(user_handler(remove_channel_member))
386 .add_request_handler(user_handler(set_channel_member_role))
387 .add_request_handler(user_handler(set_channel_visibility))
388 .add_request_handler(user_handler(rename_channel))
389 .add_request_handler(user_handler(join_channel_buffer))
390 .add_request_handler(user_handler(leave_channel_buffer))
391 .add_message_handler(user_message_handler(update_channel_buffer))
392 .add_request_handler(user_handler(rejoin_channel_buffers))
393 .add_request_handler(user_handler(get_channel_members))
394 .add_request_handler(user_handler(respond_to_channel_invite))
395 .add_request_handler(user_handler(join_channel))
396 .add_request_handler(user_handler(join_channel_chat))
397 .add_message_handler(user_message_handler(leave_channel_chat))
398 .add_request_handler(user_handler(send_channel_message))
399 .add_request_handler(user_handler(remove_channel_message))
400 .add_request_handler(user_handler(update_channel_message))
401 .add_request_handler(user_handler(get_channel_messages))
402 .add_request_handler(user_handler(get_channel_messages_by_id))
403 .add_request_handler(user_handler(get_notifications))
404 .add_request_handler(user_handler(mark_notification_as_read))
405 .add_request_handler(user_handler(move_channel))
406 .add_request_handler(user_handler(follow))
407 .add_message_handler(user_message_handler(unfollow))
408 .add_message_handler(user_message_handler(update_followers))
409 .add_request_handler(user_handler(get_private_user_info))
410 .add_message_handler(user_message_handler(acknowledge_channel_message))
411 .add_message_handler(user_message_handler(acknowledge_buffer_version))
412 .add_streaming_request_handler({
413 let app_state = app_state.clone();
414 move |request, response, session| {
415 complete_with_language_model(
416 request,
417 response,
418 session,
419 app_state.config.openai_api_key.clone(),
420 app_state.config.google_ai_api_key.clone(),
421 )
422 }
423 })
424 .add_request_handler({
425 let app_state = app_state.clone();
426 user_handler(move |request, response, session| {
427 count_tokens_with_language_model(
428 request,
429 response,
430 session,
431 app_state.config.google_ai_api_key.clone(),
432 )
433 })
434 });
435
436 Arc::new(server)
437 }
438
439 pub async fn start(&self) -> Result<()> {
440 let server_id = *self.id.lock();
441 let app_state = self.app_state.clone();
442 let peer = self.peer.clone();
443 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
444 let pool = self.connection_pool.clone();
445 let live_kit_client = self.app_state.live_kit_client.clone();
446
447 let span = info_span!("start server");
448 self.app_state.executor.spawn_detached(
449 async move {
450 tracing::info!("waiting for cleanup timeout");
451 timeout.await;
452 tracing::info!("cleanup timeout expired, retrieving stale rooms");
453 if let Some((room_ids, channel_ids)) = app_state
454 .db
455 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
456 .await
457 .trace_err()
458 {
459 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
460 tracing::info!(
461 stale_channel_buffer_count = channel_ids.len(),
462 "retrieved stale channel buffers"
463 );
464
465 for channel_id in channel_ids {
466 if let Some(refreshed_channel_buffer) = app_state
467 .db
468 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
469 .await
470 .trace_err()
471 {
472 for connection_id in refreshed_channel_buffer.connection_ids {
473 peer.send(
474 connection_id,
475 proto::UpdateChannelBufferCollaborators {
476 channel_id: channel_id.to_proto(),
477 collaborators: refreshed_channel_buffer
478 .collaborators
479 .clone(),
480 },
481 )
482 .trace_err();
483 }
484 }
485 }
486
487 for room_id in room_ids {
488 let mut contacts_to_update = HashSet::default();
489 let mut canceled_calls_to_user_ids = Vec::new();
490 let mut live_kit_room = String::new();
491 let mut delete_live_kit_room = false;
492
493 if let Some(mut refreshed_room) = app_state
494 .db
495 .clear_stale_room_participants(room_id, server_id)
496 .await
497 .trace_err()
498 {
499 tracing::info!(
500 room_id = room_id.0,
501 new_participant_count = refreshed_room.room.participants.len(),
502 "refreshed room"
503 );
504 room_updated(&refreshed_room.room, &peer);
505 if let Some(channel) = refreshed_room.channel.as_ref() {
506 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
507 }
508 contacts_to_update
509 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
510 contacts_to_update
511 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
512 canceled_calls_to_user_ids =
513 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
514 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
515 delete_live_kit_room = refreshed_room.room.participants.is_empty();
516 }
517
518 {
519 let pool = pool.lock();
520 for canceled_user_id in canceled_calls_to_user_ids {
521 for connection_id in pool.user_connection_ids(canceled_user_id) {
522 peer.send(
523 connection_id,
524 proto::CallCanceled {
525 room_id: room_id.to_proto(),
526 },
527 )
528 .trace_err();
529 }
530 }
531 }
532
533 for user_id in contacts_to_update {
534 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
535 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
536 if let Some((busy, contacts)) = busy.zip(contacts) {
537 let pool = pool.lock();
538 let updated_contact = contact_for_user(user_id, busy, &pool);
539 for contact in contacts {
540 if let db::Contact::Accepted {
541 user_id: contact_user_id,
542 ..
543 } = contact
544 {
545 for contact_conn_id in
546 pool.user_connection_ids(contact_user_id)
547 {
548 peer.send(
549 contact_conn_id,
550 proto::UpdateContacts {
551 contacts: vec![updated_contact.clone()],
552 remove_contacts: Default::default(),
553 incoming_requests: Default::default(),
554 remove_incoming_requests: Default::default(),
555 outgoing_requests: Default::default(),
556 remove_outgoing_requests: Default::default(),
557 },
558 )
559 .trace_err();
560 }
561 }
562 }
563 }
564 }
565
566 if let Some(live_kit) = live_kit_client.as_ref() {
567 if delete_live_kit_room {
568 live_kit.delete_room(live_kit_room).await.trace_err();
569 }
570 }
571 }
572 }
573
574 app_state
575 .db
576 .delete_stale_servers(&app_state.config.zed_environment, server_id)
577 .await
578 .trace_err();
579 }
580 .instrument(span),
581 );
582 Ok(())
583 }
584
585 pub fn teardown(&self) {
586 self.peer.teardown();
587 self.connection_pool.lock().reset();
588 let _ = self.teardown.send(true);
589 }
590
591 #[cfg(test)]
592 pub fn reset(&self, id: ServerId) {
593 self.teardown();
594 *self.id.lock() = id;
595 self.peer.reset(id.0 as u32);
596 let _ = self.teardown.send(false);
597 }
598
599 #[cfg(test)]
600 pub fn id(&self) -> ServerId {
601 *self.id.lock()
602 }
603
604 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
605 where
606 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
607 Fut: 'static + Send + Future<Output = Result<()>>,
608 M: EnvelopedMessage,
609 {
610 let prev_handler = self.handlers.insert(
611 TypeId::of::<M>(),
612 Box::new(move |envelope, session| {
613 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
614 let received_at = envelope.received_at;
615 tracing::info!(
616 "message received"
617 );
618 let start_time = Instant::now();
619 let future = (handler)(*envelope, session);
620 async move {
621 let result = future.await;
622 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
623 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
624 let queue_duration_ms = total_duration_ms - processing_duration_ms;
625 match result {
626 Err(error) => {
627 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
628 }
629 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
630 }
631 }
632 .boxed()
633 }),
634 );
635 if prev_handler.is_some() {
636 panic!("registered a handler for the same message twice");
637 }
638 self
639 }
640
641 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
642 where
643 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
644 Fut: 'static + Send + Future<Output = Result<()>>,
645 M: EnvelopedMessage,
646 {
647 self.add_handler(move |envelope, session| handler(envelope.payload, session));
648 self
649 }
650
651 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
652 where
653 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
654 Fut: Send + Future<Output = Result<()>>,
655 M: RequestMessage,
656 {
657 let handler = Arc::new(handler);
658 self.add_handler(move |envelope, session| {
659 let receipt = envelope.receipt();
660 let handler = handler.clone();
661 async move {
662 let peer = session.peer.clone();
663 let responded = Arc::new(AtomicBool::default());
664 let response = Response {
665 peer: peer.clone(),
666 responded: responded.clone(),
667 receipt,
668 };
669 match (handler)(envelope.payload, response, session).await {
670 Ok(()) => {
671 if responded.load(std::sync::atomic::Ordering::SeqCst) {
672 Ok(())
673 } else {
674 Err(anyhow!("handler did not send a response"))?
675 }
676 }
677 Err(error) => {
678 let proto_err = match &error {
679 Error::Internal(err) => err.to_proto(),
680 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
681 };
682 peer.respond_with_error(receipt, proto_err)?;
683 Err(error)
684 }
685 }
686 }
687 })
688 }
689
690 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
691 where
692 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
693 Fut: Send + Future<Output = Result<()>>,
694 M: RequestMessage,
695 {
696 let handler = Arc::new(handler);
697 self.add_handler(move |envelope, session| {
698 let receipt = envelope.receipt();
699 let handler = handler.clone();
700 async move {
701 let peer = session.peer.clone();
702 let response = StreamingResponse {
703 peer: peer.clone(),
704 receipt,
705 };
706 match (handler)(envelope.payload, response, session).await {
707 Ok(()) => {
708 peer.end_stream(receipt)?;
709 Ok(())
710 }
711 Err(error) => {
712 let proto_err = match &error {
713 Error::Internal(err) => err.to_proto(),
714 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
715 };
716 peer.respond_with_error(receipt, proto_err)?;
717 Err(error)
718 }
719 }
720 }
721 })
722 }
723
724 #[allow(clippy::too_many_arguments)]
725 pub fn handle_connection(
726 self: &Arc<Self>,
727 connection: Connection,
728 address: String,
729 principal: Principal,
730 zed_version: ZedVersion,
731 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
732 executor: Executor,
733 ) -> impl Future<Output = ()> {
734 let this = self.clone();
735 let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
736 principal.update_span(&span);
737
738 let mut teardown = self.teardown.subscribe();
739 async move {
740 if *teardown.borrow() {
741 tracing::error!("server is tearing down");
742 return
743 }
744 let (connection_id, handle_io, mut incoming_rx) = this
745 .peer
746 .add_connection(connection, {
747 let executor = executor.clone();
748 move |duration| executor.sleep(duration)
749 });
750 tracing::Span::current().record("connection_id", format!("{}", connection_id));
751 tracing::info!("connection opened");
752
753 let http_client = match IsahcHttpClient::new() {
754 Ok(http_client) => http_client,
755 Err(error) => {
756 tracing::error!(?error, "failed to create HTTP client");
757 return;
758 }
759 };
760
761 let session = Session {
762 principal: principal.clone(),
763 connection_id,
764 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
765 peer: this.peer.clone(),
766 connection_pool: this.connection_pool.clone(),
767 live_kit_client: this.app_state.live_kit_client.clone(),
768 http_client,
769 rate_limiter: this.app_state.rate_limiter.clone(),
770 _executor: executor.clone(),
771 };
772
773 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
774 tracing::error!(?error, "failed to send initial client update");
775 return;
776 }
777
778 let handle_io = handle_io.fuse();
779 futures::pin_mut!(handle_io);
780
781 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
782 // This prevents deadlocks when e.g., client A performs a request to client B and
783 // client B performs a request to client A. If both clients stop processing further
784 // messages until their respective request completes, they won't have a chance to
785 // respond to the other client's request and cause a deadlock.
786 //
787 // This arrangement ensures we will attempt to process earlier messages first, but fall
788 // back to processing messages arrived later in the spirit of making progress.
789 let mut foreground_message_handlers = FuturesUnordered::new();
790 let concurrent_handlers = Arc::new(Semaphore::new(256));
791 loop {
792 let next_message = async {
793 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
794 let message = incoming_rx.next().await;
795 (permit, message)
796 }.fuse();
797 futures::pin_mut!(next_message);
798 futures::select_biased! {
799 _ = teardown.changed().fuse() => return,
800 result = handle_io => {
801 if let Err(error) = result {
802 tracing::error!(?error, "error handling I/O");
803 }
804 break;
805 }
806 _ = foreground_message_handlers.next() => {}
807 next_message = next_message => {
808 let (permit, message) = next_message;
809 if let Some(message) = message {
810 let type_name = message.payload_type_name();
811 // note: we copy all the fields from the parent span so we can query them in the logs.
812 // (https://github.com/tokio-rs/tracing/issues/2670).
813 let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
814 principal.update_span(&span);
815 let span_enter = span.enter();
816 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
817 let is_background = message.is_background();
818 let handle_message = (handler)(message, session.clone());
819 drop(span_enter);
820
821 let handle_message = async move {
822 handle_message.await;
823 drop(permit);
824 }.instrument(span);
825 if is_background {
826 executor.spawn_detached(handle_message);
827 } else {
828 foreground_message_handlers.push(handle_message);
829 }
830 } else {
831 tracing::error!("no message handler");
832 }
833 } else {
834 tracing::info!("connection closed");
835 break;
836 }
837 }
838 }
839 }
840
841 drop(foreground_message_handlers);
842 tracing::info!("signing out");
843 if let Err(error) = connection_lost(session, teardown, executor).await {
844 tracing::error!(?error, "error signing out");
845 }
846
847 }.instrument(span)
848 }
849
850 async fn send_initial_client_update(
851 &self,
852 connection_id: ConnectionId,
853 principal: &Principal,
854 zed_version: ZedVersion,
855 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
856 session: &Session,
857 ) -> Result<()> {
858 self.peer.send(
859 connection_id,
860 proto::Hello {
861 peer_id: Some(connection_id.into()),
862 },
863 )?;
864 tracing::info!("sent hello message");
865
866 let Principal::User(user) = principal else {
867 return Ok(());
868 };
869
870 if let Some(send_connection_id) = send_connection_id.take() {
871 let _ = send_connection_id.send(connection_id);
872 }
873
874 if !user.connected_once {
875 self.peer.send(connection_id, proto::ShowContacts {})?;
876 self.app_state
877 .db
878 .set_user_connected_once(user.id, true)
879 .await?;
880 }
881
882 let (contacts, channels_for_user, channel_invites) = future::try_join3(
883 self.app_state.db.get_contacts(user.id),
884 self.app_state.db.get_channels_for_user(user.id),
885 self.app_state.db.get_channel_invites_for_user(user.id),
886 )
887 .await?;
888
889 {
890 let mut pool = self.connection_pool.lock();
891 pool.add_connection(connection_id, user.id, user.admin, zed_version);
892 for membership in &channels_for_user.channel_memberships {
893 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
894 }
895 self.peer.send(
896 connection_id,
897 build_initial_contacts_update(contacts, &pool),
898 )?;
899 self.peer.send(
900 connection_id,
901 build_update_user_channels(&channels_for_user),
902 )?;
903 self.peer.send(
904 connection_id,
905 build_channels_update(channels_for_user, channel_invites),
906 )?;
907 }
908
909 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
910 self.peer.send(connection_id, incoming_call)?;
911 }
912
913 update_user_contacts(user.id, &session).await?;
914 Ok(())
915 }
916
917 pub async fn invite_code_redeemed(
918 self: &Arc<Self>,
919 inviter_id: UserId,
920 invitee_id: UserId,
921 ) -> Result<()> {
922 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
923 if let Some(code) = &user.invite_code {
924 let pool = self.connection_pool.lock();
925 let invitee_contact = contact_for_user(invitee_id, false, &pool);
926 for connection_id in pool.user_connection_ids(inviter_id) {
927 self.peer.send(
928 connection_id,
929 proto::UpdateContacts {
930 contacts: vec![invitee_contact.clone()],
931 ..Default::default()
932 },
933 )?;
934 self.peer.send(
935 connection_id,
936 proto::UpdateInviteInfo {
937 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
938 count: user.invite_count as u32,
939 },
940 )?;
941 }
942 }
943 }
944 Ok(())
945 }
946
947 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
948 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
949 if let Some(invite_code) = &user.invite_code {
950 let pool = self.connection_pool.lock();
951 for connection_id in pool.user_connection_ids(user_id) {
952 self.peer.send(
953 connection_id,
954 proto::UpdateInviteInfo {
955 url: format!(
956 "{}{}",
957 self.app_state.config.invite_link_prefix, invite_code
958 ),
959 count: user.invite_count as u32,
960 },
961 )?;
962 }
963 }
964 }
965 Ok(())
966 }
967
968 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
969 ServerSnapshot {
970 connection_pool: ConnectionPoolGuard {
971 guard: self.connection_pool.lock(),
972 _not_send: PhantomData,
973 },
974 peer: &self.peer,
975 }
976 }
977}
978
979impl<'a> Deref for ConnectionPoolGuard<'a> {
980 type Target = ConnectionPool;
981
982 fn deref(&self) -> &Self::Target {
983 &self.guard
984 }
985}
986
987impl<'a> DerefMut for ConnectionPoolGuard<'a> {
988 fn deref_mut(&mut self) -> &mut Self::Target {
989 &mut self.guard
990 }
991}
992
/// Validates the connection pool when the guard is released (test builds only).
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // `Drop::drop` runs before the struct's fields are dropped, so the inner lock guard
        // is still held here and the pool cannot change while it is being checked.
        // The check is compiled only into test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
999
1000fn broadcast<F>(
1001 sender_id: Option<ConnectionId>,
1002 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1003 mut f: F,
1004) where
1005 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1006{
1007 for receiver_id in receiver_ids {
1008 if Some(receiver_id) != sender_id {
1009 if let Err(error) = f(receiver_id) {
1010 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1011 }
1012 }
1013 }
1014}
1015
/// Typed `x-zed-protocol-version` request header: the client's RPC protocol version.
pub struct ProtocolVersion(u32);
1017
1018impl Header for ProtocolVersion {
1019 fn name() -> &'static HeaderName {
1020 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1021 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1022 }
1023
1024 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1025 where
1026 Self: Sized,
1027 I: Iterator<Item = &'i axum::http::HeaderValue>,
1028 {
1029 let version = values
1030 .next()
1031 .ok_or_else(axum::headers::Error::invalid)?
1032 .to_str()
1033 .map_err(|_| axum::headers::Error::invalid())?
1034 .parse()
1035 .map_err(|_| axum::headers::Error::invalid())?;
1036 Ok(Self(version))
1037 }
1038
1039 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1040 values.extend([self.0.to_string().parse().unwrap()]);
1041 }
1042}
1043
/// Typed `x-zed-app-version` request header: the client application's semantic version.
pub struct AppVersionHeader(SemanticVersion);
1045impl Header for AppVersionHeader {
1046 fn name() -> &'static HeaderName {
1047 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1048 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1049 }
1050
1051 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1052 where
1053 Self: Sized,
1054 I: Iterator<Item = &'i axum::http::HeaderValue>,
1055 {
1056 let version = values
1057 .next()
1058 .ok_or_else(axum::headers::Error::invalid)?
1059 .to_str()
1060 .map_err(|_| axum::headers::Error::invalid())?
1061 .parse()
1062 .map_err(|_| axum::headers::Error::invalid())?;
1063 Ok(Self(version))
1064 }
1065
1066 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1067 values.extend([self.0.to_string().parse().unwrap()]);
1068 }
1069}
1070
1071pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1072 Router::new()
1073 .route("/rpc", get(handle_websocket_request))
1074 .layer(
1075 ServiceBuilder::new()
1076 .layer(Extension(server.app_state.clone()))
1077 .layer(middleware::from_fn(auth::validate_header)),
1078 )
1079 .route("/metrics", get(handle_metrics))
1080 .layer(Extension(server))
1081}
1082
1083pub async fn handle_websocket_request(
1084 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1085 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1086 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1087 Extension(server): Extension<Arc<Server>>,
1088 Extension(principal): Extension<Principal>,
1089 ws: WebSocketUpgrade,
1090) -> axum::response::Response {
1091 if protocol_version != rpc::PROTOCOL_VERSION {
1092 return (
1093 StatusCode::UPGRADE_REQUIRED,
1094 "client must be upgraded".to_string(),
1095 )
1096 .into_response();
1097 }
1098
1099 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1100 return (
1101 StatusCode::UPGRADE_REQUIRED,
1102 "no version header found".to_string(),
1103 )
1104 .into_response();
1105 };
1106
1107 if !version.can_collaborate() {
1108 return (
1109 StatusCode::UPGRADE_REQUIRED,
1110 "client must be upgraded".to_string(),
1111 )
1112 .into_response();
1113 }
1114
1115 let socket_address = socket_address.to_string();
1116 ws.on_upgrade(move |socket| {
1117 let socket = socket
1118 .map_ok(to_tungstenite_message)
1119 .err_into()
1120 .with(|message| async move { Ok(to_axum_message(message)) });
1121 let connection = Connection::new(Box::pin(socket));
1122 async move {
1123 server
1124 .handle_connection(
1125 connection,
1126 socket_address,
1127 principal,
1128 version,
1129 None,
1130 Executor::Production,
1131 )
1132 .await;
1133 }
1134 })
1135}
1136
1137pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1138 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1139 let connections_metric = CONNECTIONS_METRIC
1140 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1141
1142 let connections = server
1143 .connection_pool
1144 .lock()
1145 .connections()
1146 .filter(|connection| !connection.admin)
1147 .count();
1148 connections_metric.set(connections as _);
1149
1150 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1151 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1152 register_int_gauge!(
1153 "shared_projects",
1154 "number of open projects with one or more guests"
1155 )
1156 .unwrap()
1157 });
1158
1159 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1160 shared_projects_metric.set(shared_projects as _);
1161
1162 let encoder = prometheus::TextEncoder::new();
1163 let metric_families = prometheus::gather();
1164 let encoded_metrics = encoder
1165 .encode_to_string(&metric_families)
1166 .map_err(|err| anyhow!("{}", err))?;
1167 Ok(encoded_metrics)
1168}
1169
/// Cleans up after a connection whose transport has gone away.
///
/// The peer is disconnected, the connection is removed from the pool, and the loss is
/// recorded in the database (a database error is only logged). The function then waits for
/// whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses: for user sessions, the user leaves their room and channel
///   buffers, a pending call is declined if the user has no other online connection, and the
///   user's contacts are updated.
/// - `teardown` changes: no further cleanup is done.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Failure to record the loss is logged but does not abort the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout branch wins
    // when both are ready.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // `for_user` yields `None` for sessions without a user; those need no
            // per-user cleanup.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only decline a pending call when no other connection of this user is
                // online.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1217
1218/// Acknowledges a ping from a client, used to keep the connection alive.
1219async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1220 response.send(proto::Ack {})?;
1221 Ok(())
1222}
1223
1224/// Creates a new room for calling (outside of channels)
1225async fn create_room(
1226 _request: proto::CreateRoom,
1227 response: Response<proto::CreateRoom>,
1228 session: UserSession,
1229) -> Result<()> {
1230 let live_kit_room = nanoid::nanoid!(30);
1231
1232 let live_kit_connection_info = util::maybe!(async {
1233 let live_kit = session.live_kit_client.as_ref();
1234 let live_kit = live_kit?;
1235 let user_id = session.user_id().to_string();
1236
1237 let token = live_kit
1238 .room_token(&live_kit_room, &user_id.to_string())
1239 .trace_err()?;
1240
1241 Some(proto::LiveKitConnectionInfo {
1242 server_url: live_kit.url().into(),
1243 token,
1244 can_publish: true,
1245 })
1246 })
1247 .await;
1248
1249 let room = session
1250 .db()
1251 .await
1252 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1253 .await?;
1254
1255 response.send(proto::CreateRoomResponse {
1256 room: Some(room.clone()),
1257 live_kit_connection_info,
1258 })?;
1259
1260 update_user_contacts(session.user_id(), &session).await?;
1261 Ok(())
1262}
1263
1264/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1265async fn join_room(
1266 request: proto::JoinRoom,
1267 response: Response<proto::JoinRoom>,
1268 session: UserSession,
1269) -> Result<()> {
1270 let room_id = RoomId::from_proto(request.id);
1271
1272 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1273
1274 if let Some(channel_id) = channel_id {
1275 return join_channel_internal(channel_id, Box::new(response), session).await;
1276 }
1277
1278 let joined_room = {
1279 let room = session
1280 .db()
1281 .await
1282 .join_room(room_id, session.user_id(), session.connection_id)
1283 .await?;
1284 room_updated(&room.room, &session.peer);
1285 room.into_inner()
1286 };
1287
1288 for connection_id in session
1289 .connection_pool()
1290 .await
1291 .user_connection_ids(session.user_id())
1292 {
1293 session
1294 .peer
1295 .send(
1296 connection_id,
1297 proto::CallCanceled {
1298 room_id: room_id.to_proto(),
1299 },
1300 )
1301 .trace_err();
1302 }
1303
1304 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1305 if let Some(token) = live_kit
1306 .room_token(
1307 &joined_room.room.live_kit_room,
1308 &session.user_id().to_string(),
1309 )
1310 .trace_err()
1311 {
1312 Some(proto::LiveKitConnectionInfo {
1313 server_url: live_kit.url().into(),
1314 token,
1315 can_publish: true,
1316 })
1317 } else {
1318 None
1319 }
1320 } else {
1321 None
1322 };
1323
1324 response.send(proto::JoinRoomResponse {
1325 room: Some(joined_room.room),
1326 channel_id: None,
1327 live_kit_connection_info,
1328 })?;
1329
1330 update_user_contacts(session.user_id(), &session).await?;
1331 Ok(())
1332}
1333
1334/// Rejoin room is used to reconnect to a room after connection errors.
1335async fn rejoin_room(
1336 request: proto::RejoinRoom,
1337 response: Response<proto::RejoinRoom>,
1338 session: UserSession,
1339) -> Result<()> {
1340 let room;
1341 let channel;
1342 {
1343 let mut rejoined_room = session
1344 .db()
1345 .await
1346 .rejoin_room(request, session.user_id(), session.connection_id)
1347 .await?;
1348
1349 response.send(proto::RejoinRoomResponse {
1350 room: Some(rejoined_room.room.clone()),
1351 reshared_projects: rejoined_room
1352 .reshared_projects
1353 .iter()
1354 .map(|project| proto::ResharedProject {
1355 id: project.id.to_proto(),
1356 collaborators: project
1357 .collaborators
1358 .iter()
1359 .map(|collaborator| collaborator.to_proto())
1360 .collect(),
1361 })
1362 .collect(),
1363 rejoined_projects: rejoined_room
1364 .rejoined_projects
1365 .iter()
1366 .map(|rejoined_project| proto::RejoinedProject {
1367 id: rejoined_project.id.to_proto(),
1368 worktrees: rejoined_project
1369 .worktrees
1370 .iter()
1371 .map(|worktree| proto::WorktreeMetadata {
1372 id: worktree.id,
1373 root_name: worktree.root_name.clone(),
1374 visible: worktree.visible,
1375 abs_path: worktree.abs_path.clone(),
1376 })
1377 .collect(),
1378 collaborators: rejoined_project
1379 .collaborators
1380 .iter()
1381 .map(|collaborator| collaborator.to_proto())
1382 .collect(),
1383 language_servers: rejoined_project.language_servers.clone(),
1384 })
1385 .collect(),
1386 })?;
1387 room_updated(&rejoined_room.room, &session.peer);
1388
1389 for project in &rejoined_room.reshared_projects {
1390 for collaborator in &project.collaborators {
1391 session
1392 .peer
1393 .send(
1394 collaborator.connection_id,
1395 proto::UpdateProjectCollaborator {
1396 project_id: project.id.to_proto(),
1397 old_peer_id: Some(project.old_connection_id.into()),
1398 new_peer_id: Some(session.connection_id.into()),
1399 },
1400 )
1401 .trace_err();
1402 }
1403
1404 broadcast(
1405 Some(session.connection_id),
1406 project
1407 .collaborators
1408 .iter()
1409 .map(|collaborator| collaborator.connection_id),
1410 |connection_id| {
1411 session.peer.forward_send(
1412 session.connection_id,
1413 connection_id,
1414 proto::UpdateProject {
1415 project_id: project.id.to_proto(),
1416 worktrees: project.worktrees.clone(),
1417 },
1418 )
1419 },
1420 );
1421 }
1422
1423 for project in &rejoined_room.rejoined_projects {
1424 for collaborator in &project.collaborators {
1425 session
1426 .peer
1427 .send(
1428 collaborator.connection_id,
1429 proto::UpdateProjectCollaborator {
1430 project_id: project.id.to_proto(),
1431 old_peer_id: Some(project.old_connection_id.into()),
1432 new_peer_id: Some(session.connection_id.into()),
1433 },
1434 )
1435 .trace_err();
1436 }
1437 }
1438
1439 for project in &mut rejoined_room.rejoined_projects {
1440 for worktree in mem::take(&mut project.worktrees) {
1441 #[cfg(any(test, feature = "test-support"))]
1442 const MAX_CHUNK_SIZE: usize = 2;
1443 #[cfg(not(any(test, feature = "test-support")))]
1444 const MAX_CHUNK_SIZE: usize = 256;
1445
1446 // Stream this worktree's entries.
1447 let message = proto::UpdateWorktree {
1448 project_id: project.id.to_proto(),
1449 worktree_id: worktree.id,
1450 abs_path: worktree.abs_path.clone(),
1451 root_name: worktree.root_name,
1452 updated_entries: worktree.updated_entries,
1453 removed_entries: worktree.removed_entries,
1454 scan_id: worktree.scan_id,
1455 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1456 updated_repositories: worktree.updated_repositories,
1457 removed_repositories: worktree.removed_repositories,
1458 };
1459 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1460 session.peer.send(session.connection_id, update.clone())?;
1461 }
1462
1463 // Stream this worktree's diagnostics.
1464 for summary in worktree.diagnostic_summaries {
1465 session.peer.send(
1466 session.connection_id,
1467 proto::UpdateDiagnosticSummary {
1468 project_id: project.id.to_proto(),
1469 worktree_id: worktree.id,
1470 summary: Some(summary),
1471 },
1472 )?;
1473 }
1474
1475 for settings_file in worktree.settings_files {
1476 session.peer.send(
1477 session.connection_id,
1478 proto::UpdateWorktreeSettings {
1479 project_id: project.id.to_proto(),
1480 worktree_id: worktree.id,
1481 path: settings_file.path,
1482 content: Some(settings_file.content),
1483 },
1484 )?;
1485 }
1486 }
1487
1488 for language_server in &project.language_servers {
1489 session.peer.send(
1490 session.connection_id,
1491 proto::UpdateLanguageServer {
1492 project_id: project.id.to_proto(),
1493 language_server_id: language_server.id,
1494 variant: Some(
1495 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1496 proto::LspDiskBasedDiagnosticsUpdated {},
1497 ),
1498 ),
1499 },
1500 )?;
1501 }
1502 }
1503
1504 let rejoined_room = rejoined_room.into_inner();
1505
1506 room = rejoined_room.room;
1507 channel = rejoined_room.channel;
1508 }
1509
1510 if let Some(channel) = channel {
1511 channel_updated(
1512 &channel,
1513 &room,
1514 &session.peer,
1515 &*session.connection_pool().await,
1516 );
1517 }
1518
1519 update_user_contacts(session.user_id(), &session).await?;
1520 Ok(())
1521}
1522
1523/// leave room disconnects from the room.
1524async fn leave_room(
1525 _: proto::LeaveRoom,
1526 response: Response<proto::LeaveRoom>,
1527 session: UserSession,
1528) -> Result<()> {
1529 leave_room_for_session(&session).await?;
1530 response.send(proto::Ack {})?;
1531 Ok(())
1532}
1533
1534/// Updates the permissions of someone else in the room.
1535async fn set_room_participant_role(
1536 request: proto::SetRoomParticipantRole,
1537 response: Response<proto::SetRoomParticipantRole>,
1538 session: UserSession,
1539) -> Result<()> {
1540 let user_id = UserId::from_proto(request.user_id);
1541 let role = ChannelRole::from(request.role());
1542
1543 let (live_kit_room, can_publish) = {
1544 let room = session
1545 .db()
1546 .await
1547 .set_room_participant_role(
1548 session.user_id(),
1549 RoomId::from_proto(request.room_id),
1550 user_id,
1551 role,
1552 )
1553 .await?;
1554
1555 let live_kit_room = room.live_kit_room.clone();
1556 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1557 room_updated(&room, &session.peer);
1558 (live_kit_room, can_publish)
1559 };
1560
1561 if let Some(live_kit) = session.live_kit_client.as_ref() {
1562 live_kit
1563 .update_participant(
1564 live_kit_room.clone(),
1565 request.user_id.to_string(),
1566 live_kit_server::proto::ParticipantPermission {
1567 can_subscribe: true,
1568 can_publish,
1569 can_publish_data: can_publish,
1570 hidden: false,
1571 recorder: false,
1572 },
1573 )
1574 .await
1575 .trace_err();
1576 }
1577
1578 response.send(proto::Ack {})?;
1579 Ok(())
1580}
1581
1582/// Call someone else into the current room
1583async fn call(
1584 request: proto::Call,
1585 response: Response<proto::Call>,
1586 session: UserSession,
1587) -> Result<()> {
1588 let room_id = RoomId::from_proto(request.room_id);
1589 let calling_user_id = session.user_id();
1590 let calling_connection_id = session.connection_id;
1591 let called_user_id = UserId::from_proto(request.called_user_id);
1592 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1593 if !session
1594 .db()
1595 .await
1596 .has_contact(calling_user_id, called_user_id)
1597 .await?
1598 {
1599 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1600 }
1601
1602 let incoming_call = {
1603 let (room, incoming_call) = &mut *session
1604 .db()
1605 .await
1606 .call(
1607 room_id,
1608 calling_user_id,
1609 calling_connection_id,
1610 called_user_id,
1611 initial_project_id,
1612 )
1613 .await?;
1614 room_updated(&room, &session.peer);
1615 mem::take(incoming_call)
1616 };
1617 update_user_contacts(called_user_id, &session).await?;
1618
1619 let mut calls = session
1620 .connection_pool()
1621 .await
1622 .user_connection_ids(called_user_id)
1623 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1624 .collect::<FuturesUnordered<_>>();
1625
1626 while let Some(call_response) = calls.next().await {
1627 match call_response.as_ref() {
1628 Ok(_) => {
1629 response.send(proto::Ack {})?;
1630 return Ok(());
1631 }
1632 Err(_) => {
1633 call_response.trace_err();
1634 }
1635 }
1636 }
1637
1638 {
1639 let room = session
1640 .db()
1641 .await
1642 .call_failed(room_id, called_user_id)
1643 .await?;
1644 room_updated(&room, &session.peer);
1645 }
1646 update_user_contacts(called_user_id, &session).await?;
1647
1648 Err(anyhow!("failed to ring user"))?
1649}
1650
1651/// Cancel an outgoing call.
1652async fn cancel_call(
1653 request: proto::CancelCall,
1654 response: Response<proto::CancelCall>,
1655 session: UserSession,
1656) -> Result<()> {
1657 let called_user_id = UserId::from_proto(request.called_user_id);
1658 let room_id = RoomId::from_proto(request.room_id);
1659 {
1660 let room = session
1661 .db()
1662 .await
1663 .cancel_call(room_id, session.connection_id, called_user_id)
1664 .await?;
1665 room_updated(&room, &session.peer);
1666 }
1667
1668 for connection_id in session
1669 .connection_pool()
1670 .await
1671 .user_connection_ids(called_user_id)
1672 {
1673 session
1674 .peer
1675 .send(
1676 connection_id,
1677 proto::CallCanceled {
1678 room_id: room_id.to_proto(),
1679 },
1680 )
1681 .trace_err();
1682 }
1683 response.send(proto::Ack {})?;
1684
1685 update_user_contacts(called_user_id, &session).await?;
1686 Ok(())
1687}
1688
1689/// Decline an incoming call.
1690async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1691 let room_id = RoomId::from_proto(message.room_id);
1692 {
1693 let room = session
1694 .db()
1695 .await
1696 .decline_call(Some(room_id), session.user_id())
1697 .await?
1698 .ok_or_else(|| anyhow!("failed to decline call"))?;
1699 room_updated(&room, &session.peer);
1700 }
1701
1702 for connection_id in session
1703 .connection_pool()
1704 .await
1705 .user_connection_ids(session.user_id())
1706 {
1707 session
1708 .peer
1709 .send(
1710 connection_id,
1711 proto::CallCanceled {
1712 room_id: room_id.to_proto(),
1713 },
1714 )
1715 .trace_err();
1716 }
1717 update_user_contacts(session.user_id(), &session).await?;
1718 Ok(())
1719}
1720
1721/// Updates other participants in the room with your current location.
1722async fn update_participant_location(
1723 request: proto::UpdateParticipantLocation,
1724 response: Response<proto::UpdateParticipantLocation>,
1725 session: UserSession,
1726) -> Result<()> {
1727 let room_id = RoomId::from_proto(request.room_id);
1728 let location = request
1729 .location
1730 .ok_or_else(|| anyhow!("invalid location"))?;
1731
1732 let db = session.db().await;
1733 let room = db
1734 .update_room_participant_location(room_id, session.connection_id, location)
1735 .await?;
1736
1737 room_updated(&room, &session.peer);
1738 response.send(proto::Ack {})?;
1739 Ok(())
1740}
1741
1742/// Share a project into the room.
1743async fn share_project(
1744 request: proto::ShareProject,
1745 response: Response<proto::ShareProject>,
1746 session: Session,
1747) -> Result<()> {
1748 let (project_id, room) = &*session
1749 .db()
1750 .await
1751 .share_project(
1752 RoomId::from_proto(request.room_id),
1753 session.connection_id,
1754 &request.worktrees,
1755 )
1756 .await?;
1757 response.send(proto::ShareProjectResponse {
1758 project_id: project_id.to_proto(),
1759 })?;
1760 room_updated(&room, &session.peer);
1761
1762 Ok(())
1763}
1764
1765/// Unshare a project from the room.
1766async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1767 let project_id = ProjectId::from_proto(message.project_id);
1768
1769 let (room, guest_connection_ids) = &*session
1770 .db()
1771 .await
1772 .unshare_project(project_id, session.connection_id)
1773 .await?;
1774
1775 broadcast(
1776 Some(session.connection_id),
1777 guest_connection_ids.iter().copied(),
1778 |conn_id| session.peer.send(conn_id, message.clone()),
1779 );
1780 room_updated(&room, &session.peer);
1781
1782 Ok(())
1783}
1784
1785/// Join someone elses shared project.
1786async fn join_project(
1787 request: proto::JoinProject,
1788 response: Response<proto::JoinProject>,
1789 session: UserSession,
1790) -> Result<()> {
1791 let project_id = ProjectId::from_proto(request.project_id);
1792
1793 tracing::info!(%project_id, "join project");
1794
1795 let (project, replica_id) = &mut *session
1796 .db()
1797 .await
1798 .join_project_in_room(project_id, session.connection_id)
1799 .await?;
1800
1801 join_project_internal(response, session, project, replica_id)
1802}
1803
/// Response handle for the two join-project RPCs (`JoinProject` and `JoinHostedProject`).
///
/// Both reply with a `proto::JoinProjectResponse`, which lets `join_project_internal`
/// serve either request.
trait JoinProjectInternalResponse {
    /// Consumes the handle and sends `result` as the RPC response.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Response handle for a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s inherent `send`; the explicit path makes clear this is
        // not a recursive call to the trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Response handle for a `JoinHostedProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s inherent `send`; the explicit path makes clear this is
        // not a recursive call to the trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1817
1818fn join_project_internal(
1819 response: impl JoinProjectInternalResponse,
1820 session: UserSession,
1821 project: &mut Project,
1822 replica_id: &ReplicaId,
1823) -> Result<()> {
1824 let collaborators = project
1825 .collaborators
1826 .iter()
1827 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1828 .map(|collaborator| collaborator.to_proto())
1829 .collect::<Vec<_>>();
1830 let project_id = project.id;
1831 let guest_user_id = session.user_id();
1832
1833 let worktrees = project
1834 .worktrees
1835 .iter()
1836 .map(|(id, worktree)| proto::WorktreeMetadata {
1837 id: *id,
1838 root_name: worktree.root_name.clone(),
1839 visible: worktree.visible,
1840 abs_path: worktree.abs_path.clone(),
1841 })
1842 .collect::<Vec<_>>();
1843
1844 for collaborator in &collaborators {
1845 session
1846 .peer
1847 .send(
1848 collaborator.peer_id.unwrap().into(),
1849 proto::AddProjectCollaborator {
1850 project_id: project_id.to_proto(),
1851 collaborator: Some(proto::Collaborator {
1852 peer_id: Some(session.connection_id.into()),
1853 replica_id: replica_id.0 as u32,
1854 user_id: guest_user_id.to_proto(),
1855 }),
1856 },
1857 )
1858 .trace_err();
1859 }
1860
1861 // First, we send the metadata associated with each worktree.
1862 response.send(proto::JoinProjectResponse {
1863 project_id: project.id.0 as u64,
1864 worktrees: worktrees.clone(),
1865 replica_id: replica_id.0 as u32,
1866 collaborators: collaborators.clone(),
1867 language_servers: project.language_servers.clone(),
1868 role: project.role.into(), // todo
1869 })?;
1870
1871 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1872 #[cfg(any(test, feature = "test-support"))]
1873 const MAX_CHUNK_SIZE: usize = 2;
1874 #[cfg(not(any(test, feature = "test-support")))]
1875 const MAX_CHUNK_SIZE: usize = 256;
1876
1877 // Stream this worktree's entries.
1878 let message = proto::UpdateWorktree {
1879 project_id: project_id.to_proto(),
1880 worktree_id,
1881 abs_path: worktree.abs_path.clone(),
1882 root_name: worktree.root_name,
1883 updated_entries: worktree.entries,
1884 removed_entries: Default::default(),
1885 scan_id: worktree.scan_id,
1886 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1887 updated_repositories: worktree.repository_entries.into_values().collect(),
1888 removed_repositories: Default::default(),
1889 };
1890 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1891 session.peer.send(session.connection_id, update.clone())?;
1892 }
1893
1894 // Stream this worktree's diagnostics.
1895 for summary in worktree.diagnostic_summaries {
1896 session.peer.send(
1897 session.connection_id,
1898 proto::UpdateDiagnosticSummary {
1899 project_id: project_id.to_proto(),
1900 worktree_id: worktree.id,
1901 summary: Some(summary),
1902 },
1903 )?;
1904 }
1905
1906 for settings_file in worktree.settings_files {
1907 session.peer.send(
1908 session.connection_id,
1909 proto::UpdateWorktreeSettings {
1910 project_id: project_id.to_proto(),
1911 worktree_id: worktree.id,
1912 path: settings_file.path,
1913 content: Some(settings_file.content),
1914 },
1915 )?;
1916 }
1917 }
1918
1919 for language_server in &project.language_servers {
1920 session.peer.send(
1921 session.connection_id,
1922 proto::UpdateLanguageServer {
1923 project_id: project_id.to_proto(),
1924 language_server_id: language_server.id,
1925 variant: Some(
1926 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1927 proto::LspDiskBasedDiagnosticsUpdated {},
1928 ),
1929 ),
1930 },
1931 )?;
1932 }
1933
1934 Ok(())
1935}
1936
1937/// Leave someone elses shared project.
1938async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1939 let sender_id = session.connection_id;
1940 let project_id = ProjectId::from_proto(request.project_id);
1941 let db = session.db().await;
1942 if db.is_hosted_project(project_id).await? {
1943 let project = db.leave_hosted_project(project_id, sender_id).await?;
1944 project_left(&project, &session);
1945 return Ok(());
1946 }
1947
1948 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1949 tracing::info!(
1950 %project_id,
1951 host_user_id = ?project.host_user_id,
1952 host_connection_id = ?project.host_connection_id,
1953 "leave project"
1954 );
1955
1956 project_left(&project, &session);
1957 room_updated(&room, &session.peer);
1958
1959 Ok(())
1960}
1961
1962async fn join_hosted_project(
1963 request: proto::JoinHostedProject,
1964 response: Response<proto::JoinHostedProject>,
1965 session: UserSession,
1966) -> Result<()> {
1967 let (mut project, replica_id) = session
1968 .db()
1969 .await
1970 .join_hosted_project(
1971 ProjectId(request.project_id as i32),
1972 session.user_id(),
1973 session.connection_id,
1974 )
1975 .await?;
1976
1977 join_project_internal(response, session, &mut project, &replica_id)
1978}
1979
1980/// Updates other participants with changes to the project
1981async fn update_project(
1982 request: proto::UpdateProject,
1983 response: Response<proto::UpdateProject>,
1984 session: Session,
1985) -> Result<()> {
1986 let project_id = ProjectId::from_proto(request.project_id);
1987 let (room, guest_connection_ids) = &*session
1988 .db()
1989 .await
1990 .update_project(project_id, session.connection_id, &request.worktrees)
1991 .await?;
1992 broadcast(
1993 Some(session.connection_id),
1994 guest_connection_ids.iter().copied(),
1995 |connection_id| {
1996 session
1997 .peer
1998 .forward_send(session.connection_id, connection_id, request.clone())
1999 },
2000 );
2001 room_updated(&room, &session.peer);
2002 response.send(proto::Ack {})?;
2003
2004 Ok(())
2005}
2006
2007/// Updates other participants with changes to the worktree
2008async fn update_worktree(
2009 request: proto::UpdateWorktree,
2010 response: Response<proto::UpdateWorktree>,
2011 session: Session,
2012) -> Result<()> {
2013 let guest_connection_ids = session
2014 .db()
2015 .await
2016 .update_worktree(&request, session.connection_id)
2017 .await?;
2018
2019 broadcast(
2020 Some(session.connection_id),
2021 guest_connection_ids.iter().copied(),
2022 |connection_id| {
2023 session
2024 .peer
2025 .forward_send(session.connection_id, connection_id, request.clone())
2026 },
2027 );
2028 response.send(proto::Ack {})?;
2029 Ok(())
2030}
2031
2032/// Updates other participants with changes to the diagnostics
2033async fn update_diagnostic_summary(
2034 message: proto::UpdateDiagnosticSummary,
2035 session: Session,
2036) -> Result<()> {
2037 let guest_connection_ids = session
2038 .db()
2039 .await
2040 .update_diagnostic_summary(&message, session.connection_id)
2041 .await?;
2042
2043 broadcast(
2044 Some(session.connection_id),
2045 guest_connection_ids.iter().copied(),
2046 |connection_id| {
2047 session
2048 .peer
2049 .forward_send(session.connection_id, connection_id, message.clone())
2050 },
2051 );
2052
2053 Ok(())
2054}
2055
2056/// Updates other participants with changes to the worktree settings
2057async fn update_worktree_settings(
2058 message: proto::UpdateWorktreeSettings,
2059 session: Session,
2060) -> Result<()> {
2061 let guest_connection_ids = session
2062 .db()
2063 .await
2064 .update_worktree_settings(&message, session.connection_id)
2065 .await?;
2066
2067 broadcast(
2068 Some(session.connection_id),
2069 guest_connection_ids.iter().copied(),
2070 |connection_id| {
2071 session
2072 .peer
2073 .forward_send(session.connection_id, connection_id, message.clone())
2074 },
2075 );
2076
2077 Ok(())
2078}
2079
2080/// Notify other participants that a language server has started.
2081async fn start_language_server(
2082 request: proto::StartLanguageServer,
2083 session: Session,
2084) -> Result<()> {
2085 let guest_connection_ids = session
2086 .db()
2087 .await
2088 .start_language_server(&request, session.connection_id)
2089 .await?;
2090
2091 broadcast(
2092 Some(session.connection_id),
2093 guest_connection_ids.iter().copied(),
2094 |connection_id| {
2095 session
2096 .peer
2097 .forward_send(session.connection_id, connection_id, request.clone())
2098 },
2099 );
2100 Ok(())
2101}
2102
2103/// Notify other participants that a language server has changed.
2104async fn update_language_server(
2105 request: proto::UpdateLanguageServer,
2106 session: Session,
2107) -> Result<()> {
2108 let project_id = ProjectId::from_proto(request.project_id);
2109 let project_connection_ids = session
2110 .db()
2111 .await
2112 .project_connection_ids(project_id, session.connection_id)
2113 .await?;
2114 broadcast(
2115 Some(session.connection_id),
2116 project_connection_ids.iter().copied(),
2117 |connection_id| {
2118 session
2119 .peer
2120 .forward_send(session.connection_id, connection_id, request.clone())
2121 },
2122 );
2123 Ok(())
2124}
2125
2126/// forward a project request to the host. These requests should be read only
2127/// as guests are allowed to send them.
2128async fn forward_read_only_project_request<T>(
2129 request: T,
2130 response: Response<T>,
2131 session: Session,
2132) -> Result<()>
2133where
2134 T: EntityMessage + RequestMessage,
2135{
2136 let project_id = ProjectId::from_proto(request.remote_entity_id());
2137 let host_connection_id = session
2138 .db()
2139 .await
2140 .host_for_read_only_project_request(project_id, session.connection_id)
2141 .await?;
2142 let payload = session
2143 .peer
2144 .forward_request(session.connection_id, host_connection_id, request)
2145 .await?;
2146 response.send(payload)?;
2147 Ok(())
2148}
2149
2150/// forward a project request to the host. These requests are disallowed
2151/// for guests.
2152async fn forward_mutating_project_request<T>(
2153 request: T,
2154 response: Response<T>,
2155 session: Session,
2156) -> Result<()>
2157where
2158 T: EntityMessage + RequestMessage,
2159{
2160 let project_id = ProjectId::from_proto(request.remote_entity_id());
2161 let host_connection_id = session
2162 .db()
2163 .await
2164 .host_for_mutating_project_request(project_id, session.connection_id)
2165 .await?;
2166 let payload = session
2167 .peer
2168 .forward_request(session.connection_id, host_connection_id, request)
2169 .await?;
2170 response.send(payload)?;
2171 Ok(())
2172}
2173
2174/// Notify other participants that a new buffer has been created
2175async fn create_buffer_for_peer(
2176 request: proto::CreateBufferForPeer,
2177 session: Session,
2178) -> Result<()> {
2179 session
2180 .db()
2181 .await
2182 .check_user_is_project_host(
2183 ProjectId::from_proto(request.project_id),
2184 session.connection_id,
2185 )
2186 .await?;
2187 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2188 session
2189 .peer
2190 .forward_send(session.connection_id, peer_id.into(), request)?;
2191 Ok(())
2192}
2193
2194/// Notify other participants that a buffer has been updated. This is
2195/// allowed for guests as long as the update is limited to selections.
2196async fn update_buffer(
2197 request: proto::UpdateBuffer,
2198 response: Response<proto::UpdateBuffer>,
2199 session: Session,
2200) -> Result<()> {
2201 let project_id = ProjectId::from_proto(request.project_id);
2202 let mut guest_connection_ids;
2203 let mut host_connection_id = None;
2204
2205 let mut requires_write_permission = false;
2206
2207 for op in request.operations.iter() {
2208 match op.variant {
2209 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2210 Some(_) => requires_write_permission = true,
2211 }
2212 }
2213
2214 {
2215 let collaborators = session
2216 .db()
2217 .await
2218 .project_collaborators_for_buffer_update(
2219 project_id,
2220 session.connection_id,
2221 requires_write_permission,
2222 )
2223 .await?;
2224 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2225 for collaborator in collaborators.iter() {
2226 if collaborator.is_host {
2227 host_connection_id = Some(collaborator.connection_id);
2228 } else {
2229 guest_connection_ids.push(collaborator.connection_id);
2230 }
2231 }
2232 }
2233 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2234
2235 broadcast(
2236 Some(session.connection_id),
2237 guest_connection_ids,
2238 |connection_id| {
2239 session
2240 .peer
2241 .forward_send(session.connection_id, connection_id, request.clone())
2242 },
2243 );
2244 if host_connection_id != session.connection_id {
2245 session
2246 .peer
2247 .forward_request(session.connection_id, host_connection_id, request.clone())
2248 .await?;
2249 }
2250
2251 response.send(proto::Ack {})?;
2252 Ok(())
2253}
2254
2255/// Notify other participants that a project has been updated.
2256async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2257 request: T,
2258 session: Session,
2259) -> Result<()> {
2260 let project_id = ProjectId::from_proto(request.remote_entity_id());
2261 let project_connection_ids = session
2262 .db()
2263 .await
2264 .project_connection_ids(project_id, session.connection_id)
2265 .await?;
2266
2267 broadcast(
2268 Some(session.connection_id),
2269 project_connection_ids.iter().copied(),
2270 |connection_id| {
2271 session
2272 .peer
2273 .forward_send(session.connection_id, connection_id, request.clone())
2274 },
2275 );
2276 Ok(())
2277}
2278
2279/// Start following another user in a call.
2280async fn follow(
2281 request: proto::Follow,
2282 response: Response<proto::Follow>,
2283 session: UserSession,
2284) -> Result<()> {
2285 let room_id = RoomId::from_proto(request.room_id);
2286 let project_id = request.project_id.map(ProjectId::from_proto);
2287 let leader_id = request
2288 .leader_id
2289 .ok_or_else(|| anyhow!("invalid leader id"))?
2290 .into();
2291 let follower_id = session.connection_id;
2292
2293 session
2294 .db()
2295 .await
2296 .check_room_participants(room_id, leader_id, session.connection_id)
2297 .await?;
2298
2299 let response_payload = session
2300 .peer
2301 .forward_request(session.connection_id, leader_id, request)
2302 .await?;
2303 response.send(response_payload)?;
2304
2305 if let Some(project_id) = project_id {
2306 let room = session
2307 .db()
2308 .await
2309 .follow(room_id, project_id, leader_id, follower_id)
2310 .await?;
2311 room_updated(&room, &session.peer);
2312 }
2313
2314 Ok(())
2315}
2316
2317/// Stop following another user in a call.
2318async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2319 let room_id = RoomId::from_proto(request.room_id);
2320 let project_id = request.project_id.map(ProjectId::from_proto);
2321 let leader_id = request
2322 .leader_id
2323 .ok_or_else(|| anyhow!("invalid leader id"))?
2324 .into();
2325 let follower_id = session.connection_id;
2326
2327 session
2328 .db()
2329 .await
2330 .check_room_participants(room_id, leader_id, session.connection_id)
2331 .await?;
2332
2333 session
2334 .peer
2335 .forward_send(session.connection_id, leader_id, request)?;
2336
2337 if let Some(project_id) = project_id {
2338 let room = session
2339 .db()
2340 .await
2341 .unfollow(room_id, project_id, leader_id, follower_id)
2342 .await?;
2343 room_updated(&room, &session.peer);
2344 }
2345
2346 Ok(())
2347}
2348
2349/// Notify everyone following you of your current location.
2350async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2351 let room_id = RoomId::from_proto(request.room_id);
2352 let database = session.db.lock().await;
2353
2354 let connection_ids = if let Some(project_id) = request.project_id {
2355 let project_id = ProjectId::from_proto(project_id);
2356 database
2357 .project_connection_ids(project_id, session.connection_id)
2358 .await?
2359 } else {
2360 database
2361 .room_connection_ids(room_id, session.connection_id)
2362 .await?
2363 };
2364
2365 // For now, don't send view update messages back to that view's current leader.
2366 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2367 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2368 _ => None,
2369 });
2370
2371 for connection_id in connection_ids.iter().cloned() {
2372 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2373 session
2374 .peer
2375 .forward_send(session.connection_id, connection_id, request.clone())?;
2376 }
2377 }
2378 Ok(())
2379}
2380
2381/// Get public data about users.
2382async fn get_users(
2383 request: proto::GetUsers,
2384 response: Response<proto::GetUsers>,
2385 session: Session,
2386) -> Result<()> {
2387 let user_ids = request
2388 .user_ids
2389 .into_iter()
2390 .map(UserId::from_proto)
2391 .collect();
2392 let users = session
2393 .db()
2394 .await
2395 .get_users_by_ids(user_ids)
2396 .await?
2397 .into_iter()
2398 .map(|user| proto::User {
2399 id: user.id.to_proto(),
2400 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2401 github_login: user.github_login,
2402 })
2403 .collect();
2404 response.send(proto::UsersResponse { users })?;
2405 Ok(())
2406}
2407
2408/// Search for users (to invite) buy Github login
2409async fn fuzzy_search_users(
2410 request: proto::FuzzySearchUsers,
2411 response: Response<proto::FuzzySearchUsers>,
2412 session: UserSession,
2413) -> Result<()> {
2414 let query = request.query;
2415 let users = match query.len() {
2416 0 => vec![],
2417 1 | 2 => session
2418 .db()
2419 .await
2420 .get_user_by_github_login(&query)
2421 .await?
2422 .into_iter()
2423 .collect(),
2424 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2425 };
2426 let users = users
2427 .into_iter()
2428 .filter(|user| user.id != session.user_id())
2429 .map(|user| proto::User {
2430 id: user.id.to_proto(),
2431 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2432 github_login: user.github_login,
2433 })
2434 .collect();
2435 response.send(proto::UsersResponse { users })?;
2436 Ok(())
2437}
2438
2439/// Send a contact request to another user.
2440async fn request_contact(
2441 request: proto::RequestContact,
2442 response: Response<proto::RequestContact>,
2443 session: UserSession,
2444) -> Result<()> {
2445 let requester_id = session.user_id();
2446 let responder_id = UserId::from_proto(request.responder_id);
2447 if requester_id == responder_id {
2448 return Err(anyhow!("cannot add yourself as a contact"))?;
2449 }
2450
2451 let notifications = session
2452 .db()
2453 .await
2454 .send_contact_request(requester_id, responder_id)
2455 .await?;
2456
2457 // Update outgoing contact requests of requester
2458 let mut update = proto::UpdateContacts::default();
2459 update.outgoing_requests.push(responder_id.to_proto());
2460 for connection_id in session
2461 .connection_pool()
2462 .await
2463 .user_connection_ids(requester_id)
2464 {
2465 session.peer.send(connection_id, update.clone())?;
2466 }
2467
2468 // Update incoming contact requests of responder
2469 let mut update = proto::UpdateContacts::default();
2470 update
2471 .incoming_requests
2472 .push(proto::IncomingContactRequest {
2473 requester_id: requester_id.to_proto(),
2474 });
2475 let connection_pool = session.connection_pool().await;
2476 for connection_id in connection_pool.user_connection_ids(responder_id) {
2477 session.peer.send(connection_id, update.clone())?;
2478 }
2479
2480 send_notifications(&connection_pool, &session.peer, notifications);
2481
2482 response.send(proto::Ack {})?;
2483 Ok(())
2484}
2485
2486/// Accept or decline a contact request
2487async fn respond_to_contact_request(
2488 request: proto::RespondToContactRequest,
2489 response: Response<proto::RespondToContactRequest>,
2490 session: UserSession,
2491) -> Result<()> {
2492 let responder_id = session.user_id();
2493 let requester_id = UserId::from_proto(request.requester_id);
2494 let db = session.db().await;
2495 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2496 db.dismiss_contact_notification(responder_id, requester_id)
2497 .await?;
2498 } else {
2499 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2500
2501 let notifications = db
2502 .respond_to_contact_request(responder_id, requester_id, accept)
2503 .await?;
2504 let requester_busy = db.is_user_busy(requester_id).await?;
2505 let responder_busy = db.is_user_busy(responder_id).await?;
2506
2507 let pool = session.connection_pool().await;
2508 // Update responder with new contact
2509 let mut update = proto::UpdateContacts::default();
2510 if accept {
2511 update
2512 .contacts
2513 .push(contact_for_user(requester_id, requester_busy, &pool));
2514 }
2515 update
2516 .remove_incoming_requests
2517 .push(requester_id.to_proto());
2518 for connection_id in pool.user_connection_ids(responder_id) {
2519 session.peer.send(connection_id, update.clone())?;
2520 }
2521
2522 // Update requester with new contact
2523 let mut update = proto::UpdateContacts::default();
2524 if accept {
2525 update
2526 .contacts
2527 .push(contact_for_user(responder_id, responder_busy, &pool));
2528 }
2529 update
2530 .remove_outgoing_requests
2531 .push(responder_id.to_proto());
2532
2533 for connection_id in pool.user_connection_ids(requester_id) {
2534 session.peer.send(connection_id, update.clone())?;
2535 }
2536
2537 send_notifications(&pool, &session.peer, notifications);
2538 }
2539
2540 response.send(proto::Ack {})?;
2541 Ok(())
2542}
2543
2544/// Remove a contact.
2545async fn remove_contact(
2546 request: proto::RemoveContact,
2547 response: Response<proto::RemoveContact>,
2548 session: UserSession,
2549) -> Result<()> {
2550 let requester_id = session.user_id();
2551 let responder_id = UserId::from_proto(request.user_id);
2552 let db = session.db().await;
2553 let (contact_accepted, deleted_notification_id) =
2554 db.remove_contact(requester_id, responder_id).await?;
2555
2556 let pool = session.connection_pool().await;
2557 // Update outgoing contact requests of requester
2558 let mut update = proto::UpdateContacts::default();
2559 if contact_accepted {
2560 update.remove_contacts.push(responder_id.to_proto());
2561 } else {
2562 update
2563 .remove_outgoing_requests
2564 .push(responder_id.to_proto());
2565 }
2566 for connection_id in pool.user_connection_ids(requester_id) {
2567 session.peer.send(connection_id, update.clone())?;
2568 }
2569
2570 // Update incoming contact requests of responder
2571 let mut update = proto::UpdateContacts::default();
2572 if contact_accepted {
2573 update.remove_contacts.push(requester_id.to_proto());
2574 } else {
2575 update
2576 .remove_incoming_requests
2577 .push(requester_id.to_proto());
2578 }
2579 for connection_id in pool.user_connection_ids(responder_id) {
2580 session.peer.send(connection_id, update.clone())?;
2581 if let Some(notification_id) = deleted_notification_id {
2582 session.peer.send(
2583 connection_id,
2584 proto::DeleteNotification {
2585 notification_id: notification_id.to_proto(),
2586 },
2587 )?;
2588 }
2589 }
2590
2591 response.send(proto::Ack {})?;
2592 Ok(())
2593}
2594
2595/// Creates a new channel.
2596async fn create_channel(
2597 request: proto::CreateChannel,
2598 response: Response<proto::CreateChannel>,
2599 session: UserSession,
2600) -> Result<()> {
2601 let db = session.db().await;
2602
2603 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2604 let (channel, membership) = db
2605 .create_channel(&request.name, parent_id, session.user_id())
2606 .await?;
2607
2608 let root_id = channel.root_id();
2609 let channel = Channel::from_model(channel);
2610
2611 response.send(proto::CreateChannelResponse {
2612 channel: Some(channel.to_proto()),
2613 parent_id: request.parent_id,
2614 })?;
2615
2616 let mut connection_pool = session.connection_pool().await;
2617 if let Some(membership) = membership {
2618 connection_pool.subscribe_to_channel(
2619 membership.user_id,
2620 membership.channel_id,
2621 membership.role,
2622 );
2623 let update = proto::UpdateUserChannels {
2624 channel_memberships: vec![proto::ChannelMembership {
2625 channel_id: membership.channel_id.to_proto(),
2626 role: membership.role.into(),
2627 }],
2628 ..Default::default()
2629 };
2630 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2631 session.peer.send(connection_id, update.clone())?;
2632 }
2633 }
2634
2635 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2636 if !role.can_see_channel(channel.visibility) {
2637 continue;
2638 }
2639
2640 let update = proto::UpdateChannels {
2641 channels: vec![channel.to_proto()],
2642 ..Default::default()
2643 };
2644 session.peer.send(connection_id, update.clone())?;
2645 }
2646
2647 Ok(())
2648}
2649
2650/// Delete a channel
2651async fn delete_channel(
2652 request: proto::DeleteChannel,
2653 response: Response<proto::DeleteChannel>,
2654 session: UserSession,
2655) -> Result<()> {
2656 let db = session.db().await;
2657
2658 let channel_id = request.channel_id;
2659 let (root_channel, removed_channels) = db
2660 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2661 .await?;
2662 response.send(proto::Ack {})?;
2663
2664 // Notify members of removed channels
2665 let mut update = proto::UpdateChannels::default();
2666 update
2667 .delete_channels
2668 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2669
2670 let connection_pool = session.connection_pool().await;
2671 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2672 session.peer.send(connection_id, update.clone())?;
2673 }
2674
2675 Ok(())
2676}
2677
2678/// Invite someone to join a channel.
2679async fn invite_channel_member(
2680 request: proto::InviteChannelMember,
2681 response: Response<proto::InviteChannelMember>,
2682 session: UserSession,
2683) -> Result<()> {
2684 let db = session.db().await;
2685 let channel_id = ChannelId::from_proto(request.channel_id);
2686 let invitee_id = UserId::from_proto(request.user_id);
2687 let InviteMemberResult {
2688 channel,
2689 notifications,
2690 } = db
2691 .invite_channel_member(
2692 channel_id,
2693 invitee_id,
2694 session.user_id(),
2695 request.role().into(),
2696 )
2697 .await?;
2698
2699 let update = proto::UpdateChannels {
2700 channel_invitations: vec![channel.to_proto()],
2701 ..Default::default()
2702 };
2703
2704 let connection_pool = session.connection_pool().await;
2705 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2706 session.peer.send(connection_id, update.clone())?;
2707 }
2708
2709 send_notifications(&connection_pool, &session.peer, notifications);
2710
2711 response.send(proto::Ack {})?;
2712 Ok(())
2713}
2714
2715/// remove someone from a channel
2716async fn remove_channel_member(
2717 request: proto::RemoveChannelMember,
2718 response: Response<proto::RemoveChannelMember>,
2719 session: UserSession,
2720) -> Result<()> {
2721 let db = session.db().await;
2722 let channel_id = ChannelId::from_proto(request.channel_id);
2723 let member_id = UserId::from_proto(request.user_id);
2724
2725 let RemoveChannelMemberResult {
2726 membership_update,
2727 notification_id,
2728 } = db
2729 .remove_channel_member(channel_id, member_id, session.user_id())
2730 .await?;
2731
2732 let mut connection_pool = session.connection_pool().await;
2733 notify_membership_updated(
2734 &mut connection_pool,
2735 membership_update,
2736 member_id,
2737 &session.peer,
2738 );
2739 for connection_id in connection_pool.user_connection_ids(member_id) {
2740 if let Some(notification_id) = notification_id {
2741 session
2742 .peer
2743 .send(
2744 connection_id,
2745 proto::DeleteNotification {
2746 notification_id: notification_id.to_proto(),
2747 },
2748 )
2749 .trace_err();
2750 }
2751 }
2752
2753 response.send(proto::Ack {})?;
2754 Ok(())
2755}
2756
2757/// Toggle the channel between public and private.
2758/// Care is taken to maintain the invariant that public channels only descend from public channels,
2759/// (though members-only channels can appear at any point in the hierarchy).
2760async fn set_channel_visibility(
2761 request: proto::SetChannelVisibility,
2762 response: Response<proto::SetChannelVisibility>,
2763 session: UserSession,
2764) -> Result<()> {
2765 let db = session.db().await;
2766 let channel_id = ChannelId::from_proto(request.channel_id);
2767 let visibility = request.visibility().into();
2768
2769 let channel_model = db
2770 .set_channel_visibility(channel_id, visibility, session.user_id())
2771 .await?;
2772 let root_id = channel_model.root_id();
2773 let channel = Channel::from_model(channel_model);
2774
2775 let mut connection_pool = session.connection_pool().await;
2776 for (user_id, role) in connection_pool
2777 .channel_user_ids(root_id)
2778 .collect::<Vec<_>>()
2779 .into_iter()
2780 {
2781 let update = if role.can_see_channel(channel.visibility) {
2782 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2783 proto::UpdateChannels {
2784 channels: vec![channel.to_proto()],
2785 ..Default::default()
2786 }
2787 } else {
2788 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2789 proto::UpdateChannels {
2790 delete_channels: vec![channel.id.to_proto()],
2791 ..Default::default()
2792 }
2793 };
2794
2795 for connection_id in connection_pool.user_connection_ids(user_id) {
2796 session.peer.send(connection_id, update.clone())?;
2797 }
2798 }
2799
2800 response.send(proto::Ack {})?;
2801 Ok(())
2802}
2803
2804/// Alter the role for a user in the channel.
2805async fn set_channel_member_role(
2806 request: proto::SetChannelMemberRole,
2807 response: Response<proto::SetChannelMemberRole>,
2808 session: UserSession,
2809) -> Result<()> {
2810 let db = session.db().await;
2811 let channel_id = ChannelId::from_proto(request.channel_id);
2812 let member_id = UserId::from_proto(request.user_id);
2813 let result = db
2814 .set_channel_member_role(
2815 channel_id,
2816 session.user_id(),
2817 member_id,
2818 request.role().into(),
2819 )
2820 .await?;
2821
2822 match result {
2823 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2824 let mut connection_pool = session.connection_pool().await;
2825 notify_membership_updated(
2826 &mut connection_pool,
2827 membership_update,
2828 member_id,
2829 &session.peer,
2830 )
2831 }
2832 db::SetMemberRoleResult::InviteUpdated(channel) => {
2833 let update = proto::UpdateChannels {
2834 channel_invitations: vec![channel.to_proto()],
2835 ..Default::default()
2836 };
2837
2838 for connection_id in session
2839 .connection_pool()
2840 .await
2841 .user_connection_ids(member_id)
2842 {
2843 session.peer.send(connection_id, update.clone())?;
2844 }
2845 }
2846 }
2847
2848 response.send(proto::Ack {})?;
2849 Ok(())
2850}
2851
2852/// Change the name of a channel
2853async fn rename_channel(
2854 request: proto::RenameChannel,
2855 response: Response<proto::RenameChannel>,
2856 session: UserSession,
2857) -> Result<()> {
2858 let db = session.db().await;
2859 let channel_id = ChannelId::from_proto(request.channel_id);
2860 let channel_model = db
2861 .rename_channel(channel_id, session.user_id(), &request.name)
2862 .await?;
2863 let root_id = channel_model.root_id();
2864 let channel = Channel::from_model(channel_model);
2865
2866 response.send(proto::RenameChannelResponse {
2867 channel: Some(channel.to_proto()),
2868 })?;
2869
2870 let connection_pool = session.connection_pool().await;
2871 let update = proto::UpdateChannels {
2872 channels: vec![channel.to_proto()],
2873 ..Default::default()
2874 };
2875 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2876 if role.can_see_channel(channel.visibility) {
2877 session.peer.send(connection_id, update.clone())?;
2878 }
2879 }
2880
2881 Ok(())
2882}
2883
2884/// Move a channel to a new parent.
2885async fn move_channel(
2886 request: proto::MoveChannel,
2887 response: Response<proto::MoveChannel>,
2888 session: UserSession,
2889) -> Result<()> {
2890 let channel_id = ChannelId::from_proto(request.channel_id);
2891 let to = ChannelId::from_proto(request.to);
2892
2893 let (root_id, channels) = session
2894 .db()
2895 .await
2896 .move_channel(channel_id, to, session.user_id())
2897 .await?;
2898
2899 let connection_pool = session.connection_pool().await;
2900 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2901 let channels = channels
2902 .iter()
2903 .filter_map(|channel| {
2904 if role.can_see_channel(channel.visibility) {
2905 Some(channel.to_proto())
2906 } else {
2907 None
2908 }
2909 })
2910 .collect::<Vec<_>>();
2911 if channels.is_empty() {
2912 continue;
2913 }
2914
2915 let update = proto::UpdateChannels {
2916 channels,
2917 ..Default::default()
2918 };
2919
2920 session.peer.send(connection_id, update.clone())?;
2921 }
2922
2923 response.send(Ack {})?;
2924 Ok(())
2925}
2926
2927/// Get the list of channel members
2928async fn get_channel_members(
2929 request: proto::GetChannelMembers,
2930 response: Response<proto::GetChannelMembers>,
2931 session: UserSession,
2932) -> Result<()> {
2933 let db = session.db().await;
2934 let channel_id = ChannelId::from_proto(request.channel_id);
2935 let members = db
2936 .get_channel_participant_details(channel_id, session.user_id())
2937 .await?;
2938 response.send(proto::GetChannelMembersResponse { members })?;
2939 Ok(())
2940}
2941
2942/// Accept or decline a channel invitation.
2943async fn respond_to_channel_invite(
2944 request: proto::RespondToChannelInvite,
2945 response: Response<proto::RespondToChannelInvite>,
2946 session: UserSession,
2947) -> Result<()> {
2948 let db = session.db().await;
2949 let channel_id = ChannelId::from_proto(request.channel_id);
2950 let RespondToChannelInvite {
2951 membership_update,
2952 notifications,
2953 } = db
2954 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2955 .await?;
2956
2957 let mut connection_pool = session.connection_pool().await;
2958 if let Some(membership_update) = membership_update {
2959 notify_membership_updated(
2960 &mut connection_pool,
2961 membership_update,
2962 session.user_id(),
2963 &session.peer,
2964 );
2965 } else {
2966 let update = proto::UpdateChannels {
2967 remove_channel_invitations: vec![channel_id.to_proto()],
2968 ..Default::default()
2969 };
2970
2971 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2972 session.peer.send(connection_id, update.clone())?;
2973 }
2974 };
2975
2976 send_notifications(&connection_pool, &session.peer, notifications);
2977
2978 response.send(proto::Ack {})?;
2979
2980 Ok(())
2981}
2982
2983/// Join the channels' room
2984async fn join_channel(
2985 request: proto::JoinChannel,
2986 response: Response<proto::JoinChannel>,
2987 session: UserSession,
2988) -> Result<()> {
2989 let channel_id = ChannelId::from_proto(request.channel_id);
2990 join_channel_internal(channel_id, Box::new(response), session).await
2991}
2992
/// A response handle that can answer a channel join with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` back as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// NOTE(review): each impl calls `Response::<_>::send` through an explicit type
// path. This presumably targets `Response`'s own `send` method rather than this
// trait's method of the same name, so keep the path explicit when editing.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3006
/// Shared implementation for joining a channel's call room.
///
/// It is generic over the response handle, so it can answer both `JoinChannel`
/// and `JoinRoom` requests. In order, it:
/// 1. calls `leave_room_for_session` for the caller's session;
/// 2. joins the channel in the database, which yields the joined room, an
///    optional membership update, and the caller's role;
/// 3. builds LiveKit connection info when a LiveKit client is configured;
/// 4. replies to the caller, forwards any membership update through
///    `notify_membership_updated`, and calls `room_updated`;
/// 5. calls `channel_updated` and then `update_user_contacts` for the caller.
///
/// # Errors
/// Fails if leaving the previous room, the database join, the reply, or the
/// contacts update fails. It also fails if the joined room carries no channel
/// ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the database guard and the connection-pool guard. Both
    // are released before `channel_updated` locks the connection pool again below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish. Every other role gets a
        // room token with publishing allowed. `trace_err()` turns a token error
        // into `None` (presumably logging it), so the join still succeeds, just
        // without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // The caller gets its reply before other participants are notified.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join must produce a channel; its absence is reported as an error.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3086
3087/// Start editing the channel notes
3088async fn join_channel_buffer(
3089 request: proto::JoinChannelBuffer,
3090 response: Response<proto::JoinChannelBuffer>,
3091 session: UserSession,
3092) -> Result<()> {
3093 let db = session.db().await;
3094 let channel_id = ChannelId::from_proto(request.channel_id);
3095
3096 let open_response = db
3097 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3098 .await?;
3099
3100 let collaborators = open_response.collaborators.clone();
3101 response.send(open_response)?;
3102
3103 let update = UpdateChannelBufferCollaborators {
3104 channel_id: channel_id.to_proto(),
3105 collaborators: collaborators.clone(),
3106 };
3107 channel_buffer_updated(
3108 session.connection_id,
3109 collaborators
3110 .iter()
3111 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3112 &update,
3113 &session.peer,
3114 );
3115
3116 Ok(())
3117}
3118
3119/// Edit the channel notes
3120async fn update_channel_buffer(
3121 request: proto::UpdateChannelBuffer,
3122 session: UserSession,
3123) -> Result<()> {
3124 let db = session.db().await;
3125 let channel_id = ChannelId::from_proto(request.channel_id);
3126
3127 let (collaborators, non_collaborators, epoch, version) = db
3128 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3129 .await?;
3130
3131 channel_buffer_updated(
3132 session.connection_id,
3133 collaborators,
3134 &proto::UpdateChannelBuffer {
3135 channel_id: channel_id.to_proto(),
3136 operations: request.operations,
3137 },
3138 &session.peer,
3139 );
3140
3141 let pool = &*session.connection_pool().await;
3142
3143 broadcast(
3144 None,
3145 non_collaborators
3146 .iter()
3147 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3148 |peer_id| {
3149 session.peer.send(
3150 peer_id,
3151 proto::UpdateChannels {
3152 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3153 channel_id: channel_id.to_proto(),
3154 epoch: epoch as u64,
3155 version: version.clone(),
3156 }],
3157 ..Default::default()
3158 },
3159 )
3160 },
3161 );
3162
3163 Ok(())
3164}
3165
3166/// Rejoin the channel notes after a connection blip
3167async fn rejoin_channel_buffers(
3168 request: proto::RejoinChannelBuffers,
3169 response: Response<proto::RejoinChannelBuffers>,
3170 session: UserSession,
3171) -> Result<()> {
3172 let db = session.db().await;
3173 let buffers = db
3174 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3175 .await?;
3176
3177 for rejoined_buffer in &buffers {
3178 let collaborators_to_notify = rejoined_buffer
3179 .buffer
3180 .collaborators
3181 .iter()
3182 .filter_map(|c| Some(c.peer_id?.into()));
3183 channel_buffer_updated(
3184 session.connection_id,
3185 collaborators_to_notify,
3186 &proto::UpdateChannelBufferCollaborators {
3187 channel_id: rejoined_buffer.buffer.channel_id,
3188 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3189 },
3190 &session.peer,
3191 );
3192 }
3193
3194 response.send(proto::RejoinChannelBuffersResponse {
3195 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3196 })?;
3197
3198 Ok(())
3199}
3200
3201/// Stop editing the channel notes
3202async fn leave_channel_buffer(
3203 request: proto::LeaveChannelBuffer,
3204 response: Response<proto::LeaveChannelBuffer>,
3205 session: UserSession,
3206) -> Result<()> {
3207 let db = session.db().await;
3208 let channel_id = ChannelId::from_proto(request.channel_id);
3209
3210 let left_buffer = db
3211 .leave_channel_buffer(channel_id, session.connection_id)
3212 .await?;
3213
3214 response.send(Ack {})?;
3215
3216 channel_buffer_updated(
3217 session.connection_id,
3218 left_buffer.connections,
3219 &proto::UpdateChannelBufferCollaborators {
3220 channel_id: channel_id.to_proto(),
3221 collaborators: left_buffer.collaborators,
3222 },
3223 &session.peer,
3224 );
3225
3226 Ok(())
3227}
3228
3229fn channel_buffer_updated<T: EnvelopedMessage>(
3230 sender_id: ConnectionId,
3231 collaborators: impl IntoIterator<Item = ConnectionId>,
3232 message: &T,
3233 peer: &Peer,
3234) {
3235 broadcast(Some(sender_id), collaborators, |peer_id| {
3236 peer.send(peer_id, message.clone())
3237 });
3238}
3239
3240fn send_notifications(
3241 connection_pool: &ConnectionPool,
3242 peer: &Peer,
3243 notifications: db::NotificationBatch,
3244) {
3245 for (user_id, notification) in notifications {
3246 for connection_id in connection_pool.user_connection_ids(user_id) {
3247 if let Err(error) = peer.send(
3248 connection_id,
3249 proto::AddNotification {
3250 notification: Some(notification.clone()),
3251 },
3252 ) {
3253 tracing::error!(
3254 "failed to send notification to {:?} {}",
3255 connection_id,
3256 error
3257 );
3258 }
3259 }
3260 }
3261}
3262
3263/// Send a message to the channel
3264async fn send_channel_message(
3265 request: proto::SendChannelMessage,
3266 response: Response<proto::SendChannelMessage>,
3267 session: UserSession,
3268) -> Result<()> {
3269 // Validate the message body.
3270 let body = request.body.trim().to_string();
3271 if body.len() > MAX_MESSAGE_LEN {
3272 return Err(anyhow!("message is too long"))?;
3273 }
3274 if body.is_empty() {
3275 return Err(anyhow!("message can't be blank"))?;
3276 }
3277
3278 // TODO: adjust mentions if body is trimmed
3279
3280 let timestamp = OffsetDateTime::now_utc();
3281 let nonce = request
3282 .nonce
3283 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3284
3285 let channel_id = ChannelId::from_proto(request.channel_id);
3286 let CreatedChannelMessage {
3287 message_id,
3288 participant_connection_ids,
3289 channel_members,
3290 notifications,
3291 } = session
3292 .db()
3293 .await
3294 .create_channel_message(
3295 channel_id,
3296 session.user_id(),
3297 &body,
3298 &request.mentions,
3299 timestamp,
3300 nonce.clone().into(),
3301 match request.reply_to_message_id {
3302 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3303 None => None,
3304 },
3305 )
3306 .await?;
3307
3308 let message = proto::ChannelMessage {
3309 sender_id: session.user_id().to_proto(),
3310 id: message_id.to_proto(),
3311 body,
3312 mentions: request.mentions,
3313 timestamp: timestamp.unix_timestamp() as u64,
3314 nonce: Some(nonce),
3315 reply_to_message_id: request.reply_to_message_id,
3316 edited_at: None,
3317 };
3318 broadcast(
3319 Some(session.connection_id),
3320 participant_connection_ids,
3321 |connection| {
3322 session.peer.send(
3323 connection,
3324 proto::ChannelMessageSent {
3325 channel_id: channel_id.to_proto(),
3326 message: Some(message.clone()),
3327 },
3328 )
3329 },
3330 );
3331 response.send(proto::SendChannelMessageResponse {
3332 message: Some(message),
3333 })?;
3334
3335 let pool = &*session.connection_pool().await;
3336 broadcast(
3337 None,
3338 channel_members
3339 .iter()
3340 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3341 |peer_id| {
3342 session.peer.send(
3343 peer_id,
3344 proto::UpdateChannels {
3345 latest_channel_message_ids: vec![proto::ChannelMessageId {
3346 channel_id: channel_id.to_proto(),
3347 message_id: message_id.to_proto(),
3348 }],
3349 ..Default::default()
3350 },
3351 )
3352 },
3353 );
3354 send_notifications(pool, &session.peer, notifications);
3355
3356 Ok(())
3357}
3358
3359/// Delete a channel message
3360async fn remove_channel_message(
3361 request: proto::RemoveChannelMessage,
3362 response: Response<proto::RemoveChannelMessage>,
3363 session: UserSession,
3364) -> Result<()> {
3365 let channel_id = ChannelId::from_proto(request.channel_id);
3366 let message_id = MessageId::from_proto(request.message_id);
3367 let connection_ids = session
3368 .db()
3369 .await
3370 .remove_channel_message(channel_id, message_id, session.user_id())
3371 .await?;
3372 broadcast(Some(session.connection_id), connection_ids, |connection| {
3373 session.peer.send(connection, request.clone())
3374 });
3375 response.send(proto::Ack {})?;
3376 Ok(())
3377}
3378
3379async fn update_channel_message(
3380 request: proto::UpdateChannelMessage,
3381 response: Response<proto::UpdateChannelMessage>,
3382 session: UserSession,
3383) -> Result<()> {
3384 let channel_id = ChannelId::from_proto(request.channel_id);
3385 let message_id = MessageId::from_proto(request.message_id);
3386 let updated_at = OffsetDateTime::now_utc();
3387 let UpdatedChannelMessage {
3388 message_id,
3389 participant_connection_ids,
3390 notifications,
3391 reply_to_message_id,
3392 timestamp,
3393 } = session
3394 .db()
3395 .await
3396 .update_channel_message(
3397 channel_id,
3398 message_id,
3399 session.user_id(),
3400 request.body.as_str(),
3401 &request.mentions,
3402 updated_at,
3403 )
3404 .await?;
3405
3406 let nonce = request
3407 .nonce
3408 .clone()
3409 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3410
3411 let message = proto::ChannelMessage {
3412 sender_id: session.user_id().to_proto(),
3413 id: message_id.to_proto(),
3414 body: request.body.clone(),
3415 mentions: request.mentions.clone(),
3416 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3417 nonce: Some(nonce),
3418 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3419 edited_at: Some(updated_at.unix_timestamp() as u64),
3420 };
3421
3422 response.send(proto::Ack {})?;
3423
3424 let pool = &*session.connection_pool().await;
3425 broadcast(
3426 Some(session.connection_id),
3427 participant_connection_ids,
3428 |connection| {
3429 session.peer.send(
3430 connection,
3431 proto::ChannelMessageUpdate {
3432 channel_id: channel_id.to_proto(),
3433 message: Some(message.clone()),
3434 },
3435 )
3436 },
3437 );
3438
3439 send_notifications(pool, &session.peer, notifications);
3440
3441 Ok(())
3442}
3443
3444/// Mark a channel message as read
3445async fn acknowledge_channel_message(
3446 request: proto::AckChannelMessage,
3447 session: UserSession,
3448) -> Result<()> {
3449 let channel_id = ChannelId::from_proto(request.channel_id);
3450 let message_id = MessageId::from_proto(request.message_id);
3451 let notifications = session
3452 .db()
3453 .await
3454 .observe_channel_message(channel_id, session.user_id(), message_id)
3455 .await?;
3456 send_notifications(
3457 &*session.connection_pool().await,
3458 &session.peer,
3459 notifications,
3460 );
3461 Ok(())
3462}
3463
3464/// Mark a buffer version as synced
3465async fn acknowledge_buffer_version(
3466 request: proto::AckBufferOperation,
3467 session: UserSession,
3468) -> Result<()> {
3469 let buffer_id = BufferId::from_proto(request.buffer_id);
3470 session
3471 .db()
3472 .await
3473 .observe_buffer_version(
3474 buffer_id,
3475 session.user_id(),
3476 request.epoch as i32,
3477 &request.version,
3478 )
3479 .await?;
3480 Ok(())
3481}
3482
3483struct CompleteWithLanguageModelRateLimit;
3484
3485impl RateLimit for CompleteWithLanguageModelRateLimit {
3486 fn capacity() -> usize {
3487 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3488 .ok()
3489 .and_then(|v| v.parse().ok())
3490 .unwrap_or(120) // Picked arbitrarily
3491 }
3492
3493 fn refill_duration() -> chrono::Duration {
3494 chrono::Duration::hours(1)
3495 }
3496
3497 fn db_name() -> &'static str {
3498 "complete-with-language-model"
3499 }
3500}
3501
3502async fn complete_with_language_model(
3503 request: proto::CompleteWithLanguageModel,
3504 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3505 session: Session,
3506 open_ai_api_key: Option<Arc<str>>,
3507 google_ai_api_key: Option<Arc<str>>,
3508) -> Result<()> {
3509 let Some(session) = session.for_user() else {
3510 return Err(anyhow!("user not found"))?;
3511 };
3512 authorize_access_to_language_models(&session).await?;
3513 session
3514 .rate_limiter
3515 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3516 .await?;
3517
3518 if request.model.starts_with("gpt") {
3519 let api_key =
3520 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3521 complete_with_open_ai(request, response, session, api_key).await?;
3522 } else if request.model.starts_with("gemini") {
3523 let api_key = google_ai_api_key
3524 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3525 complete_with_google_ai(request, response, session, api_key).await?;
3526 }
3527
3528 Ok(())
3529}
3530
3531async fn complete_with_open_ai(
3532 request: proto::CompleteWithLanguageModel,
3533 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3534 session: UserSession,
3535 api_key: Arc<str>,
3536) -> Result<()> {
3537 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3538
3539 let mut completion_stream = open_ai::stream_completion(
3540 &session.http_client,
3541 OPEN_AI_API_URL,
3542 &api_key,
3543 crate::ai::language_model_request_to_open_ai(request)?,
3544 )
3545 .await
3546 .context("open_ai::stream_completion request failed")?;
3547
3548 while let Some(event) = completion_stream.next().await {
3549 let event = event?;
3550 response.send(proto::LanguageModelResponse {
3551 choices: event
3552 .choices
3553 .into_iter()
3554 .map(|choice| proto::LanguageModelChoiceDelta {
3555 index: choice.index,
3556 delta: Some(proto::LanguageModelResponseMessage {
3557 role: choice.delta.role.map(|role| match role {
3558 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3559 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3560 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3561 } as i32),
3562 content: choice.delta.content,
3563 }),
3564 finish_reason: choice.finish_reason,
3565 })
3566 .collect(),
3567 })?;
3568 }
3569
3570 Ok(())
3571}
3572
3573async fn complete_with_google_ai(
3574 request: proto::CompleteWithLanguageModel,
3575 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3576 session: UserSession,
3577 api_key: Arc<str>,
3578) -> Result<()> {
3579 let mut stream = google_ai::stream_generate_content(
3580 &session.http_client,
3581 google_ai::API_URL,
3582 api_key.as_ref(),
3583 crate::ai::language_model_request_to_google_ai(request)?,
3584 )
3585 .await
3586 .context("google_ai::stream_generate_content request failed")?;
3587
3588 while let Some(event) = stream.next().await {
3589 let event = event?;
3590 response.send(proto::LanguageModelResponse {
3591 choices: event
3592 .candidates
3593 .unwrap_or_default()
3594 .into_iter()
3595 .map(|candidate| proto::LanguageModelChoiceDelta {
3596 index: candidate.index as u32,
3597 delta: Some(proto::LanguageModelResponseMessage {
3598 role: Some(match candidate.content.role {
3599 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3600 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3601 } as i32),
3602 content: Some(
3603 candidate
3604 .content
3605 .parts
3606 .into_iter()
3607 .filter_map(|part| match part {
3608 google_ai::Part::TextPart(part) => Some(part.text),
3609 google_ai::Part::InlineDataPart(_) => None,
3610 })
3611 .collect(),
3612 ),
3613 }),
3614 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3615 })
3616 .collect(),
3617 })?;
3618 }
3619
3620 Ok(())
3621}
3622
3623struct CountTokensWithLanguageModelRateLimit;
3624
3625impl RateLimit for CountTokensWithLanguageModelRateLimit {
3626 fn capacity() -> usize {
3627 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3628 .ok()
3629 .and_then(|v| v.parse().ok())
3630 .unwrap_or(600) // Picked arbitrarily
3631 }
3632
3633 fn refill_duration() -> chrono::Duration {
3634 chrono::Duration::hours(1)
3635 }
3636
3637 fn db_name() -> &'static str {
3638 "count-tokens-with-language-model"
3639 }
3640}
3641
3642async fn count_tokens_with_language_model(
3643 request: proto::CountTokensWithLanguageModel,
3644 response: Response<proto::CountTokensWithLanguageModel>,
3645 session: UserSession,
3646 google_ai_api_key: Option<Arc<str>>,
3647) -> Result<()> {
3648 authorize_access_to_language_models(&session).await?;
3649
3650 if !request.model.starts_with("gemini") {
3651 return Err(anyhow!(
3652 "counting tokens for model: {:?} is not supported",
3653 request.model
3654 ))?;
3655 }
3656
3657 session
3658 .rate_limiter
3659 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3660 .await?;
3661
3662 let api_key = google_ai_api_key
3663 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3664 let tokens_response = google_ai::count_tokens(
3665 &session.http_client,
3666 google_ai::API_URL,
3667 &api_key,
3668 crate::ai::count_tokens_request_to_google_ai(request)?,
3669 )
3670 .await?;
3671 response.send(proto::CountTokensResponse {
3672 token_count: tokens_response.total_tokens as u32,
3673 })?;
3674 Ok(())
3675}
3676
3677async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3678 let db = session.db().await;
3679 let flags = db.get_user_flags(session.user_id()).await?;
3680 if flags.iter().any(|flag| flag == "language-models") {
3681 Ok(())
3682 } else {
3683 Err(anyhow!("permission denied"))?
3684 }
3685}
3686
3687/// Start receiving chat updates for a channel
3688async fn join_channel_chat(
3689 request: proto::JoinChannelChat,
3690 response: Response<proto::JoinChannelChat>,
3691 session: UserSession,
3692) -> Result<()> {
3693 let channel_id = ChannelId::from_proto(request.channel_id);
3694
3695 let db = session.db().await;
3696 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3697 .await?;
3698 let messages = db
3699 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3700 .await?;
3701 response.send(proto::JoinChannelChatResponse {
3702 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3703 messages,
3704 })?;
3705 Ok(())
3706}
3707
3708/// Stop receiving chat updates for a channel
3709async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3710 let channel_id = ChannelId::from_proto(request.channel_id);
3711 session
3712 .db()
3713 .await
3714 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3715 .await?;
3716 Ok(())
3717}
3718
3719/// Retrieve the chat history for a channel
3720async fn get_channel_messages(
3721 request: proto::GetChannelMessages,
3722 response: Response<proto::GetChannelMessages>,
3723 session: UserSession,
3724) -> Result<()> {
3725 let channel_id = ChannelId::from_proto(request.channel_id);
3726 let messages = session
3727 .db()
3728 .await
3729 .get_channel_messages(
3730 channel_id,
3731 session.user_id(),
3732 MESSAGE_COUNT_PER_PAGE,
3733 Some(MessageId::from_proto(request.before_message_id)),
3734 )
3735 .await?;
3736 response.send(proto::GetChannelMessagesResponse {
3737 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3738 messages,
3739 })?;
3740 Ok(())
3741}
3742
3743/// Retrieve specific chat messages
3744async fn get_channel_messages_by_id(
3745 request: proto::GetChannelMessagesById,
3746 response: Response<proto::GetChannelMessagesById>,
3747 session: UserSession,
3748) -> Result<()> {
3749 let message_ids = request
3750 .message_ids
3751 .iter()
3752 .map(|id| MessageId::from_proto(*id))
3753 .collect::<Vec<_>>();
3754 let messages = session
3755 .db()
3756 .await
3757 .get_channel_messages_by_id(session.user_id(), &message_ids)
3758 .await?;
3759 response.send(proto::GetChannelMessagesResponse {
3760 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3761 messages,
3762 })?;
3763 Ok(())
3764}
3765
3766/// Retrieve the current users notifications
3767async fn get_notifications(
3768 request: proto::GetNotifications,
3769 response: Response<proto::GetNotifications>,
3770 session: UserSession,
3771) -> Result<()> {
3772 let notifications = session
3773 .db()
3774 .await
3775 .get_notifications(
3776 session.user_id(),
3777 NOTIFICATION_COUNT_PER_PAGE,
3778 request
3779 .before_id
3780 .map(|id| db::NotificationId::from_proto(id)),
3781 )
3782 .await?;
3783 response.send(proto::GetNotificationsResponse {
3784 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3785 notifications,
3786 })?;
3787 Ok(())
3788}
3789
3790/// Mark notifications as read
3791async fn mark_notification_as_read(
3792 request: proto::MarkNotificationRead,
3793 response: Response<proto::MarkNotificationRead>,
3794 session: UserSession,
3795) -> Result<()> {
3796 let database = &session.db().await;
3797 let notifications = database
3798 .mark_notification_as_read_by_id(
3799 session.user_id(),
3800 NotificationId::from_proto(request.notification_id),
3801 )
3802 .await?;
3803 send_notifications(
3804 &*session.connection_pool().await,
3805 &session.peer,
3806 notifications,
3807 );
3808 response.send(proto::Ack {})?;
3809 Ok(())
3810}
3811
3812/// Get the current users information
3813async fn get_private_user_info(
3814 _request: proto::GetPrivateUserInfo,
3815 response: Response<proto::GetPrivateUserInfo>,
3816 session: UserSession,
3817) -> Result<()> {
3818 let db = session.db().await;
3819
3820 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3821 let user = db
3822 .get_user_by_id(session.user_id())
3823 .await?
3824 .ok_or_else(|| anyhow!("user not found"))?;
3825 let flags = db.get_user_flags(session.user_id()).await?;
3826
3827 response.send(proto::GetPrivateUserInfoResponse {
3828 metrics_id,
3829 staff: user.admin,
3830 flags,
3831 })?;
3832 Ok(())
3833}
3834
3835fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3836 match message {
3837 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3838 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3839 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3840 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3841 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3842 code: frame.code.into(),
3843 reason: frame.reason,
3844 })),
3845 }
3846}
3847
3848fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3849 match message {
3850 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3851 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3852 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3853 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3854 AxumMessage::Close(frame) => {
3855 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3856 code: frame.code.into(),
3857 reason: frame.reason,
3858 }))
3859 }
3860 }
3861}
3862
3863fn notify_membership_updated(
3864 connection_pool: &mut ConnectionPool,
3865 result: MembershipUpdated,
3866 user_id: UserId,
3867 peer: &Peer,
3868) {
3869 for membership in &result.new_channels.channel_memberships {
3870 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3871 }
3872 for channel_id in &result.removed_channels {
3873 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3874 }
3875
3876 let user_channels_update = proto::UpdateUserChannels {
3877 channel_memberships: result
3878 .new_channels
3879 .channel_memberships
3880 .iter()
3881 .map(|cm| proto::ChannelMembership {
3882 channel_id: cm.channel_id.to_proto(),
3883 role: cm.role.into(),
3884 })
3885 .collect(),
3886 ..Default::default()
3887 };
3888
3889 let mut update = build_channels_update(result.new_channels, vec![]);
3890 update.delete_channels = result
3891 .removed_channels
3892 .into_iter()
3893 .map(|id| id.to_proto())
3894 .collect();
3895 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3896
3897 for connection_id in connection_pool.user_connection_ids(user_id) {
3898 peer.send(connection_id, user_channels_update.clone())
3899 .trace_err();
3900 peer.send(connection_id, update.clone()).trace_err();
3901 }
3902}
3903
3904fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3905 proto::UpdateUserChannels {
3906 channel_memberships: channels
3907 .channel_memberships
3908 .iter()
3909 .map(|m| proto::ChannelMembership {
3910 channel_id: m.channel_id.to_proto(),
3911 role: m.role.into(),
3912 })
3913 .collect(),
3914 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3915 observed_channel_message_id: channels.observed_channel_messages.clone(),
3916 }
3917}
3918
3919fn build_channels_update(
3920 channels: ChannelsForUser,
3921 channel_invites: Vec<db::Channel>,
3922) -> proto::UpdateChannels {
3923 let mut update = proto::UpdateChannels::default();
3924
3925 for channel in channels.channels {
3926 update.channels.push(channel.to_proto());
3927 }
3928
3929 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3930 update.latest_channel_message_ids = channels.latest_channel_messages;
3931
3932 for (channel_id, participants) in channels.channel_participants {
3933 update
3934 .channel_participants
3935 .push(proto::ChannelParticipants {
3936 channel_id: channel_id.to_proto(),
3937 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3938 });
3939 }
3940
3941 for channel in channel_invites {
3942 update.channel_invitations.push(channel.to_proto());
3943 }
3944 for project in channels.hosted_projects {
3945 update.hosted_projects.push(project);
3946 }
3947
3948 update
3949}
3950
3951fn build_initial_contacts_update(
3952 contacts: Vec<db::Contact>,
3953 pool: &ConnectionPool,
3954) -> proto::UpdateContacts {
3955 let mut update = proto::UpdateContacts::default();
3956
3957 for contact in contacts {
3958 match contact {
3959 db::Contact::Accepted { user_id, busy } => {
3960 update.contacts.push(contact_for_user(user_id, busy, &pool));
3961 }
3962 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3963 db::Contact::Incoming { user_id } => {
3964 update
3965 .incoming_requests
3966 .push(proto::IncomingContactRequest {
3967 requester_id: user_id.to_proto(),
3968 })
3969 }
3970 }
3971 }
3972
3973 update
3974}
3975
3976fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3977 proto::Contact {
3978 user_id: user_id.to_proto(),
3979 online: pool.is_user_online(user_id),
3980 busy,
3981 }
3982}
3983
3984fn room_updated(room: &proto::Room, peer: &Peer) {
3985 broadcast(
3986 None,
3987 room.participants
3988 .iter()
3989 .filter_map(|participant| Some(participant.peer_id?.into())),
3990 |peer_id| {
3991 peer.send(
3992 peer_id,
3993 proto::RoomUpdated {
3994 room: Some(room.clone()),
3995 },
3996 )
3997 },
3998 );
3999}
4000
4001fn channel_updated(
4002 channel: &db::channel::Model,
4003 room: &proto::Room,
4004 peer: &Peer,
4005 pool: &ConnectionPool,
4006) {
4007 let participants = room
4008 .participants
4009 .iter()
4010 .map(|p| p.user_id)
4011 .collect::<Vec<_>>();
4012
4013 broadcast(
4014 None,
4015 pool.channel_connection_ids(channel.root_id())
4016 .filter_map(|(channel_id, role)| {
4017 role.can_see_channel(channel.visibility).then(|| channel_id)
4018 }),
4019 |peer_id| {
4020 peer.send(
4021 peer_id,
4022 proto::UpdateChannels {
4023 channel_participants: vec![proto::ChannelParticipants {
4024 channel_id: channel.id.to_proto(),
4025 participant_user_ids: participants.clone(),
4026 }],
4027 ..Default::default()
4028 },
4029 )
4030 },
4031 );
4032}
4033
4034async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4035 let db = session.db().await;
4036
4037 let contacts = db.get_contacts(user_id).await?;
4038 let busy = db.is_user_busy(user_id).await?;
4039
4040 let pool = session.connection_pool().await;
4041 let updated_contact = contact_for_user(user_id, busy, &pool);
4042 for contact in contacts {
4043 if let db::Contact::Accepted {
4044 user_id: contact_user_id,
4045 ..
4046 } = contact
4047 {
4048 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4049 session
4050 .peer
4051 .send(
4052 contact_conn_id,
4053 proto::UpdateContacts {
4054 contacts: vec![updated_contact.clone()],
4055 remove_contacts: Default::default(),
4056 incoming_requests: Default::default(),
4057 remove_incoming_requests: Default::default(),
4058 outgoing_requests: Default::default(),
4059 remove_outgoing_requests: Default::default(),
4060 },
4061 )
4062 .trace_err();
4063 }
4064 }
4065 }
4066 Ok(())
4067}
4068
/// Removes this session's connection from the room it is in, if any, and
/// propagates the consequences.
///
/// When the database reports that a room was left, this:
/// - notifies the collaborators of every project the user left (`project_left`),
/// - broadcasts the updated room to its participants,
/// - broadcasts the channel's participant list when the room has a channel,
/// - sends `CallCanceled` to each user whose pending call was canceled,
/// - refreshes the contact status of this user and of those callees, and
/// - if a LiveKit client is configured, removes the user from the LiveKit
///   room, deleting that room when the database marked it as deleted.
///
/// Returns `Ok(())` immediately when the database reports no room was left.
/// LiveKit and send failures are logged rather than returned. Errors from the
/// database or from `update_user_contacts` are propagated.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are only assigned in the branch where a
    // room was actually left; the other branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order matters: the LiveKit room name must be taken out of
        // `left_room.room` before the whole room is taken two lines below.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before the loop below, which locks
    // the pool again inside `update_user_contacts`.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4142
4143async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4144 let left_channel_buffers = session
4145 .db()
4146 .await
4147 .leave_channel_buffers(session.connection_id)
4148 .await?;
4149
4150 for left_buffer in left_channel_buffers {
4151 channel_buffer_updated(
4152 session.connection_id,
4153 left_buffer.connections,
4154 &proto::UpdateChannelBufferCollaborators {
4155 channel_id: left_buffer.channel_id.to_proto(),
4156 collaborators: left_buffer.collaborators,
4157 },
4158 &session.peer,
4159 );
4160 }
4161
4162 Ok(())
4163}
4164
4165fn project_left(project: &db::LeftProject, session: &UserSession) {
4166 for connection_id in &project.connection_ids {
4167 if project.host_user_id == Some(session.user_id()) {
4168 session
4169 .peer
4170 .send(
4171 *connection_id,
4172 proto::UnshareProject {
4173 project_id: project.id.to_proto(),
4174 },
4175 )
4176 .trace_err();
4177 } else {
4178 session
4179 .peer
4180 .send(
4181 *connection_id,
4182 proto::RemoveProjectCollaborator {
4183 project_id: project.id.to_proto(),
4184 peer_id: Some(session.connection_id.into()),
4185 },
4186 )
4187 .trace_err();
4188 }
4189 }
4190}
4191
4192pub trait ResultExt {
4193 type Ok;
4194
4195 fn trace_err(self) -> Option<Self::Ok>;
4196}
4197
4198impl<T, E> ResultExt for Result<T, E>
4199where
4200 E: std::fmt::Debug,
4201{
4202 type Ok = T;
4203
4204 fn trace_err(self) -> Option<T> {
4205 match self {
4206 Ok(value) => Some(value),
4207 Err(error) => {
4208 tracing::error!("{:?}", error);
4209 None
4210 }
4211 }
4212 }
4213}