1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 future::Future,
53 marker::PhantomData,
54 mem,
55 net::SocketAddr,
56 ops::{Deref, DerefMut},
57 rc::Rc,
58 sync::{
59 atomic::{AtomicBool, Ordering::SeqCst},
60 Arc, OnceLock,
61 },
62 time::{Duration, Instant},
63};
64use time::OffsetDateTime;
65use tokio::sync::{watch, Semaphore};
66use tower::ServiceBuilder;
67use tracing::{
68 field::{self},
69 info_span, instrument, Instrument,
70};
71use util::{http::IsahcHttpClient, SemanticVersion};
72
/// NOTE(review): presumably how long a dropped connection may take to reconnect before it
/// is treated as gone; the code that consumes this value is outside this view — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
/// Delay that `Server::start` waits before cleaning up stale server resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the limits below are not referenced in the visible part of this file;
// their meaning is presumed from their names — confirm against the handlers that use them.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
81
/// Type-erased message handler as stored in `Server::handlers`.
///
/// It receives the boxed envelope together with the connection's [`Session`] and returns a
/// boxed future that resolves to `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
84
85struct Response<R> {
86 peer: Arc<Peer>,
87 receipt: Receipt<R>,
88 responded: Arc<AtomicBool>,
89}
90
91impl<R: RequestMessage> Response<R> {
92 fn send(self, payload: R::Response) -> Result<()> {
93 self.responded.store(true, SeqCst);
94 self.peer.respond(self.receipt, payload)?;
95 Ok(())
96 }
97}
98
99struct StreamingResponse<R: RequestMessage> {
100 peer: Arc<Peer>,
101 receipt: Receipt<R>,
102}
103
104impl<R: RequestMessage> StreamingResponse<R> {
105 fn send(&self, payload: R::Response) -> Result<()> {
106 self.peer.respond(self.receipt, payload)?;
107 Ok(())
108 }
109}
110
111#[derive(Clone, Debug)]
112pub enum Principal {
113 User(User),
114 Impersonated { user: User, admin: User },
115 DevServer(dev_server::Model),
116}
117
118impl Principal {
119 fn update_span(&self, span: &tracing::Span) {
120 match &self {
121 Principal::User(user) => {
122 span.record("user_id", &user.id.0);
123 span.record("login", &user.github_login);
124 }
125 Principal::Impersonated { user, admin } => {
126 span.record("user_id", &user.id.0);
127 span.record("login", &user.github_login);
128 span.record("impersonator", &admin.github_login);
129 }
130 Principal::DevServer(dev_server) => {
131 span.record("dev_server_id", &dev_server.id.0);
132 }
133 }
134 }
135}
136
/// Per-connection state; a clone of it is passed to every message handler invocation.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    /// Id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the server; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    http_client: IsahcHttpClient,
    rate_limiter: Arc<RateLimiter>,
    // Not read anywhere in the visible part of this file.
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn user_id(&self) -> Option<UserId> {
175 match &self.principal {
176 Principal::User(user) => Some(user.id),
177 Principal::Impersonated { user, .. } => Some(user.id),
178 Principal::DevServer(_) => None,
179 }
180 }
181}
182
183impl Debug for Session {
184 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
185 let mut result = f.debug_struct("Session");
186 match &self.principal {
187 Principal::User(user) => {
188 result.field("user", &user.github_login);
189 }
190 Principal::Impersonated { user, admin } => {
191 result.field("user", &user.github_login);
192 result.field("impersonator", &admin.github_login);
193 }
194 Principal::DevServer(dev_server) => {
195 result.field("dev_server", &dev_server.id);
196 }
197 }
198 result.field("connection_id", &self.connection_id).finish()
199 }
200}
201
202struct UserSession(Session);
203
204impl UserSession {
205 pub fn new(s: Session) -> Option<Self> {
206 s.user_id().map(|_| UserSession(s))
207 }
208 pub fn user_id(&self) -> UserId {
209 self.0.user_id().unwrap()
210 }
211}
212
213impl Deref for UserSession {
214 type Target = Session;
215
216 fn deref(&self) -> &Self::Target {
217 &self.0
218 }
219}
220impl DerefMut for UserSession {
221 fn deref_mut(&mut self) -> &mut Self::Target {
222 &mut self.0
223 }
224}
225
226fn user_handler<M: RequestMessage, Fut>(
227 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
228) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
229where
230 Fut: Send + Future<Output = Result<()>>,
231{
232 let handler = Arc::new(handler);
233 move |message, response, session| {
234 let handler = handler.clone();
235 Box::pin(async move {
236 if let Some(user_session) = session.for_user() {
237 Ok(handler(message, response, user_session).await?)
238 } else {
239 Err(Error::Internal(anyhow!("must be a user")))
240 }
241 })
242 }
243}
244
245fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
246 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
247) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
248where
249 InnertRetFut: Send + Future<Output = Result<()>>,
250{
251 let handler = Arc::new(handler);
252 move |message, session| {
253 let handler = handler.clone();
254 Box::pin(async move {
255 if let Some(user_session) = session.for_user() {
256 Ok(handler(message, user_session).await?)
257 } else {
258 Err(Error::Internal(anyhow!("must be a user")))
259 }
260 })
261 }
262}
263
264struct DbHandle(Arc<Database>);
265
266impl Deref for DbHandle {
267 type Target = Database;
268
269 fn deref(&self) -> &Self::Target {
270 self.0.as_ref()
271 }
272}
273
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// Server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
282
/// Wrapper around a locked [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is `!Send`, so this marker makes the whole wrapper `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
287
/// Serializable view of the server's peer and connection pool.
///
/// It holds the connection-pool lock for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through the guard's `Deref` target (the pool itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
294
295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
296where
297 S: Serializer,
298 T: Deref<Target = U>,
299 U: Serialize,
300{
301 Serialize::serialize(value.deref(), serializer)
302}
303
304impl Server {
/// Creates a server with the given id and registers every message and request handler.
///
/// The teardown flag starts out `false`. Registration goes through `add_handler`, which
/// panics if two handlers are registered for the same message type.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        teardown: watch::channel(false).0,
    };

    // Handlers wrapped in `user_handler` / `user_message_handler` reject sessions that
    // have no user id; the others receive the plain `Session`.
    server
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        .add_request_handler(share_project)
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
        .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
        .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
        .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
        .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
        .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
        .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
        .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
        .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
        .add_request_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        )
        .add_request_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        )
        .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
        .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
        .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
        .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
        .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
        .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
        .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
        .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        // The language-model handlers capture `app_state` so they can read the API keys
        // from the config on every request.
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                complete_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                    app_state.config.google_ai_api_key.clone(),
                )
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                count_tokens_with_language_model(
                    request,
                    response,
                    session,
                    app_state.config.google_ai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
437
/// Spawns a detached background task that cleans up resources left behind by stale servers.
///
/// The task waits `CLEANUP_TIMEOUT`, asks the database for stale room and channel ids,
/// refreshes each of them, notifies the affected connections, and finally deletes the
/// stale server records. Every failure inside the task goes through `trace_err` and does
/// not stop the remaining cleanup. This method itself returns `Ok(())` right after
/// spawning the task.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: clear stale collaborators and send the refreshed
                // collaborator list to every connection the database reports.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                // Rooms: clear stale participants, then notify about the room, canceled
                // calls and contacts.
                for room_id in room_ids {
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        // The LiveKit room is deleted only once no participants remain.
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        // Both lookups must succeed; the pool is locked only afterwards.
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                // Only accepted contacts are notified.
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            // Runs even when retrieving the stale resource ids failed above.
            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
583
/// Tears down the peer, resets the connection pool, and publishes `true` on the teardown
/// channel.
pub fn teardown(&self) {
    self.peer.teardown();
    self.connection_pool.lock().reset();
    // The send result is deliberately ignored.
    let _ = self.teardown.send(true);
}
589
/// Test-only: tears the server down, switches it to the new `id`, and clears the
/// teardown flag again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // Publishing `false` undoes the `true` that `teardown` published above.
    let _ = self.teardown.send(false);
}
597
/// Test-only: returns the current server id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    *self.id.lock()
}
602
603 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
604 where
605 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
606 Fut: 'static + Send + Future<Output = Result<()>>,
607 M: EnvelopedMessage,
608 {
609 let prev_handler = self.handlers.insert(
610 TypeId::of::<M>(),
611 Box::new(move |envelope, session| {
612 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
613 let received_at = envelope.received_at;
614 tracing::info!(
615 "message received"
616 );
617 let start_time = Instant::now();
618 let future = (handler)(*envelope, session);
619 async move {
620 let result = future.await;
621 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
622 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
623 let queue_duration_ms = total_duration_ms - processing_duration_ms;
624 match result {
625 Err(error) => {
626 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
627 }
628 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
629 }
630 }
631 .boxed()
632 }),
633 );
634 if prev_handler.is_some() {
635 panic!("registered a handler for the same message twice");
636 }
637 self
638 }
639
640 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
641 where
642 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
643 Fut: 'static + Send + Future<Output = Result<()>>,
644 M: EnvelopedMessage,
645 {
646 self.add_handler(move |envelope, session| handler(envelope.payload, session));
647 self
648 }
649
650 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
651 where
652 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
653 Fut: Send + Future<Output = Result<()>>,
654 M: RequestMessage,
655 {
656 let handler = Arc::new(handler);
657 self.add_handler(move |envelope, session| {
658 let receipt = envelope.receipt();
659 let handler = handler.clone();
660 async move {
661 let peer = session.peer.clone();
662 let responded = Arc::new(AtomicBool::default());
663 let response = Response {
664 peer: peer.clone(),
665 responded: responded.clone(),
666 receipt,
667 };
668 match (handler)(envelope.payload, response, session).await {
669 Ok(()) => {
670 if responded.load(std::sync::atomic::Ordering::SeqCst) {
671 Ok(())
672 } else {
673 Err(anyhow!("handler did not send a response"))?
674 }
675 }
676 Err(error) => {
677 let proto_err = match &error {
678 Error::Internal(err) => err.to_proto(),
679 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
680 };
681 peer.respond_with_error(receipt, proto_err)?;
682 Err(error)
683 }
684 }
685 }
686 })
687 }
688
689 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
690 where
691 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
692 Fut: Send + Future<Output = Result<()>>,
693 M: RequestMessage,
694 {
695 let handler = Arc::new(handler);
696 self.add_handler(move |envelope, session| {
697 let receipt = envelope.receipt();
698 let handler = handler.clone();
699 async move {
700 let peer = session.peer.clone();
701 let response = StreamingResponse {
702 peer: peer.clone(),
703 receipt,
704 };
705 match (handler)(envelope.payload, response, session).await {
706 Ok(()) => {
707 peer.end_stream(receipt)?;
708 Ok(())
709 }
710 Err(error) => {
711 let proto_err = match &error {
712 Error::Internal(err) => err.to_proto(),
713 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
714 };
715 peer.respond_with_error(receipt, proto_err)?;
716 Err(error)
717 }
718 }
719 }
720 })
721 }
722
/// Returns a future that serves one client connection until it closes or the server
/// tears down.
///
/// The future registers the connection with the peer, builds the connection's
/// [`Session`], sends the initial client update, and then dispatches incoming messages to
/// the registered handlers. At most 256 handler invocations are in flight per connection
/// (enforced by a semaphore). When the I/O future finishes or the incoming stream ends,
/// `connection_lost` is called for the session.
///
/// The future returns early, without calling `connection_lost`, when the server is
/// already tearing down, when the HTTP client cannot be created, when the initial client
/// update fails, or when teardown is signalled while the loop is running.
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
) -> impl Future<Output = ()> {
    let this = self.clone();
    let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
    principal.update_span(&span);

    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }
        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));
        tracing::info!("connection opened");

        let http_client = match IsahcHttpClient::new() {
            Ok(http_client) => http_client,
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            live_kit_client: this.app_state.live_kit_client.clone(),
            http_client,
            rate_limiter: this.app_state.rate_limiter.clone(),
            _executor: executor.clone(),
        };

        if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when, e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stopped processing further
        // messages until their respective request completed, they would never get a chance to
        // respond to the other client's request, causing a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages that arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        let concurrent_handlers = Arc::new(Semaphore::new(256));
        loop {
            // A permit is acquired before the next message is read, so no new message is
            // taken from the stream while all 256 permits are held by running handlers.
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // `select_biased!` polls the branches in the order written below.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            // The permit is released only once the handler has finished.
                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Dropping the set discards any foreground handlers that have not completed.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
848
849 async fn send_initial_client_update(
850 &self,
851 connection_id: ConnectionId,
852 principal: &Principal,
853 zed_version: ZedVersion,
854 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
855 session: &Session,
856 ) -> Result<()> {
857 self.peer.send(
858 connection_id,
859 proto::Hello {
860 peer_id: Some(connection_id.into()),
861 },
862 )?;
863 tracing::info!("sent hello message");
864
865 let Principal::User(user) = principal else {
866 return Ok(());
867 };
868
869 if let Some(send_connection_id) = send_connection_id.take() {
870 let _ = send_connection_id.send(connection_id);
871 }
872
873 if !user.connected_once {
874 self.peer.send(connection_id, proto::ShowContacts {})?;
875 self.app_state
876 .db
877 .set_user_connected_once(user.id, true)
878 .await?;
879 }
880
881 let (contacts, channels_for_user, channel_invites) = future::try_join3(
882 self.app_state.db.get_contacts(user.id),
883 self.app_state.db.get_channels_for_user(user.id),
884 self.app_state.db.get_channel_invites_for_user(user.id),
885 )
886 .await?;
887
888 {
889 let mut pool = self.connection_pool.lock();
890 pool.add_connection(connection_id, user.id, user.admin, zed_version);
891 for membership in &channels_for_user.channel_memberships {
892 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
893 }
894 self.peer.send(
895 connection_id,
896 build_initial_contacts_update(contacts, &pool),
897 )?;
898 self.peer.send(
899 connection_id,
900 build_update_user_channels(&channels_for_user),
901 )?;
902 self.peer.send(
903 connection_id,
904 build_channels_update(channels_for_user, channel_invites),
905 )?;
906 }
907
908 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
909 self.peer.send(connection_id, incoming_call)?;
910 }
911
912 update_user_contacts(user.id, &session).await?;
913 Ok(())
914 }
915
916 pub async fn invite_code_redeemed(
917 self: &Arc<Self>,
918 inviter_id: UserId,
919 invitee_id: UserId,
920 ) -> Result<()> {
921 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
922 if let Some(code) = &user.invite_code {
923 let pool = self.connection_pool.lock();
924 let invitee_contact = contact_for_user(invitee_id, false, &pool);
925 for connection_id in pool.user_connection_ids(inviter_id) {
926 self.peer.send(
927 connection_id,
928 proto::UpdateContacts {
929 contacts: vec![invitee_contact.clone()],
930 ..Default::default()
931 },
932 )?;
933 self.peer.send(
934 connection_id,
935 proto::UpdateInviteInfo {
936 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
937 count: user.invite_count as u32,
938 },
939 )?;
940 }
941 }
942 }
943 Ok(())
944 }
945
946 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
947 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
948 if let Some(invite_code) = &user.invite_code {
949 let pool = self.connection_pool.lock();
950 for connection_id in pool.user_connection_ids(user_id) {
951 self.peer.send(
952 connection_id,
953 proto::UpdateInviteInfo {
954 url: format!(
955 "{}{}",
956 self.app_state.config.invite_link_prefix, invite_code
957 ),
958 count: user.invite_count as u32,
959 },
960 )?;
961 }
962 }
963 }
964 Ok(())
965 }
966
967 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
968 ServerSnapshot {
969 connection_pool: ConnectionPoolGuard {
970 guard: self.connection_pool.lock(),
971 _not_send: PhantomData,
972 },
973 peer: &self.peer,
974 }
975 }
976}
977
978impl<'a> Deref for ConnectionPoolGuard<'a> {
979 type Target = ConnectionPool;
980
981 fn deref(&self) -> &Self::Target {
982 &self.guard
983 }
984}
985
986impl<'a> DerefMut for ConnectionPoolGuard<'a> {
987 fn deref_mut(&mut self) -> &mut Self::Target {
988 &mut self.guard
989 }
990}
991
992impl<'a> Drop for ConnectionPoolGuard<'a> {
993 fn drop(&mut self) {
994 #[cfg(test)]
995 self.check_invariants();
996 }
997}
998
999fn broadcast<F>(
1000 sender_id: Option<ConnectionId>,
1001 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1002 mut f: F,
1003) where
1004 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1005{
1006 for receiver_id in receiver_ids {
1007 if Some(receiver_id) != sender_id {
1008 if let Err(error) = f(receiver_id) {
1009 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1010 }
1011 }
1012 }
1013}
1014
1015pub struct ProtocolVersion(u32);
1016
1017impl Header for ProtocolVersion {
1018 fn name() -> &'static HeaderName {
1019 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1020 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1021 }
1022
1023 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1024 where
1025 Self: Sized,
1026 I: Iterator<Item = &'i axum::http::HeaderValue>,
1027 {
1028 let version = values
1029 .next()
1030 .ok_or_else(axum::headers::Error::invalid)?
1031 .to_str()
1032 .map_err(|_| axum::headers::Error::invalid())?
1033 .parse()
1034 .map_err(|_| axum::headers::Error::invalid())?;
1035 Ok(Self(version))
1036 }
1037
1038 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1039 values.extend([self.0.to_string().parse().unwrap()]);
1040 }
1041}
1042
1043pub struct AppVersionHeader(SemanticVersion);
1044impl Header for AppVersionHeader {
1045 fn name() -> &'static HeaderName {
1046 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1047 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1048 }
1049
1050 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1051 where
1052 Self: Sized,
1053 I: Iterator<Item = &'i axum::http::HeaderValue>,
1054 {
1055 let version = values
1056 .next()
1057 .ok_or_else(axum::headers::Error::invalid)?
1058 .to_str()
1059 .map_err(|_| axum::headers::Error::invalid())?
1060 .parse()
1061 .map_err(|_| axum::headers::Error::invalid())?;
1062 Ok(Self(version))
1063 }
1064
1065 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1066 values.extend([self.0.to_string().parse().unwrap()]);
1067 }
1068}
1069
1070pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1071 Router::new()
1072 .route("/rpc", get(handle_websocket_request))
1073 .layer(
1074 ServiceBuilder::new()
1075 .layer(Extension(server.app_state.clone()))
1076 .layer(middleware::from_fn(auth::validate_header)),
1077 )
1078 .route("/metrics", get(handle_metrics))
1079 .layer(Extension(server))
1080}
1081
1082pub async fn handle_websocket_request(
1083 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1084 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1085 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1086 Extension(server): Extension<Arc<Server>>,
1087 Extension(principal): Extension<Principal>,
1088 ws: WebSocketUpgrade,
1089) -> axum::response::Response {
1090 if protocol_version != rpc::PROTOCOL_VERSION {
1091 return (
1092 StatusCode::UPGRADE_REQUIRED,
1093 "client must be upgraded".to_string(),
1094 )
1095 .into_response();
1096 }
1097
1098 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1099 return (
1100 StatusCode::UPGRADE_REQUIRED,
1101 "no version header found".to_string(),
1102 )
1103 .into_response();
1104 };
1105
1106 if !version.can_collaborate() {
1107 return (
1108 StatusCode::UPGRADE_REQUIRED,
1109 "client must be upgraded".to_string(),
1110 )
1111 .into_response();
1112 }
1113
1114 let socket_address = socket_address.to_string();
1115 ws.on_upgrade(move |socket| {
1116 let socket = socket
1117 .map_ok(to_tungstenite_message)
1118 .err_into()
1119 .with(|message| async move { Ok(to_axum_message(message)) });
1120 let connection = Connection::new(Box::pin(socket));
1121 async move {
1122 server
1123 .handle_connection(
1124 connection,
1125 socket_address,
1126 principal,
1127 version,
1128 None,
1129 Executor::Production,
1130 )
1131 .await;
1132 }
1133 })
1134}
1135
1136pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1137 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1138 let connections_metric = CONNECTIONS_METRIC
1139 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1140
1141 let connections = server
1142 .connection_pool
1143 .lock()
1144 .connections()
1145 .filter(|connection| !connection.admin)
1146 .count();
1147 connections_metric.set(connections as _);
1148
1149 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1150 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1151 register_int_gauge!(
1152 "shared_projects",
1153 "number of open projects with one or more guests"
1154 )
1155 .unwrap()
1156 });
1157
1158 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1159 shared_projects_metric.set(shared_projects as _);
1160
1161 let encoder = prometheus::TextEncoder::new();
1162 let metric_families = prometheus::gather();
1163 let encoded_metrics = encoder
1164 .encode_to_string(&metric_families)
1165 .map_err(|err| anyhow!("{}", err))?;
1166 Ok(encoded_metrics)
1167}
1168
1169#[instrument(err, skip(executor))]
1170async fn connection_lost(
1171 session: Session,
1172 mut teardown: watch::Receiver<bool>,
1173 executor: Executor,
1174) -> Result<()> {
1175 session.peer.disconnect(session.connection_id);
1176 session
1177 .connection_pool()
1178 .await
1179 .remove_connection(session.connection_id)?;
1180
1181 session
1182 .db()
1183 .await
1184 .connection_lost(session.connection_id)
1185 .await
1186 .trace_err();
1187
1188 futures::select_biased! {
1189 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1190 if let Some(session) = session.for_user() {
1191 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1192 leave_room_for_session(&session).await.trace_err();
1193 leave_channel_buffers_for_session(&session)
1194 .await
1195 .trace_err();
1196
1197 if !session
1198 .connection_pool()
1199 .await
1200 .is_user_online(session.user_id())
1201 {
1202 let db = session.db().await;
1203 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1204 room_updated(&room, &session.peer);
1205 }
1206 }
1207
1208 update_user_contacts(session.user_id(), &session).await?;
1209 }
1210 }
1211 _ = teardown.changed().fuse() => {}
1212 }
1213
1214 Ok(())
1215}
1216
1217/// Acknowledges a ping from a client, used to keep the connection alive.
1218async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1219 response.send(proto::Ack {})?;
1220 Ok(())
1221}
1222
1223/// Creates a new room for calling (outside of channels)
1224async fn create_room(
1225 _request: proto::CreateRoom,
1226 response: Response<proto::CreateRoom>,
1227 session: UserSession,
1228) -> Result<()> {
1229 let live_kit_room = nanoid::nanoid!(30);
1230
1231 let live_kit_connection_info = util::maybe!(async {
1232 let live_kit = session.live_kit_client.as_ref();
1233 let live_kit = live_kit?;
1234 let user_id = session.user_id().to_string();
1235
1236 let token = live_kit
1237 .room_token(&live_kit_room, &user_id.to_string())
1238 .trace_err()?;
1239
1240 Some(proto::LiveKitConnectionInfo {
1241 server_url: live_kit.url().into(),
1242 token,
1243 can_publish: true,
1244 })
1245 })
1246 .await;
1247
1248 let room = session
1249 .db()
1250 .await
1251 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1252 .await?;
1253
1254 response.send(proto::CreateRoomResponse {
1255 room: Some(room.clone()),
1256 live_kit_connection_info,
1257 })?;
1258
1259 update_user_contacts(session.user_id(), &session).await?;
1260 Ok(())
1261}
1262
1263/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1264async fn join_room(
1265 request: proto::JoinRoom,
1266 response: Response<proto::JoinRoom>,
1267 session: UserSession,
1268) -> Result<()> {
1269 let room_id = RoomId::from_proto(request.id);
1270
1271 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1272
1273 if let Some(channel_id) = channel_id {
1274 return join_channel_internal(channel_id, Box::new(response), session).await;
1275 }
1276
1277 let joined_room = {
1278 let room = session
1279 .db()
1280 .await
1281 .join_room(room_id, session.user_id(), session.connection_id)
1282 .await?;
1283 room_updated(&room.room, &session.peer);
1284 room.into_inner()
1285 };
1286
1287 for connection_id in session
1288 .connection_pool()
1289 .await
1290 .user_connection_ids(session.user_id())
1291 {
1292 session
1293 .peer
1294 .send(
1295 connection_id,
1296 proto::CallCanceled {
1297 room_id: room_id.to_proto(),
1298 },
1299 )
1300 .trace_err();
1301 }
1302
1303 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1304 if let Some(token) = live_kit
1305 .room_token(
1306 &joined_room.room.live_kit_room,
1307 &session.user_id().to_string(),
1308 )
1309 .trace_err()
1310 {
1311 Some(proto::LiveKitConnectionInfo {
1312 server_url: live_kit.url().into(),
1313 token,
1314 can_publish: true,
1315 })
1316 } else {
1317 None
1318 }
1319 } else {
1320 None
1321 };
1322
1323 response.send(proto::JoinRoomResponse {
1324 room: Some(joined_room.room),
1325 channel_id: None,
1326 live_kit_connection_info,
1327 })?;
1328
1329 update_user_contacts(session.user_id(), &session).await?;
1330 Ok(())
1331}
1332
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The steps run in this order:
/// 1. Rejoin the room in the database and send the `RejoinRoomResponse`, which lists
///    the reshared and rejoined projects.
/// 2. Broadcast the updated room.
/// 3. For reshared projects: tell each collaborator about the new peer id and forward
///    an `UpdateProject` with the project's worktrees.
/// 4. For rejoined projects: tell each collaborator about the new peer id, then stream
///    worktree entries, diagnostic summaries, settings files and language-server
///    updates back to the rejoining connection.
/// 5. If the room belongs to a channel, broadcast the channel update.
/// 6. Update the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Declared outside the block so they outlive `rejoined_room`.
    let room;
    let channel;
    {
        // NOTE(review): `rejoined_room` looks like a guard around database state
        // (it is unwrapped with `into_inner()` below) — confirm.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Respond first, with metadata only; worktree contents are streamed below.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the old connection id was replaced by
            // this session's connection id. Send failures are logged, not propagated.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktrees to every collaborator other than this
            // connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            // Same peer-id replacement notice as for reshared projects.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        for project in &mut rejoined_room.rejoined_projects {
            // `mem::take` moves the worktrees out of the project, leaving an empty
            // collection behind, so their contents can be consumed by value.
            for worktree in mem::take(&mut project.worktrees) {
                // Chunk size is 2 in test builds and 256 otherwise.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    // True when the completed scan id equals the current scan id.
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Send a `DiskBasedDiagnosticsUpdated` notice for every language server
            // of the project.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1521
1522/// leave room disconnects from the room.
1523async fn leave_room(
1524 _: proto::LeaveRoom,
1525 response: Response<proto::LeaveRoom>,
1526 session: UserSession,
1527) -> Result<()> {
1528 leave_room_for_session(&session).await?;
1529 response.send(proto::Ack {})?;
1530 Ok(())
1531}
1532
1533/// Updates the permissions of someone else in the room.
1534async fn set_room_participant_role(
1535 request: proto::SetRoomParticipantRole,
1536 response: Response<proto::SetRoomParticipantRole>,
1537 session: UserSession,
1538) -> Result<()> {
1539 let user_id = UserId::from_proto(request.user_id);
1540 let role = ChannelRole::from(request.role());
1541
1542 let (live_kit_room, can_publish) = {
1543 let room = session
1544 .db()
1545 .await
1546 .set_room_participant_role(
1547 session.user_id(),
1548 RoomId::from_proto(request.room_id),
1549 user_id,
1550 role,
1551 )
1552 .await?;
1553
1554 let live_kit_room = room.live_kit_room.clone();
1555 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1556 room_updated(&room, &session.peer);
1557 (live_kit_room, can_publish)
1558 };
1559
1560 if let Some(live_kit) = session.live_kit_client.as_ref() {
1561 live_kit
1562 .update_participant(
1563 live_kit_room.clone(),
1564 request.user_id.to_string(),
1565 live_kit_server::proto::ParticipantPermission {
1566 can_subscribe: true,
1567 can_publish,
1568 can_publish_data: can_publish,
1569 hidden: false,
1570 recorder: false,
1571 },
1572 )
1573 .await
1574 .trace_err();
1575 }
1576
1577 response.send(proto::Ack {})?;
1578 Ok(())
1579}
1580
1581/// Call someone else into the current room
1582async fn call(
1583 request: proto::Call,
1584 response: Response<proto::Call>,
1585 session: UserSession,
1586) -> Result<()> {
1587 let room_id = RoomId::from_proto(request.room_id);
1588 let calling_user_id = session.user_id();
1589 let calling_connection_id = session.connection_id;
1590 let called_user_id = UserId::from_proto(request.called_user_id);
1591 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1592 if !session
1593 .db()
1594 .await
1595 .has_contact(calling_user_id, called_user_id)
1596 .await?
1597 {
1598 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1599 }
1600
1601 let incoming_call = {
1602 let (room, incoming_call) = &mut *session
1603 .db()
1604 .await
1605 .call(
1606 room_id,
1607 calling_user_id,
1608 calling_connection_id,
1609 called_user_id,
1610 initial_project_id,
1611 )
1612 .await?;
1613 room_updated(&room, &session.peer);
1614 mem::take(incoming_call)
1615 };
1616 update_user_contacts(called_user_id, &session).await?;
1617
1618 let mut calls = session
1619 .connection_pool()
1620 .await
1621 .user_connection_ids(called_user_id)
1622 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1623 .collect::<FuturesUnordered<_>>();
1624
1625 while let Some(call_response) = calls.next().await {
1626 match call_response.as_ref() {
1627 Ok(_) => {
1628 response.send(proto::Ack {})?;
1629 return Ok(());
1630 }
1631 Err(_) => {
1632 call_response.trace_err();
1633 }
1634 }
1635 }
1636
1637 {
1638 let room = session
1639 .db()
1640 .await
1641 .call_failed(room_id, called_user_id)
1642 .await?;
1643 room_updated(&room, &session.peer);
1644 }
1645 update_user_contacts(called_user_id, &session).await?;
1646
1647 Err(anyhow!("failed to ring user"))?
1648}
1649
1650/// Cancel an outgoing call.
1651async fn cancel_call(
1652 request: proto::CancelCall,
1653 response: Response<proto::CancelCall>,
1654 session: UserSession,
1655) -> Result<()> {
1656 let called_user_id = UserId::from_proto(request.called_user_id);
1657 let room_id = RoomId::from_proto(request.room_id);
1658 {
1659 let room = session
1660 .db()
1661 .await
1662 .cancel_call(room_id, session.connection_id, called_user_id)
1663 .await?;
1664 room_updated(&room, &session.peer);
1665 }
1666
1667 for connection_id in session
1668 .connection_pool()
1669 .await
1670 .user_connection_ids(called_user_id)
1671 {
1672 session
1673 .peer
1674 .send(
1675 connection_id,
1676 proto::CallCanceled {
1677 room_id: room_id.to_proto(),
1678 },
1679 )
1680 .trace_err();
1681 }
1682 response.send(proto::Ack {})?;
1683
1684 update_user_contacts(called_user_id, &session).await?;
1685 Ok(())
1686}
1687
1688/// Decline an incoming call.
1689async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1690 let room_id = RoomId::from_proto(message.room_id);
1691 {
1692 let room = session
1693 .db()
1694 .await
1695 .decline_call(Some(room_id), session.user_id())
1696 .await?
1697 .ok_or_else(|| anyhow!("failed to decline call"))?;
1698 room_updated(&room, &session.peer);
1699 }
1700
1701 for connection_id in session
1702 .connection_pool()
1703 .await
1704 .user_connection_ids(session.user_id())
1705 {
1706 session
1707 .peer
1708 .send(
1709 connection_id,
1710 proto::CallCanceled {
1711 room_id: room_id.to_proto(),
1712 },
1713 )
1714 .trace_err();
1715 }
1716 update_user_contacts(session.user_id(), &session).await?;
1717 Ok(())
1718}
1719
1720/// Updates other participants in the room with your current location.
1721async fn update_participant_location(
1722 request: proto::UpdateParticipantLocation,
1723 response: Response<proto::UpdateParticipantLocation>,
1724 session: UserSession,
1725) -> Result<()> {
1726 let room_id = RoomId::from_proto(request.room_id);
1727 let location = request
1728 .location
1729 .ok_or_else(|| anyhow!("invalid location"))?;
1730
1731 let db = session.db().await;
1732 let room = db
1733 .update_room_participant_location(room_id, session.connection_id, location)
1734 .await?;
1735
1736 room_updated(&room, &session.peer);
1737 response.send(proto::Ack {})?;
1738 Ok(())
1739}
1740
1741/// Share a project into the room.
1742async fn share_project(
1743 request: proto::ShareProject,
1744 response: Response<proto::ShareProject>,
1745 session: Session,
1746) -> Result<()> {
1747 let (project_id, room) = &*session
1748 .db()
1749 .await
1750 .share_project(
1751 RoomId::from_proto(request.room_id),
1752 session.connection_id,
1753 &request.worktrees,
1754 )
1755 .await?;
1756 response.send(proto::ShareProjectResponse {
1757 project_id: project_id.to_proto(),
1758 })?;
1759 room_updated(&room, &session.peer);
1760
1761 Ok(())
1762}
1763
1764/// Unshare a project from the room.
1765async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1766 let project_id = ProjectId::from_proto(message.project_id);
1767
1768 let (room, guest_connection_ids) = &*session
1769 .db()
1770 .await
1771 .unshare_project(project_id, session.connection_id)
1772 .await?;
1773
1774 broadcast(
1775 Some(session.connection_id),
1776 guest_connection_ids.iter().copied(),
1777 |conn_id| session.peer.send(conn_id, message.clone()),
1778 );
1779 room_updated(&room, &session.peer);
1780
1781 Ok(())
1782}
1783
1784/// Join someone elses shared project.
1785async fn join_project(
1786 request: proto::JoinProject,
1787 response: Response<proto::JoinProject>,
1788 session: UserSession,
1789) -> Result<()> {
1790 let project_id = ProjectId::from_proto(request.project_id);
1791
1792 tracing::info!(%project_id, "join project");
1793
1794 let (project, replica_id) = &mut *session
1795 .db()
1796 .await
1797 .join_project_in_room(project_id, session.connection_id)
1798 .await?;
1799
1800 join_project_internal(response, session, project, replica_id)
1801}
1802
1803trait JoinProjectInternalResponse {
1804 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1805}
1806impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1807 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1808 Response::<proto::JoinProject>::send(self, result)
1809 }
1810}
1811impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1812 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1813 Response::<proto::JoinHostedProject>::send(self, result)
1814 }
1815}
1816
/// Completes a project join for `session`.
///
/// The steps run in this order:
/// 1. Notify every other collaborator that a new collaborator was added.
/// 2. Send the `JoinProjectResponse` containing worktree metadata, collaborators,
///    language servers and the role.
/// 3. Stream each worktree's entries, diagnostic summaries and settings files to the
///    joining connection.
/// 4. Send a `DiskBasedDiagnosticsUpdated` notice for each language server.
///
/// `project.worktrees` is emptied by this function (it is taken with `mem::take`).
fn join_project_internal(
    response: impl JoinProjectInternalResponse,
    session: UserSession,
    project: &mut Project,
    replica_id: &ReplicaId,
) -> Result<()> {
    // Everyone in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to the existing ones. Send failures are logged,
    // not propagated.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                proto::AddProjectCollaborator {
                    project_id: project_id.to_proto(),
                    collaborator: Some(proto::Collaborator {
                        peer_id: Some(session.connection_id.into()),
                        replica_id: replica_id.0 as u32,
                        user_id: guest_user_id.to_proto(),
                    }),
                },
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers: project.language_servers.clone(),
        role: project.role.into(), // todo
    })?;

    // Take the worktrees out of the project so their contents can be moved into the
    // outgoing messages.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Chunk size is 2 in test builds and 256 otherwise.
        #[cfg(any(test, feature = "test-support"))]
        const MAX_CHUNK_SIZE: usize = 2;
        #[cfg(not(any(test, feature = "test-support")))]
        const MAX_CHUNK_SIZE: usize = 256;

        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            // True when the current scan id equals the completed scan id.
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        for summary in worktree.diagnostic_summaries {
            session.peer.send(
                session.connection_id,
                proto::UpdateDiagnosticSummary {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                },
            )?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                },
            )?;
        }
    }

    // Send a `DiskBasedDiagnosticsUpdated` notice for every language server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                language_server_id: language_server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1935
