1mod connection_pool;
2
3use crate::{
4 auth::{self},
5 db::{
6 self, dev_server, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
8 NotificationId, Project, ProjectId, RemoveChannelMemberResult, ReplicaId,
9 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use semantic_version::SemanticVersion;
50use serde::{Serialize, Serializer};
51use std::{
52 any::TypeId,
53 future::Future,
54 marker::PhantomData,
55 mem,
56 net::SocketAddr,
57 ops::{Deref, DerefMut},
58 rc::Rc,
59 sync::{
60 atomic::{AtomicBool, Ordering::SeqCst},
61 Arc, OnceLock,
62 },
63 time::{Duration, Instant},
64};
65use time::OffsetDateTime;
66use tokio::sync::{watch, Semaphore};
67use tower::ServiceBuilder;
68use tracing::{
69 field::{self},
70 info_span, instrument, Instrument,
71};
72use util::http::IsahcHttpClient;
73
/// How long a disconnected client is given to reconnect.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before cleaning up resources owned by previous server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone,
/// their old resources can be cleaned up safely, so we wait a little longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when fetching channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on the length of a single channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
82
83type MessageHandler =
84 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
85
86struct Response<R> {
87 peer: Arc<Peer>,
88 receipt: Receipt<R>,
89 responded: Arc<AtomicBool>,
90}
91
92impl<R: RequestMessage> Response<R> {
93 fn send(self, payload: R::Response) -> Result<()> {
94 self.responded.store(true, SeqCst);
95 self.peer.respond(self.receipt, payload)?;
96 Ok(())
97 }
98}
99
100struct StreamingResponse<R: RequestMessage> {
101 peer: Arc<Peer>,
102 receipt: Receipt<R>,
103}
104
105impl<R: RequestMessage> StreamingResponse<R> {
106 fn send(&self, payload: R::Response) -> Result<()> {
107 self.peer.respond(self.receipt, payload)?;
108 Ok(())
109 }
110}
111
112#[derive(Clone, Debug)]
113pub enum Principal {
114 User(User),
115 Impersonated { user: User, admin: User },
116 DevServer(dev_server::Model),
117}
118
119impl Principal {
120 fn update_span(&self, span: &tracing::Span) {
121 match &self {
122 Principal::User(user) => {
123 span.record("user_id", &user.id.0);
124 span.record("login", &user.github_login);
125 }
126 Principal::Impersonated { user, admin } => {
127 span.record("user_id", &user.id.0);
128 span.record("login", &user.github_login);
129 span.record("impersonator", &admin.github_login);
130 }
131 Principal::DevServer(dev_server) => {
132 span.record("dev_server_id", &dev_server.id.0);
133 }
134 }
135 }
136}
137
138#[derive(Clone)]
139struct Session {
140 principal: Principal,
141 connection_id: ConnectionId,
142 db: Arc<tokio::sync::Mutex<DbHandle>>,
143 peer: Arc<Peer>,
144 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
145 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
146 http_client: IsahcHttpClient,
147 rate_limiter: Arc<RateLimiter>,
148 _executor: Executor,
149}
150
151impl Session {
152 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
153 #[cfg(test)]
154 tokio::task::yield_now().await;
155 let guard = self.db.lock().await;
156 #[cfg(test)]
157 tokio::task::yield_now().await;
158 guard
159 }
160
161 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
162 #[cfg(test)]
163 tokio::task::yield_now().await;
164 let guard = self.connection_pool.lock();
165 ConnectionPoolGuard {
166 guard,
167 _not_send: PhantomData,
168 }
169 }
170
171 fn for_user(self) -> Option<UserSession> {
172 UserSession::new(self)
173 }
174
175 fn user_id(&self) -> Option<UserId> {
176 match &self.principal {
177 Principal::User(user) => Some(user.id),
178 Principal::Impersonated { user, .. } => Some(user.id),
179 Principal::DevServer(_) => None,
180 }
181 }
182}
183
184impl Debug for Session {
185 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
186 let mut result = f.debug_struct("Session");
187 match &self.principal {
188 Principal::User(user) => {
189 result.field("user", &user.github_login);
190 }
191 Principal::Impersonated { user, admin } => {
192 result.field("user", &user.github_login);
193 result.field("impersonator", &admin.github_login);
194 }
195 Principal::DevServer(dev_server) => {
196 result.field("dev_server", &dev_server.id);
197 }
198 }
199 result.field("connection_id", &self.connection_id).finish()
200 }
201}
202
203struct UserSession(Session);
204
205impl UserSession {
206 pub fn new(s: Session) -> Option<Self> {
207 s.user_id().map(|_| UserSession(s))
208 }
209 pub fn user_id(&self) -> UserId {
210 self.0.user_id().unwrap()
211 }
212}
213
214impl Deref for UserSession {
215 type Target = Session;
216
217 fn deref(&self) -> &Self::Target {
218 &self.0
219 }
220}
221impl DerefMut for UserSession {
222 fn deref_mut(&mut self) -> &mut Self::Target {
223 &mut self.0
224 }
225}
226
227fn user_handler<M: RequestMessage, Fut>(
228 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
229) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
230where
231 Fut: Send + Future<Output = Result<()>>,
232{
233 let handler = Arc::new(handler);
234 move |message, response, session| {
235 let handler = handler.clone();
236 Box::pin(async move {
237 if let Some(user_session) = session.for_user() {
238 Ok(handler(message, response, user_session).await?)
239 } else {
240 Err(Error::Internal(anyhow!("must be a user")))
241 }
242 })
243 }
244}
245
246fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
247 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
248) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
249where
250 InnertRetFut: Send + Future<Output = Result<()>>,
251{
252 let handler = Arc::new(handler);
253 move |message, session| {
254 let handler = handler.clone();
255 Box::pin(async move {
256 if let Some(user_session) = session.for_user() {
257 Ok(handler(message, user_session).await?)
258 } else {
259 Err(Error::Internal(anyhow!("must be a user")))
260 }
261 })
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
275pub struct Server {
276 id: parking_lot::Mutex<ServerId>,
277 peer: Arc<Peer>,
278 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
279 app_state: Arc<AppState>,
280 handlers: HashMap<TypeId, MessageHandler>,
281 teardown: watch::Sender<bool>,
282}
283
284pub(crate) struct ConnectionPoolGuard<'a> {
285 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
286 _not_send: PhantomData<Rc<()>>,
287}
288
289#[derive(Serialize)]
290pub struct ServerSnapshot<'a> {
291 peer: &'a Peer,
292 #[serde(serialize_with = "serialize_deref")]
293 connection_pool: ConnectionPoolGuard<'a>,
294}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
306 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
307 let mut server = Self {
308 id: parking_lot::Mutex::new(id),
309 peer: Peer::new(id.0 as u32),
310 app_state: app_state.clone(),
311 connection_pool: Default::default(),
312 handlers: Default::default(),
313 teardown: watch::channel(false).0,
314 };
315
316 server
317 .add_request_handler(ping)
318 .add_request_handler(user_handler(create_room))
319 .add_request_handler(user_handler(join_room))
320 .add_request_handler(user_handler(rejoin_room))
321 .add_request_handler(user_handler(leave_room))
322 .add_request_handler(user_handler(set_room_participant_role))
323 .add_request_handler(user_handler(call))
324 .add_request_handler(user_handler(cancel_call))
325 .add_message_handler(user_message_handler(decline_call))
326 .add_request_handler(user_handler(update_participant_location))
327 .add_request_handler(share_project)
328 .add_message_handler(unshare_project)
329 .add_request_handler(user_handler(join_project))
330 .add_request_handler(user_handler(join_hosted_project))
331 .add_message_handler(user_message_handler(leave_project))
332 .add_request_handler(update_project)
333 .add_request_handler(update_worktree)
334 .add_message_handler(start_language_server)
335 .add_message_handler(update_language_server)
336 .add_message_handler(update_diagnostic_summary)
337 .add_message_handler(update_worktree_settings)
338 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
339 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
340 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
341 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
342 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
344 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
345 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
346 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
347 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
348 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
349 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
350 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
353 )
354 .add_request_handler(
355 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
356 )
357 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
358 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
359 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
360 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
361 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
362 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
364 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
365 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
369 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
370 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
371 .add_message_handler(create_buffer_for_peer)
372 .add_request_handler(update_buffer)
373 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
375 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
376 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
377 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
378 .add_request_handler(get_users)
379 .add_request_handler(user_handler(fuzzy_search_users))
380 .add_request_handler(user_handler(request_contact))
381 .add_request_handler(user_handler(remove_contact))
382 .add_request_handler(user_handler(respond_to_contact_request))
383 .add_request_handler(user_handler(create_channel))
384 .add_request_handler(user_handler(delete_channel))
385 .add_request_handler(user_handler(invite_channel_member))
386 .add_request_handler(user_handler(remove_channel_member))
387 .add_request_handler(user_handler(set_channel_member_role))
388 .add_request_handler(user_handler(set_channel_visibility))
389 .add_request_handler(user_handler(rename_channel))
390 .add_request_handler(user_handler(join_channel_buffer))
391 .add_request_handler(user_handler(leave_channel_buffer))
392 .add_message_handler(user_message_handler(update_channel_buffer))
393 .add_request_handler(user_handler(rejoin_channel_buffers))
394 .add_request_handler(user_handler(get_channel_members))
395 .add_request_handler(user_handler(respond_to_channel_invite))
396 .add_request_handler(user_handler(join_channel))
397 .add_request_handler(user_handler(join_channel_chat))
398 .add_message_handler(user_message_handler(leave_channel_chat))
399 .add_request_handler(user_handler(send_channel_message))
400 .add_request_handler(user_handler(remove_channel_message))
401 .add_request_handler(user_handler(update_channel_message))
402 .add_request_handler(user_handler(get_channel_messages))
403 .add_request_handler(user_handler(get_channel_messages_by_id))
404 .add_request_handler(user_handler(get_notifications))
405 .add_request_handler(user_handler(mark_notification_as_read))
406 .add_request_handler(user_handler(move_channel))
407 .add_request_handler(user_handler(follow))
408 .add_message_handler(user_message_handler(unfollow))
409 .add_message_handler(user_message_handler(update_followers))
410 .add_request_handler(user_handler(get_private_user_info))
411 .add_message_handler(user_message_handler(acknowledge_channel_message))
412 .add_message_handler(user_message_handler(acknowledge_buffer_version))
413 .add_streaming_request_handler({
414 let app_state = app_state.clone();
415 move |request, response, session| {
416 complete_with_language_model(
417 request,
418 response,
419 session,
420 app_state.config.openai_api_key.clone(),
421 app_state.config.google_ai_api_key.clone(),
422 )
423 }
424 })
425 .add_request_handler({
426 let app_state = app_state.clone();
427 user_handler(move |request, response, session| {
428 count_tokens_with_language_model(
429 request,
430 response,
431 session,
432 app_state.config.google_ai_api_key.clone(),
433 )
434 })
435 });
436
437 Arc::new(server)
438 }
439
440 pub async fn start(&self) -> Result<()> {
441 let server_id = *self.id.lock();
442 let app_state = self.app_state.clone();
443 let peer = self.peer.clone();
444 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
445 let pool = self.connection_pool.clone();
446 let live_kit_client = self.app_state.live_kit_client.clone();
447
448 let span = info_span!("start server");
449 self.app_state.executor.spawn_detached(
450 async move {
451 tracing::info!("waiting for cleanup timeout");
452 timeout.await;
453 tracing::info!("cleanup timeout expired, retrieving stale rooms");
454 if let Some((room_ids, channel_ids)) = app_state
455 .db
456 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
457 .await
458 .trace_err()
459 {
460 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
461 tracing::info!(
462 stale_channel_buffer_count = channel_ids.len(),
463 "retrieved stale channel buffers"
464 );
465
466 for channel_id in channel_ids {
467 if let Some(refreshed_channel_buffer) = app_state
468 .db
469 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
470 .await
471 .trace_err()
472 {
473 for connection_id in refreshed_channel_buffer.connection_ids {
474 peer.send(
475 connection_id,
476 proto::UpdateChannelBufferCollaborators {
477 channel_id: channel_id.to_proto(),
478 collaborators: refreshed_channel_buffer
479 .collaborators
480 .clone(),
481 },
482 )
483 .trace_err();
484 }
485 }
486 }
487
488 for room_id in room_ids {
489 let mut contacts_to_update = HashSet::default();
490 let mut canceled_calls_to_user_ids = Vec::new();
491 let mut live_kit_room = String::new();
492 let mut delete_live_kit_room = false;
493
494 if let Some(mut refreshed_room) = app_state
495 .db
496 .clear_stale_room_participants(room_id, server_id)
497 .await
498 .trace_err()
499 {
500 tracing::info!(
501 room_id = room_id.0,
502 new_participant_count = refreshed_room.room.participants.len(),
503 "refreshed room"
504 );
505 room_updated(&refreshed_room.room, &peer);
506 if let Some(channel) = refreshed_room.channel.as_ref() {
507 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
508 }
509 contacts_to_update
510 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
511 contacts_to_update
512 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
513 canceled_calls_to_user_ids =
514 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
515 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
516 delete_live_kit_room = refreshed_room.room.participants.is_empty();
517 }
518
519 {
520 let pool = pool.lock();
521 for canceled_user_id in canceled_calls_to_user_ids {
522 for connection_id in pool.user_connection_ids(canceled_user_id) {
523 peer.send(
524 connection_id,
525 proto::CallCanceled {
526 room_id: room_id.to_proto(),
527 },
528 )
529 .trace_err();
530 }
531 }
532 }
533
534 for user_id in contacts_to_update {
535 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
536 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
537 if let Some((busy, contacts)) = busy.zip(contacts) {
538 let pool = pool.lock();
539 let updated_contact = contact_for_user(user_id, busy, &pool);
540 for contact in contacts {
541 if let db::Contact::Accepted {
542 user_id: contact_user_id,
543 ..
544 } = contact
545 {
546 for contact_conn_id in
547 pool.user_connection_ids(contact_user_id)
548 {
549 peer.send(
550 contact_conn_id,
551 proto::UpdateContacts {
552 contacts: vec![updated_contact.clone()],
553 remove_contacts: Default::default(),
554 incoming_requests: Default::default(),
555 remove_incoming_requests: Default::default(),
556 outgoing_requests: Default::default(),
557 remove_outgoing_requests: Default::default(),
558 },
559 )
560 .trace_err();
561 }
562 }
563 }
564 }
565 }
566
567 if let Some(live_kit) = live_kit_client.as_ref() {
568 if delete_live_kit_room {
569 live_kit.delete_room(live_kit_room).await.trace_err();
570 }
571 }
572 }
573 }
574
575 app_state
576 .db
577 .delete_stale_servers(&app_state.config.zed_environment, server_id)
578 .await
579 .trace_err();
580 }
581 .instrument(span),
582 );
583 Ok(())
584 }
585
586 pub fn teardown(&self) {
587 self.peer.teardown();
588 self.connection_pool.lock().reset();
589 let _ = self.teardown.send(true);
590 }
591
592 #[cfg(test)]
593 pub fn reset(&self, id: ServerId) {
594 self.teardown();
595 *self.id.lock() = id;
596 self.peer.reset(id.0 as u32);
597 let _ = self.teardown.send(false);
598 }
599
600 #[cfg(test)]
601 pub fn id(&self) -> ServerId {
602 *self.id.lock()
603 }
604
605 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
606 where
607 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
608 Fut: 'static + Send + Future<Output = Result<()>>,
609 M: EnvelopedMessage,
610 {
611 let prev_handler = self.handlers.insert(
612 TypeId::of::<M>(),
613 Box::new(move |envelope, session| {
614 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
615 let received_at = envelope.received_at;
616 tracing::info!(
617 "message received"
618 );
619 let start_time = Instant::now();
620 let future = (handler)(*envelope, session);
621 async move {
622 let result = future.await;
623 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
624 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
625 let queue_duration_ms = total_duration_ms - processing_duration_ms;
626 match result {
627 Err(error) => {
628 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
629 }
630 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
631 }
632 }
633 .boxed()
634 }),
635 );
636 if prev_handler.is_some() {
637 panic!("registered a handler for the same message twice");
638 }
639 self
640 }
641
642 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
643 where
644 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
645 Fut: 'static + Send + Future<Output = Result<()>>,
646 M: EnvelopedMessage,
647 {
648 self.add_handler(move |envelope, session| handler(envelope.payload, session));
649 self
650 }
651
652 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
653 where
654 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
655 Fut: Send + Future<Output = Result<()>>,
656 M: RequestMessage,
657 {
658 let handler = Arc::new(handler);
659 self.add_handler(move |envelope, session| {
660 let receipt = envelope.receipt();
661 let handler = handler.clone();
662 async move {
663 let peer = session.peer.clone();
664 let responded = Arc::new(AtomicBool::default());
665 let response = Response {
666 peer: peer.clone(),
667 responded: responded.clone(),
668 receipt,
669 };
670 match (handler)(envelope.payload, response, session).await {
671 Ok(()) => {
672 if responded.load(std::sync::atomic::Ordering::SeqCst) {
673 Ok(())
674 } else {
675 Err(anyhow!("handler did not send a response"))?
676 }
677 }
678 Err(error) => {
679 let proto_err = match &error {
680 Error::Internal(err) => err.to_proto(),
681 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
682 };
683 peer.respond_with_error(receipt, proto_err)?;
684 Err(error)
685 }
686 }
687 }
688 })
689 }
690
691 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
692 where
693 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
694 Fut: Send + Future<Output = Result<()>>,
695 M: RequestMessage,
696 {
697 let handler = Arc::new(handler);
698 self.add_handler(move |envelope, session| {
699 let receipt = envelope.receipt();
700 let handler = handler.clone();
701 async move {
702 let peer = session.peer.clone();
703 let response = StreamingResponse {
704 peer: peer.clone(),
705 receipt,
706 };
707 match (handler)(envelope.payload, response, session).await {
708 Ok(()) => {
709 peer.end_stream(receipt)?;
710 Ok(())
711 }
712 Err(error) => {
713 let proto_err = match &error {
714 Error::Internal(err) => err.to_proto(),
715 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
716 };
717 peer.respond_with_error(receipt, proto_err)?;
718 Err(error)
719 }
720 }
721 }
722 })
723 }
724
725 #[allow(clippy::too_many_arguments)]
726 pub fn handle_connection(
727 self: &Arc<Self>,
728 connection: Connection,
729 address: String,
730 principal: Principal,
731 zed_version: ZedVersion,
732 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
733 executor: Executor,
734 ) -> impl Future<Output = ()> {
735 let this = self.clone();
736 let span = info_span!("handle connection", %address, impersonator = field::Empty, connection_id = field::Empty);
737 principal.update_span(&span);
738
739 let mut teardown = self.teardown.subscribe();
740 async move {
741 if *teardown.borrow() {
742 tracing::error!("server is tearing down");
743 return
744 }
745 let (connection_id, handle_io, mut incoming_rx) = this
746 .peer
747 .add_connection(connection, {
748 let executor = executor.clone();
749 move |duration| executor.sleep(duration)
750 });
751 tracing::Span::current().record("connection_id", format!("{}", connection_id));
752 tracing::info!("connection opened");
753
754 let http_client = match IsahcHttpClient::new() {
755 Ok(http_client) => http_client,
756 Err(error) => {
757 tracing::error!(?error, "failed to create HTTP client");
758 return;
759 }
760 };
761
762 let session = Session {
763 principal: principal.clone(),
764 connection_id,
765 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
766 peer: this.peer.clone(),
767 connection_pool: this.connection_pool.clone(),
768 live_kit_client: this.app_state.live_kit_client.clone(),
769 http_client,
770 rate_limiter: this.app_state.rate_limiter.clone(),
771 _executor: executor.clone(),
772 };
773
774 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
775 tracing::error!(?error, "failed to send initial client update");
776 return;
777 }
778
779 let handle_io = handle_io.fuse();
780 futures::pin_mut!(handle_io);
781
782 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
783 // This prevents deadlocks when e.g., client A performs a request to client B and
784 // client B performs a request to client A. If both clients stop processing further
785 // messages until their respective request completes, they won't have a chance to
786 // respond to the other client's request and cause a deadlock.
787 //
788 // This arrangement ensures we will attempt to process earlier messages first, but fall
789 // back to processing messages arrived later in the spirit of making progress.
790 let mut foreground_message_handlers = FuturesUnordered::new();
791 let concurrent_handlers = Arc::new(Semaphore::new(256));
792 loop {
793 let next_message = async {
794 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
795 let message = incoming_rx.next().await;
796 (permit, message)
797 }.fuse();
798 futures::pin_mut!(next_message);
799 futures::select_biased! {
800 _ = teardown.changed().fuse() => return,
801 result = handle_io => {
802 if let Err(error) = result {
803 tracing::error!(?error, "error handling I/O");
804 }
805 break;
806 }
807 _ = foreground_message_handlers.next() => {}
808 next_message = next_message => {
809 let (permit, message) = next_message;
810 if let Some(message) = message {
811 let type_name = message.payload_type_name();
812 // note: we copy all the fields from the parent span so we can query them in the logs.
813 // (https://github.com/tokio-rs/tracing/issues/2670).
814 let span = tracing::info_span!("receive message", %connection_id, %address, type_name);
815 principal.update_span(&span);
816 let span_enter = span.enter();
817 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
818 let is_background = message.is_background();
819 let handle_message = (handler)(message, session.clone());
820 drop(span_enter);
821
822 let handle_message = async move {
823 handle_message.await;
824 drop(permit);
825 }.instrument(span);
826 if is_background {
827 executor.spawn_detached(handle_message);
828 } else {
829 foreground_message_handlers.push(handle_message);
830 }
831 } else {
832 tracing::error!("no message handler");
833 }
834 } else {
835 tracing::info!("connection closed");
836 break;
837 }
838 }
839 }
840 }
841
842 drop(foreground_message_handlers);
843 tracing::info!("signing out");
844 if let Err(error) = connection_lost(session, teardown, executor).await {
845 tracing::error!(?error, "error signing out");
846 }
847
848 }.instrument(span)
849 }
850
851 async fn send_initial_client_update(
852 &self,
853 connection_id: ConnectionId,
854 principal: &Principal,
855 zed_version: ZedVersion,
856 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
857 session: &Session,
858 ) -> Result<()> {
859 self.peer.send(
860 connection_id,
861 proto::Hello {
862 peer_id: Some(connection_id.into()),
863 },
864 )?;
865 tracing::info!("sent hello message");
866
867 let Principal::User(user) = principal else {
868 return Ok(());
869 };
870
871 if let Some(send_connection_id) = send_connection_id.take() {
872 let _ = send_connection_id.send(connection_id);
873 }
874
875 if !user.connected_once {
876 self.peer.send(connection_id, proto::ShowContacts {})?;
877 self.app_state
878 .db
879 .set_user_connected_once(user.id, true)
880 .await?;
881 }
882
883 let (contacts, channels_for_user, channel_invites) = future::try_join3(
884 self.app_state.db.get_contacts(user.id),
885 self.app_state.db.get_channels_for_user(user.id),
886 self.app_state.db.get_channel_invites_for_user(user.id),
887 )
888 .await?;
889
890 {
891 let mut pool = self.connection_pool.lock();
892 pool.add_connection(connection_id, user.id, user.admin, zed_version);
893 for membership in &channels_for_user.channel_memberships {
894 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
895 }
896 self.peer.send(
897 connection_id,
898 build_initial_contacts_update(contacts, &pool),
899 )?;
900 self.peer.send(
901 connection_id,
902 build_update_user_channels(&channels_for_user),
903 )?;
904 self.peer.send(
905 connection_id,
906 build_channels_update(channels_for_user, channel_invites),
907 )?;
908 }
909
910 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
911 self.peer.send(connection_id, incoming_call)?;
912 }
913
914 update_user_contacts(user.id, &session).await?;
915 Ok(())
916 }
917
918 pub async fn invite_code_redeemed(
919 self: &Arc<Self>,
920 inviter_id: UserId,
921 invitee_id: UserId,
922 ) -> Result<()> {
923 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
924 if let Some(code) = &user.invite_code {
925 let pool = self.connection_pool.lock();
926 let invitee_contact = contact_for_user(invitee_id, false, &pool);
927 for connection_id in pool.user_connection_ids(inviter_id) {
928 self.peer.send(
929 connection_id,
930 proto::UpdateContacts {
931 contacts: vec![invitee_contact.clone()],
932 ..Default::default()
933 },
934 )?;
935 self.peer.send(
936 connection_id,
937 proto::UpdateInviteInfo {
938 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
939 count: user.invite_count as u32,
940 },
941 )?;
942 }
943 }
944 }
945 Ok(())
946 }
947
948 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
949 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
950 if let Some(invite_code) = &user.invite_code {
951 let pool = self.connection_pool.lock();
952 for connection_id in pool.user_connection_ids(user_id) {
953 self.peer.send(
954 connection_id,
955 proto::UpdateInviteInfo {
956 url: format!(
957 "{}{}",
958 self.app_state.config.invite_link_prefix, invite_code
959 ),
960 count: user.invite_count as u32,
961 },
962 )?;
963 }
964 }
965 }
966 Ok(())
967 }
968
969 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
970 ServerSnapshot {
971 connection_pool: ConnectionPoolGuard {
972 guard: self.connection_pool.lock(),
973 _not_send: PhantomData,
974 },
975 peer: &self.peer,
976 }
977 }
978}
979
980impl<'a> Deref for ConnectionPoolGuard<'a> {
981 type Target = ConnectionPool;
982
983 fn deref(&self) -> &Self::Target {
984 &self.guard
985 }
986}
987
988impl<'a> DerefMut for ConnectionPoolGuard<'a> {
989 fn deref_mut(&mut self) -> &mut Self::Target {
990 &mut self.guard
991 }
992}
993
994impl<'a> Drop for ConnectionPoolGuard<'a> {
995 fn drop(&mut self) {
996 #[cfg(test)]
997 self.check_invariants();
998 }
999}
1000
1001fn broadcast<F>(
1002 sender_id: Option<ConnectionId>,
1003 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1004 mut f: F,
1005) where
1006 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1007{
1008 for receiver_id in receiver_ids {
1009 if Some(receiver_id) != sender_id {
1010 if let Err(error) = f(receiver_id) {
1011 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1012 }
1013 }
1014 }
1015}
1016
1017pub struct ProtocolVersion(u32);
1018
1019impl Header for ProtocolVersion {
1020 fn name() -> &'static HeaderName {
1021 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1022 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1023 }
1024
1025 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1026 where
1027 Self: Sized,
1028 I: Iterator<Item = &'i axum::http::HeaderValue>,
1029 {
1030 let version = values
1031 .next()
1032 .ok_or_else(axum::headers::Error::invalid)?
1033 .to_str()
1034 .map_err(|_| axum::headers::Error::invalid())?
1035 .parse()
1036 .map_err(|_| axum::headers::Error::invalid())?;
1037 Ok(Self(version))
1038 }
1039
1040 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1041 values.extend([self.0.to_string().parse().unwrap()]);
1042 }
1043}
1044
1045pub struct AppVersionHeader(SemanticVersion);
1046impl Header for AppVersionHeader {
1047 fn name() -> &'static HeaderName {
1048 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1049 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1050 }
1051
1052 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1053 where
1054 Self: Sized,
1055 I: Iterator<Item = &'i axum::http::HeaderValue>,
1056 {
1057 let version = values
1058 .next()
1059 .ok_or_else(axum::headers::Error::invalid)?
1060 .to_str()
1061 .map_err(|_| axum::headers::Error::invalid())?
1062 .parse()
1063 .map_err(|_| axum::headers::Error::invalid())?;
1064 Ok(Self(version))
1065 }
1066
1067 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1068 values.extend([self.0.to_string().parse().unwrap()]);
1069 }
1070}
1071
/// Builds the collab server's HTTP router with the `/rpc` websocket endpoint
/// and the `/metrics` Prometheus endpoint.
///
/// Ordering matters: axum's `Router::layer` only wraps routes that were added
/// before it. `/rpc` is therefore wrapped by the `AppState` extension and the
/// `auth::validate_header` middleware. `/metrics`, added afterwards, is not.
/// The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to `/rpc` (the only route registered so far).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to every route above, including `/metrics`.
        .layer(Extension(server))
}
1083
1084pub async fn handle_websocket_request(
1085 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1086 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1087 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1088 Extension(server): Extension<Arc<Server>>,
1089 Extension(principal): Extension<Principal>,
1090 ws: WebSocketUpgrade,
1091) -> axum::response::Response {
1092 if protocol_version != rpc::PROTOCOL_VERSION {
1093 return (
1094 StatusCode::UPGRADE_REQUIRED,
1095 "client must be upgraded".to_string(),
1096 )
1097 .into_response();
1098 }
1099
1100 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1101 return (
1102 StatusCode::UPGRADE_REQUIRED,
1103 "no version header found".to_string(),
1104 )
1105 .into_response();
1106 };
1107
1108 if !version.can_collaborate() {
1109 return (
1110 StatusCode::UPGRADE_REQUIRED,
1111 "client must be upgraded".to_string(),
1112 )
1113 .into_response();
1114 }
1115
1116 let socket_address = socket_address.to_string();
1117 ws.on_upgrade(move |socket| {
1118 let socket = socket
1119 .map_ok(to_tungstenite_message)
1120 .err_into()
1121 .with(|message| async move { Ok(to_axum_message(message)) });
1122 let connection = Connection::new(Box::pin(socket));
1123 async move {
1124 server
1125 .handle_connection(
1126 connection,
1127 socket_address,
1128 principal,
1129 version,
1130 None,
1131 Executor::Production,
1132 )
1133 .await;
1134 }
1135 })
1136}
1137
1138pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1139 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1140 let connections_metric = CONNECTIONS_METRIC
1141 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1142
1143 let connections = server
1144 .connection_pool
1145 .lock()
1146 .connections()
1147 .filter(|connection| !connection.admin)
1148 .count();
1149 connections_metric.set(connections as _);
1150
1151 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1152 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1153 register_int_gauge!(
1154 "shared_projects",
1155 "number of open projects with one or more guests"
1156 )
1157 .unwrap()
1158 });
1159
1160 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1161 shared_projects_metric.set(shared_projects as _);
1162
1163 let encoder = prometheus::TextEncoder::new();
1164 let metric_families = prometheus::gather();
1165 let encoded_metrics = encoder
1166 .encode_to_string(&metric_families)
1167 .map_err(|err| anyhow!("{}", err))?;
1168 Ok(encoded_metrics)
1169}
1170
/// Cleans up after a connection drops.
///
/// The steps run in this order:
/// 1. Disconnect the peer and remove the connection from the pool. A failure
///    here is returned.
/// 2. Tell the database the connection was lost. Errors are only logged.
/// 3. Wait up to `RECONNECT_TIMEOUT`, giving the client a chance to reconnect.
///
/// If the timeout elapses first and this is a user session, the user's
/// room-level and channel-buffer resources are released (errors logged). If
/// the user then has no remaining online connections, any pending call is
/// declined. Finally the user's contacts are refreshed. If `teardown` changes
/// first, that post-timeout cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure is logged but does not abort cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in declaration order, so the timeout
    // branch is checked before `teardown` on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // Presumably `None` for sessions not tied to a user — TODO confirm.
            if let Some(session) = session.for_user() {
                log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                leave_room_for_session(&session).await.trace_err();
                leave_channel_buffers_for_session(&session)
                    .await
                    .trace_err();

                // Only decline a pending call once the user has no other live
                // connection that could still answer it.
                if !session
                    .connection_pool()
                    .await
                    .is_user_online(session.user_id())
                {
                    let db = session.db().await;
                    if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                        room_updated(&room, &session.peer);
                    }
                }

                update_user_contacts(session.user_id(), &session).await?;
            }
        }
        // Any change on the teardown channel skips the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1218