1936/// Leave someone elses shared project.
1937async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1938 let sender_id = session.connection_id;
1939 let project_id = ProjectId::from_proto(request.project_id);
1940 let db = session.db().await;
1941 if db.is_hosted_project(project_id).await? {
1942 let project = db.leave_hosted_project(project_id, sender_id).await?;
1943 project_left(&project, &session);
1944 return Ok(());
1945 }
1946
1947 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1948 tracing::info!(
1949 %project_id,
1950 host_user_id = ?project.host_user_id,
1951 host_connection_id = ?project.host_connection_id,
1952 "leave project"
1953 );
1954
1955 project_left(&project, &session);
1956 room_updated(&room, &session.peer);
1957
1958 Ok(())
1959}
1960
1961async fn join_hosted_project(
1962 request: proto::JoinHostedProject,
1963 response: Response<proto::JoinHostedProject>,
1964 session: UserSession,
1965) -> Result<()> {
1966 let (mut project, replica_id) = session
1967 .db()
1968 .await
1969 .join_hosted_project(
1970 ProjectId(request.project_id as i32),
1971 session.user_id(),
1972 session.connection_id,
1973 )
1974 .await?;
1975
1976 join_project_internal(response, session, &mut project, &replica_id)
1977}
1978
1979/// Updates other participants with changes to the project
1980async fn update_project(
1981 request: proto::UpdateProject,
1982 response: Response<proto::UpdateProject>,
1983 session: Session,
1984) -> Result<()> {
1985 let project_id = ProjectId::from_proto(request.project_id);
1986 let (room, guest_connection_ids) = &*session
1987 .db()
1988 .await
1989 .update_project(project_id, session.connection_id, &request.worktrees)
1990 .await?;
1991 broadcast(
1992 Some(session.connection_id),
1993 guest_connection_ids.iter().copied(),
1994 |connection_id| {
1995 session
1996 .peer
1997 .forward_send(session.connection_id, connection_id, request.clone())
1998 },
1999 );
2000 room_updated(&room, &session.peer);
2001 response.send(proto::Ack {})?;
2002
2003 Ok(())
2004}
2005
2006/// Updates other participants with changes to the worktree
2007async fn update_worktree(
2008 request: proto::UpdateWorktree,
2009 response: Response<proto::UpdateWorktree>,
2010 session: Session,
2011) -> Result<()> {
2012 let guest_connection_ids = session
2013 .db()
2014 .await
2015 .update_worktree(&request, session.connection_id)
2016 .await?;
2017
2018 broadcast(
2019 Some(session.connection_id),
2020 guest_connection_ids.iter().copied(),
2021 |connection_id| {
2022 session
2023 .peer
2024 .forward_send(session.connection_id, connection_id, request.clone())
2025 },
2026 );
2027 response.send(proto::Ack {})?;
2028 Ok(())
2029}
2030
2031/// Updates other participants with changes to the diagnostics
2032async fn update_diagnostic_summary(
2033 message: proto::UpdateDiagnosticSummary,
2034 session: Session,
2035) -> Result<()> {
2036 let guest_connection_ids = session
2037 .db()
2038 .await
2039 .update_diagnostic_summary(&message, session.connection_id)
2040 .await?;
2041
2042 broadcast(
2043 Some(session.connection_id),
2044 guest_connection_ids.iter().copied(),
2045 |connection_id| {
2046 session
2047 .peer
2048 .forward_send(session.connection_id, connection_id, message.clone())
2049 },
2050 );
2051
2052 Ok(())
2053}
2054
2055/// Updates other participants with changes to the worktree settings
2056async fn update_worktree_settings(
2057 message: proto::UpdateWorktreeSettings,
2058 session: Session,
2059) -> Result<()> {
2060 let guest_connection_ids = session
2061 .db()
2062 .await
2063 .update_worktree_settings(&message, session.connection_id)
2064 .await?;
2065
2066 broadcast(
2067 Some(session.connection_id),
2068 guest_connection_ids.iter().copied(),
2069 |connection_id| {
2070 session
2071 .peer
2072 .forward_send(session.connection_id, connection_id, message.clone())
2073 },
2074 );
2075
2076 Ok(())
2077}
2078
2079/// Notify other participants that a language server has started.
2080async fn start_language_server(
2081 request: proto::StartLanguageServer,
2082 session: Session,
2083) -> Result<()> {
2084 let guest_connection_ids = session
2085 .db()
2086 .await
2087 .start_language_server(&request, session.connection_id)
2088 .await?;
2089
2090 broadcast(
2091 Some(session.connection_id),
2092 guest_connection_ids.iter().copied(),
2093 |connection_id| {
2094 session
2095 .peer
2096 .forward_send(session.connection_id, connection_id, request.clone())
2097 },
2098 );
2099 Ok(())
2100}
2101
2102/// Notify other participants that a language server has changed.
2103async fn update_language_server(
2104 request: proto::UpdateLanguageServer,
2105 session: Session,
2106) -> Result<()> {
2107 let project_id = ProjectId::from_proto(request.project_id);
2108 let project_connection_ids = session
2109 .db()
2110 .await
2111 .project_connection_ids(project_id, session.connection_id)
2112 .await?;
2113 broadcast(
2114 Some(session.connection_id),
2115 project_connection_ids.iter().copied(),
2116 |connection_id| {
2117 session
2118 .peer
2119 .forward_send(session.connection_id, connection_id, request.clone())
2120 },
2121 );
2122 Ok(())
2123}
2124
2125/// forward a project request to the host. These requests should be read only
2126/// as guests are allowed to send them.
2127async fn forward_read_only_project_request<T>(
2128 request: T,
2129 response: Response<T>,
2130 session: Session,
2131) -> Result<()>
2132where
2133 T: EntityMessage + RequestMessage,
2134{
2135 let project_id = ProjectId::from_proto(request.remote_entity_id());
2136 let host_connection_id = session
2137 .db()
2138 .await
2139 .host_for_read_only_project_request(project_id, session.connection_id)
2140 .await?;
2141 let payload = session
2142 .peer
2143 .forward_request(session.connection_id, host_connection_id, request)
2144 .await?;
2145 response.send(payload)?;
2146 Ok(())
2147}
2148
2149/// forward a project request to the host. These requests are disallowed
2150/// for guests.
2151async fn forward_mutating_project_request<T>(
2152 request: T,
2153 response: Response<T>,
2154 session: Session,
2155) -> Result<()>
2156where
2157 T: EntityMessage + RequestMessage,
2158{
2159 let project_id = ProjectId::from_proto(request.remote_entity_id());
2160 let host_connection_id = session
2161 .db()
2162 .await
2163 .host_for_mutating_project_request(project_id, session.connection_id)
2164 .await?;
2165 let payload = session
2166 .peer
2167 .forward_request(session.connection_id, host_connection_id, request)
2168 .await?;
2169 response.send(payload)?;
2170 Ok(())
2171}
2172
2173/// Notify other participants that a new buffer has been created
2174async fn create_buffer_for_peer(
2175 request: proto::CreateBufferForPeer,
2176 session: Session,
2177) -> Result<()> {
2178 session
2179 .db()
2180 .await
2181 .check_user_is_project_host(
2182 ProjectId::from_proto(request.project_id),
2183 session.connection_id,
2184 )
2185 .await?;
2186 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2187 session
2188 .peer
2189 .forward_send(session.connection_id, peer_id.into(), request)?;
2190 Ok(())
2191}
2192
2193/// Notify other participants that a buffer has been updated. This is
2194/// allowed for guests as long as the update is limited to selections.
2195async fn update_buffer(
2196 request: proto::UpdateBuffer,
2197 response: Response<proto::UpdateBuffer>,
2198 session: Session,
2199) -> Result<()> {
2200 let project_id = ProjectId::from_proto(request.project_id);
2201 let mut guest_connection_ids;
2202 let mut host_connection_id = None;
2203
2204 let mut requires_write_permission = false;
2205
2206 for op in request.operations.iter() {
2207 match op.variant {
2208 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2209 Some(_) => requires_write_permission = true,
2210 }
2211 }
2212
2213 {
2214 let collaborators = session
2215 .db()
2216 .await
2217 .project_collaborators_for_buffer_update(
2218 project_id,
2219 session.connection_id,
2220 requires_write_permission,
2221 )
2222 .await?;
2223 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2224 for collaborator in collaborators.iter() {
2225 if collaborator.is_host {
2226 host_connection_id = Some(collaborator.connection_id);
2227 } else {
2228 guest_connection_ids.push(collaborator.connection_id);
2229 }
2230 }
2231 }
2232 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2233
2234 broadcast(
2235 Some(session.connection_id),
2236 guest_connection_ids,
2237 |connection_id| {
2238 session
2239 .peer
2240 .forward_send(session.connection_id, connection_id, request.clone())
2241 },
2242 );
2243 if host_connection_id != session.connection_id {
2244 session
2245 .peer
2246 .forward_request(session.connection_id, host_connection_id, request.clone())
2247 .await?;
2248 }
2249
2250 response.send(proto::Ack {})?;
2251 Ok(())
2252}
2253
2254/// Notify other participants that a project has been updated.
2255async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2256 request: T,
2257 session: Session,
2258) -> Result<()> {
2259 let project_id = ProjectId::from_proto(request.remote_entity_id());
2260 let project_connection_ids = session
2261 .db()
2262 .await
2263 .project_connection_ids(project_id, session.connection_id)
2264 .await?;
2265
2266 broadcast(
2267 Some(session.connection_id),
2268 project_connection_ids.iter().copied(),
2269 |connection_id| {
2270 session
2271 .peer
2272 .forward_send(session.connection_id, connection_id, request.clone())
2273 },
2274 );
2275 Ok(())
2276}
2277
2278/// Start following another user in a call.
2279async fn follow(
2280 request: proto::Follow,
2281 response: Response<proto::Follow>,
2282 session: UserSession,
2283) -> Result<()> {
2284 let room_id = RoomId::from_proto(request.room_id);
2285 let project_id = request.project_id.map(ProjectId::from_proto);
2286 let leader_id = request
2287 .leader_id
2288 .ok_or_else(|| anyhow!("invalid leader id"))?
2289 .into();
2290 let follower_id = session.connection_id;
2291
2292 session
2293 .db()
2294 .await
2295 .check_room_participants(room_id, leader_id, session.connection_id)
2296 .await?;
2297
2298 let response_payload = session
2299 .peer
2300 .forward_request(session.connection_id, leader_id, request)
2301 .await?;
2302 response.send(response_payload)?;
2303
2304 if let Some(project_id) = project_id {
2305 let room = session
2306 .db()
2307 .await
2308 .follow(room_id, project_id, leader_id, follower_id)
2309 .await?;
2310 room_updated(&room, &session.peer);
2311 }
2312
2313 Ok(())
2314}
2315
2316/// Stop following another user in a call.
2317async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2318 let room_id = RoomId::from_proto(request.room_id);
2319 let project_id = request.project_id.map(ProjectId::from_proto);
2320 let leader_id = request
2321 .leader_id
2322 .ok_or_else(|| anyhow!("invalid leader id"))?
2323 .into();
2324 let follower_id = session.connection_id;
2325
2326 session
2327 .db()
2328 .await
2329 .check_room_participants(room_id, leader_id, session.connection_id)
2330 .await?;
2331
2332 session
2333 .peer
2334 .forward_send(session.connection_id, leader_id, request)?;
2335
2336 if let Some(project_id) = project_id {
2337 let room = session
2338 .db()
2339 .await
2340 .unfollow(room_id, project_id, leader_id, follower_id)
2341 .await?;
2342 room_updated(&room, &session.peer);
2343 }
2344
2345 Ok(())
2346}
2347
2348/// Notify everyone following you of your current location.
2349async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2350 let room_id = RoomId::from_proto(request.room_id);
2351 let database = session.db.lock().await;
2352
2353 let connection_ids = if let Some(project_id) = request.project_id {
2354 let project_id = ProjectId::from_proto(project_id);
2355 database
2356 .project_connection_ids(project_id, session.connection_id)
2357 .await?
2358 } else {
2359 database
2360 .room_connection_ids(room_id, session.connection_id)
2361 .await?
2362 };
2363
2364 // For now, don't send view update messages back to that view's current leader.
2365 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2366 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2367 _ => None,
2368 });
2369
2370 for connection_id in connection_ids.iter().cloned() {
2371 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2372 session
2373 .peer
2374 .forward_send(session.connection_id, connection_id, request.clone())?;
2375 }
2376 }
2377 Ok(())
2378}
2379
2380/// Get public data about users.
2381async fn get_users(
2382 request: proto::GetUsers,
2383 response: Response<proto::GetUsers>,
2384 session: Session,
2385) -> Result<()> {
2386 let user_ids = request
2387 .user_ids
2388 .into_iter()
2389 .map(UserId::from_proto)
2390 .collect();
2391 let users = session
2392 .db()
2393 .await
2394 .get_users_by_ids(user_ids)
2395 .await?
2396 .into_iter()
2397 .map(|user| proto::User {
2398 id: user.id.to_proto(),
2399 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2400 github_login: user.github_login,
2401 })
2402 .collect();
2403 response.send(proto::UsersResponse { users })?;
2404 Ok(())
2405}
2406
2407/// Search for users (to invite) buy Github login
2408async fn fuzzy_search_users(
2409 request: proto::FuzzySearchUsers,
2410 response: Response<proto::FuzzySearchUsers>,
2411 session: UserSession,
2412) -> Result<()> {
2413 let query = request.query;
2414 let users = match query.len() {
2415 0 => vec![],
2416 1 | 2 => session
2417 .db()
2418 .await
2419 .get_user_by_github_login(&query)
2420 .await?
2421 .into_iter()
2422 .collect(),
2423 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2424 };
2425 let users = users
2426 .into_iter()
2427 .filter(|user| user.id != session.user_id())
2428 .map(|user| proto::User {
2429 id: user.id.to_proto(),
2430 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2431 github_login: user.github_login,
2432 })
2433 .collect();
2434 response.send(proto::UsersResponse { users })?;
2435 Ok(())
2436}
2437
2438/// Send a contact request to another user.
2439async fn request_contact(
2440 request: proto::RequestContact,
2441 response: Response<proto::RequestContact>,
2442 session: UserSession,
2443) -> Result<()> {
2444 let requester_id = session.user_id();
2445 let responder_id = UserId::from_proto(request.responder_id);
2446 if requester_id == responder_id {
2447 return Err(anyhow!("cannot add yourself as a contact"))?;
2448 }
2449
2450 let notifications = session
2451 .db()
2452 .await
2453 .send_contact_request(requester_id, responder_id)
2454 .await?;
2455
2456 // Update outgoing contact requests of requester
2457 let mut update = proto::UpdateContacts::default();
2458 update.outgoing_requests.push(responder_id.to_proto());
2459 for connection_id in session
2460 .connection_pool()
2461 .await
2462 .user_connection_ids(requester_id)
2463 {
2464 session.peer.send(connection_id, update.clone())?;
2465 }
2466
2467 // Update incoming contact requests of responder
2468 let mut update = proto::UpdateContacts::default();
2469 update
2470 .incoming_requests
2471 .push(proto::IncomingContactRequest {
2472 requester_id: requester_id.to_proto(),
2473 });
2474 let connection_pool = session.connection_pool().await;
2475 for connection_id in connection_pool.user_connection_ids(responder_id) {
2476 session.peer.send(connection_id, update.clone())?;
2477 }
2478
2479 send_notifications(&connection_pool, &session.peer, notifications);
2480
2481 response.send(proto::Ack {})?;
2482 Ok(())
2483}
2484
2485/// Accept or decline a contact request
2486async fn respond_to_contact_request(
2487 request: proto::RespondToContactRequest,
2488 response: Response<proto::RespondToContactRequest>,
2489 session: UserSession,
2490) -> Result<()> {
2491 let responder_id = session.user_id();
2492 let requester_id = UserId::from_proto(request.requester_id);
2493 let db = session.db().await;
2494 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2495 db.dismiss_contact_notification(responder_id, requester_id)
2496 .await?;
2497 } else {
2498 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2499
2500 let notifications = db
2501 .respond_to_contact_request(responder_id, requester_id, accept)
2502 .await?;
2503 let requester_busy = db.is_user_busy(requester_id).await?;
2504 let responder_busy = db.is_user_busy(responder_id).await?;
2505
2506 let pool = session.connection_pool().await;
2507 // Update responder with new contact
2508 let mut update = proto::UpdateContacts::default();
2509 if accept {
2510 update
2511 .contacts
2512 .push(contact_for_user(requester_id, requester_busy, &pool));
2513 }
2514 update
2515 .remove_incoming_requests
2516 .push(requester_id.to_proto());
2517 for connection_id in pool.user_connection_ids(responder_id) {
2518 session.peer.send(connection_id, update.clone())?;
2519 }
2520
2521 // Update requester with new contact
2522 let mut update = proto::UpdateContacts::default();
2523 if accept {
2524 update
2525 .contacts
2526 .push(contact_for_user(responder_id, responder_busy, &pool));
2527 }
2528 update
2529 .remove_outgoing_requests
2530 .push(responder_id.to_proto());
2531
2532 for connection_id in pool.user_connection_ids(requester_id) {
2533 session.peer.send(connection_id, update.clone())?;
2534 }
2535
2536 send_notifications(&pool, &session.peer, notifications);
2537 }
2538
2539 response.send(proto::Ack {})?;
2540 Ok(())
2541}
2542
2543/// Remove a contact.
2544async fn remove_contact(
2545 request: proto::RemoveContact,
2546 response: Response<proto::RemoveContact>,
2547 session: UserSession,
2548) -> Result<()> {
2549 let requester_id = session.user_id();
2550 let responder_id = UserId::from_proto(request.user_id);
2551 let db = session.db().await;
2552 let (contact_accepted, deleted_notification_id) =
2553 db.remove_contact(requester_id, responder_id).await?;
2554
2555 let pool = session.connection_pool().await;
2556 // Update outgoing contact requests of requester
2557 let mut update = proto::UpdateContacts::default();
2558 if contact_accepted {
2559 update.remove_contacts.push(responder_id.to_proto());
2560 } else {
2561 update
2562 .remove_outgoing_requests
2563 .push(responder_id.to_proto());
2564 }
2565 for connection_id in pool.user_connection_ids(requester_id) {
2566 session.peer.send(connection_id, update.clone())?;
2567 }
2568
2569 // Update incoming contact requests of responder
2570 let mut update = proto::UpdateContacts::default();
2571 if contact_accepted {
2572 update.remove_contacts.push(requester_id.to_proto());
2573 } else {
2574 update
2575 .remove_incoming_requests
2576 .push(requester_id.to_proto());
2577 }
2578 for connection_id in pool.user_connection_ids(responder_id) {
2579 session.peer.send(connection_id, update.clone())?;
2580 if let Some(notification_id) = deleted_notification_id {
2581 session.peer.send(
2582 connection_id,
2583 proto::DeleteNotification {
2584 notification_id: notification_id.to_proto(),
2585 },
2586 )?;
2587 }
2588 }
2589
2590 response.send(proto::Ack {})?;
2591 Ok(())
2592}
2593
2594/// Creates a new channel.
2595async fn create_channel(
2596 request: proto::CreateChannel,
2597 response: Response<proto::CreateChannel>,
2598 session: UserSession,
2599) -> Result<()> {
2600 let db = session.db().await;
2601
2602 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2603 let (channel, membership) = db
2604 .create_channel(&request.name, parent_id, session.user_id())
2605 .await?;
2606
2607 let root_id = channel.root_id();
2608 let channel = Channel::from_model(channel);
2609
2610 response.send(proto::CreateChannelResponse {
2611 channel: Some(channel.to_proto()),
2612 parent_id: request.parent_id,
2613 })?;
2614
2615 let mut connection_pool = session.connection_pool().await;
2616 if let Some(membership) = membership {
2617 connection_pool.subscribe_to_channel(
2618 membership.user_id,
2619 membership.channel_id,
2620 membership.role,
2621 );
2622 let update = proto::UpdateUserChannels {
2623 channel_memberships: vec![proto::ChannelMembership {
2624 channel_id: membership.channel_id.to_proto(),
2625 role: membership.role.into(),
2626 }],
2627 ..Default::default()
2628 };
2629 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2630 session.peer.send(connection_id, update.clone())?;
2631 }
2632 }
2633
2634 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2635 if !role.can_see_channel(channel.visibility) {
2636 continue;
2637 }
2638
2639 let update = proto::UpdateChannels {
2640 channels: vec![channel.to_proto()],
2641 ..Default::default()
2642 };
2643 session.peer.send(connection_id, update.clone())?;
2644 }
2645
2646 Ok(())
2647}
2648
2649/// Delete a channel
2650async fn delete_channel(
2651 request: proto::DeleteChannel,
2652 response: Response<proto::DeleteChannel>,
2653 session: UserSession,
2654) -> Result<()> {
2655 let db = session.db().await;
2656
2657 let channel_id = request.channel_id;
2658 let (root_channel, removed_channels) = db
2659 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2660 .await?;
2661 response.send(proto::Ack {})?;
2662
2663 // Notify members of removed channels
2664 let mut update = proto::UpdateChannels::default();
2665 update
2666 .delete_channels
2667 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2668
2669 let connection_pool = session.connection_pool().await;
2670 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2671 session.peer.send(connection_id, update.clone())?;
2672 }
2673
2674 Ok(())
2675}
2676
2677/// Invite someone to join a channel.
2678async fn invite_channel_member(
2679 request: proto::InviteChannelMember,
2680 response: Response<proto::InviteChannelMember>,
2681 session: UserSession,
2682) -> Result<()> {
2683 let db = session.db().await;
2684 let channel_id = ChannelId::from_proto(request.channel_id);
2685 let invitee_id = UserId::from_proto(request.user_id);
2686 let InviteMemberResult {
2687 channel,
2688 notifications,
2689 } = db
2690 .invite_channel_member(
2691 channel_id,
2692 invitee_id,
2693 session.user_id(),
2694 request.role().into(),
2695 )
2696 .await?;
2697
2698 let update = proto::UpdateChannels {
2699 channel_invitations: vec![channel.to_proto()],
2700 ..Default::default()
2701 };
2702
2703 let connection_pool = session.connection_pool().await;
2704 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2705 session.peer.send(connection_id, update.clone())?;
2706 }
2707
2708 send_notifications(&connection_pool, &session.peer, notifications);
2709
2710 response.send(proto::Ack {})?;
2711 Ok(())
2712}
2713
2714/// remove someone from a channel
2715async fn remove_channel_member(
2716 request: proto::RemoveChannelMember,
2717 response: Response<proto::RemoveChannelMember>,
2718 session: UserSession,
2719) -> Result<()> {
2720 let db = session.db().await;
2721 let channel_id = ChannelId::from_proto(request.channel_id);
2722 let member_id = UserId::from_proto(request.user_id);
2723
2724 let RemoveChannelMemberResult {
2725 membership_update,
2726 notification_id,
2727 } = db
2728 .remove_channel_member(channel_id, member_id, session.user_id())
2729 .await?;
2730
2731 let mut connection_pool = session.connection_pool().await;
2732 notify_membership_updated(
2733 &mut connection_pool,
2734 membership_update,
2735 member_id,
2736 &session.peer,
2737 );
2738 for connection_id in connection_pool.user_connection_ids(member_id) {
2739 if let Some(notification_id) = notification_id {
2740 session
2741 .peer
2742 .send(
2743 connection_id,
2744 proto::DeleteNotification {
2745 notification_id: notification_id.to_proto(),
2746 },
2747 )
2748 .trace_err();
2749 }
2750 }
2751
2752 response.send(proto::Ack {})?;
2753 Ok(())
2754}
2755
2756/// Toggle the channel between public and private.
2757/// Care is taken to maintain the invariant that public channels only descend from public channels,
2758/// (though members-only channels can appear at any point in the hierarchy).
2759async fn set_channel_visibility(
2760 request: proto::SetChannelVisibility,
2761 response: Response<proto::SetChannelVisibility>,
2762 session: UserSession,
2763) -> Result<()> {
2764 let db = session.db().await;
2765 let channel_id = ChannelId::from_proto(request.channel_id);
2766 let visibility = request.visibility().into();
2767
2768 let channel_model = db
2769 .set_channel_visibility(channel_id, visibility, session.user_id())
2770 .await?;
2771 let root_id = channel_model.root_id();
2772 let channel = Channel::from_model(channel_model);
2773
2774 let mut connection_pool = session.connection_pool().await;
2775 for (user_id, role) in connection_pool
2776 .channel_user_ids(root_id)
2777 .collect::<Vec<_>>()
2778 .into_iter()
2779 {
2780 let update = if role.can_see_channel(channel.visibility) {
2781 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2782 proto::UpdateChannels {
2783 channels: vec![channel.to_proto()],
2784 ..Default::default()
2785 }
2786 } else {
2787 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2788 proto::UpdateChannels {
2789 delete_channels: vec![channel.id.to_proto()],
2790 ..Default::default()
2791 }
2792 };
2793
2794 for connection_id in connection_pool.user_connection_ids(user_id) {
2795 session.peer.send(connection_id, update.clone())?;
2796 }
2797 }
2798
2799 response.send(proto::Ack {})?;
2800 Ok(())
2801}
2802
2803/// Alter the role for a user in the channel.
2804async fn set_channel_member_role(
2805 request: proto::SetChannelMemberRole,
2806 response: Response<proto::SetChannelMemberRole>,
2807 session: UserSession,
2808) -> Result<()> {
2809 let db = session.db().await;
2810 let channel_id = ChannelId::from_proto(request.channel_id);
2811 let member_id = UserId::from_proto(request.user_id);
2812 let result = db
2813 .set_channel_member_role(
2814 channel_id,
2815 session.user_id(),
2816 member_id,
2817 request.role().into(),
2818 )
2819 .await?;
2820
2821 match result {
2822 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2823 let mut connection_pool = session.connection_pool().await;
2824 notify_membership_updated(
2825 &mut connection_pool,
2826 membership_update,
2827 member_id,
2828 &session.peer,
2829 )
2830 }
2831 db::SetMemberRoleResult::InviteUpdated(channel) => {
2832 let update = proto::UpdateChannels {
2833 channel_invitations: vec![channel.to_proto()],
2834 ..Default::default()
2835 };
2836
2837 for connection_id in session
2838 .connection_pool()
2839 .await
2840 .user_connection_ids(member_id)
2841 {
2842 session.peer.send(connection_id, update.clone())?;
2843 }
2844 }
2845 }
2846
2847 response.send(proto::Ack {})?;
2848 Ok(())
2849}
2850
2851/// Change the name of a channel
2852async fn rename_channel(
2853 request: proto::RenameChannel,
2854 response: Response<proto::RenameChannel>,
2855 session: UserSession,
2856) -> Result<()> {
2857 let db = session.db().await;
2858 let channel_id = ChannelId::from_proto(request.channel_id);
2859 let channel_model = db
2860 .rename_channel(channel_id, session.user_id(), &request.name)
2861 .await?;
2862 let root_id = channel_model.root_id();
2863 let channel = Channel::from_model(channel_model);
2864
2865 response.send(proto::RenameChannelResponse {
2866 channel: Some(channel.to_proto()),
2867 })?;
2868
2869 let connection_pool = session.connection_pool().await;
2870 let update = proto::UpdateChannels {
2871 channels: vec![channel.to_proto()],
2872 ..Default::default()
2873 };
2874 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2875 if role.can_see_channel(channel.visibility) {
2876 session.peer.send(connection_id, update.clone())?;
2877 }
2878 }
2879
2880 Ok(())
2881}
2882
2883/// Move a channel to a new parent.
2884async fn move_channel(
2885 request: proto::MoveChannel,
2886 response: Response<proto::MoveChannel>,
2887 session: UserSession,
2888) -> Result<()> {
2889 let channel_id = ChannelId::from_proto(request.channel_id);
2890 let to = ChannelId::from_proto(request.to);
2891
2892 let (root_id, channels) = session
2893 .db()
2894 .await
2895 .move_channel(channel_id, to, session.user_id())
2896 .await?;
2897
2898 let connection_pool = session.connection_pool().await;
2899 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2900 let channels = channels
2901 .iter()
2902 .filter_map(|channel| {
2903 if role.can_see_channel(channel.visibility) {
2904 Some(channel.to_proto())
2905 } else {
2906 None
2907 }
2908 })
2909 .collect::<Vec<_>>();
2910 if channels.is_empty() {
2911 continue;
2912 }
2913
2914 let update = proto::UpdateChannels {
2915 channels,
2916 ..Default::default()
2917 };
2918
2919 session.peer.send(connection_id, update.clone())?;
2920 }
2921
2922 response.send(Ack {})?;
2923 Ok(())
2924}
2925
2926/// Get the list of channel members
2927async fn get_channel_members(
2928 request: proto::GetChannelMembers,
2929 response: Response<proto::GetChannelMembers>,
2930 session: UserSession,
2931) -> Result<()> {
2932 let db = session.db().await;
2933 let channel_id = ChannelId::from_proto(request.channel_id);
2934 let members = db
2935 .get_channel_participant_details(channel_id, session.user_id())
2936 .await?;
2937 response.send(proto::GetChannelMembersResponse { members })?;
2938 Ok(())
2939}
2940
2941/// Accept or decline a channel invitation.
2942async fn respond_to_channel_invite(
2943 request: proto::RespondToChannelInvite,
2944 response: Response<proto::RespondToChannelInvite>,
2945 session: UserSession,
2946) -> Result<()> {
2947 let db = session.db().await;
2948 let channel_id = ChannelId::from_proto(request.channel_id);
2949 let RespondToChannelInvite {
2950 membership_update,
2951 notifications,
2952 } = db
2953 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2954 .await?;
2955
2956 let mut connection_pool = session.connection_pool().await;
2957 if let Some(membership_update) = membership_update {
2958 notify_membership_updated(
2959 &mut connection_pool,
2960 membership_update,
2961 session.user_id(),
2962 &session.peer,
2963 );
2964 } else {
2965 let update = proto::UpdateChannels {
2966 remove_channel_invitations: vec![channel_id.to_proto()],
2967 ..Default::default()
2968 };
2969
2970 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2971 session.peer.send(connection_id, update.clone())?;
2972 }
2973 };
2974
2975 send_notifications(&connection_pool, &session.peer, notifications);
2976
2977 response.send(proto::Ack {})?;
2978
2979 Ok(())
2980}
2981
2982/// Join the channels' room
2983async fn join_channel(
2984 request: proto::JoinChannel,
2985 response: Response<proto::JoinChannel>,
2986 session: UserSession,
2987) -> Result<()> {
2988 let channel_id = ChannelId::from_proto(request.channel_id);
2989 join_channel_internal(channel_id, Box::new(response), session).await
2990}
2991
/// A response handle that can be completed with a `proto::JoinRoomResponse`.
///
/// Implemented for the response types of both `proto::JoinChannel` and
/// `proto::JoinRoom` so that `join_channel_internal` can accept either.
trait JoinChannelInternalResponse {
    /// Consumes the handle and sends `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response be completed through the shared trait.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    /// Forwards to `Response::<proto::JoinChannel>::send`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response be completed through the shared trait.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    /// Forwards to `Response::<proto::JoinRoom>::send`.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3005
/// Joins the room of `channel_id` on behalf of the session's user and replies
/// through `response`.
///
/// Steps, in order: leave any room the session is currently in, join the
/// channel in the database, build optional LiveKit connection info, send the
/// response, announce any membership update and the room update, then
/// announce the channel update and refresh the user's contacts.
///
/// # Errors
///
/// Fails if leaving the previous room, joining the channel, or sending the
/// response fails, if the database did not return a channel for the joined
/// room, or if updating the user's contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // `db` and the connection pool guard acquired below live only inside this
    // block; only the joined room escapes it.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Only produced when a LiveKit client is configured. A failure to
        // create the token is traced and results in no connection info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; every other role
            // gets a room token and may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is acquired again here, after the guard from the
    // block above has been dropped.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3085
3086/// Start editing the channel notes
3087async fn join_channel_buffer(
3088 request: proto::JoinChannelBuffer,
3089 response: Response<proto::JoinChannelBuffer>,
3090 session: UserSession,
3091) -> Result<()> {
3092 let db = session.db().await;
3093 let channel_id = ChannelId::from_proto(request.channel_id);
3094
3095 let open_response = db
3096 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3097 .await?;
3098
3099 let collaborators = open_response.collaborators.clone();
3100 response.send(open_response)?;
3101
3102 let update = UpdateChannelBufferCollaborators {
3103 channel_id: channel_id.to_proto(),
3104 collaborators: collaborators.clone(),
3105 };
3106 channel_buffer_updated(
3107 session.connection_id,
3108 collaborators
3109 .iter()
3110 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3111 &update,
3112 &session.peer,
3113 );
3114
3115 Ok(())
3116}
3117
3118/// Edit the channel notes
3119async fn update_channel_buffer(
3120 request: proto::UpdateChannelBuffer,
3121 session: UserSession,
3122) -> Result<()> {
3123 let db = session.db().await;
3124 let channel_id = ChannelId::from_proto(request.channel_id);
3125
3126 let (collaborators, non_collaborators, epoch, version) = db
3127 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3128 .await?;
3129
3130 channel_buffer_updated(
3131 session.connection_id,
3132 collaborators,
3133 &proto::UpdateChannelBuffer {
3134 channel_id: channel_id.to_proto(),
3135 operations: request.operations,
3136 },
3137 &session.peer,
3138 );
3139
3140 let pool = &*session.connection_pool().await;
3141
3142 broadcast(
3143 None,
3144 non_collaborators
3145 .iter()
3146 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3147 |peer_id| {
3148 session.peer.send(
3149 peer_id,
3150 proto::UpdateChannels {
3151 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3152 channel_id: channel_id.to_proto(),
3153 epoch: epoch as u64,
3154 version: version.clone(),
3155 }],
3156 ..Default::default()
3157 },
3158 )
3159 },
3160 );
3161
3162 Ok(())
3163}
3164
3165/// Rejoin the channel notes after a connection blip
3166async fn rejoin_channel_buffers(
3167 request: proto::RejoinChannelBuffers,
3168 response: Response<proto::RejoinChannelBuffers>,
3169 session: UserSession,
3170) -> Result<()> {
3171 let db = session.db().await;
3172 let buffers = db
3173 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3174 .await?;
3175
3176 for rejoined_buffer in &buffers {
3177 let collaborators_to_notify = rejoined_buffer
3178 .buffer
3179 .collaborators
3180 .iter()
3181 .filter_map(|c| Some(c.peer_id?.into()));
3182 channel_buffer_updated(
3183 session.connection_id,
3184 collaborators_to_notify,
3185 &proto::UpdateChannelBufferCollaborators {
3186 channel_id: rejoined_buffer.buffer.channel_id,
3187 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3188 },
3189 &session.peer,
3190 );
3191 }
3192
3193 response.send(proto::RejoinChannelBuffersResponse {
3194 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3195 })?;
3196
3197 Ok(())
3198}
3199
3200/// Stop editing the channel notes
3201async fn leave_channel_buffer(
3202 request: proto::LeaveChannelBuffer,
3203 response: Response<proto::LeaveChannelBuffer>,
3204 session: UserSession,
3205) -> Result<()> {
3206 let db = session.db().await;
3207 let channel_id = ChannelId::from_proto(request.channel_id);
3208
3209 let left_buffer = db
3210 .leave_channel_buffer(channel_id, session.connection_id)
3211 .await?;
3212
3213 response.send(Ack {})?;
3214
3215 channel_buffer_updated(
3216 session.connection_id,
3217 left_buffer.connections,
3218 &proto::UpdateChannelBufferCollaborators {
3219 channel_id: channel_id.to_proto(),
3220 collaborators: left_buffer.collaborators,
3221 },
3222 &session.peer,
3223 );
3224
3225 Ok(())
3226}
3227
3228fn channel_buffer_updated<T: EnvelopedMessage>(
3229 sender_id: ConnectionId,
3230 collaborators: impl IntoIterator<Item = ConnectionId>,
3231 message: &T,
3232 peer: &Peer,
3233) {
3234 broadcast(Some(sender_id), collaborators, |peer_id| {
3235 peer.send(peer_id, message.clone())
3236 });
3237}
3238
3239fn send_notifications(
3240 connection_pool: &ConnectionPool,
3241 peer: &Peer,
3242 notifications: db::NotificationBatch,
3243) {
3244 for (user_id, notification) in notifications {
3245 for connection_id in connection_pool.user_connection_ids(user_id) {
3246 if let Err(error) = peer.send(
3247 connection_id,
3248 proto::AddNotification {
3249 notification: Some(notification.clone()),
3250 },
3251 ) {
3252 tracing::error!(
3253 "failed to send notification to {:?} {}",
3254 connection_id,
3255 error
3256 );
3257 }
3258 }
3259 }
3260}
3261
3262/// Send a message to the channel
3263async fn send_channel_message(
3264 request: proto::SendChannelMessage,
3265 response: Response<proto::SendChannelMessage>,
3266 session: UserSession,
3267) -> Result<()> {
3268 // Validate the message body.
3269 let body = request.body.trim().to_string();
3270 if body.len() > MAX_MESSAGE_LEN {
3271 return Err(anyhow!("message is too long"))?;
3272 }
3273 if body.is_empty() {
3274 return Err(anyhow!("message can't be blank"))?;
3275 }
3276
3277 // TODO: adjust mentions if body is trimmed
3278
3279 let timestamp = OffsetDateTime::now_utc();
3280 let nonce = request
3281 .nonce
3282 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3283
3284 let channel_id = ChannelId::from_proto(request.channel_id);
3285 let CreatedChannelMessage {
3286 message_id,
3287 participant_connection_ids,
3288 channel_members,
3289 notifications,
3290 } = session
3291 .db()
3292 .await
3293 .create_channel_message(
3294 channel_id,
3295 session.user_id(),
3296 &body,
3297 &request.mentions,
3298 timestamp,
3299 nonce.clone().into(),
3300 match request.reply_to_message_id {
3301 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3302 None => None,
3303 },
3304 )
3305 .await?;
3306
3307 let message = proto::ChannelMessage {
3308 sender_id: session.user_id().to_proto(),
3309 id: message_id.to_proto(),
3310 body,
3311 mentions: request.mentions,
3312 timestamp: timestamp.unix_timestamp() as u64,
3313 nonce: Some(nonce),
3314 reply_to_message_id: request.reply_to_message_id,
3315 edited_at: None,
3316 };
3317 broadcast(
3318 Some(session.connection_id),
3319 participant_connection_ids,
3320 |connection| {
3321 session.peer.send(
3322 connection,
3323 proto::ChannelMessageSent {
3324 channel_id: channel_id.to_proto(),
3325 message: Some(message.clone()),
3326 },
3327 )
3328 },
3329 );
3330 response.send(proto::SendChannelMessageResponse {
3331 message: Some(message),
3332 })?;
3333
3334 let pool = &*session.connection_pool().await;
3335 broadcast(
3336 None,
3337 channel_members
3338 .iter()
3339 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3340 |peer_id| {
3341 session.peer.send(
3342 peer_id,
3343 proto::UpdateChannels {
3344 latest_channel_message_ids: vec![proto::ChannelMessageId {
3345 channel_id: channel_id.to_proto(),
3346 message_id: message_id.to_proto(),
3347 }],
3348 ..Default::default()
3349 },
3350 )
3351 },
3352 );
3353 send_notifications(pool, &session.peer, notifications);
3354
3355 Ok(())
3356}
3357
3358/// Delete a channel message
3359async fn remove_channel_message(
3360 request: proto::RemoveChannelMessage,
3361 response: Response<proto::RemoveChannelMessage>,
3362 session: UserSession,
3363) -> Result<()> {
3364 let channel_id = ChannelId::from_proto(request.channel_id);
3365 let message_id = MessageId::from_proto(request.message_id);
3366 let connection_ids = session
3367 .db()
3368 .await
3369 .remove_channel_message(channel_id, message_id, session.user_id())
3370 .await?;
3371 broadcast(Some(session.connection_id), connection_ids, |connection| {
3372 session.peer.send(connection, request.clone())
3373 });
3374 response.send(proto::Ack {})?;
3375 Ok(())
3376}
3377
3378async fn update_channel_message(
3379 request: proto::UpdateChannelMessage,
3380 response: Response<proto::UpdateChannelMessage>,
3381 session: UserSession,
3382) -> Result<()> {
3383 let channel_id = ChannelId::from_proto(request.channel_id);
3384 let message_id = MessageId::from_proto(request.message_id);
3385 let updated_at = OffsetDateTime::now_utc();
3386 let UpdatedChannelMessage {
3387 message_id,
3388 participant_connection_ids,
3389 notifications,
3390 reply_to_message_id,
3391 timestamp,
3392 } = session
3393 .db()
3394 .await
3395 .update_channel_message(
3396 channel_id,
3397 message_id,
3398 session.user_id(),
3399 request.body.as_str(),
3400 &request.mentions,
3401 updated_at,
3402 )
3403 .await?;
3404
3405 let nonce = request
3406 .nonce
3407 .clone()
3408 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3409
3410 let message = proto::ChannelMessage {
3411 sender_id: session.user_id().to_proto(),
3412 id: message_id.to_proto(),
3413 body: request.body.clone(),
3414 mentions: request.mentions.clone(),
3415 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3416 nonce: Some(nonce),
3417 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3418 edited_at: Some(updated_at.unix_timestamp() as u64),
3419 };
3420
3421 response.send(proto::Ack {})?;
3422
3423 let pool = &*session.connection_pool().await;
3424 broadcast(
3425 Some(session.connection_id),
3426 participant_connection_ids,
3427 |connection| {
3428 session.peer.send(
3429 connection,
3430 proto::ChannelMessageUpdate {
3431 channel_id: channel_id.to_proto(),
3432 message: Some(message.clone()),
3433 },
3434 )
3435 },
3436 );
3437
3438 send_notifications(pool, &session.peer, notifications);
3439
3440 Ok(())
3441}
3442
3443/// Mark a channel message as read
3444async fn acknowledge_channel_message(
3445 request: proto::AckChannelMessage,
3446 session: UserSession,
3447) -> Result<()> {
3448 let channel_id = ChannelId::from_proto(request.channel_id);
3449 let message_id = MessageId::from_proto(request.message_id);
3450 let notifications = session
3451 .db()
3452 .await
3453 .observe_channel_message(channel_id, session.user_id(), message_id)
3454 .await?;
3455 send_notifications(
3456 &*session.connection_pool().await,
3457 &session.peer,
3458 notifications,
3459 );
3460 Ok(())
3461}
3462
3463/// Mark a buffer version as synced
3464async fn acknowledge_buffer_version(
3465 request: proto::AckBufferOperation,
3466 session: UserSession,
3467) -> Result<()> {
3468 let buffer_id = BufferId::from_proto(request.buffer_id);
3469 session
3470 .db()
3471 .await
3472 .observe_buffer_version(
3473 buffer_id,
3474 session.user_id(),
3475 request.epoch as i32,
3476 &request.version,
3477 )
3478 .await?;
3479 Ok(())
3480}
3481
3482struct CompleteWithLanguageModelRateLimit;
3483
3484impl RateLimit for CompleteWithLanguageModelRateLimit {
3485 fn capacity() -> usize {
3486 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3487 .ok()
3488 .and_then(|v| v.parse().ok())
3489 .unwrap_or(120) // Picked arbitrarily
3490 }
3491
3492 fn refill_duration() -> chrono::Duration {
3493 chrono::Duration::hours(1)
3494 }
3495
3496 fn db_name() -> &'static str {
3497 "complete-with-language-model"
3498 }
3499}
3500
3501async fn complete_with_language_model(
3502 request: proto::CompleteWithLanguageModel,
3503 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3504 session: Session,
3505 open_ai_api_key: Option<Arc<str>>,
3506 google_ai_api_key: Option<Arc<str>>,
3507) -> Result<()> {
3508 let Some(session) = session.for_user() else {
3509 return Err(anyhow!("user not found"))?;
3510 };
3511 authorize_access_to_language_models(&session).await?;
3512 session
3513 .rate_limiter
3514 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3515 .await?;
3516
3517 if request.model.starts_with("gpt") {
3518 let api_key =
3519 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3520 complete_with_open_ai(request, response, session, api_key).await?;
3521 } else if request.model.starts_with("gemini") {
3522 let api_key = google_ai_api_key
3523 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3524 complete_with_google_ai(request, response, session, api_key).await?;
3525 }
3526
3527 Ok(())
3528}
3529
3530async fn complete_with_open_ai(
3531 request: proto::CompleteWithLanguageModel,
3532 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3533 session: UserSession,
3534 api_key: Arc<str>,
3535) -> Result<()> {
3536 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3537
3538 let mut completion_stream = open_ai::stream_completion(
3539 &session.http_client,
3540 OPEN_AI_API_URL,
3541 &api_key,
3542 crate::ai::language_model_request_to_open_ai(request)?,
3543 )
3544 .await
3545 .context("open_ai::stream_completion request failed")?;
3546
3547 while let Some(event) = completion_stream.next().await {
3548 let event = event?;
3549 response.send(proto::LanguageModelResponse {
3550 choices: event
3551 .choices
3552 .into_iter()
3553 .map(|choice| proto::LanguageModelChoiceDelta {
3554 index: choice.index,
3555 delta: Some(proto::LanguageModelResponseMessage {
3556 role: choice.delta.role.map(|role| match role {
3557 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3558 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3559 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3560 } as i32),
3561 content: choice.delta.content,
3562 }),
3563 finish_reason: choice.finish_reason,
3564 })
3565 .collect(),
3566 })?;
3567 }
3568
3569 Ok(())
3570}
3571
3572async fn complete_with_google_ai(
3573 request: proto::CompleteWithLanguageModel,
3574 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3575 session: UserSession,
3576 api_key: Arc<str>,
3577) -> Result<()> {
3578 let mut stream = google_ai::stream_generate_content(
3579 &session.http_client,
3580 google_ai::API_URL,
3581 api_key.as_ref(),
3582 crate::ai::language_model_request_to_google_ai(request)?,
3583 )
3584 .await
3585 .context("google_ai::stream_generate_content request failed")?;
3586
3587 while let Some(event) = stream.next().await {
3588 let event = event?;
3589 response.send(proto::LanguageModelResponse {
3590 choices: event
3591 .candidates
3592 .unwrap_or_default()
3593 .into_iter()
3594 .map(|candidate| proto::LanguageModelChoiceDelta {
3595 index: candidate.index as u32,
3596 delta: Some(proto::LanguageModelResponseMessage {
3597 role: Some(match candidate.content.role {
3598 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3599 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3600 } as i32),
3601 content: Some(
3602 candidate
3603 .content
3604 .parts
3605 .into_iter()
3606 .filter_map(|part| match part {
3607 google_ai::Part::TextPart(part) => Some(part.text),
3608 google_ai::Part::InlineDataPart(_) => None,
3609 })
3610 .collect(),
3611 ),
3612 }),
3613 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3614 })
3615 .collect(),
3616 })?;
3617 }
3618
3619 Ok(())
3620}
3621
3622struct CountTokensWithLanguageModelRateLimit;
3623
3624impl RateLimit for CountTokensWithLanguageModelRateLimit {
3625 fn capacity() -> usize {
3626 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3627 .ok()
3628 .and_then(|v| v.parse().ok())
3629 .unwrap_or(600) // Picked arbitrarily
3630 }
3631
3632 fn refill_duration() -> chrono::Duration {
3633 chrono::Duration::hours(1)
3634 }
3635
3636 fn db_name() -> &'static str {
3637 "count-tokens-with-language-model"
3638 }
3639}
3640
3641async fn count_tokens_with_language_model(
3642 request: proto::CountTokensWithLanguageModel,
3643 response: Response<proto::CountTokensWithLanguageModel>,
3644 session: UserSession,
3645 google_ai_api_key: Option<Arc<str>>,
3646) -> Result<()> {
3647 authorize_access_to_language_models(&session).await?;
3648
3649 if !request.model.starts_with("gemini") {
3650 return Err(anyhow!(
3651 "counting tokens for model: {:?} is not supported",
3652 request.model
3653 ))?;
3654 }
3655
3656 session
3657 .rate_limiter
3658 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3659 .await?;
3660
3661 let api_key = google_ai_api_key
3662 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3663 let tokens_response = google_ai::count_tokens(
3664 &session.http_client,
3665 google_ai::API_URL,
3666 &api_key,
3667 crate::ai::count_tokens_request_to_google_ai(request)?,
3668 )
3669 .await?;
3670 response.send(proto::CountTokensResponse {
3671 token_count: tokens_response.total_tokens as u32,
3672 })?;
3673 Ok(())
3674}
3675
3676async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3677 let db = session.db().await;
3678 let flags = db.get_user_flags(session.user_id()).await?;
3679 if flags.iter().any(|flag| flag == "language-models") {
3680 Ok(())
3681 } else {
3682 Err(anyhow!("permission denied"))?
3683 }
3684}
3685
3686/// Start receiving chat updates for a channel
3687async fn join_channel_chat(
3688 request: proto::JoinChannelChat,
3689 response: Response<proto::JoinChannelChat>,
3690 session: UserSession,
3691) -> Result<()> {
3692 let channel_id = ChannelId::from_proto(request.channel_id);
3693
3694 let db = session.db().await;
3695 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3696 .await?;
3697 let messages = db
3698 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3699 .await?;
3700 response.send(proto::JoinChannelChatResponse {
3701 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3702 messages,
3703 })?;
3704 Ok(())
3705}
3706
3707/// Stop receiving chat updates for a channel
3708async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3709 let channel_id = ChannelId::from_proto(request.channel_id);
3710 session
3711 .db()
3712 .await
3713 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3714 .await?;
3715 Ok(())
3716}
3717
3718/// Retrieve the chat history for a channel
3719async fn get_channel_messages(
3720 request: proto::GetChannelMessages,
3721 response: Response<proto::GetChannelMessages>,
3722 session: UserSession,
3723) -> Result<()> {
3724 let channel_id = ChannelId::from_proto(request.channel_id);
3725 let messages = session
3726 .db()
3727 .await
3728 .get_channel_messages(
3729 channel_id,
3730 session.user_id(),
3731 MESSAGE_COUNT_PER_PAGE,
3732 Some(MessageId::from_proto(request.before_message_id)),
3733 )
3734 .await?;
3735 response.send(proto::GetChannelMessagesResponse {
3736 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3737 messages,
3738 })?;
3739 Ok(())
3740}
3741
3742/// Retrieve specific chat messages
3743async fn get_channel_messages_by_id(
3744 request: proto::GetChannelMessagesById,
3745 response: Response<proto::GetChannelMessagesById>,
3746 session: UserSession,
3747) -> Result<()> {
3748 let message_ids = request
3749 .message_ids
3750 .iter()
3751 .map(|id| MessageId::from_proto(*id))
3752 .collect::<Vec<_>>();
3753 let messages = session
3754 .db()
3755 .await
3756 .get_channel_messages_by_id(session.user_id(), &message_ids)
3757 .await?;
3758 response.send(proto::GetChannelMessagesResponse {
3759 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3760 messages,
3761 })?;
3762 Ok(())
3763}
3764
3765/// Retrieve the current users notifications
3766async fn get_notifications(
3767 request: proto::GetNotifications,
3768 response: Response<proto::GetNotifications>,
3769 session: UserSession,
3770) -> Result<()> {
3771 let notifications = session
3772 .db()
3773 .await
3774 .get_notifications(
3775 session.user_id(),
3776 NOTIFICATION_COUNT_PER_PAGE,
3777 request
3778 .before_id
3779 .map(|id| db::NotificationId::from_proto(id)),
3780 )
3781 .await?;
3782 response.send(proto::GetNotificationsResponse {
3783 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3784 notifications,
3785 })?;
3786 Ok(())
3787}
3788
3789/// Mark notifications as read
3790async fn mark_notification_as_read(
3791 request: proto::MarkNotificationRead,
3792 response: Response<proto::MarkNotificationRead>,
3793 session: UserSession,
3794) -> Result<()> {
3795 let database = &session.db().await;
3796 let notifications = database
3797 .mark_notification_as_read_by_id(
3798 session.user_id(),
3799 NotificationId::from_proto(request.notification_id),
3800 )
3801 .await?;
3802 send_notifications(
3803 &*session.connection_pool().await,
3804 &session.peer,
3805 notifications,
3806 );
3807 response.send(proto::Ack {})?;
3808 Ok(())
3809}
3810
3811/// Get the current users information
3812async fn get_private_user_info(
3813 _request: proto::GetPrivateUserInfo,
3814 response: Response<proto::GetPrivateUserInfo>,
3815 session: UserSession,
3816) -> Result<()> {
3817 let db = session.db().await;
3818
3819 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3820 let user = db
3821 .get_user_by_id(session.user_id())
3822 .await?
3823 .ok_or_else(|| anyhow!("user not found"))?;
3824 let flags = db.get_user_flags(session.user_id()).await?;
3825
3826 response.send(proto::GetPrivateUserInfoResponse {
3827 metrics_id,
3828 staff: user.admin,
3829 flags,
3830 })?;
3831 Ok(())
3832}
3833
3834fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3835 match message {
3836 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3837 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3838 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3839 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3840 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3841 code: frame.code.into(),
3842 reason: frame.reason,
3843 })),
3844 }
3845}
3846
3847fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3848 match message {
3849 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3850 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3851 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3852 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3853 AxumMessage::Close(frame) => {
3854 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3855 code: frame.code.into(),
3856 reason: frame.reason,
3857 }))
3858 }
3859 }
3860}
3861
3862fn notify_membership_updated(
3863 connection_pool: &mut ConnectionPool,
3864 result: MembershipUpdated,
3865 user_id: UserId,
3866 peer: &Peer,
3867) {
3868 for membership in &result.new_channels.channel_memberships {
3869 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3870 }
3871 for channel_id in &result.removed_channels {
3872 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3873 }
3874
3875 let user_channels_update = proto::UpdateUserChannels {
3876 channel_memberships: result
3877 .new_channels
3878 .channel_memberships
3879 .iter()
3880 .map(|cm| proto::ChannelMembership {
3881 channel_id: cm.channel_id.to_proto(),
3882 role: cm.role.into(),
3883 })
3884 .collect(),
3885 ..Default::default()
3886 };
3887
3888 let mut update = build_channels_update(result.new_channels, vec![]);
3889 update.delete_channels = result
3890 .removed_channels
3891 .into_iter()
3892 .map(|id| id.to_proto())
3893 .collect();
3894 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3895
3896 for connection_id in connection_pool.user_connection_ids(user_id) {
3897 peer.send(connection_id, user_channels_update.clone())
3898 .trace_err();
3899 peer.send(connection_id, update.clone()).trace_err();
3900 }
3901}
3902
3903fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3904 proto::UpdateUserChannels {
3905 channel_memberships: channels
3906 .channel_memberships
3907 .iter()
3908 .map(|m| proto::ChannelMembership {
3909 channel_id: m.channel_id.to_proto(),
3910 role: m.role.into(),
3911 })
3912 .collect(),
3913 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3914 observed_channel_message_id: channels.observed_channel_messages.clone(),
3915 }
3916}
3917
3918fn build_channels_update(
3919 channels: ChannelsForUser,
3920 channel_invites: Vec<db::Channel>,
3921) -> proto::UpdateChannels {
3922 let mut update = proto::UpdateChannels::default();
3923
3924 for channel in channels.channels {
3925 update.channels.push(channel.to_proto());
3926 }
3927
3928 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3929 update.latest_channel_message_ids = channels.latest_channel_messages;
3930
3931 for (channel_id, participants) in channels.channel_participants {
3932 update
3933 .channel_participants
3934 .push(proto::ChannelParticipants {
3935 channel_id: channel_id.to_proto(),
3936 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3937 });
3938 }
3939
3940 for channel in channel_invites {
3941 update.channel_invitations.push(channel.to_proto());
3942 }
3943 for project in channels.hosted_projects {
3944 update.hosted_projects.push(project);
3945 }
3946
3947 update
3948}
3949
3950fn build_initial_contacts_update(
3951 contacts: Vec<db::Contact>,
3952 pool: &ConnectionPool,
3953) -> proto::UpdateContacts {
3954 let mut update = proto::UpdateContacts::default();
3955
3956 for contact in contacts {
3957 match contact {
3958 db::Contact::Accepted { user_id, busy } => {
3959 update.contacts.push(contact_for_user(user_id, busy, &pool));
3960 }
3961 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3962 db::Contact::Incoming { user_id } => {
3963 update
3964 .incoming_requests
3965 .push(proto::IncomingContactRequest {
3966 requester_id: user_id.to_proto(),
3967 })
3968 }
3969 }
3970 }
3971
3972 update
3973}
3974
3975fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3976 proto::Contact {
3977 user_id: user_id.to_proto(),
3978 online: pool.is_user_online(user_id),
3979 busy,
3980 }
3981}
3982
3983fn room_updated(room: &proto::Room, peer: &Peer) {
3984 broadcast(
3985 None,
3986 room.participants
3987 .iter()
3988 .filter_map(|participant| Some(participant.peer_id?.into())),
3989 |peer_id| {
3990 peer.send(
3991 peer_id,
3992 proto::RoomUpdated {
3993 room: Some(room.clone()),
3994 },
3995 )
3996 },
3997 );
3998}
3999
4000fn channel_updated(
4001 channel: &db::channel::Model,
4002 room: &proto::Room,
4003 peer: &Peer,
4004 pool: &ConnectionPool,
4005) {
4006 let participants = room
4007 .participants
4008 .iter()
4009 .map(|p| p.user_id)
4010 .collect::<Vec<_>>();
4011
4012 broadcast(
4013 None,
4014 pool.channel_connection_ids(channel.root_id())
4015 .filter_map(|(channel_id, role)| {
4016 role.can_see_channel(channel.visibility).then(|| channel_id)
4017 }),
4018 |peer_id| {
4019 peer.send(
4020 peer_id,
4021 proto::UpdateChannels {
4022 channel_participants: vec![proto::ChannelParticipants {
4023 channel_id: channel.id.to_proto(),
4024 participant_user_ids: participants.clone(),
4025 }],
4026 ..Default::default()
4027 },
4028 )
4029 },
4030 );
4031}
4032
4033async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4034 let db = session.db().await;
4035
4036 let contacts = db.get_contacts(user_id).await?;
4037 let busy = db.is_user_busy(user_id).await?;
4038
4039 let pool = session.connection_pool().await;
4040 let updated_contact = contact_for_user(user_id, busy, &pool);
4041 for contact in contacts {
4042 if let db::Contact::Accepted {
4043 user_id: contact_user_id,
4044 ..
4045 } = contact
4046 {
4047 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4048 session
4049 .peer
4050 .send(
4051 contact_conn_id,
4052 proto::UpdateContacts {
4053 contacts: vec![updated_contact.clone()],
4054 remove_contacts: Default::default(),
4055 incoming_requests: Default::default(),
4056 remove_incoming_requests: Default::default(),
4057 outgoing_requests: Default::default(),
4058 remove_outgoing_requests: Default::default(),
4059 },
4060 )
4061 .trace_err();
4062 }
4063 }
4064 }
4065 Ok(())
4066}
4067
/// Remove this session's connection from the room it is in, if any, and notify
/// everyone affected.
///
/// When `leave_room` reports that a room was left, this function:
///
/// 1. announces the departure on each project that was left,
/// 2. broadcasts the updated room to its participants,
/// 3. broadcasts the channel's participant list when the room has a channel,
/// 4. sends `CallCanceled` to every connection of each user whose call was canceled,
/// 5. refreshes contact state for the leaving user and the canceled users, and
/// 6. removes the participant from LiveKit, deleting the LiveKit room when the room
///    was deleted.
///
/// Returns `Ok(())` without doing anything when `leave_room` returns `None`.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and assigned inside the `if let` below so that the values remain
    // available after `left_room` goes out of scope at the end of that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move the pieces needed later out of `left_room`.
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before `room`
        // itself is taken, so the `room` broadcast below carries a default
        // `live_kit_room` value — confirm this is intended.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // `leave_room` returned `None`: there is nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // The connection pool handle is only held for the duration of this block.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged by `trace_err` and do not
    // fail this function.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4141