1219/// Acknowledges a ping from a client, used to keep the connection alive.
1220async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1221 response.send(proto::Ack {})?;
1222 Ok(())
1223}
1224
1225/// Creates a new room for calling (outside of channels)
1226async fn create_room(
1227 _request: proto::CreateRoom,
1228 response: Response<proto::CreateRoom>,
1229 session: UserSession,
1230) -> Result<()> {
1231 let live_kit_room = nanoid::nanoid!(30);
1232
1233 let live_kit_connection_info = util::maybe!(async {
1234 let live_kit = session.live_kit_client.as_ref();
1235 let live_kit = live_kit?;
1236 let user_id = session.user_id().to_string();
1237
1238 let token = live_kit
1239 .room_token(&live_kit_room, &user_id.to_string())
1240 .trace_err()?;
1241
1242 Some(proto::LiveKitConnectionInfo {
1243 server_url: live_kit.url().into(),
1244 token,
1245 can_publish: true,
1246 })
1247 })
1248 .await;
1249
1250 let room = session
1251 .db()
1252 .await
1253 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1254 .await?;
1255
1256 response.send(proto::CreateRoomResponse {
1257 room: Some(room.clone()),
1258 live_kit_connection_info,
1259 })?;
1260
1261 update_user_contacts(session.user_id(), &session).await?;
1262 Ok(())
1263}
1264
1265/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1266async fn join_room(
1267 request: proto::JoinRoom,
1268 response: Response<proto::JoinRoom>,
1269 session: UserSession,
1270) -> Result<()> {
1271 let room_id = RoomId::from_proto(request.id);
1272
1273 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1274
1275 if let Some(channel_id) = channel_id {
1276 return join_channel_internal(channel_id, Box::new(response), session).await;
1277 }
1278
1279 let joined_room = {
1280 let room = session
1281 .db()
1282 .await
1283 .join_room(room_id, session.user_id(), session.connection_id)
1284 .await?;
1285 room_updated(&room.room, &session.peer);
1286 room.into_inner()
1287 };
1288
1289 for connection_id in session
1290 .connection_pool()
1291 .await
1292 .user_connection_ids(session.user_id())
1293 {
1294 session
1295 .peer
1296 .send(
1297 connection_id,
1298 proto::CallCanceled {
1299 room_id: room_id.to_proto(),
1300 },
1301 )
1302 .trace_err();
1303 }
1304
1305 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1306 if let Some(token) = live_kit
1307 .room_token(
1308 &joined_room.room.live_kit_room,
1309 &session.user_id().to_string(),
1310 )
1311 .trace_err()
1312 {
1313 Some(proto::LiveKitConnectionInfo {
1314 server_url: live_kit.url().into(),
1315 token,
1316 can_publish: true,
1317 })
1318 } else {
1319 None
1320 }
1321 } else {
1322 None
1323 };
1324
1325 response.send(proto::JoinRoomResponse {
1326 room: Some(joined_room.room),
1327 channel_id: None,
1328 live_kit_connection_info,
1329 })?;
1330
1331 update_user_contacts(session.user_id(), &session).await?;
1332 Ok(())
1333}
1334
1335/// Rejoin room is used to reconnect to a room after connection errors.
1336async fn rejoin_room(
1337 request: proto::RejoinRoom,
1338 response: Response<proto::RejoinRoom>,
1339 session: UserSession,
1340) -> Result<()> {
1341 let room;
1342 let channel;
1343 {
1344 let mut rejoined_room = session
1345 .db()
1346 .await
1347 .rejoin_room(request, session.user_id(), session.connection_id)
1348 .await?;
1349
1350 response.send(proto::RejoinRoomResponse {
1351 room: Some(rejoined_room.room.clone()),
1352 reshared_projects: rejoined_room
1353 .reshared_projects
1354 .iter()
1355 .map(|project| proto::ResharedProject {
1356 id: project.id.to_proto(),
1357 collaborators: project
1358 .collaborators
1359 .iter()
1360 .map(|collaborator| collaborator.to_proto())
1361 .collect(),
1362 })
1363 .collect(),
1364 rejoined_projects: rejoined_room
1365 .rejoined_projects
1366 .iter()
1367 .map(|rejoined_project| proto::RejoinedProject {
1368 id: rejoined_project.id.to_proto(),
1369 worktrees: rejoined_project
1370 .worktrees
1371 .iter()
1372 .map(|worktree| proto::WorktreeMetadata {
1373 id: worktree.id,
1374 root_name: worktree.root_name.clone(),
1375 visible: worktree.visible,
1376 abs_path: worktree.abs_path.clone(),
1377 })
1378 .collect(),
1379 collaborators: rejoined_project
1380 .collaborators
1381 .iter()
1382 .map(|collaborator| collaborator.to_proto())
1383 .collect(),
1384 language_servers: rejoined_project.language_servers.clone(),
1385 })
1386 .collect(),
1387 })?;
1388 room_updated(&rejoined_room.room, &session.peer);
1389
1390 for project in &rejoined_room.reshared_projects {
1391 for collaborator in &project.collaborators {
1392 session
1393 .peer
1394 .send(
1395 collaborator.connection_id,
1396 proto::UpdateProjectCollaborator {
1397 project_id: project.id.to_proto(),
1398 old_peer_id: Some(project.old_connection_id.into()),
1399 new_peer_id: Some(session.connection_id.into()),
1400 },
1401 )
1402 .trace_err();
1403 }
1404
1405 broadcast(
1406 Some(session.connection_id),
1407 project
1408 .collaborators
1409 .iter()
1410 .map(|collaborator| collaborator.connection_id),
1411 |connection_id| {
1412 session.peer.forward_send(
1413 session.connection_id,
1414 connection_id,
1415 proto::UpdateProject {
1416 project_id: project.id.to_proto(),
1417 worktrees: project.worktrees.clone(),
1418 },
1419 )
1420 },
1421 );
1422 }
1423
1424 for project in &rejoined_room.rejoined_projects {
1425 for collaborator in &project.collaborators {
1426 session
1427 .peer
1428 .send(
1429 collaborator.connection_id,
1430 proto::UpdateProjectCollaborator {
1431 project_id: project.id.to_proto(),
1432 old_peer_id: Some(project.old_connection_id.into()),
1433 new_peer_id: Some(session.connection_id.into()),
1434 },
1435 )
1436 .trace_err();
1437 }
1438 }
1439
1440 for project in &mut rejoined_room.rejoined_projects {
1441 for worktree in mem::take(&mut project.worktrees) {
1442 #[cfg(any(test, feature = "test-support"))]
1443 const MAX_CHUNK_SIZE: usize = 2;
1444 #[cfg(not(any(test, feature = "test-support")))]
1445 const MAX_CHUNK_SIZE: usize = 256;
1446
1447 // Stream this worktree's entries.
1448 let message = proto::UpdateWorktree {
1449 project_id: project.id.to_proto(),
1450 worktree_id: worktree.id,
1451 abs_path: worktree.abs_path.clone(),
1452 root_name: worktree.root_name,
1453 updated_entries: worktree.updated_entries,
1454 removed_entries: worktree.removed_entries,
1455 scan_id: worktree.scan_id,
1456 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1457 updated_repositories: worktree.updated_repositories,
1458 removed_repositories: worktree.removed_repositories,
1459 };
1460 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1461 session.peer.send(session.connection_id, update.clone())?;
1462 }
1463
1464 // Stream this worktree's diagnostics.
1465 for summary in worktree.diagnostic_summaries {
1466 session.peer.send(
1467 session.connection_id,
1468 proto::UpdateDiagnosticSummary {
1469 project_id: project.id.to_proto(),
1470 worktree_id: worktree.id,
1471 summary: Some(summary),
1472 },
1473 )?;
1474 }
1475
1476 for settings_file in worktree.settings_files {
1477 session.peer.send(
1478 session.connection_id,
1479 proto::UpdateWorktreeSettings {
1480 project_id: project.id.to_proto(),
1481 worktree_id: worktree.id,
1482 path: settings_file.path,
1483 content: Some(settings_file.content),
1484 },
1485 )?;
1486 }
1487 }
1488
1489 for language_server in &project.language_servers {
1490 session.peer.send(
1491 session.connection_id,
1492 proto::UpdateLanguageServer {
1493 project_id: project.id.to_proto(),
1494 language_server_id: language_server.id,
1495 variant: Some(
1496 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1497 proto::LspDiskBasedDiagnosticsUpdated {},
1498 ),
1499 ),
1500 },
1501 )?;
1502 }
1503 }
1504
1505 let rejoined_room = rejoined_room.into_inner();
1506
1507 room = rejoined_room.room;
1508 channel = rejoined_room.channel;
1509 }
1510
1511 if let Some(channel) = channel {
1512 channel_updated(
1513 &channel,
1514 &room,
1515 &session.peer,
1516 &*session.connection_pool().await,
1517 );
1518 }
1519
1520 update_user_contacts(session.user_id(), &session).await?;
1521 Ok(())
1522}
1523
1524/// leave room disconnects from the room.
1525async fn leave_room(
1526 _: proto::LeaveRoom,
1527 response: Response<proto::LeaveRoom>,
1528 session: UserSession,
1529) -> Result<()> {
1530 leave_room_for_session(&session).await?;
1531 response.send(proto::Ack {})?;
1532 Ok(())
1533}
1534
1535/// Updates the permissions of someone else in the room.
1536async fn set_room_participant_role(
1537 request: proto::SetRoomParticipantRole,
1538 response: Response<proto::SetRoomParticipantRole>,
1539 session: UserSession,
1540) -> Result<()> {
1541 let user_id = UserId::from_proto(request.user_id);
1542 let role = ChannelRole::from(request.role());
1543
1544 let (live_kit_room, can_publish) = {
1545 let room = session
1546 .db()
1547 .await
1548 .set_room_participant_role(
1549 session.user_id(),
1550 RoomId::from_proto(request.room_id),
1551 user_id,
1552 role,
1553 )
1554 .await?;
1555
1556 let live_kit_room = room.live_kit_room.clone();
1557 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1558 room_updated(&room, &session.peer);
1559 (live_kit_room, can_publish)
1560 };
1561
1562 if let Some(live_kit) = session.live_kit_client.as_ref() {
1563 live_kit
1564 .update_participant(
1565 live_kit_room.clone(),
1566 request.user_id.to_string(),
1567 live_kit_server::proto::ParticipantPermission {
1568 can_subscribe: true,
1569 can_publish,
1570 can_publish_data: can_publish,
1571 hidden: false,
1572 recorder: false,
1573 },
1574 )
1575 .await
1576 .trace_err();
1577 }
1578
1579 response.send(proto::Ack {})?;
1580 Ok(())
1581}
1582
1583/// Call someone else into the current room
1584async fn call(
1585 request: proto::Call,
1586 response: Response<proto::Call>,
1587 session: UserSession,
1588) -> Result<()> {
1589 let room_id = RoomId::from_proto(request.room_id);
1590 let calling_user_id = session.user_id();
1591 let calling_connection_id = session.connection_id;
1592 let called_user_id = UserId::from_proto(request.called_user_id);
1593 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1594 if !session
1595 .db()
1596 .await
1597 .has_contact(calling_user_id, called_user_id)
1598 .await?
1599 {
1600 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1601 }
1602
1603 let incoming_call = {
1604 let (room, incoming_call) = &mut *session
1605 .db()
1606 .await
1607 .call(
1608 room_id,
1609 calling_user_id,
1610 calling_connection_id,
1611 called_user_id,
1612 initial_project_id,
1613 )
1614 .await?;
1615 room_updated(&room, &session.peer);
1616 mem::take(incoming_call)
1617 };
1618 update_user_contacts(called_user_id, &session).await?;
1619
1620 let mut calls = session
1621 .connection_pool()
1622 .await
1623 .user_connection_ids(called_user_id)
1624 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1625 .collect::<FuturesUnordered<_>>();
1626
1627 while let Some(call_response) = calls.next().await {
1628 match call_response.as_ref() {
1629 Ok(_) => {
1630 response.send(proto::Ack {})?;
1631 return Ok(());
1632 }
1633 Err(_) => {
1634 call_response.trace_err();
1635 }
1636 }
1637 }
1638
1639 {
1640 let room = session
1641 .db()
1642 .await
1643 .call_failed(room_id, called_user_id)
1644 .await?;
1645 room_updated(&room, &session.peer);
1646 }
1647 update_user_contacts(called_user_id, &session).await?;
1648
1649 Err(anyhow!("failed to ring user"))?
1650}
1651
1652/// Cancel an outgoing call.
1653async fn cancel_call(
1654 request: proto::CancelCall,
1655 response: Response<proto::CancelCall>,
1656 session: UserSession,
1657) -> Result<()> {
1658 let called_user_id = UserId::from_proto(request.called_user_id);
1659 let room_id = RoomId::from_proto(request.room_id);
1660 {
1661 let room = session
1662 .db()
1663 .await
1664 .cancel_call(room_id, session.connection_id, called_user_id)
1665 .await?;
1666 room_updated(&room, &session.peer);
1667 }
1668
1669 for connection_id in session
1670 .connection_pool()
1671 .await
1672 .user_connection_ids(called_user_id)
1673 {
1674 session
1675 .peer
1676 .send(
1677 connection_id,
1678 proto::CallCanceled {
1679 room_id: room_id.to_proto(),
1680 },
1681 )
1682 .trace_err();
1683 }
1684 response.send(proto::Ack {})?;
1685
1686 update_user_contacts(called_user_id, &session).await?;
1687 Ok(())
1688}
1689
1690/// Decline an incoming call.
1691async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1692 let room_id = RoomId::from_proto(message.room_id);
1693 {
1694 let room = session
1695 .db()
1696 .await
1697 .decline_call(Some(room_id), session.user_id())
1698 .await?
1699 .ok_or_else(|| anyhow!("failed to decline call"))?;
1700 room_updated(&room, &session.peer);
1701 }
1702
1703 for connection_id in session
1704 .connection_pool()
1705 .await
1706 .user_connection_ids(session.user_id())
1707 {
1708 session
1709 .peer
1710 .send(
1711 connection_id,
1712 proto::CallCanceled {
1713 room_id: room_id.to_proto(),
1714 },
1715 )
1716 .trace_err();
1717 }
1718 update_user_contacts(session.user_id(), &session).await?;
1719 Ok(())
1720}
1721
1722/// Updates other participants in the room with your current location.
1723async fn update_participant_location(
1724 request: proto::UpdateParticipantLocation,
1725 response: Response<proto::UpdateParticipantLocation>,
1726 session: UserSession,
1727) -> Result<()> {
1728 let room_id = RoomId::from_proto(request.room_id);
1729 let location = request
1730 .location
1731 .ok_or_else(|| anyhow!("invalid location"))?;
1732
1733 let db = session.db().await;
1734 let room = db
1735 .update_room_participant_location(room_id, session.connection_id, location)
1736 .await?;
1737
1738 room_updated(&room, &session.peer);
1739 response.send(proto::Ack {})?;
1740 Ok(())
1741}
1742
1743/// Share a project into the room.
1744async fn share_project(
1745 request: proto::ShareProject,
1746 response: Response<proto::ShareProject>,
1747 session: Session,
1748) -> Result<()> {
1749 let (project_id, room) = &*session
1750 .db()
1751 .await
1752 .share_project(
1753 RoomId::from_proto(request.room_id),
1754 session.connection_id,
1755 &request.worktrees,
1756 )
1757 .await?;
1758 response.send(proto::ShareProjectResponse {
1759 project_id: project_id.to_proto(),
1760 })?;
1761 room_updated(&room, &session.peer);
1762
1763 Ok(())
1764}
1765
1766/// Unshare a project from the room.
1767async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1768 let project_id = ProjectId::from_proto(message.project_id);
1769
1770 let (room, guest_connection_ids) = &*session
1771 .db()
1772 .await
1773 .unshare_project(project_id, session.connection_id)
1774 .await?;
1775
1776 broadcast(
1777 Some(session.connection_id),
1778 guest_connection_ids.iter().copied(),
1779 |conn_id| session.peer.send(conn_id, message.clone()),
1780 );
1781 room_updated(&room, &session.peer);
1782
1783 Ok(())
1784}
1785
1786/// Join someone elses shared project.
1787async fn join_project(
1788 request: proto::JoinProject,
1789 response: Response<proto::JoinProject>,
1790 session: UserSession,
1791) -> Result<()> {
1792 let project_id = ProjectId::from_proto(request.project_id);
1793
1794 tracing::info!(%project_id, "join project");
1795
1796 let (project, replica_id) = &mut *session
1797 .db()
1798 .await
1799 .join_project_in_room(project_id, session.connection_id)
1800 .await?;
1801
1802 join_project_internal(response, session, project, replica_id)
1803}
1804
1805trait JoinProjectInternalResponse {
1806 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
1807}
1808impl JoinProjectInternalResponse for Response<proto::JoinProject> {
1809 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1810 Response::<proto::JoinProject>::send(self, result)
1811 }
1812}
1813impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
1814 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
1815 Response::<proto::JoinHostedProject>::send(self, result)
1816 }
1817}
1818
1819fn join_project_internal(
1820 response: impl JoinProjectInternalResponse,
1821 session: UserSession,
1822 project: &mut Project,
1823 replica_id: &ReplicaId,
1824) -> Result<()> {
1825 let collaborators = project
1826 .collaborators
1827 .iter()
1828 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1829 .map(|collaborator| collaborator.to_proto())
1830 .collect::<Vec<_>>();
1831 let project_id = project.id;
1832 let guest_user_id = session.user_id();
1833
1834 let worktrees = project
1835 .worktrees
1836 .iter()
1837 .map(|(id, worktree)| proto::WorktreeMetadata {
1838 id: *id,
1839 root_name: worktree.root_name.clone(),
1840 visible: worktree.visible,
1841 abs_path: worktree.abs_path.clone(),
1842 })
1843 .collect::<Vec<_>>();
1844
1845 for collaborator in &collaborators {
1846 session
1847 .peer
1848 .send(
1849 collaborator.peer_id.unwrap().into(),
1850 proto::AddProjectCollaborator {
1851 project_id: project_id.to_proto(),
1852 collaborator: Some(proto::Collaborator {
1853 peer_id: Some(session.connection_id.into()),
1854 replica_id: replica_id.0 as u32,
1855 user_id: guest_user_id.to_proto(),
1856 }),
1857 },
1858 )
1859 .trace_err();
1860 }
1861
1862 // First, we send the metadata associated with each worktree.
1863 response.send(proto::JoinProjectResponse {
1864 project_id: project.id.0 as u64,
1865 worktrees: worktrees.clone(),
1866 replica_id: replica_id.0 as u32,
1867 collaborators: collaborators.clone(),
1868 language_servers: project.language_servers.clone(),
1869 role: project.role.into(), // todo
1870 })?;
1871
1872 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1873 #[cfg(any(test, feature = "test-support"))]
1874 const MAX_CHUNK_SIZE: usize = 2;
1875 #[cfg(not(any(test, feature = "test-support")))]
1876 const MAX_CHUNK_SIZE: usize = 256;
1877
1878 // Stream this worktree's entries.
1879 let message = proto::UpdateWorktree {
1880 project_id: project_id.to_proto(),
1881 worktree_id,
1882 abs_path: worktree.abs_path.clone(),
1883 root_name: worktree.root_name,
1884 updated_entries: worktree.entries,
1885 removed_entries: Default::default(),
1886 scan_id: worktree.scan_id,
1887 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1888 updated_repositories: worktree.repository_entries.into_values().collect(),
1889 removed_repositories: Default::default(),
1890 };
1891 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1892 session.peer.send(session.connection_id, update.clone())?;
1893 }
1894
1895 // Stream this worktree's diagnostics.
1896 for summary in worktree.diagnostic_summaries {
1897 session.peer.send(
1898 session.connection_id,
1899 proto::UpdateDiagnosticSummary {
1900 project_id: project_id.to_proto(),
1901 worktree_id: worktree.id,
1902 summary: Some(summary),
1903 },
1904 )?;
1905 }
1906
1907 for settings_file in worktree.settings_files {
1908 session.peer.send(
1909 session.connection_id,
1910 proto::UpdateWorktreeSettings {
1911 project_id: project_id.to_proto(),
1912 worktree_id: worktree.id,
1913 path: settings_file.path,
1914 content: Some(settings_file.content),
1915 },
1916 )?;
1917 }
1918 }
1919
1920 for language_server in &project.language_servers {
1921 session.peer.send(
1922 session.connection_id,
1923 proto::UpdateLanguageServer {
1924 project_id: project_id.to_proto(),
1925 language_server_id: language_server.id,
1926 variant: Some(
1927 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1928 proto::LspDiskBasedDiagnosticsUpdated {},
1929 ),
1930 ),
1931 },
1932 )?;
1933 }
1934
1935 Ok(())
1936}
1937
1938/// Leave someone elses shared project.
1939async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
1940 let sender_id = session.connection_id;
1941 let project_id = ProjectId::from_proto(request.project_id);
1942 let db = session.db().await;
1943 if db.is_hosted_project(project_id).await? {
1944 let project = db.leave_hosted_project(project_id, sender_id).await?;
1945 project_left(&project, &session);
1946 return Ok(());
1947 }
1948
1949 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1950 tracing::info!(
1951 %project_id,
1952 host_user_id = ?project.host_user_id,
1953 host_connection_id = ?project.host_connection_id,
1954 "leave project"
1955 );
1956
1957 project_left(&project, &session);
1958 room_updated(&room, &session.peer);
1959
1960 Ok(())
1961}
1962
1963async fn join_hosted_project(
1964 request: proto::JoinHostedProject,
1965 response: Response<proto::JoinHostedProject>,
1966 session: UserSession,
1967) -> Result<()> {
1968 let (mut project, replica_id) = session
1969 .db()
1970 .await
1971 .join_hosted_project(
1972 ProjectId(request.project_id as i32),
1973 session.user_id(),
1974 session.connection_id,
1975 )
1976 .await?;
1977
1978 join_project_internal(response, session, &mut project, &replica_id)
1979}
1980
1981/// Updates other participants with changes to the project
1982async fn update_project(
1983 request: proto::UpdateProject,
1984 response: Response<proto::UpdateProject>,
1985 session: Session,
1986) -> Result<()> {
1987 let project_id = ProjectId::from_proto(request.project_id);
1988 let (room, guest_connection_ids) = &*session
1989 .db()
1990 .await
1991 .update_project(project_id, session.connection_id, &request.worktrees)
1992 .await?;
1993 broadcast(
1994 Some(session.connection_id),
1995 guest_connection_ids.iter().copied(),
1996 |connection_id| {
1997 session
1998 .peer
1999 .forward_send(session.connection_id, connection_id, request.clone())
2000 },
2001 );
2002 room_updated(&room, &session.peer);
2003 response.send(proto::Ack {})?;
2004
2005 Ok(())
2006}
2007
2008/// Updates other participants with changes to the worktree
2009async fn update_worktree(
2010 request: proto::UpdateWorktree,
2011 response: Response<proto::UpdateWorktree>,
2012 session: Session,
2013) -> Result<()> {
2014 let guest_connection_ids = session
2015 .db()
2016 .await
2017 .update_worktree(&request, session.connection_id)
2018 .await?;
2019
2020 broadcast(
2021 Some(session.connection_id),
2022 guest_connection_ids.iter().copied(),
2023 |connection_id| {
2024 session
2025 .peer
2026 .forward_send(session.connection_id, connection_id, request.clone())
2027 },
2028 );
2029 response.send(proto::Ack {})?;
2030 Ok(())
2031}
2032
2033/// Updates other participants with changes to the diagnostics
2034async fn update_diagnostic_summary(
2035 message: proto::UpdateDiagnosticSummary,
2036 session: Session,
2037) -> Result<()> {
2038 let guest_connection_ids = session
2039 .db()
2040 .await
2041 .update_diagnostic_summary(&message, session.connection_id)
2042 .await?;
2043
2044 broadcast(
2045 Some(session.connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |connection_id| {
2048 session
2049 .peer
2050 .forward_send(session.connection_id, connection_id, message.clone())
2051 },
2052 );
2053
2054 Ok(())
2055}
2056
2057/// Updates other participants with changes to the worktree settings
2058async fn update_worktree_settings(
2059 message: proto::UpdateWorktreeSettings,
2060 session: Session,
2061) -> Result<()> {
2062 let guest_connection_ids = session
2063 .db()
2064 .await
2065 .update_worktree_settings(&message, session.connection_id)
2066 .await?;
2067
2068 broadcast(
2069 Some(session.connection_id),
2070 guest_connection_ids.iter().copied(),
2071 |connection_id| {
2072 session
2073 .peer
2074 .forward_send(session.connection_id, connection_id, message.clone())
2075 },
2076 );
2077
2078 Ok(())
2079}
2080
2081/// Notify other participants that a language server has started.
2082async fn start_language_server(
2083 request: proto::StartLanguageServer,
2084 session: Session,
2085) -> Result<()> {
2086 let guest_connection_ids = session
2087 .db()
2088 .await
2089 .start_language_server(&request, session.connection_id)
2090 .await?;
2091
2092 broadcast(
2093 Some(session.connection_id),
2094 guest_connection_ids.iter().copied(),
2095 |connection_id| {
2096 session
2097 .peer
2098 .forward_send(session.connection_id, connection_id, request.clone())
2099 },
2100 );
2101 Ok(())
2102}
2103
2104/// Notify other participants that a language server has changed.
2105async fn update_language_server(
2106 request: proto::UpdateLanguageServer,
2107 session: Session,
2108) -> Result<()> {
2109 let project_id = ProjectId::from_proto(request.project_id);
2110 let project_connection_ids = session
2111 .db()
2112 .await
2113 .project_connection_ids(project_id, session.connection_id)
2114 .await?;
2115 broadcast(
2116 Some(session.connection_id),
2117 project_connection_ids.iter().copied(),
2118 |connection_id| {
2119 session
2120 .peer
2121 .forward_send(session.connection_id, connection_id, request.clone())
2122 },
2123 );
2124 Ok(())
2125}
2126
2127/// forward a project request to the host. These requests should be read only
2128/// as guests are allowed to send them.
2129async fn forward_read_only_project_request<T>(
2130 request: T,
2131 response: Response<T>,
2132 session: Session,
2133) -> Result<()>
2134where
2135 T: EntityMessage + RequestMessage,
2136{
2137 let project_id = ProjectId::from_proto(request.remote_entity_id());
2138 let host_connection_id = session
2139 .db()
2140 .await
2141 .host_for_read_only_project_request(project_id, session.connection_id)
2142 .await?;
2143 let payload = session
2144 .peer
2145 .forward_request(session.connection_id, host_connection_id, request)
2146 .await?;
2147 response.send(payload)?;
2148 Ok(())
2149}
2150
2151/// forward a project request to the host. These requests are disallowed
2152/// for guests.
2153async fn forward_mutating_project_request<T>(
2154 request: T,
2155 response: Response<T>,
2156 session: Session,
2157) -> Result<()>
2158where
2159 T: EntityMessage + RequestMessage,
2160{
2161 let project_id = ProjectId::from_proto(request.remote_entity_id());
2162 let host_connection_id = session
2163 .db()
2164 .await
2165 .host_for_mutating_project_request(project_id, session.connection_id)
2166 .await?;
2167 let payload = session
2168 .peer
2169 .forward_request(session.connection_id, host_connection_id, request)
2170 .await?;
2171 response.send(payload)?;
2172 Ok(())
2173}
2174
2175/// Notify other participants that a new buffer has been created
2176async fn create_buffer_for_peer(
2177 request: proto::CreateBufferForPeer,
2178 session: Session,
2179) -> Result<()> {
2180 session
2181 .db()
2182 .await
2183 .check_user_is_project_host(
2184 ProjectId::from_proto(request.project_id),
2185 session.connection_id,
2186 )
2187 .await?;
2188 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2189 session
2190 .peer
2191 .forward_send(session.connection_id, peer_id.into(), request)?;
2192 Ok(())
2193}
2194
2195/// Notify other participants that a buffer has been updated. This is
2196/// allowed for guests as long as the update is limited to selections.
2197async fn update_buffer(
2198 request: proto::UpdateBuffer,
2199 response: Response<proto::UpdateBuffer>,
2200 session: Session,
2201) -> Result<()> {
2202 let project_id = ProjectId::from_proto(request.project_id);
2203 let mut guest_connection_ids;
2204 let mut host_connection_id = None;
2205
2206 let mut requires_write_permission = false;
2207
2208 for op in request.operations.iter() {
2209 match op.variant {
2210 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2211 Some(_) => requires_write_permission = true,
2212 }
2213 }
2214
2215 {
2216 let collaborators = session
2217 .db()
2218 .await
2219 .project_collaborators_for_buffer_update(
2220 project_id,
2221 session.connection_id,
2222 requires_write_permission,
2223 )
2224 .await?;
2225 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2226 for collaborator in collaborators.iter() {
2227 if collaborator.is_host {
2228 host_connection_id = Some(collaborator.connection_id);
2229 } else {
2230 guest_connection_ids.push(collaborator.connection_id);
2231 }
2232 }
2233 }
2234 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2235
2236 broadcast(
2237 Some(session.connection_id),
2238 guest_connection_ids,
2239 |connection_id| {
2240 session
2241 .peer
2242 .forward_send(session.connection_id, connection_id, request.clone())
2243 },
2244 );
2245 if host_connection_id != session.connection_id {
2246 session
2247 .peer
2248 .forward_request(session.connection_id, host_connection_id, request.clone())
2249 .await?;
2250 }
2251
2252 response.send(proto::Ack {})?;
2253 Ok(())
2254}
2255
2256/// Notify other participants that a project has been updated.
2257async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2258 request: T,
2259 session: Session,
2260) -> Result<()> {
2261 let project_id = ProjectId::from_proto(request.remote_entity_id());
2262 let project_connection_ids = session
2263 .db()
2264 .await
2265 .project_connection_ids(project_id, session.connection_id)
2266 .await?;
2267
2268 broadcast(
2269 Some(session.connection_id),
2270 project_connection_ids.iter().copied(),
2271 |connection_id| {
2272 session
2273 .peer
2274 .forward_send(session.connection_id, connection_id, request.clone())
2275 },
2276 );
2277 Ok(())
2278}
2279
2280/// Start following another user in a call.
2281async fn follow(
2282 request: proto::Follow,
2283 response: Response<proto::Follow>,
2284 session: UserSession,
2285) -> Result<()> {
2286 let room_id = RoomId::from_proto(request.room_id);
2287 let project_id = request.project_id.map(ProjectId::from_proto);
2288 let leader_id = request
2289 .leader_id
2290 .ok_or_else(|| anyhow!("invalid leader id"))?
2291 .into();
2292 let follower_id = session.connection_id;
2293
2294 session
2295 .db()
2296 .await
2297 .check_room_participants(room_id, leader_id, session.connection_id)
2298 .await?;
2299
2300 let response_payload = session
2301 .peer
2302 .forward_request(session.connection_id, leader_id, request)
2303 .await?;
2304 response.send(response_payload)?;
2305
2306 if let Some(project_id) = project_id {
2307 let room = session
2308 .db()
2309 .await
2310 .follow(room_id, project_id, leader_id, follower_id)
2311 .await?;
2312 room_updated(&room, &session.peer);
2313 }
2314
2315 Ok(())
2316}
2317
2318/// Stop following another user in a call.
2319async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2320 let room_id = RoomId::from_proto(request.room_id);
2321 let project_id = request.project_id.map(ProjectId::from_proto);
2322 let leader_id = request
2323 .leader_id
2324 .ok_or_else(|| anyhow!("invalid leader id"))?
2325 .into();
2326 let follower_id = session.connection_id;
2327
2328 session
2329 .db()
2330 .await
2331 .check_room_participants(room_id, leader_id, session.connection_id)
2332 .await?;
2333
2334 session
2335 .peer
2336 .forward_send(session.connection_id, leader_id, request)?;
2337
2338 if let Some(project_id) = project_id {
2339 let room = session
2340 .db()
2341 .await
2342 .unfollow(room_id, project_id, leader_id, follower_id)
2343 .await?;
2344 room_updated(&room, &session.peer);
2345 }
2346
2347 Ok(())
2348}
2349
2350/// Notify everyone following you of your current location.
2351async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
2352 let room_id = RoomId::from_proto(request.room_id);
2353 let database = session.db.lock().await;
2354
2355 let connection_ids = if let Some(project_id) = request.project_id {
2356 let project_id = ProjectId::from_proto(project_id);
2357 database
2358 .project_connection_ids(project_id, session.connection_id)
2359 .await?
2360 } else {
2361 database
2362 .room_connection_ids(room_id, session.connection_id)
2363 .await?
2364 };
2365
2366 // For now, don't send view update messages back to that view's current leader.
2367 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2368 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2369 _ => None,
2370 });
2371
2372 for connection_id in connection_ids.iter().cloned() {
2373 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2374 session
2375 .peer
2376 .forward_send(session.connection_id, connection_id, request.clone())?;
2377 }
2378 }
2379 Ok(())
2380}
2381
2382/// Get public data about users.
2383async fn get_users(
2384 request: proto::GetUsers,
2385 response: Response<proto::GetUsers>,
2386 session: Session,
2387) -> Result<()> {
2388 let user_ids = request
2389 .user_ids
2390 .into_iter()
2391 .map(UserId::from_proto)
2392 .collect();
2393 let users = session
2394 .db()
2395 .await
2396 .get_users_by_ids(user_ids)
2397 .await?
2398 .into_iter()
2399 .map(|user| proto::User {
2400 id: user.id.to_proto(),
2401 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2402 github_login: user.github_login,
2403 })
2404 .collect();
2405 response.send(proto::UsersResponse { users })?;
2406 Ok(())
2407}
2408
2409/// Search for users (to invite) buy Github login
2410async fn fuzzy_search_users(
2411 request: proto::FuzzySearchUsers,
2412 response: Response<proto::FuzzySearchUsers>,
2413 session: UserSession,
2414) -> Result<()> {
2415 let query = request.query;
2416 let users = match query.len() {
2417 0 => vec![],
2418 1 | 2 => session
2419 .db()
2420 .await
2421 .get_user_by_github_login(&query)
2422 .await?
2423 .into_iter()
2424 .collect(),
2425 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2426 };
2427 let users = users
2428 .into_iter()
2429 .filter(|user| user.id != session.user_id())
2430 .map(|user| proto::User {
2431 id: user.id.to_proto(),
2432 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2433 github_login: user.github_login,
2434 })
2435 .collect();
2436 response.send(proto::UsersResponse { users })?;
2437 Ok(())
2438}
2439
2440/// Send a contact request to another user.
2441async fn request_contact(
2442 request: proto::RequestContact,
2443 response: Response<proto::RequestContact>,
2444 session: UserSession,
2445) -> Result<()> {
2446 let requester_id = session.user_id();
2447 let responder_id = UserId::from_proto(request.responder_id);
2448 if requester_id == responder_id {
2449 return Err(anyhow!("cannot add yourself as a contact"))?;
2450 }
2451
2452 let notifications = session
2453 .db()
2454 .await
2455 .send_contact_request(requester_id, responder_id)
2456 .await?;
2457
2458 // Update outgoing contact requests of requester
2459 let mut update = proto::UpdateContacts::default();
2460 update.outgoing_requests.push(responder_id.to_proto());
2461 for connection_id in session
2462 .connection_pool()
2463 .await
2464 .user_connection_ids(requester_id)
2465 {
2466 session.peer.send(connection_id, update.clone())?;
2467 }
2468
2469 // Update incoming contact requests of responder
2470 let mut update = proto::UpdateContacts::default();
2471 update
2472 .incoming_requests
2473 .push(proto::IncomingContactRequest {
2474 requester_id: requester_id.to_proto(),
2475 });
2476 let connection_pool = session.connection_pool().await;
2477 for connection_id in connection_pool.user_connection_ids(responder_id) {
2478 session.peer.send(connection_id, update.clone())?;
2479 }
2480
2481 send_notifications(&connection_pool, &session.peer, notifications);
2482
2483 response.send(proto::Ack {})?;
2484 Ok(())
2485}
2486
2487/// Accept or decline a contact request
2488async fn respond_to_contact_request(
2489 request: proto::RespondToContactRequest,
2490 response: Response<proto::RespondToContactRequest>,
2491 session: UserSession,
2492) -> Result<()> {
2493 let responder_id = session.user_id();
2494 let requester_id = UserId::from_proto(request.requester_id);
2495 let db = session.db().await;
2496 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2497 db.dismiss_contact_notification(responder_id, requester_id)
2498 .await?;
2499 } else {
2500 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2501
2502 let notifications = db
2503 .respond_to_contact_request(responder_id, requester_id, accept)
2504 .await?;
2505 let requester_busy = db.is_user_busy(requester_id).await?;
2506 let responder_busy = db.is_user_busy(responder_id).await?;
2507
2508 let pool = session.connection_pool().await;
2509 // Update responder with new contact
2510 let mut update = proto::UpdateContacts::default();
2511 if accept {
2512 update
2513 .contacts
2514 .push(contact_for_user(requester_id, requester_busy, &pool));
2515 }
2516 update
2517 .remove_incoming_requests
2518 .push(requester_id.to_proto());
2519 for connection_id in pool.user_connection_ids(responder_id) {
2520 session.peer.send(connection_id, update.clone())?;
2521 }
2522
2523 // Update requester with new contact
2524 let mut update = proto::UpdateContacts::default();
2525 if accept {
2526 update
2527 .contacts
2528 .push(contact_for_user(responder_id, responder_busy, &pool));
2529 }
2530 update
2531 .remove_outgoing_requests
2532 .push(responder_id.to_proto());
2533
2534 for connection_id in pool.user_connection_ids(requester_id) {
2535 session.peer.send(connection_id, update.clone())?;
2536 }
2537
2538 send_notifications(&pool, &session.peer, notifications);
2539 }
2540
2541 response.send(proto::Ack {})?;
2542 Ok(())
2543}
2544
2545/// Remove a contact.
2546async fn remove_contact(
2547 request: proto::RemoveContact,
2548 response: Response<proto::RemoveContact>,
2549 session: UserSession,
2550) -> Result<()> {
2551 let requester_id = session.user_id();
2552 let responder_id = UserId::from_proto(request.user_id);
2553 let db = session.db().await;
2554 let (contact_accepted, deleted_notification_id) =
2555 db.remove_contact(requester_id, responder_id).await?;
2556
2557 let pool = session.connection_pool().await;
2558 // Update outgoing contact requests of requester
2559 let mut update = proto::UpdateContacts::default();
2560 if contact_accepted {
2561 update.remove_contacts.push(responder_id.to_proto());
2562 } else {
2563 update
2564 .remove_outgoing_requests
2565 .push(responder_id.to_proto());
2566 }
2567 for connection_id in pool.user_connection_ids(requester_id) {
2568 session.peer.send(connection_id, update.clone())?;
2569 }
2570
2571 // Update incoming contact requests of responder
2572 let mut update = proto::UpdateContacts::default();
2573 if contact_accepted {
2574 update.remove_contacts.push(requester_id.to_proto());
2575 } else {
2576 update
2577 .remove_incoming_requests
2578 .push(requester_id.to_proto());
2579 }
2580 for connection_id in pool.user_connection_ids(responder_id) {
2581 session.peer.send(connection_id, update.clone())?;
2582 if let Some(notification_id) = deleted_notification_id {
2583 session.peer.send(
2584 connection_id,
2585 proto::DeleteNotification {
2586 notification_id: notification_id.to_proto(),
2587 },
2588 )?;
2589 }
2590 }
2591
2592 response.send(proto::Ack {})?;
2593 Ok(())
2594}
2595
2596/// Creates a new channel.
2597async fn create_channel(
2598 request: proto::CreateChannel,
2599 response: Response<proto::CreateChannel>,
2600 session: UserSession,
2601) -> Result<()> {
2602 let db = session.db().await;
2603
2604 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2605 let (channel, membership) = db
2606 .create_channel(&request.name, parent_id, session.user_id())
2607 .await?;
2608
2609 let root_id = channel.root_id();
2610 let channel = Channel::from_model(channel);
2611
2612 response.send(proto::CreateChannelResponse {
2613 channel: Some(channel.to_proto()),
2614 parent_id: request.parent_id,
2615 })?;
2616
2617 let mut connection_pool = session.connection_pool().await;
2618 if let Some(membership) = membership {
2619 connection_pool.subscribe_to_channel(
2620 membership.user_id,
2621 membership.channel_id,
2622 membership.role,
2623 );
2624 let update = proto::UpdateUserChannels {
2625 channel_memberships: vec![proto::ChannelMembership {
2626 channel_id: membership.channel_id.to_proto(),
2627 role: membership.role.into(),
2628 }],
2629 ..Default::default()
2630 };
2631 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2632 session.peer.send(connection_id, update.clone())?;
2633 }
2634 }
2635
2636 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2637 if !role.can_see_channel(channel.visibility) {
2638 continue;
2639 }
2640
2641 let update = proto::UpdateChannels {
2642 channels: vec![channel.to_proto()],
2643 ..Default::default()
2644 };
2645 session.peer.send(connection_id, update.clone())?;
2646 }
2647
2648 Ok(())
2649}
2650
2651/// Delete a channel
2652async fn delete_channel(
2653 request: proto::DeleteChannel,
2654 response: Response<proto::DeleteChannel>,
2655 session: UserSession,
2656) -> Result<()> {
2657 let db = session.db().await;
2658
2659 let channel_id = request.channel_id;
2660 let (root_channel, removed_channels) = db
2661 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2662 .await?;
2663 response.send(proto::Ack {})?;
2664
2665 // Notify members of removed channels
2666 let mut update = proto::UpdateChannels::default();
2667 update
2668 .delete_channels
2669 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2670
2671 let connection_pool = session.connection_pool().await;
2672 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2673 session.peer.send(connection_id, update.clone())?;
2674 }
2675
2676 Ok(())
2677}
2678
2679/// Invite someone to join a channel.
2680async fn invite_channel_member(
2681 request: proto::InviteChannelMember,
2682 response: Response<proto::InviteChannelMember>,
2683 session: UserSession,
2684) -> Result<()> {
2685 let db = session.db().await;
2686 let channel_id = ChannelId::from_proto(request.channel_id);
2687 let invitee_id = UserId::from_proto(request.user_id);
2688 let InviteMemberResult {
2689 channel,
2690 notifications,
2691 } = db
2692 .invite_channel_member(
2693 channel_id,
2694 invitee_id,
2695 session.user_id(),
2696 request.role().into(),
2697 )
2698 .await?;
2699
2700 let update = proto::UpdateChannels {
2701 channel_invitations: vec![channel.to_proto()],
2702 ..Default::default()
2703 };
2704
2705 let connection_pool = session.connection_pool().await;
2706 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2707 session.peer.send(connection_id, update.clone())?;
2708 }
2709
2710 send_notifications(&connection_pool, &session.peer, notifications);
2711
2712 response.send(proto::Ack {})?;
2713 Ok(())
2714}
2715
2716/// remove someone from a channel
2717async fn remove_channel_member(
2718 request: proto::RemoveChannelMember,
2719 response: Response<proto::RemoveChannelMember>,
2720 session: UserSession,
2721) -> Result<()> {
2722 let db = session.db().await;
2723 let channel_id = ChannelId::from_proto(request.channel_id);
2724 let member_id = UserId::from_proto(request.user_id);
2725
2726 let RemoveChannelMemberResult {
2727 membership_update,
2728 notification_id,
2729 } = db
2730 .remove_channel_member(channel_id, member_id, session.user_id())
2731 .await?;
2732
2733 let mut connection_pool = session.connection_pool().await;
2734 notify_membership_updated(
2735 &mut connection_pool,
2736 membership_update,
2737 member_id,
2738 &session.peer,
2739 );
2740 for connection_id in connection_pool.user_connection_ids(member_id) {
2741 if let Some(notification_id) = notification_id {
2742 session
2743 .peer
2744 .send(
2745 connection_id,
2746 proto::DeleteNotification {
2747 notification_id: notification_id.to_proto(),
2748 },
2749 )
2750 .trace_err();
2751 }
2752 }
2753
2754 response.send(proto::Ack {})?;
2755 Ok(())
2756}
2757
/// Toggles a channel between public and members-only visibility.
///
/// The change itself is made by the database-layer `set_channel_visibility`.
/// According to the original note, it keeps the invariant that public channels
/// only descend from public channels (members-only channels may appear at any
/// point in the hierarchy). NOTE(review): that check is not visible here.
///
/// Every user connected to the channel's tree is then re-evaluated. Users whose
/// role can see the channel under its new visibility are subscribed and sent
/// the channel. Everyone else is unsubscribed and told to delete it. The
/// requester is acknowledged last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first, presumably because the loop body
    // mutates `connection_pool` (subscribe/unsubscribe) while iterating.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        // Deliver to every live connection of this user; the first send error aborts.
        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2804
2805/// Alter the role for a user in the channel.
2806async fn set_channel_member_role(
2807 request: proto::SetChannelMemberRole,
2808 response: Response<proto::SetChannelMemberRole>,
2809 session: UserSession,
2810) -> Result<()> {
2811 let db = session.db().await;
2812 let channel_id = ChannelId::from_proto(request.channel_id);
2813 let member_id = UserId::from_proto(request.user_id);
2814 let result = db
2815 .set_channel_member_role(
2816 channel_id,
2817 session.user_id(),
2818 member_id,
2819 request.role().into(),
2820 )
2821 .await?;
2822
2823 match result {
2824 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2825 let mut connection_pool = session.connection_pool().await;
2826 notify_membership_updated(
2827 &mut connection_pool,
2828 membership_update,
2829 member_id,
2830 &session.peer,
2831 )
2832 }
2833 db::SetMemberRoleResult::InviteUpdated(channel) => {
2834 let update = proto::UpdateChannels {
2835 channel_invitations: vec![channel.to_proto()],
2836 ..Default::default()
2837 };
2838
2839 for connection_id in session
2840 .connection_pool()
2841 .await
2842 .user_connection_ids(member_id)
2843 {
2844 session.peer.send(connection_id, update.clone())?;
2845 }
2846 }
2847 }
2848
2849 response.send(proto::Ack {})?;
2850 Ok(())
2851}
2852
2853/// Change the name of a channel
2854async fn rename_channel(
2855 request: proto::RenameChannel,
2856 response: Response<proto::RenameChannel>,
2857 session: UserSession,
2858) -> Result<()> {
2859 let db = session.db().await;
2860 let channel_id = ChannelId::from_proto(request.channel_id);
2861 let channel_model = db
2862 .rename_channel(channel_id, session.user_id(), &request.name)
2863 .await?;
2864 let root_id = channel_model.root_id();
2865 let channel = Channel::from_model(channel_model);
2866
2867 response.send(proto::RenameChannelResponse {
2868 channel: Some(channel.to_proto()),
2869 })?;
2870
2871 let connection_pool = session.connection_pool().await;
2872 let update = proto::UpdateChannels {
2873 channels: vec![channel.to_proto()],
2874 ..Default::default()
2875 };
2876 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2877 if role.can_see_channel(channel.visibility) {
2878 session.peer.send(connection_id, update.clone())?;
2879 }
2880 }
2881
2882 Ok(())
2883}
2884
2885/// Move a channel to a new parent.
2886async fn move_channel(
2887 request: proto::MoveChannel,
2888 response: Response<proto::MoveChannel>,
2889 session: UserSession,
2890) -> Result<()> {
2891 let channel_id = ChannelId::from_proto(request.channel_id);
2892 let to = ChannelId::from_proto(request.to);
2893
2894 let (root_id, channels) = session
2895 .db()
2896 .await
2897 .move_channel(channel_id, to, session.user_id())
2898 .await?;
2899
2900 let connection_pool = session.connection_pool().await;
2901 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2902 let channels = channels
2903 .iter()
2904 .filter_map(|channel| {
2905 if role.can_see_channel(channel.visibility) {
2906 Some(channel.to_proto())
2907 } else {
2908 None
2909 }
2910 })
2911 .collect::<Vec<_>>();
2912 if channels.is_empty() {
2913 continue;
2914 }
2915
2916 let update = proto::UpdateChannels {
2917 channels,
2918 ..Default::default()
2919 };
2920
2921 session.peer.send(connection_id, update.clone())?;
2922 }
2923
2924 response.send(Ack {})?;
2925 Ok(())
2926}
2927
2928/// Get the list of channel members
2929async fn get_channel_members(
2930 request: proto::GetChannelMembers,
2931 response: Response<proto::GetChannelMembers>,
2932 session: UserSession,
2933) -> Result<()> {
2934 let db = session.db().await;
2935 let channel_id = ChannelId::from_proto(request.channel_id);
2936 let members = db
2937 .get_channel_participant_details(channel_id, session.user_id())
2938 .await?;
2939 response.send(proto::GetChannelMembersResponse { members })?;
2940 Ok(())
2941}
2942
2943/// Accept or decline a channel invitation.
2944async fn respond_to_channel_invite(
2945 request: proto::RespondToChannelInvite,
2946 response: Response<proto::RespondToChannelInvite>,
2947 session: UserSession,
2948) -> Result<()> {
2949 let db = session.db().await;
2950 let channel_id = ChannelId::from_proto(request.channel_id);
2951 let RespondToChannelInvite {
2952 membership_update,
2953 notifications,
2954 } = db
2955 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
2956 .await?;
2957
2958 let mut connection_pool = session.connection_pool().await;
2959 if let Some(membership_update) = membership_update {
2960 notify_membership_updated(
2961 &mut connection_pool,
2962 membership_update,
2963 session.user_id(),
2964 &session.peer,
2965 );
2966 } else {
2967 let update = proto::UpdateChannels {
2968 remove_channel_invitations: vec![channel_id.to_proto()],
2969 ..Default::default()
2970 };
2971
2972 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
2973 session.peer.send(connection_id, update.clone())?;
2974 }
2975 };
2976
2977 send_notifications(&connection_pool, &session.peer, notifications);
2978
2979 response.send(proto::Ack {})?;
2980
2981 Ok(())
2982}
2983
2984/// Join the channels' room
2985async fn join_channel(
2986 request: proto::JoinChannel,
2987 response: Response<proto::JoinChannel>,
2988 session: UserSession,
2989) -> Result<()> {
2990 let channel_id = ChannelId::from_proto(request.channel_id);
2991 join_channel_internal(channel_id, Box::new(response), session).await
2992}
2993
/// A response handle that can carry a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` answer both `JoinChannel` requests and
/// (presumably, via a handler elsewhere in this file) `JoinRoom` requests.
trait JoinChannelInternalResponse {
    /// Sends `result` back to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Resolves to `Response`'s own inherent `send` (inherent items take
        // precedence over trait items in path resolution), so this does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Inherent `Response::send`, as above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3007
/// Moves the session's user into the call room of `channel_id`.
///
/// Steps, in order:
/// 1. leave whatever room the session is currently in;
/// 2. join the channel's room in the database, which also returns the user's
///    role and any membership change;
/// 3. reply with the room, the channel id, and (when a LiveKit client is
///    configured and a token could be minted) LiveKit connection info;
/// 4. propagate the membership change, if any, then call `room_updated`;
/// 5. call `channel_updated` and refresh the user's contacts.
///
/// `response` is generic over [`JoinChannelInternalResponse`] so the same flow
/// can answer different request types.
///
/// # Errors
/// Fails if leaving the old room, joining, sending the reply, or updating
/// contacts fails. It also fails if the database returned no channel; that
/// check happens after the reply has already been sent and the room broadcast.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so `db` and the `connection_pool` guard taken inside are dropped
    // before the connection pool is locked again for `channel_updated` below.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token with `can_publish: false`; all other roles get
        // a room token with `can_publish: true`. With no LiveKit client, or if
        // minting fails (`trace_err()?` turns the error into `None`), no
        // connection info is sent.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Errors if the database did not return the channel (the reply was already sent).
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3087
3088/// Start editing the channel notes
3089async fn join_channel_buffer(
3090 request: proto::JoinChannelBuffer,
3091 response: Response<proto::JoinChannelBuffer>,
3092 session: UserSession,
3093) -> Result<()> {
3094 let db = session.db().await;
3095 let channel_id = ChannelId::from_proto(request.channel_id);
3096
3097 let open_response = db
3098 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3099 .await?;
3100
3101 let collaborators = open_response.collaborators.clone();
3102 response.send(open_response)?;
3103
3104 let update = UpdateChannelBufferCollaborators {
3105 channel_id: channel_id.to_proto(),
3106 collaborators: collaborators.clone(),
3107 };
3108 channel_buffer_updated(
3109 session.connection_id,
3110 collaborators
3111 .iter()
3112 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3113 &update,
3114 &session.peer,
3115 );
3116
3117 Ok(())
3118}
3119
3120/// Edit the channel notes
3121async fn update_channel_buffer(
3122 request: proto::UpdateChannelBuffer,
3123 session: UserSession,
3124) -> Result<()> {
3125 let db = session.db().await;
3126 let channel_id = ChannelId::from_proto(request.channel_id);
3127
3128 let (collaborators, non_collaborators, epoch, version) = db
3129 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3130 .await?;
3131
3132 channel_buffer_updated(
3133 session.connection_id,
3134 collaborators,
3135 &proto::UpdateChannelBuffer {
3136 channel_id: channel_id.to_proto(),
3137 operations: request.operations,
3138 },
3139 &session.peer,
3140 );
3141
3142 let pool = &*session.connection_pool().await;
3143
3144 broadcast(
3145 None,
3146 non_collaborators
3147 .iter()
3148 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3149 |peer_id| {
3150 session.peer.send(
3151 peer_id,
3152 proto::UpdateChannels {
3153 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3154 channel_id: channel_id.to_proto(),
3155 epoch: epoch as u64,
3156 version: version.clone(),
3157 }],
3158 ..Default::default()
3159 },
3160 )
3161 },
3162 );
3163
3164 Ok(())
3165}
3166
3167/// Rejoin the channel notes after a connection blip
3168async fn rejoin_channel_buffers(
3169 request: proto::RejoinChannelBuffers,
3170 response: Response<proto::RejoinChannelBuffers>,
3171 session: UserSession,
3172) -> Result<()> {
3173 let db = session.db().await;
3174 let buffers = db
3175 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3176 .await?;
3177
3178 for rejoined_buffer in &buffers {
3179 let collaborators_to_notify = rejoined_buffer
3180 .buffer
3181 .collaborators
3182 .iter()
3183 .filter_map(|c| Some(c.peer_id?.into()));
3184 channel_buffer_updated(
3185 session.connection_id,
3186 collaborators_to_notify,
3187 &proto::UpdateChannelBufferCollaborators {
3188 channel_id: rejoined_buffer.buffer.channel_id,
3189 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3190 },
3191 &session.peer,
3192 );
3193 }
3194
3195 response.send(proto::RejoinChannelBuffersResponse {
3196 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3197 })?;
3198
3199 Ok(())
3200}
3201
3202/// Stop editing the channel notes
3203async fn leave_channel_buffer(
3204 request: proto::LeaveChannelBuffer,
3205 response: Response<proto::LeaveChannelBuffer>,
3206 session: UserSession,
3207) -> Result<()> {
3208 let db = session.db().await;
3209 let channel_id = ChannelId::from_proto(request.channel_id);
3210
3211 let left_buffer = db
3212 .leave_channel_buffer(channel_id, session.connection_id)
3213 .await?;
3214
3215 response.send(Ack {})?;
3216
3217 channel_buffer_updated(
3218 session.connection_id,
3219 left_buffer.connections,
3220 &proto::UpdateChannelBufferCollaborators {
3221 channel_id: channel_id.to_proto(),
3222 collaborators: left_buffer.collaborators,
3223 },
3224 &session.peer,
3225 );
3226
3227 Ok(())
3228}
3229
3230fn channel_buffer_updated<T: EnvelopedMessage>(
3231 sender_id: ConnectionId,
3232 collaborators: impl IntoIterator<Item = ConnectionId>,
3233 message: &T,
3234 peer: &Peer,
3235) {
3236 broadcast(Some(sender_id), collaborators, |peer_id| {
3237 peer.send(peer_id, message.clone())
3238 });
3239}
3240
3241fn send_notifications(
3242 connection_pool: &ConnectionPool,
3243 peer: &Peer,
3244 notifications: db::NotificationBatch,
3245) {
3246 for (user_id, notification) in notifications {
3247 for connection_id in connection_pool.user_connection_ids(user_id) {
3248 if let Err(error) = peer.send(
3249 connection_id,
3250 proto::AddNotification {
3251 notification: Some(notification.clone()),
3252 },
3253 ) {
3254 tracing::error!(
3255 "failed to send notification to {:?} {}",
3256 connection_id,
3257 error
3258 );
3259 }
3260 }
3261 }
3262}
3263
3264/// Send a message to the channel
3265async fn send_channel_message(
3266 request: proto::SendChannelMessage,
3267 response: Response<proto::SendChannelMessage>,
3268 session: UserSession,
3269) -> Result<()> {
3270 // Validate the message body.
3271 let body = request.body.trim().to_string();
3272 if body.len() > MAX_MESSAGE_LEN {
3273 return Err(anyhow!("message is too long"))?;
3274 }
3275 if body.is_empty() {
3276 return Err(anyhow!("message can't be blank"))?;
3277 }
3278
3279 // TODO: adjust mentions if body is trimmed
3280
3281 let timestamp = OffsetDateTime::now_utc();
3282 let nonce = request
3283 .nonce
3284 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3285
3286 let channel_id = ChannelId::from_proto(request.channel_id);
3287 let CreatedChannelMessage {
3288 message_id,
3289 participant_connection_ids,
3290 channel_members,
3291 notifications,
3292 } = session
3293 .db()
3294 .await
3295 .create_channel_message(
3296 channel_id,
3297 session.user_id(),
3298 &body,
3299 &request.mentions,
3300 timestamp,
3301 nonce.clone().into(),
3302 match request.reply_to_message_id {
3303 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3304 None => None,
3305 },
3306 )
3307 .await?;
3308
3309 let message = proto::ChannelMessage {
3310 sender_id: session.user_id().to_proto(),
3311 id: message_id.to_proto(),
3312 body,
3313 mentions: request.mentions,
3314 timestamp: timestamp.unix_timestamp() as u64,
3315 nonce: Some(nonce),
3316 reply_to_message_id: request.reply_to_message_id,
3317 edited_at: None,
3318 };
3319 broadcast(
3320 Some(session.connection_id),
3321 participant_connection_ids,
3322 |connection| {
3323 session.peer.send(
3324 connection,
3325 proto::ChannelMessageSent {
3326 channel_id: channel_id.to_proto(),
3327 message: Some(message.clone()),
3328 },
3329 )
3330 },
3331 );
3332 response.send(proto::SendChannelMessageResponse {
3333 message: Some(message),
3334 })?;
3335
3336 let pool = &*session.connection_pool().await;
3337 broadcast(
3338 None,
3339 channel_members
3340 .iter()
3341 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3342 |peer_id| {
3343 session.peer.send(
3344 peer_id,
3345 proto::UpdateChannels {
3346 latest_channel_message_ids: vec![proto::ChannelMessageId {
3347 channel_id: channel_id.to_proto(),
3348 message_id: message_id.to_proto(),
3349 }],
3350 ..Default::default()
3351 },
3352 )
3353 },
3354 );
3355 send_notifications(pool, &session.peer, notifications);
3356
3357 Ok(())
3358}
3359
3360/// Delete a channel message
3361async fn remove_channel_message(
3362 request: proto::RemoveChannelMessage,
3363 response: Response<proto::RemoveChannelMessage>,
3364 session: UserSession,
3365) -> Result<()> {
3366 let channel_id = ChannelId::from_proto(request.channel_id);
3367 let message_id = MessageId::from_proto(request.message_id);
3368 let connection_ids = session
3369 .db()
3370 .await
3371 .remove_channel_message(channel_id, message_id, session.user_id())
3372 .await?;
3373 broadcast(Some(session.connection_id), connection_ids, |connection| {
3374 session.peer.send(connection, request.clone())
3375 });
3376 response.send(proto::Ack {})?;
3377 Ok(())
3378}
3379
3380async fn update_channel_message(
3381 request: proto::UpdateChannelMessage,
3382 response: Response<proto::UpdateChannelMessage>,
3383 session: UserSession,
3384) -> Result<()> {
3385 let channel_id = ChannelId::from_proto(request.channel_id);
3386 let message_id = MessageId::from_proto(request.message_id);
3387 let updated_at = OffsetDateTime::now_utc();
3388 let UpdatedChannelMessage {
3389 message_id,
3390 participant_connection_ids,
3391 notifications,
3392 reply_to_message_id,
3393 timestamp,
3394 } = session
3395 .db()
3396 .await
3397 .update_channel_message(
3398 channel_id,
3399 message_id,
3400 session.user_id(),
3401 request.body.as_str(),
3402 &request.mentions,
3403 updated_at,
3404 )
3405 .await?;
3406
3407 let nonce = request
3408 .nonce
3409 .clone()
3410 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3411
3412 let message = proto::ChannelMessage {
3413 sender_id: session.user_id().to_proto(),
3414 id: message_id.to_proto(),
3415 body: request.body.clone(),
3416 mentions: request.mentions.clone(),
3417 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3418 nonce: Some(nonce),
3419 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3420 edited_at: Some(updated_at.unix_timestamp() as u64),
3421 };
3422
3423 response.send(proto::Ack {})?;
3424
3425 let pool = &*session.connection_pool().await;
3426 broadcast(
3427 Some(session.connection_id),
3428 participant_connection_ids,
3429 |connection| {
3430 session.peer.send(
3431 connection,
3432 proto::ChannelMessageUpdate {
3433 channel_id: channel_id.to_proto(),
3434 message: Some(message.clone()),
3435 },
3436 )
3437 },
3438 );
3439
3440 send_notifications(pool, &session.peer, notifications);
3441
3442 Ok(())
3443}
3444
3445/// Mark a channel message as read
3446async fn acknowledge_channel_message(
3447 request: proto::AckChannelMessage,
3448 session: UserSession,
3449) -> Result<()> {
3450 let channel_id = ChannelId::from_proto(request.channel_id);
3451 let message_id = MessageId::from_proto(request.message_id);
3452 let notifications = session
3453 .db()
3454 .await
3455 .observe_channel_message(channel_id, session.user_id(), message_id)
3456 .await?;
3457 send_notifications(
3458 &*session.connection_pool().await,
3459 &session.peer,
3460 notifications,
3461 );
3462 Ok(())
3463}
3464
3465/// Mark a buffer version as synced
3466async fn acknowledge_buffer_version(
3467 request: proto::AckBufferOperation,
3468 session: UserSession,
3469) -> Result<()> {
3470 let buffer_id = BufferId::from_proto(request.buffer_id);
3471 session
3472 .db()
3473 .await
3474 .observe_buffer_version(
3475 buffer_id,
3476 session.user_id(),
3477 request.epoch as i32,
3478 &request.version,
3479 )
3480 .await?;
3481 Ok(())
3482}
3483
3484struct CompleteWithLanguageModelRateLimit;
3485
3486impl RateLimit for CompleteWithLanguageModelRateLimit {
3487 fn capacity() -> usize {
3488 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3489 .ok()
3490 .and_then(|v| v.parse().ok())
3491 .unwrap_or(120) // Picked arbitrarily
3492 }
3493
3494 fn refill_duration() -> chrono::Duration {
3495 chrono::Duration::hours(1)
3496 }
3497
3498 fn db_name() -> &'static str {
3499 "complete-with-language-model"
3500 }
3501}
3502
3503async fn complete_with_language_model(
3504 request: proto::CompleteWithLanguageModel,
3505 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3506 session: Session,
3507 open_ai_api_key: Option<Arc<str>>,
3508 google_ai_api_key: Option<Arc<str>>,
3509) -> Result<()> {
3510 let Some(session) = session.for_user() else {
3511 return Err(anyhow!("user not found"))?;
3512 };
3513 authorize_access_to_language_models(&session).await?;
3514 session
3515 .rate_limiter
3516 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
3517 .await?;
3518
3519 if request.model.starts_with("gpt") {
3520 let api_key =
3521 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3522 complete_with_open_ai(request, response, session, api_key).await?;
3523 } else if request.model.starts_with("gemini") {
3524 let api_key = google_ai_api_key
3525 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3526 complete_with_google_ai(request, response, session, api_key).await?;
3527 }
3528
3529 Ok(())
3530}
3531
3532async fn complete_with_open_ai(
3533 request: proto::CompleteWithLanguageModel,
3534 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3535 session: UserSession,
3536 api_key: Arc<str>,
3537) -> Result<()> {
3538 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3539
3540 let mut completion_stream = open_ai::stream_completion(
3541 &session.http_client,
3542 OPEN_AI_API_URL,
3543 &api_key,
3544 crate::ai::language_model_request_to_open_ai(request)?,
3545 )
3546 .await
3547 .context("open_ai::stream_completion request failed")?;
3548
3549 while let Some(event) = completion_stream.next().await {
3550 let event = event?;
3551 response.send(proto::LanguageModelResponse {
3552 choices: event
3553 .choices
3554 .into_iter()
3555 .map(|choice| proto::LanguageModelChoiceDelta {
3556 index: choice.index,
3557 delta: Some(proto::LanguageModelResponseMessage {
3558 role: choice.delta.role.map(|role| match role {
3559 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3560 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3561 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3562 } as i32),
3563 content: choice.delta.content,
3564 }),
3565 finish_reason: choice.finish_reason,
3566 })
3567 .collect(),
3568 })?;
3569 }
3570
3571 Ok(())
3572}
3573
3574async fn complete_with_google_ai(
3575 request: proto::CompleteWithLanguageModel,
3576 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3577 session: UserSession,
3578 api_key: Arc<str>,
3579) -> Result<()> {
3580 let mut stream = google_ai::stream_generate_content(
3581 &session.http_client,
3582 google_ai::API_URL,
3583 api_key.as_ref(),
3584 crate::ai::language_model_request_to_google_ai(request)?,
3585 )
3586 .await
3587 .context("google_ai::stream_generate_content request failed")?;
3588
3589 while let Some(event) = stream.next().await {
3590 let event = event?;
3591 response.send(proto::LanguageModelResponse {
3592 choices: event
3593 .candidates
3594 .unwrap_or_default()
3595 .into_iter()
3596 .map(|candidate| proto::LanguageModelChoiceDelta {
3597 index: candidate.index as u32,
3598 delta: Some(proto::LanguageModelResponseMessage {
3599 role: Some(match candidate.content.role {
3600 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3601 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3602 } as i32),
3603 content: Some(
3604 candidate
3605 .content
3606 .parts
3607 .into_iter()
3608 .filter_map(|part| match part {
3609 google_ai::Part::TextPart(part) => Some(part.text),
3610 google_ai::Part::InlineDataPart(_) => None,
3611 })
3612 .collect(),
3613 ),
3614 }),
3615 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3616 })
3617 .collect(),
3618 })?;
3619 }
3620
3621 Ok(())
3622}
3623
3624struct CountTokensWithLanguageModelRateLimit;
3625
3626impl RateLimit for CountTokensWithLanguageModelRateLimit {
3627 fn capacity() -> usize {
3628 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3629 .ok()
3630 .and_then(|v| v.parse().ok())
3631 .unwrap_or(600) // Picked arbitrarily
3632 }
3633
3634 fn refill_duration() -> chrono::Duration {
3635 chrono::Duration::hours(1)
3636 }
3637
3638 fn db_name() -> &'static str {
3639 "count-tokens-with-language-model"
3640 }
3641}
3642
3643async fn count_tokens_with_language_model(
3644 request: proto::CountTokensWithLanguageModel,
3645 response: Response<proto::CountTokensWithLanguageModel>,
3646 session: UserSession,
3647 google_ai_api_key: Option<Arc<str>>,
3648) -> Result<()> {
3649 authorize_access_to_language_models(&session).await?;
3650
3651 if !request.model.starts_with("gemini") {
3652 return Err(anyhow!(
3653 "counting tokens for model: {:?} is not supported",
3654 request.model
3655 ))?;
3656 }
3657
3658 session
3659 .rate_limiter
3660 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
3661 .await?;
3662
3663 let api_key = google_ai_api_key
3664 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3665 let tokens_response = google_ai::count_tokens(
3666 &session.http_client,
3667 google_ai::API_URL,
3668 &api_key,
3669 crate::ai::count_tokens_request_to_google_ai(request)?,
3670 )
3671 .await?;
3672 response.send(proto::CountTokensResponse {
3673 token_count: tokens_response.total_tokens as u32,
3674 })?;
3675 Ok(())
3676}
3677
3678async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
3679 let db = session.db().await;
3680 let flags = db.get_user_flags(session.user_id()).await?;
3681 if flags.iter().any(|flag| flag == "language-models") {
3682 Ok(())
3683 } else {
3684 Err(anyhow!("permission denied"))?
3685 }
3686}
3687
3688/// Start receiving chat updates for a channel
3689async fn join_channel_chat(
3690 request: proto::JoinChannelChat,
3691 response: Response<proto::JoinChannelChat>,
3692 session: UserSession,
3693) -> Result<()> {
3694 let channel_id = ChannelId::from_proto(request.channel_id);
3695
3696 let db = session.db().await;
3697 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3698 .await?;
3699 let messages = db
3700 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3701 .await?;
3702 response.send(proto::JoinChannelChatResponse {
3703 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3704 messages,
3705 })?;
3706 Ok(())
3707}
3708
3709/// Stop receiving chat updates for a channel
3710async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
3711 let channel_id = ChannelId::from_proto(request.channel_id);
3712 session
3713 .db()
3714 .await
3715 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3716 .await?;
3717 Ok(())
3718}
3719
3720/// Retrieve the chat history for a channel
3721async fn get_channel_messages(
3722 request: proto::GetChannelMessages,
3723 response: Response<proto::GetChannelMessages>,
3724 session: UserSession,
3725) -> Result<()> {
3726 let channel_id = ChannelId::from_proto(request.channel_id);
3727 let messages = session
3728 .db()
3729 .await
3730 .get_channel_messages(
3731 channel_id,
3732 session.user_id(),
3733 MESSAGE_COUNT_PER_PAGE,
3734 Some(MessageId::from_proto(request.before_message_id)),
3735 )
3736 .await?;
3737 response.send(proto::GetChannelMessagesResponse {
3738 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3739 messages,
3740 })?;
3741 Ok(())
3742}
3743
3744/// Retrieve specific chat messages
3745async fn get_channel_messages_by_id(
3746 request: proto::GetChannelMessagesById,
3747 response: Response<proto::GetChannelMessagesById>,
3748 session: UserSession,
3749) -> Result<()> {
3750 let message_ids = request
3751 .message_ids
3752 .iter()
3753 .map(|id| MessageId::from_proto(*id))
3754 .collect::<Vec<_>>();
3755 let messages = session
3756 .db()
3757 .await
3758 .get_channel_messages_by_id(session.user_id(), &message_ids)
3759 .await?;
3760 response.send(proto::GetChannelMessagesResponse {
3761 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3762 messages,
3763 })?;
3764 Ok(())
3765}
3766
3767/// Retrieve the current users notifications
3768async fn get_notifications(
3769 request: proto::GetNotifications,
3770 response: Response<proto::GetNotifications>,
3771 session: UserSession,
3772) -> Result<()> {
3773 let notifications = session
3774 .db()
3775 .await
3776 .get_notifications(
3777 session.user_id(),
3778 NOTIFICATION_COUNT_PER_PAGE,
3779 request
3780 .before_id
3781 .map(|id| db::NotificationId::from_proto(id)),
3782 )
3783 .await?;
3784 response.send(proto::GetNotificationsResponse {
3785 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3786 notifications,
3787 })?;
3788 Ok(())
3789}
3790
3791/// Mark notifications as read
3792async fn mark_notification_as_read(
3793 request: proto::MarkNotificationRead,
3794 response: Response<proto::MarkNotificationRead>,
3795 session: UserSession,
3796) -> Result<()> {
3797 let database = &session.db().await;
3798 let notifications = database
3799 .mark_notification_as_read_by_id(
3800 session.user_id(),
3801 NotificationId::from_proto(request.notification_id),
3802 )
3803 .await?;
3804 send_notifications(
3805 &*session.connection_pool().await,
3806 &session.peer,
3807 notifications,
3808 );
3809 response.send(proto::Ack {})?;
3810 Ok(())
3811}
3812
3813/// Get the current users information
3814async fn get_private_user_info(
3815 _request: proto::GetPrivateUserInfo,
3816 response: Response<proto::GetPrivateUserInfo>,
3817 session: UserSession,
3818) -> Result<()> {
3819 let db = session.db().await;
3820
3821 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
3822 let user = db
3823 .get_user_by_id(session.user_id())
3824 .await?
3825 .ok_or_else(|| anyhow!("user not found"))?;
3826 let flags = db.get_user_flags(session.user_id()).await?;
3827
3828 response.send(proto::GetPrivateUserInfoResponse {
3829 metrics_id,
3830 staff: user.admin,
3831 flags,
3832 })?;
3833 Ok(())
3834}
3835
3836fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3837 match message {
3838 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3839 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3840 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3841 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3842 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3843 code: frame.code.into(),
3844 reason: frame.reason,
3845 })),
3846 }
3847}
3848
3849fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3850 match message {
3851 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3852 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3853 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3854 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3855 AxumMessage::Close(frame) => {
3856 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3857 code: frame.code.into(),
3858 reason: frame.reason,
3859 }))
3860 }
3861 }
3862}
3863
3864fn notify_membership_updated(
3865 connection_pool: &mut ConnectionPool,
3866 result: MembershipUpdated,
3867 user_id: UserId,
3868 peer: &Peer,
3869) {
3870 for membership in &result.new_channels.channel_memberships {
3871 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3872 }
3873 for channel_id in &result.removed_channels {
3874 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3875 }
3876
3877 let user_channels_update = proto::UpdateUserChannels {
3878 channel_memberships: result
3879 .new_channels
3880 .channel_memberships
3881 .iter()
3882 .map(|cm| proto::ChannelMembership {
3883 channel_id: cm.channel_id.to_proto(),
3884 role: cm.role.into(),
3885 })
3886 .collect(),
3887 ..Default::default()
3888 };
3889
3890 let mut update = build_channels_update(result.new_channels, vec![]);
3891 update.delete_channels = result
3892 .removed_channels
3893 .into_iter()
3894 .map(|id| id.to_proto())
3895 .collect();
3896 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3897
3898 for connection_id in connection_pool.user_connection_ids(user_id) {
3899 peer.send(connection_id, user_channels_update.clone())
3900 .trace_err();
3901 peer.send(connection_id, update.clone()).trace_err();
3902 }
3903}
3904
3905fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3906 proto::UpdateUserChannels {
3907 channel_memberships: channels
3908 .channel_memberships
3909 .iter()
3910 .map(|m| proto::ChannelMembership {
3911 channel_id: m.channel_id.to_proto(),
3912 role: m.role.into(),
3913 })
3914 .collect(),
3915 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3916 observed_channel_message_id: channels.observed_channel_messages.clone(),
3917 }
3918}
3919
3920fn build_channels_update(
3921 channels: ChannelsForUser,
3922 channel_invites: Vec<db::Channel>,
3923) -> proto::UpdateChannels {
3924 let mut update = proto::UpdateChannels::default();
3925
3926 for channel in channels.channels {
3927 update.channels.push(channel.to_proto());
3928 }
3929
3930 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3931 update.latest_channel_message_ids = channels.latest_channel_messages;
3932
3933 for (channel_id, participants) in channels.channel_participants {
3934 update
3935 .channel_participants
3936 .push(proto::ChannelParticipants {
3937 channel_id: channel_id.to_proto(),
3938 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3939 });
3940 }
3941
3942 for channel in channel_invites {
3943 update.channel_invitations.push(channel.to_proto());
3944 }
3945 for project in channels.hosted_projects {
3946 update.hosted_projects.push(project);
3947 }
3948
3949 update
3950}
3951
3952fn build_initial_contacts_update(
3953 contacts: Vec<db::Contact>,
3954 pool: &ConnectionPool,
3955) -> proto::UpdateContacts {
3956 let mut update = proto::UpdateContacts::default();
3957
3958 for contact in contacts {
3959 match contact {
3960 db::Contact::Accepted { user_id, busy } => {
3961 update.contacts.push(contact_for_user(user_id, busy, &pool));
3962 }
3963 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3964 db::Contact::Incoming { user_id } => {
3965 update
3966 .incoming_requests
3967 .push(proto::IncomingContactRequest {
3968 requester_id: user_id.to_proto(),
3969 })
3970 }
3971 }
3972 }
3973
3974 update
3975}
3976
3977fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3978 proto::Contact {
3979 user_id: user_id.to_proto(),
3980 online: pool.is_user_online(user_id),
3981 busy,
3982 }
3983}
3984
3985fn room_updated(room: &proto::Room, peer: &Peer) {
3986 broadcast(
3987 None,
3988 room.participants
3989 .iter()
3990 .filter_map(|participant| Some(participant.peer_id?.into())),
3991 |peer_id| {
3992 peer.send(
3993 peer_id,
3994 proto::RoomUpdated {
3995 room: Some(room.clone()),
3996 },
3997 )
3998 },
3999 );
4000}
4001
4002fn channel_updated(
4003 channel: &db::channel::Model,
4004 room: &proto::Room,
4005 peer: &Peer,
4006 pool: &ConnectionPool,
4007) {
4008 let participants = room
4009 .participants
4010 .iter()
4011 .map(|p| p.user_id)
4012 .collect::<Vec<_>>();
4013
4014 broadcast(
4015 None,
4016 pool.channel_connection_ids(channel.root_id())
4017 .filter_map(|(channel_id, role)| {
4018 role.can_see_channel(channel.visibility).then(|| channel_id)
4019 }),
4020 |peer_id| {
4021 peer.send(
4022 peer_id,
4023 proto::UpdateChannels {
4024 channel_participants: vec![proto::ChannelParticipants {
4025 channel_id: channel.id.to_proto(),
4026 participant_user_ids: participants.clone(),
4027 }],
4028 ..Default::default()
4029 },
4030 )
4031 },
4032 );
4033}
4034
4035async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4036 let db = session.db().await;
4037
4038 let contacts = db.get_contacts(user_id).await?;
4039 let busy = db.is_user_busy(user_id).await?;
4040
4041 let pool = session.connection_pool().await;
4042 let updated_contact = contact_for_user(user_id, busy, &pool);
4043 for contact in contacts {
4044 if let db::Contact::Accepted {
4045 user_id: contact_user_id,
4046 ..
4047 } = contact
4048 {
4049 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4050 session
4051 .peer
4052 .send(
4053 contact_conn_id,
4054 proto::UpdateContacts {
4055 contacts: vec![updated_contact.clone()],
4056 remove_contacts: Default::default(),
4057 incoming_requests: Default::default(),
4058 remove_incoming_requests: Default::default(),
4059 outgoing_requests: Default::default(),
4060 remove_outgoing_requests: Default::default(),
4061 },
4062 )
4063 .trace_err();
4064 }
4065 }
4066 }
4067 Ok(())
4068}
4069
/// Remove this session's connection from the room it is in, if any.
///
/// If the database reports that a room was left, this function does the following
/// in order:
/// 1. Notifies collaborators of each project the connection left, via `project_left`.
/// 2. Broadcasts the updated room to its participants.
/// 3. Updates the owning channel's watchers, when the room belongs to a channel.
/// 4. Sends `CallCanceled` to every user whose pending call into the room was
///    canceled.
/// 5. Refreshes the contact status of the leaving user and of those callees.
/// 6. Removes the user from the LiveKit room, also deleting that room when
///    `left_room.deleted` is set.
///
/// Returns `Ok(())` immediately when the connection was not in a room. Send and
/// LiveKit failures are only logged; database errors are propagated.
async fn leave_room_for_session(session: &UserSession) -> Result<()> {
    // A set, so a user who is both the leaver and a canceled callee is refreshed once.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once inside the
    // `if let` below. The `else` branch returns early, so every later use sees an
    // initialized value.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard is a temporary in the `if let`
    // scrutinee. In Rust editions before 2024 it stays alive until the end of the
    // whole `if let ... else` statement, so it is held while `project_left` and
    // `room_updated` run.
    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        // `mem::take` moves fields out of `left_room` without cloning.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken out of `left_room.room` before
        // `room` itself is taken. The room broadcast by `room_updated` below
        // therefore carries a default (empty) `live_kit_room`. Confirm this is
        // intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the loop below.
    // `update_user_contacts` acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4143
4144async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4145 let left_channel_buffers = session
4146 .db()
4147 .await
4148 .leave_channel_buffers(session.connection_id)
4149 .await?;
4150
4151 for left_buffer in left_channel_buffers {
4152 channel_buffer_updated(
4153 session.connection_id,
4154 left_buffer.connections,
4155 &proto::UpdateChannelBufferCollaborators {
4156 channel_id: left_buffer.channel_id.to_proto(),
4157 collaborators: left_buffer.collaborators,
4158 },
4159 &session.peer,
4160 );
4161 }
4162
4163 Ok(())
4164}
4165
4166fn project_left(project: &db::LeftProject, session: &UserSession) {
4167 for connection_id in &project.connection_ids {
4168 if project.host_user_id == Some(session.user_id()) {
4169 session
4170 .peer
4171 .send(
4172 *connection_id,
4173 proto::UnshareProject {
4174 project_id: project.id.to_proto(),
4175 },
4176 )
4177 .trace_err();
4178 } else {
4179 session
4180 .peer
4181 .send(
4182 *connection_id,
4183 proto::RemoveProjectCollaborator {
4184 project_id: project.id.to_proto(),
4185 peer_id: Some(session.connection_id.into()),
4186 },
4187 )
4188 .trace_err();
4189 }
4190 }
4191}
4192
4193pub trait ResultExt {
4194 type Ok;
4195
4196 fn trace_err(self) -> Option<Self::Ok>;
4197}
4198
4199impl<T, E> ResultExt for Result<T, E>
4200where
4201 E: std::fmt::Debug,
4202{
4203 type Ok = T;
4204
4205 fn trace_err(self) -> Option<T> {
4206 match self {
4207 Ok(value) => Some(value),
4208 Err(error) => {
4209 tracing::error!("{:?}", error);
4210 None
4211 }
4212 }
4213 }
4214}