4142async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4143 let left_channel_buffers = session
4144 .db()
4145 .await
4146 .leave_channel_buffers(session.connection_id)
4147 .await?;
4148
4149 for left_buffer in left_channel_buffers {
4150 channel_buffer_updated(
4151 session.connection_id,
4152 left_buffer.connections,
4153 &proto::UpdateChannelBufferCollaborators {
4154 channel_id: left_buffer.channel_id.to_proto(),
4155 collaborators: left_buffer.collaborators,
4156 },
4157 &session.peer,
4158 );
4159 }
4160
4161 Ok(())
4162}
4163
4164fn project_left(project: &db::LeftProject, session: &UserSession) {
4165 for connection_id in &project.connection_ids {
4166 if project.host_user_id == Some(session.user_id()) {
4167 session
4168 .peer
4169 .send(
4170 *connection_id,
4171 proto::UnshareProject {
4172 project_id: project.id.to_proto(),
4173 },
4174 )
4175 .trace_err();
4176 } else {
4177 session
4178 .peer
4179 .send(
4180 *connection_id,
4181 proto::RemoveProjectCollaborator {
4182 project_id: project.id.to_proto(),
4183 peer_id: Some(session.connection_id.into()),
4184 },
4185 )
4186 .trace_err();
4187 }
4188 }
4189}
4190
4191pub trait ResultExt {
4192 type Ok;
4193
4194 fn trace_err(self) -> Option<Self::Ok>;
4195}
4196
4197impl<T, E> ResultExt for Result<T, E>
4198where
4199 E: std::fmt::Debug,
4200{
4201 type Ok = T;
4202
4203 fn trace_err(self) -> Option<T> {
4204 match self {
4205 Ok(value) => Some(value),
4206 Err(error) => {
4207 tracing::error!("{:?}", error);
4208 None
4209 }
4210 }
4211 }
4212}