1mod connection_pool;
2
3use crate::{
4 auth::{self, Impersonator},
5 db::{
6 self, BufferId, Channel, ChannelId, ChannelRole, ChannelsForUser, CreatedChannelMessage,
7 Database, InviteMemberResult, MembershipUpdated, MessageId, NotificationId, Project,
8 ProjectId, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId, ServerId,
9 UpdatedChannelMessage, User, UserId,
10 },
11 executor::Executor,
12 AppState, Error, RateLimit, RateLimiter, Result,
13};
14use anyhow::{anyhow, Context as _};
15use async_tungstenite::tungstenite::{
16 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
17};
18use axum::{
19 body::Body,
20 extract::{
21 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
22 ConnectInfo, WebSocketUpgrade,
23 },
24 headers::{Header, HeaderName},
25 http::StatusCode,
26 middleware,
27 response::IntoResponse,
28 routing::get,
29 Extension, Router, TypedHeader,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34
35use futures::{
36 channel::oneshot,
37 future::{self, BoxFuture},
38 stream::FuturesUnordered,
39 FutureExt, SinkExt, StreamExt, TryStreamExt,
40};
41use prometheus::{register_int_gauge, IntGauge};
42use rpc::{
43 proto::{
44 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
45 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
46 },
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48};
49use serde::{Serialize, Serializer};
50use std::{
51 any::TypeId,
52 future::Future,
53 marker::PhantomData,
54 mem,
55 net::SocketAddr,
56 ops::{Deref, DerefMut},
57 rc::Rc,
58 sync::{
59 atomic::{AtomicBool, Ordering::SeqCst},
60 Arc, OnceLock,
61 },
62 time::{Duration, Instant},
63};
64use time::OffsetDateTime;
65use tokio::sync::{watch, Semaphore};
66use tower::ServiceBuilder;
67use tracing::{field, info_span, instrument, Instrument};
68use util::{http::IsahcHttpClient, SemanticVersion};
69
/// Reconnection grace period.
///
/// NOTE(review): its use is outside this chunk — presumably how long a lost
/// connection may take to come back before its state is cleaned up; confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone,
// we can clean up old resources.
/// Delay `Server::start` waits before clearing resources left behind by stale servers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message queries (presumably used by the chat handlers; confirm).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a chat message's length (presumably in bytes; confirm at use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably used by `get_notifications`; confirm).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
78
/// Type-erased handler for one incoming message type.
///
/// It receives the boxed envelope plus the sender's `Session` and returns a boxed
/// `'static` future. Handlers are stored in `Server::handlers`, keyed by the
/// message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
81
82struct Response<R> {
83 peer: Arc<Peer>,
84 receipt: Receipt<R>,
85 responded: Arc<AtomicBool>,
86}
87
88impl<R: RequestMessage> Response<R> {
89 fn send(self, payload: R::Response) -> Result<()> {
90 self.responded.store(true, SeqCst);
91 self.peer.respond(self.receipt, payload)?;
92 Ok(())
93 }
94}
95
96struct StreamingResponse<R: RequestMessage> {
97 peer: Arc<Peer>,
98 receipt: Receipt<R>,
99}
100
101impl<R: RequestMessage> StreamingResponse<R> {
102 fn send(&self, payload: R::Response) -> Result<()> {
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone)]
109struct Session {
110 user_id: UserId,
111 connection_id: ConnectionId,
112 db: Arc<tokio::sync::Mutex<DbHandle>>,
113 peer: Arc<Peer>,
114 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
115 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
116 http_client: IsahcHttpClient,
117 rate_limiter: Arc<RateLimiter>,
118 _executor: Executor,
119}
120
121impl Session {
122 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
123 #[cfg(test)]
124 tokio::task::yield_now().await;
125 let guard = self.db.lock().await;
126 #[cfg(test)]
127 tokio::task::yield_now().await;
128 guard
129 }
130
131 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
132 #[cfg(test)]
133 tokio::task::yield_now().await;
134 let guard = self.connection_pool.lock();
135 ConnectionPoolGuard {
136 guard,
137 _not_send: PhantomData,
138 }
139 }
140}
141
142impl Debug for Session {
143 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
144 f.debug_struct("Session")
145 .field("user_id", &self.user_id)
146 .field("connection_id", &self.connection_id)
147 .finish()
148 }
149}
150
151struct DbHandle(Arc<Database>);
152
153impl Deref for DbHandle {
154 type Target = Database;
155
156 fn deref(&self) -> &Self::Target {
157 self.0.as_ref()
158 }
159}
160
/// The collab RPC server.
///
/// It owns the `Peer`, the pool of live connections, and the table of message
/// handlers registered in `Server::new`.
pub struct Server {
    /// Identifier of this server instance; replaced in tests via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; `handle_connection` loops return when it changes.
    teardown: watch::Sender<bool>,
}
169
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`. Presumably this keeps
/// the synchronous `parking_lot` lock from being held across an `.await` inside a
/// future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
174
/// Serializable view of the server.
///
/// It contains the peer plus the connection pool; the pool stays locked for as
/// long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
181
182pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
183where
184 S: Serializer,
185 T: Deref<Target = U>,
186 U: Serialize,
187{
188 Serialize::serialize(value.deref(), serializer)
189}
190
191impl Server {
192 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
193 let mut server = Self {
194 id: parking_lot::Mutex::new(id),
195 peer: Peer::new(id.0 as u32),
196 app_state: app_state.clone(),
197 connection_pool: Default::default(),
198 handlers: Default::default(),
199 teardown: watch::channel(false).0,
200 };
201
202 server
203 .add_request_handler(ping)
204 .add_request_handler(create_room)
205 .add_request_handler(join_room)
206 .add_request_handler(rejoin_room)
207 .add_request_handler(leave_room)
208 .add_request_handler(set_room_participant_role)
209 .add_request_handler(call)
210 .add_request_handler(cancel_call)
211 .add_message_handler(decline_call)
212 .add_request_handler(update_participant_location)
213 .add_request_handler(share_project)
214 .add_message_handler(unshare_project)
215 .add_request_handler(join_project)
216 .add_request_handler(join_hosted_project)
217 .add_message_handler(leave_project)
218 .add_request_handler(update_project)
219 .add_request_handler(update_worktree)
220 .add_message_handler(start_language_server)
221 .add_message_handler(update_language_server)
222 .add_message_handler(update_diagnostic_summary)
223 .add_message_handler(update_worktree_settings)
224 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
225 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
226 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
227 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
228 .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
229 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
230 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
231 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
232 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
233 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
234 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
235 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
236 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
237 .add_request_handler(
238 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
239 )
240 .add_request_handler(
241 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
242 )
243 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
244 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
245 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
246 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
247 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
248 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
249 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
250 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
251 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
252 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
253 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
254 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
255 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
256 .add_message_handler(create_buffer_for_peer)
257 .add_request_handler(update_buffer)
258 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
259 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
260 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
261 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
262 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
263 .add_request_handler(get_users)
264 .add_request_handler(fuzzy_search_users)
265 .add_request_handler(request_contact)
266 .add_request_handler(remove_contact)
267 .add_request_handler(respond_to_contact_request)
268 .add_request_handler(create_channel)
269 .add_request_handler(delete_channel)
270 .add_request_handler(invite_channel_member)
271 .add_request_handler(remove_channel_member)
272 .add_request_handler(set_channel_member_role)
273 .add_request_handler(set_channel_visibility)
274 .add_request_handler(rename_channel)
275 .add_request_handler(join_channel_buffer)
276 .add_request_handler(leave_channel_buffer)
277 .add_message_handler(update_channel_buffer)
278 .add_request_handler(rejoin_channel_buffers)
279 .add_request_handler(get_channel_members)
280 .add_request_handler(respond_to_channel_invite)
281 .add_request_handler(join_channel)
282 .add_request_handler(join_channel_chat)
283 .add_message_handler(leave_channel_chat)
284 .add_request_handler(send_channel_message)
285 .add_request_handler(remove_channel_message)
286 .add_request_handler(update_channel_message)
287 .add_request_handler(get_channel_messages)
288 .add_request_handler(get_channel_messages_by_id)
289 .add_request_handler(get_notifications)
290 .add_request_handler(mark_notification_as_read)
291 .add_request_handler(move_channel)
292 .add_request_handler(follow)
293 .add_message_handler(unfollow)
294 .add_message_handler(update_followers)
295 .add_request_handler(get_private_user_info)
296 .add_message_handler(acknowledge_channel_message)
297 .add_message_handler(acknowledge_buffer_version)
298 .add_streaming_request_handler({
299 let app_state = app_state.clone();
300 move |request, response, session| {
301 complete_with_language_model(
302 request,
303 response,
304 session,
305 app_state.config.openai_api_key.clone(),
306 app_state.config.google_ai_api_key.clone(),
307 )
308 }
309 })
310 .add_request_handler({
311 let app_state = app_state.clone();
312 move |request, response, session| {
313 count_tokens_with_language_model(
314 request,
315 response,
316 session,
317 app_state.config.google_ai_api_key.clone(),
318 )
319 }
320 });
321
322 Arc::new(server)
323 }
324
325 pub async fn start(&self) -> Result<()> {
326 let server_id = *self.id.lock();
327 let app_state = self.app_state.clone();
328 let peer = self.peer.clone();
329 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
330 let pool = self.connection_pool.clone();
331 let live_kit_client = self.app_state.live_kit_client.clone();
332
333 let span = info_span!("start server");
334 self.app_state.executor.spawn_detached(
335 async move {
336 tracing::info!("waiting for cleanup timeout");
337 timeout.await;
338 tracing::info!("cleanup timeout expired, retrieving stale rooms");
339 if let Some((room_ids, channel_ids)) = app_state
340 .db
341 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
342 .await
343 .trace_err()
344 {
345 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
346 tracing::info!(
347 stale_channel_buffer_count = channel_ids.len(),
348 "retrieved stale channel buffers"
349 );
350
351 for channel_id in channel_ids {
352 if let Some(refreshed_channel_buffer) = app_state
353 .db
354 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
355 .await
356 .trace_err()
357 {
358 for connection_id in refreshed_channel_buffer.connection_ids {
359 peer.send(
360 connection_id,
361 proto::UpdateChannelBufferCollaborators {
362 channel_id: channel_id.to_proto(),
363 collaborators: refreshed_channel_buffer
364 .collaborators
365 .clone(),
366 },
367 )
368 .trace_err();
369 }
370 }
371 }
372
373 for room_id in room_ids {
374 let mut contacts_to_update = HashSet::default();
375 let mut canceled_calls_to_user_ids = Vec::new();
376 let mut live_kit_room = String::new();
377 let mut delete_live_kit_room = false;
378
379 if let Some(mut refreshed_room) = app_state
380 .db
381 .clear_stale_room_participants(room_id, server_id)
382 .await
383 .trace_err()
384 {
385 tracing::info!(
386 room_id = room_id.0,
387 new_participant_count = refreshed_room.room.participants.len(),
388 "refreshed room"
389 );
390 room_updated(&refreshed_room.room, &peer);
391 if let Some(channel) = refreshed_room.channel.as_ref() {
392 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
393 }
394 contacts_to_update
395 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
396 contacts_to_update
397 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
398 canceled_calls_to_user_ids =
399 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
400 live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
401 delete_live_kit_room = refreshed_room.room.participants.is_empty();
402 }
403
404 {
405 let pool = pool.lock();
406 for canceled_user_id in canceled_calls_to_user_ids {
407 for connection_id in pool.user_connection_ids(canceled_user_id) {
408 peer.send(
409 connection_id,
410 proto::CallCanceled {
411 room_id: room_id.to_proto(),
412 },
413 )
414 .trace_err();
415 }
416 }
417 }
418
419 for user_id in contacts_to_update {
420 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
421 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
422 if let Some((busy, contacts)) = busy.zip(contacts) {
423 let pool = pool.lock();
424 let updated_contact = contact_for_user(user_id, busy, &pool);
425 for contact in contacts {
426 if let db::Contact::Accepted {
427 user_id: contact_user_id,
428 ..
429 } = contact
430 {
431 for contact_conn_id in
432 pool.user_connection_ids(contact_user_id)
433 {
434 peer.send(
435 contact_conn_id,
436 proto::UpdateContacts {
437 contacts: vec![updated_contact.clone()],
438 remove_contacts: Default::default(),
439 incoming_requests: Default::default(),
440 remove_incoming_requests: Default::default(),
441 outgoing_requests: Default::default(),
442 remove_outgoing_requests: Default::default(),
443 },
444 )
445 .trace_err();
446 }
447 }
448 }
449 }
450 }
451
452 if let Some(live_kit) = live_kit_client.as_ref() {
453 if delete_live_kit_room {
454 live_kit.delete_room(live_kit_room).await.trace_err();
455 }
456 }
457 }
458 }
459
460 app_state
461 .db
462 .delete_stale_servers(&app_state.config.zed_environment, server_id)
463 .await
464 .trace_err();
465 }
466 .instrument(span),
467 );
468 Ok(())
469 }
470
471 pub fn teardown(&self) {
472 self.peer.teardown();
473 self.connection_pool.lock().reset();
474 let _ = self.teardown.send(true);
475 }
476
477 #[cfg(test)]
478 pub fn reset(&self, id: ServerId) {
479 self.teardown();
480 *self.id.lock() = id;
481 self.peer.reset(id.0 as u32);
482 let _ = self.teardown.send(false);
483 }
484
485 #[cfg(test)]
486 pub fn id(&self) -> ServerId {
487 *self.id.lock()
488 }
489
490 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
491 where
492 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
493 Fut: 'static + Send + Future<Output = Result<()>>,
494 M: EnvelopedMessage,
495 {
496 let prev_handler = self.handlers.insert(
497 TypeId::of::<M>(),
498 Box::new(move |envelope, session| {
499 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
500 let received_at = envelope.received_at;
501 tracing::info!(
502 "message received"
503 );
504 let start_time = Instant::now();
505 let future = (handler)(*envelope, session);
506 async move {
507 let result = future.await;
508 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
509 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
510 let queue_duration_ms = total_duration_ms - processing_duration_ms;
511 match result {
512 Err(error) => {
513 tracing::error!(%error, total_duration_ms, processing_duration_ms, queue_duration_ms, "error handling message")
514 }
515 Ok(()) => tracing::info!(total_duration_ms, processing_duration_ms, queue_duration_ms, "finished handling message"),
516 }
517 }
518 .boxed()
519 }),
520 );
521 if prev_handler.is_some() {
522 panic!("registered a handler for the same message twice");
523 }
524 self
525 }
526
527 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
528 where
529 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
530 Fut: 'static + Send + Future<Output = Result<()>>,
531 M: EnvelopedMessage,
532 {
533 self.add_handler(move |envelope, session| handler(envelope.payload, session));
534 self
535 }
536
537 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
538 where
539 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
540 Fut: Send + Future<Output = Result<()>>,
541 M: RequestMessage,
542 {
543 let handler = Arc::new(handler);
544 self.add_handler(move |envelope, session| {
545 let receipt = envelope.receipt();
546 let handler = handler.clone();
547 async move {
548 let peer = session.peer.clone();
549 let responded = Arc::new(AtomicBool::default());
550 let response = Response {
551 peer: peer.clone(),
552 responded: responded.clone(),
553 receipt,
554 };
555 match (handler)(envelope.payload, response, session).await {
556 Ok(()) => {
557 if responded.load(std::sync::atomic::Ordering::SeqCst) {
558 Ok(())
559 } else {
560 Err(anyhow!("handler did not send a response"))?
561 }
562 }
563 Err(error) => {
564 let proto_err = match &error {
565 Error::Internal(err) => err.to_proto(),
566 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
567 };
568 peer.respond_with_error(receipt, proto_err)?;
569 Err(error)
570 }
571 }
572 }
573 })
574 }
575
576 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
577 where
578 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
579 Fut: Send + Future<Output = Result<()>>,
580 M: RequestMessage,
581 {
582 let handler = Arc::new(handler);
583 self.add_handler(move |envelope, session| {
584 let receipt = envelope.receipt();
585 let handler = handler.clone();
586 async move {
587 let peer = session.peer.clone();
588 let response = StreamingResponse {
589 peer: peer.clone(),
590 receipt,
591 };
592 match (handler)(envelope.payload, response, session).await {
593 Ok(()) => {
594 peer.end_stream(receipt)?;
595 Ok(())
596 }
597 Err(error) => {
598 let proto_err = match &error {
599 Error::Internal(err) => err.to_proto(),
600 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
601 };
602 peer.respond_with_error(receipt, proto_err)?;
603 Err(error)
604 }
605 }
606 }
607 })
608 }
609
610 #[allow(clippy::too_many_arguments)]
611 pub fn handle_connection(
612 self: &Arc<Self>,
613 connection: Connection,
614 address: String,
615 user: User,
616 zed_version: ZedVersion,
617 impersonator: Option<User>,
618 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
619 executor: Executor,
620 ) -> impl Future<Output = ()> {
621 let this = self.clone();
622 let user_id = user.id;
623 let login = user.github_login.clone();
624 let span = info_span!("handle connection", %user_id, %login, %address, impersonator = field::Empty, connection_id = field::Empty);
625 if let Some(impersonator) = impersonator {
626 span.record("impersonator", &impersonator.github_login);
627 }
628 let mut teardown = self.teardown.subscribe();
629 async move {
630 if *teardown.borrow() {
631 tracing::error!("server is tearing down");
632 return
633 }
634 let (connection_id, handle_io, mut incoming_rx) = this
635 .peer
636 .add_connection(connection, {
637 let executor = executor.clone();
638 move |duration| executor.sleep(duration)
639 });
640 tracing::Span::current().record("connection_id", format!("{}", connection_id));
641 tracing::info!("connection opened");
642
643 let http_client = match IsahcHttpClient::new() {
644 Ok(http_client) => http_client,
645 Err(error) => {
646 tracing::error!(?error, "failed to create HTTP client");
647 return;
648 }
649 };
650
651 let session = Session {
652 user_id,
653 connection_id,
654 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
655 peer: this.peer.clone(),
656 connection_pool: this.connection_pool.clone(),
657 live_kit_client: this.app_state.live_kit_client.clone(),
658 http_client,
659 rate_limiter: this.app_state.rate_limiter.clone(),
660 _executor: executor.clone(),
661 };
662
663 if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
664 tracing::error!(?error, "failed to send initial client update");
665 return;
666 }
667
668 let handle_io = handle_io.fuse();
669 futures::pin_mut!(handle_io);
670
671 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
672 // This prevents deadlocks when e.g., client A performs a request to client B and
673 // client B performs a request to client A. If both clients stop processing further
674 // messages until their respective request completes, they won't have a chance to
675 // respond to the other client's request and cause a deadlock.
676 //
677 // This arrangement ensures we will attempt to process earlier messages first, but fall
678 // back to processing messages arrived later in the spirit of making progress.
679 let mut foreground_message_handlers = FuturesUnordered::new();
680 let concurrent_handlers = Arc::new(Semaphore::new(256));
681 loop {
682 let next_message = async {
683 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
684 let message = incoming_rx.next().await;
685 (permit, message)
686 }.fuse();
687 futures::pin_mut!(next_message);
688 futures::select_biased! {
689 _ = teardown.changed().fuse() => return,
690 result = handle_io => {
691 if let Err(error) = result {
692 tracing::error!(?error, "error handling I/O");
693 }
694 break;
695 }
696 _ = foreground_message_handlers.next() => {}
697 next_message = next_message => {
698 let (permit, message) = next_message;
699 if let Some(message) = message {
700 let type_name = message.payload_type_name();
701 // note: we copy all the fields from the parent span so we can query them in the logs.
702 // (https://github.com/tokio-rs/tracing/issues/2670).
703 let span = tracing::info_span!("receive message", %user_id, %login, %connection_id, %address, type_name);
704 let span_enter = span.enter();
705 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
706 let is_background = message.is_background();
707 let handle_message = (handler)(message, session.clone());
708 drop(span_enter);
709
710 let handle_message = async move {
711 handle_message.await;
712 drop(permit);
713 }.instrument(span);
714 if is_background {
715 executor.spawn_detached(handle_message);
716 } else {
717 foreground_message_handlers.push(handle_message);
718 }
719 } else {
720 tracing::error!("no message handler");
721 }
722 } else {
723 tracing::info!("connection closed");
724 break;
725 }
726 }
727 }
728 }
729
730 drop(foreground_message_handlers);
731 tracing::info!("signing out");
732 if let Err(error) = connection_lost(session, teardown, executor).await {
733 tracing::error!(?error, "error signing out");
734 }
735
736 }.instrument(span)
737 }
738
739 async fn send_initial_client_update(
740 &self,
741 connection_id: ConnectionId,
742 user: User,
743 zed_version: ZedVersion,
744 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
745 session: &Session,
746 ) -> Result<()> {
747 self.peer.send(
748 connection_id,
749 proto::Hello {
750 peer_id: Some(connection_id.into()),
751 },
752 )?;
753 tracing::info!("sent hello message");
754
755 if let Some(send_connection_id) = send_connection_id.take() {
756 let _ = send_connection_id.send(connection_id);
757 }
758
759 if !user.connected_once {
760 self.peer.send(connection_id, proto::ShowContacts {})?;
761 self.app_state
762 .db
763 .set_user_connected_once(user.id, true)
764 .await?;
765 }
766
767 let (contacts, channels_for_user, channel_invites) = future::try_join3(
768 self.app_state.db.get_contacts(user.id),
769 self.app_state.db.get_channels_for_user(user.id),
770 self.app_state.db.get_channel_invites_for_user(user.id),
771 )
772 .await?;
773
774 {
775 let mut pool = self.connection_pool.lock();
776 pool.add_connection(connection_id, user.id, user.admin, zed_version);
777 for membership in &channels_for_user.channel_memberships {
778 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
779 }
780 self.peer.send(
781 connection_id,
782 build_initial_contacts_update(contacts, &pool),
783 )?;
784 self.peer.send(
785 connection_id,
786 build_update_user_channels(&channels_for_user),
787 )?;
788 self.peer.send(
789 connection_id,
790 build_channels_update(channels_for_user, channel_invites),
791 )?;
792 }
793
794 if let Some(incoming_call) = self.app_state.db.incoming_call_for_user(user.id).await? {
795 self.peer.send(connection_id, incoming_call)?;
796 }
797
798 update_user_contacts(user.id, &session).await?;
799 Ok(())
800 }
801
802 pub async fn invite_code_redeemed(
803 self: &Arc<Self>,
804 inviter_id: UserId,
805 invitee_id: UserId,
806 ) -> Result<()> {
807 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
808 if let Some(code) = &user.invite_code {
809 let pool = self.connection_pool.lock();
810 let invitee_contact = contact_for_user(invitee_id, false, &pool);
811 for connection_id in pool.user_connection_ids(inviter_id) {
812 self.peer.send(
813 connection_id,
814 proto::UpdateContacts {
815 contacts: vec![invitee_contact.clone()],
816 ..Default::default()
817 },
818 )?;
819 self.peer.send(
820 connection_id,
821 proto::UpdateInviteInfo {
822 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
823 count: user.invite_count as u32,
824 },
825 )?;
826 }
827 }
828 }
829 Ok(())
830 }
831
832 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
833 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
834 if let Some(invite_code) = &user.invite_code {
835 let pool = self.connection_pool.lock();
836 for connection_id in pool.user_connection_ids(user_id) {
837 self.peer.send(
838 connection_id,
839 proto::UpdateInviteInfo {
840 url: format!(
841 "{}{}",
842 self.app_state.config.invite_link_prefix, invite_code
843 ),
844 count: user.invite_count as u32,
845 },
846 )?;
847 }
848 }
849 }
850 Ok(())
851 }
852
853 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
854 ServerSnapshot {
855 connection_pool: ConnectionPoolGuard {
856 guard: self.connection_pool.lock(),
857 _not_send: PhantomData,
858 },
859 peer: &self.peer,
860 }
861 }
862}
863
864impl<'a> Deref for ConnectionPoolGuard<'a> {
865 type Target = ConnectionPool;
866
867 fn deref(&self) -> &Self::Target {
868 &self.guard
869 }
870}
871
872impl<'a> DerefMut for ConnectionPoolGuard<'a> {
873 fn deref_mut(&mut self) -> &mut Self::Target {
874 &mut self.guard
875 }
876}
877
878impl<'a> Drop for ConnectionPoolGuard<'a> {
879 fn drop(&mut self) {
880 #[cfg(test)]
881 self.check_invariants();
882 }
883}
884
885fn broadcast<F>(
886 sender_id: Option<ConnectionId>,
887 receiver_ids: impl IntoIterator<Item = ConnectionId>,
888 mut f: F,
889) where
890 F: FnMut(ConnectionId) -> anyhow::Result<()>,
891{
892 for receiver_id in receiver_ids {
893 if Some(receiver_id) != sender_id {
894 if let Err(error) = f(receiver_id) {
895 tracing::error!("failed to send to {:?} {}", receiver_id, error);
896 }
897 }
898 }
899}
900
901pub struct ProtocolVersion(u32);
902
903impl Header for ProtocolVersion {
904 fn name() -> &'static HeaderName {
905 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
906 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
907 }
908
909 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
910 where
911 Self: Sized,
912 I: Iterator<Item = &'i axum::http::HeaderValue>,
913 {
914 let version = values
915 .next()
916 .ok_or_else(axum::headers::Error::invalid)?
917 .to_str()
918 .map_err(|_| axum::headers::Error::invalid())?
919 .parse()
920 .map_err(|_| axum::headers::Error::invalid())?;
921 Ok(Self(version))
922 }
923
924 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
925 values.extend([self.0.to_string().parse().unwrap()]);
926 }
927}
928
929pub struct AppVersionHeader(SemanticVersion);
930impl Header for AppVersionHeader {
931 fn name() -> &'static HeaderName {
932 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
933 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
934 }
935
936 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
937 where
938 Self: Sized,
939 I: Iterator<Item = &'i axum::http::HeaderValue>,
940 {
941 let version = values
942 .next()
943 .ok_or_else(axum::headers::Error::invalid)?
944 .to_str()
945 .map_err(|_| axum::headers::Error::invalid())?
946 .parse()
947 .map_err(|_| axum::headers::Error::invalid())?;
948 Ok(Self(version))
949 }
950
951 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
952 values.extend([self.0.to_string().parse().unwrap()]);
953 }
954}
955
956pub fn routes(server: Arc<Server>) -> Router<(), Body> {
957 Router::new()
958 .route("/rpc", get(handle_websocket_request))
959 .layer(
960 ServiceBuilder::new()
961 .layer(Extension(server.app_state.clone()))
962 .layer(middleware::from_fn(auth::validate_header)),
963 )
964 .route("/metrics", get(handle_metrics))
965 .layer(Extension(server))
966}
967
968pub async fn handle_websocket_request(
969 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
970 app_version_header: Option<TypedHeader<AppVersionHeader>>,
971 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
972 Extension(server): Extension<Arc<Server>>,
973 Extension(user): Extension<User>,
974 Extension(impersonator): Extension<Impersonator>,
975 ws: WebSocketUpgrade,
976) -> axum::response::Response {
977 if protocol_version != rpc::PROTOCOL_VERSION {
978 return (
979 StatusCode::UPGRADE_REQUIRED,
980 "client must be upgraded".to_string(),
981 )
982 .into_response();
983 }
984
985 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
986 return (
987 StatusCode::UPGRADE_REQUIRED,
988 "no version header found".to_string(),
989 )
990 .into_response();
991 };
992
993 if !version.can_collaborate() {
994 return (
995 StatusCode::UPGRADE_REQUIRED,
996 "client must be upgraded".to_string(),
997 )
998 .into_response();
999 }
1000
1001 let socket_address = socket_address.to_string();
1002 ws.on_upgrade(move |socket| {
1003 let socket = socket
1004 .map_ok(to_tungstenite_message)
1005 .err_into()
1006 .with(|message| async move { Ok(to_axum_message(message)) });
1007 let connection = Connection::new(Box::pin(socket));
1008 async move {
1009 server
1010 .handle_connection(
1011 connection,
1012 socket_address,
1013 user,
1014 version,
1015 impersonator.0,
1016 None,
1017 Executor::Production,
1018 )
1019 .await;
1020 }
1021 })
1022}
1023
1024pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1025 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1026 let connections_metric = CONNECTIONS_METRIC
1027 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1028
1029 let connections = server
1030 .connection_pool
1031 .lock()
1032 .connections()
1033 .filter(|connection| !connection.admin)
1034 .count();
1035 connections_metric.set(connections as _);
1036
1037 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1038 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1039 register_int_gauge!(
1040 "shared_projects",
1041 "number of open projects with one or more guests"
1042 )
1043 .unwrap()
1044 });
1045
1046 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1047 shared_projects_metric.set(shared_projects as _);
1048
1049 let encoder = prometheus::TextEncoder::new();
1050 let metric_families = prometheus::gather();
1051 let encoded_metrics = encoder
1052 .encode_to_string(&metric_families)
1053 .map_err(|err| anyhow!("{}", err))?;
1054 Ok(encoded_metrics)
1055}
1056
/// Cleans up after a client's connection drops.
///
/// Immediate steps:
/// - disconnect the peer;
/// - remove the connection from the pool;
/// - record the lost connection in the database (a failure here is only traced).
///
/// It then races `RECONNECT_TIMEOUT` against a change on `teardown`. If the
/// teardown signal changes first, nothing more is done. If the timeout fires, it:
/// - leaves the session's room and channel buffers;
/// - declines a pending call when the user has no other online connection;
/// - refreshes the user's contacts.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails, or if
/// updating contacts fails after the timeout. The other cleanup steps only trace
/// their errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here must not prevent the rest of the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in order, so the timeout branch wins if
    // both futures are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id, session.connection_id);
            leave_room_for_session(&session).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id)
            {
                let db = session.db().await;
                // NOTE(review): `None` presumably means "whichever room is calling" — confirm against `decline_call`.
                if let Some(room) = db.decline_call(None, session.user_id).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id, &session).await?;
        }
        // A change on `teardown` ends the wait without performing the cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1102
1103/// Acknowledges a ping from a client, used to keep the connection alive.
1104async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1105 response.send(proto::Ack {})?;
1106 Ok(())
1107}
1108
1109/// Creates a new room for calling (outside of channels)
1110async fn create_room(
1111 _request: proto::CreateRoom,
1112 response: Response<proto::CreateRoom>,
1113 session: Session,
1114) -> Result<()> {
1115 let live_kit_room = nanoid::nanoid!(30);
1116
1117 let live_kit_connection_info = {
1118 let live_kit_room = live_kit_room.clone();
1119 let live_kit = session.live_kit_client.as_ref();
1120
1121 util::async_maybe!({
1122 let live_kit = live_kit?;
1123
1124 let token = live_kit
1125 .room_token(&live_kit_room, &session.user_id.to_string())
1126 .trace_err()?;
1127
1128 Some(proto::LiveKitConnectionInfo {
1129 server_url: live_kit.url().into(),
1130 token,
1131 can_publish: true,
1132 })
1133 })
1134 }
1135 .await;
1136
1137 let room = session
1138 .db()
1139 .await
1140 .create_room(session.user_id, session.connection_id, &live_kit_room)
1141 .await?;
1142
1143 response.send(proto::CreateRoomResponse {
1144 room: Some(room.clone()),
1145 live_kit_connection_info,
1146 })?;
1147
1148 update_user_contacts(session.user_id, &session).await?;
1149 Ok(())
1150}
1151
1152/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1153async fn join_room(
1154 request: proto::JoinRoom,
1155 response: Response<proto::JoinRoom>,
1156 session: Session,
1157) -> Result<()> {
1158 let room_id = RoomId::from_proto(request.id);
1159
1160 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1161
1162 if let Some(channel_id) = channel_id {
1163 return join_channel_internal(channel_id, Box::new(response), session).await;
1164 }
1165
1166 let joined_room = {
1167 let room = session
1168 .db()
1169 .await
1170 .join_room(room_id, session.user_id, session.connection_id)
1171 .await?;
1172 room_updated(&room.room, &session.peer);
1173 room.into_inner()
1174 };
1175
1176 for connection_id in session
1177 .connection_pool()
1178 .await
1179 .user_connection_ids(session.user_id)
1180 {
1181 session
1182 .peer
1183 .send(
1184 connection_id,
1185 proto::CallCanceled {
1186 room_id: room_id.to_proto(),
1187 },
1188 )
1189 .trace_err();
1190 }
1191
1192 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1193 if let Some(token) = live_kit
1194 .room_token(
1195 &joined_room.room.live_kit_room,
1196 &session.user_id.to_string(),
1197 )
1198 .trace_err()
1199 {
1200 Some(proto::LiveKitConnectionInfo {
1201 server_url: live_kit.url().into(),
1202 token,
1203 can_publish: true,
1204 })
1205 } else {
1206 None
1207 }
1208 } else {
1209 None
1210 };
1211
1212 response.send(proto::JoinRoomResponse {
1213 room: Some(joined_room.room),
1214 channel_id: None,
1215 live_kit_connection_info,
1216 })?;
1217
1218 update_user_contacts(session.user_id, &session).await?;
1219 Ok(())
1220}
1221
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Replies with the room plus the `reshared_projects` and `rejoined_projects`
/// returned by the database. NOTE(review): these are presumably projects this
/// user hosts and projects this user had joined as a guest, respectively —
/// confirm against `Database::rejoin_room`.
///
/// After replying, it:
/// - broadcasts the updated room;
/// - for every project in either list, tells its collaborators that this user's
///   peer id changed from the old connection to the current one;
/// - forwards each reshared project's worktree list to its collaborators;
/// - streams each rejoined project's worktree entries, diagnostic summaries,
///   settings files and language-server status back to this connection;
/// - sends a channel update if the room belongs to a channel;
/// - refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // The database result lives only inside the block below; `room` and
    // `channel` carry what is still needed after it has been dropped.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id, session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| proto::RejoinedProject {
                    id: rejoined_project.id.to_proto(),
                    worktrees: rejoined_project
                        .worktrees
                        .iter()
                        .map(|worktree| proto::WorktreeMetadata {
                            id: worktree.id,
                            root_name: worktree.root_name.clone(),
                            visible: worktree.visible,
                            abs_path: worktree.abs_path.clone(),
                        })
                        .collect(),
                    collaborators: rejoined_project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                    language_servers: rejoined_project.language_servers.clone(),
                })
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at this connection instead of the old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to its collaborators as if sent by this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        for project in &rejoined_room.rejoined_projects {
            // Point each collaborator at this connection instead of the old one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }
        }

        // Stream each rejoined project's state back to this connection. Worktrees are
        // taken out of the result so their contents can be moved into messages.
        for project in &mut rejoined_room.rejoined_projects {
            for worktree in mem::take(&mut project.worktrees) {
                // Tests use a tiny chunk size, presumably to exercise multi-message updates.
                #[cfg(any(test, feature = "test-support"))]
                const MAX_CHUNK_SIZE: usize = 2;
                #[cfg(not(any(test, feature = "test-support")))]
                const MAX_CHUNK_SIZE: usize = 256;

                // Stream this worktree's entries.
                let message = proto::UpdateWorktree {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    abs_path: worktree.abs_path.clone(),
                    root_name: worktree.root_name,
                    updated_entries: worktree.updated_entries,
                    removed_entries: worktree.removed_entries,
                    scan_id: worktree.scan_id,
                    is_last_update: worktree.completed_scan_id == worktree.scan_id,
                    updated_repositories: worktree.updated_repositories,
                    removed_repositories: worktree.removed_repositories,
                };
                for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
                    session.peer.send(session.connection_id, update.clone())?;
                }

                // Stream this worktree's diagnostics.
                for summary in worktree.diagnostic_summaries {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateDiagnosticSummary {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            summary: Some(summary),
                        },
                    )?;
                }

                // Stream this worktree's settings files.
                for settings_file in worktree.settings_files {
                    session.peer.send(
                        session.connection_id,
                        proto::UpdateWorktreeSettings {
                            project_id: project.id.to_proto(),
                            worktree_id: worktree.id,
                            path: settings_file.path,
                            content: Some(settings_file.content),
                        },
                    )?;
                }
            }

            // Report each language server with a `DiskBasedDiagnosticsUpdated` variant.
            for language_server in &project.language_servers {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateLanguageServer {
                        project_id: project.id.to_proto(),
                        language_server_id: language_server.id,
                        variant: Some(
                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                                proto::LspDiskBasedDiagnosticsUpdated {},
                            ),
                        ),
                    },
                )?;
            }
        }

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
1410
1411/// leave room disconnects from the room.
1412async fn leave_room(
1413 _: proto::LeaveRoom,
1414 response: Response<proto::LeaveRoom>,
1415 session: Session,
1416) -> Result<()> {
1417 leave_room_for_session(&session).await?;
1418 response.send(proto::Ack {})?;
1419 Ok(())
1420}
1421
1422/// Updates the permissions of someone else in the room.
1423async fn set_room_participant_role(
1424 request: proto::SetRoomParticipantRole,
1425 response: Response<proto::SetRoomParticipantRole>,
1426 session: Session,
1427) -> Result<()> {
1428 let user_id = UserId::from_proto(request.user_id);
1429 let role = ChannelRole::from(request.role());
1430
1431 let (live_kit_room, can_publish) = {
1432 let room = session
1433 .db()
1434 .await
1435 .set_room_participant_role(
1436 session.user_id,
1437 RoomId::from_proto(request.room_id),
1438 user_id,
1439 role,
1440 )
1441 .await?;
1442
1443 let live_kit_room = room.live_kit_room.clone();
1444 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1445 room_updated(&room, &session.peer);
1446 (live_kit_room, can_publish)
1447 };
1448
1449 if let Some(live_kit) = session.live_kit_client.as_ref() {
1450 live_kit
1451 .update_participant(
1452 live_kit_room.clone(),
1453 request.user_id.to_string(),
1454 live_kit_server::proto::ParticipantPermission {
1455 can_subscribe: true,
1456 can_publish,
1457 can_publish_data: can_publish,
1458 hidden: false,
1459 recorder: false,
1460 },
1461 )
1462 .await
1463 .trace_err();
1464 }
1465
1466 response.send(proto::Ack {})?;
1467 Ok(())
1468}
1469
1470/// Call someone else into the current room
1471async fn call(
1472 request: proto::Call,
1473 response: Response<proto::Call>,
1474 session: Session,
1475) -> Result<()> {
1476 let room_id = RoomId::from_proto(request.room_id);
1477 let calling_user_id = session.user_id;
1478 let calling_connection_id = session.connection_id;
1479 let called_user_id = UserId::from_proto(request.called_user_id);
1480 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1481 if !session
1482 .db()
1483 .await
1484 .has_contact(calling_user_id, called_user_id)
1485 .await?
1486 {
1487 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1488 }
1489
1490 let incoming_call = {
1491 let (room, incoming_call) = &mut *session
1492 .db()
1493 .await
1494 .call(
1495 room_id,
1496 calling_user_id,
1497 calling_connection_id,
1498 called_user_id,
1499 initial_project_id,
1500 )
1501 .await?;
1502 room_updated(&room, &session.peer);
1503 mem::take(incoming_call)
1504 };
1505 update_user_contacts(called_user_id, &session).await?;
1506
1507 let mut calls = session
1508 .connection_pool()
1509 .await
1510 .user_connection_ids(called_user_id)
1511 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1512 .collect::<FuturesUnordered<_>>();
1513
1514 while let Some(call_response) = calls.next().await {
1515 match call_response.as_ref() {
1516 Ok(_) => {
1517 response.send(proto::Ack {})?;
1518 return Ok(());
1519 }
1520 Err(_) => {
1521 call_response.trace_err();
1522 }
1523 }
1524 }
1525
1526 {
1527 let room = session
1528 .db()
1529 .await
1530 .call_failed(room_id, called_user_id)
1531 .await?;
1532 room_updated(&room, &session.peer);
1533 }
1534 update_user_contacts(called_user_id, &session).await?;
1535
1536 Err(anyhow!("failed to ring user"))?
1537}
1538
1539/// Cancel an outgoing call.
1540async fn cancel_call(
1541 request: proto::CancelCall,
1542 response: Response<proto::CancelCall>,
1543 session: Session,
1544) -> Result<()> {
1545 let called_user_id = UserId::from_proto(request.called_user_id);
1546 let room_id = RoomId::from_proto(request.room_id);
1547 {
1548 let room = session
1549 .db()
1550 .await
1551 .cancel_call(room_id, session.connection_id, called_user_id)
1552 .await?;
1553 room_updated(&room, &session.peer);
1554 }
1555
1556 for connection_id in session
1557 .connection_pool()
1558 .await
1559 .user_connection_ids(called_user_id)
1560 {
1561 session
1562 .peer
1563 .send(
1564 connection_id,
1565 proto::CallCanceled {
1566 room_id: room_id.to_proto(),
1567 },
1568 )
1569 .trace_err();
1570 }
1571 response.send(proto::Ack {})?;
1572
1573 update_user_contacts(called_user_id, &session).await?;
1574 Ok(())
1575}
1576
1577/// Decline an incoming call.
1578async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1579 let room_id = RoomId::from_proto(message.room_id);
1580 {
1581 let room = session
1582 .db()
1583 .await
1584 .decline_call(Some(room_id), session.user_id)
1585 .await?
1586 .ok_or_else(|| anyhow!("failed to decline call"))?;
1587 room_updated(&room, &session.peer);
1588 }
1589
1590 for connection_id in session
1591 .connection_pool()
1592 .await
1593 .user_connection_ids(session.user_id)
1594 {
1595 session
1596 .peer
1597 .send(
1598 connection_id,
1599 proto::CallCanceled {
1600 room_id: room_id.to_proto(),
1601 },
1602 )
1603 .trace_err();
1604 }
1605 update_user_contacts(session.user_id, &session).await?;
1606 Ok(())
1607}
1608
1609/// Updates other participants in the room with your current location.
1610async fn update_participant_location(
1611 request: proto::UpdateParticipantLocation,
1612 response: Response<proto::UpdateParticipantLocation>,
1613 session: Session,
1614) -> Result<()> {
1615 let room_id = RoomId::from_proto(request.room_id);
1616 let location = request
1617 .location
1618 .ok_or_else(|| anyhow!("invalid location"))?;
1619
1620 let db = session.db().await;
1621 let room = db
1622 .update_room_participant_location(room_id, session.connection_id, location)
1623 .await?;
1624
1625 room_updated(&room, &session.peer);
1626 response.send(proto::Ack {})?;
1627 Ok(())
1628}
1629
1630/// Share a project into the room.
1631async fn share_project(
1632 request: proto::ShareProject,
1633 response: Response<proto::ShareProject>,
1634 session: Session,
1635) -> Result<()> {
1636 let (project_id, room) = &*session
1637 .db()
1638 .await
1639 .share_project(
1640 RoomId::from_proto(request.room_id),
1641 session.connection_id,
1642 &request.worktrees,
1643 )
1644 .await?;
1645 response.send(proto::ShareProjectResponse {
1646 project_id: project_id.to_proto(),
1647 })?;
1648 room_updated(&room, &session.peer);
1649
1650 Ok(())
1651}
1652
1653/// Unshare a project from the room.
1654async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1655 let project_id = ProjectId::from_proto(message.project_id);
1656
1657 let (room, guest_connection_ids) = &*session
1658 .db()
1659 .await
1660 .unshare_project(project_id, session.connection_id)
1661 .await?;
1662
1663 broadcast(
1664 Some(session.connection_id),
1665 guest_connection_ids.iter().copied(),
1666 |conn_id| session.peer.send(conn_id, message.clone()),
1667 );
1668 room_updated(&room, &session.peer);
1669
1670 Ok(())
1671}
1672
1673/// Join someone elses shared project.
1674async fn join_project(
1675 request: proto::JoinProject,
1676 response: Response<proto::JoinProject>,
1677 session: Session,
1678) -> Result<()> {
1679 let project_id = ProjectId::from_proto(request.project_id);
1680
1681 tracing::info!(%project_id, "join project");
1682
1683 let (project, replica_id) = &mut *session
1684 .db()
1685 .await
1686 .join_project_in_room(project_id, session.connection_id)
1687 .await?;
1688
1689 join_project_internal(response, session, project, replica_id)
1690}
1691
/// Abstracts over the response handles that can answer a project join.
///
/// This lets `join_project_internal` reply to both `JoinProject` and
/// `JoinHostedProject` requests with the same `JoinProjectResponse`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating join request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinHostedProject>::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
1705
1706fn join_project_internal(
1707 response: impl JoinProjectInternalResponse,
1708 session: Session,
1709 project: &mut Project,
1710 replica_id: &ReplicaId,
1711) -> Result<()> {
1712 let collaborators = project
1713 .collaborators
1714 .iter()
1715 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1716 .map(|collaborator| collaborator.to_proto())
1717 .collect::<Vec<_>>();
1718 let project_id = project.id;
1719 let guest_user_id = session.user_id;
1720
1721 let worktrees = project
1722 .worktrees
1723 .iter()
1724 .map(|(id, worktree)| proto::WorktreeMetadata {
1725 id: *id,
1726 root_name: worktree.root_name.clone(),
1727 visible: worktree.visible,
1728 abs_path: worktree.abs_path.clone(),
1729 })
1730 .collect::<Vec<_>>();
1731
1732 for collaborator in &collaborators {
1733 session
1734 .peer
1735 .send(
1736 collaborator.peer_id.unwrap().into(),
1737 proto::AddProjectCollaborator {
1738 project_id: project_id.to_proto(),
1739 collaborator: Some(proto::Collaborator {
1740 peer_id: Some(session.connection_id.into()),
1741 replica_id: replica_id.0 as u32,
1742 user_id: guest_user_id.to_proto(),
1743 }),
1744 },
1745 )
1746 .trace_err();
1747 }
1748
1749 // First, we send the metadata associated with each worktree.
1750 response.send(proto::JoinProjectResponse {
1751 project_id: project.id.0 as u64,
1752 worktrees: worktrees.clone(),
1753 replica_id: replica_id.0 as u32,
1754 collaborators: collaborators.clone(),
1755 language_servers: project.language_servers.clone(),
1756 role: project.role.into(), // todo
1757 })?;
1758
1759 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1760 #[cfg(any(test, feature = "test-support"))]
1761 const MAX_CHUNK_SIZE: usize = 2;
1762 #[cfg(not(any(test, feature = "test-support")))]
1763 const MAX_CHUNK_SIZE: usize = 256;
1764
1765 // Stream this worktree's entries.
1766 let message = proto::UpdateWorktree {
1767 project_id: project_id.to_proto(),
1768 worktree_id,
1769 abs_path: worktree.abs_path.clone(),
1770 root_name: worktree.root_name,
1771 updated_entries: worktree.entries,
1772 removed_entries: Default::default(),
1773 scan_id: worktree.scan_id,
1774 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1775 updated_repositories: worktree.repository_entries.into_values().collect(),
1776 removed_repositories: Default::default(),
1777 };
1778 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1779 session.peer.send(session.connection_id, update.clone())?;
1780 }
1781
1782 // Stream this worktree's diagnostics.
1783 for summary in worktree.diagnostic_summaries {
1784 session.peer.send(
1785 session.connection_id,
1786 proto::UpdateDiagnosticSummary {
1787 project_id: project_id.to_proto(),
1788 worktree_id: worktree.id,
1789 summary: Some(summary),
1790 },
1791 )?;
1792 }
1793
1794 for settings_file in worktree.settings_files {
1795 session.peer.send(
1796 session.connection_id,
1797 proto::UpdateWorktreeSettings {
1798 project_id: project_id.to_proto(),
1799 worktree_id: worktree.id,
1800 path: settings_file.path,
1801 content: Some(settings_file.content),
1802 },
1803 )?;
1804 }
1805 }
1806
1807 for language_server in &project.language_servers {
1808 session.peer.send(
1809 session.connection_id,
1810 proto::UpdateLanguageServer {
1811 project_id: project_id.to_proto(),
1812 language_server_id: language_server.id,
1813 variant: Some(
1814 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1815 proto::LspDiskBasedDiagnosticsUpdated {},
1816 ),
1817 ),
1818 },
1819 )?;
1820 }
1821
1822 Ok(())
1823}
1824
1825/// Leave someone elses shared project.
1826async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
1827 let sender_id = session.connection_id;
1828 let project_id = ProjectId::from_proto(request.project_id);
1829 let db = session.db().await;
1830 if db.is_hosted_project(project_id).await? {
1831 let project = db.leave_hosted_project(project_id, sender_id).await?;
1832 project_left(&project, &session);
1833 return Ok(());
1834 }
1835
1836 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1837 tracing::info!(
1838 %project_id,
1839 host_user_id = ?project.host_user_id,
1840 host_connection_id = ?project.host_connection_id,
1841 "leave project"
1842 );
1843
1844 project_left(&project, &session);
1845 room_updated(&room, &session.peer);
1846
1847 Ok(())
1848}
1849
1850async fn join_hosted_project(
1851 request: proto::JoinHostedProject,
1852 response: Response<proto::JoinHostedProject>,
1853 session: Session,
1854) -> Result<()> {
1855 let (mut project, replica_id) = session
1856 .db()
1857 .await
1858 .join_hosted_project(
1859 ProjectId(request.project_id as i32),
1860 session.user_id,
1861 session.connection_id,
1862 )
1863 .await?;
1864
1865 join_project_internal(response, session, &mut project, &replica_id)
1866}
1867
1868/// Updates other participants with changes to the project
1869async fn update_project(
1870 request: proto::UpdateProject,
1871 response: Response<proto::UpdateProject>,
1872 session: Session,
1873) -> Result<()> {
1874 let project_id = ProjectId::from_proto(request.project_id);
1875 let (room, guest_connection_ids) = &*session
1876 .db()
1877 .await
1878 .update_project(project_id, session.connection_id, &request.worktrees)
1879 .await?;
1880 broadcast(
1881 Some(session.connection_id),
1882 guest_connection_ids.iter().copied(),
1883 |connection_id| {
1884 session
1885 .peer
1886 .forward_send(session.connection_id, connection_id, request.clone())
1887 },
1888 );
1889 room_updated(&room, &session.peer);
1890 response.send(proto::Ack {})?;
1891
1892 Ok(())
1893}
1894
1895/// Updates other participants with changes to the worktree
1896async fn update_worktree(
1897 request: proto::UpdateWorktree,
1898 response: Response<proto::UpdateWorktree>,
1899 session: Session,
1900) -> Result<()> {
1901 let guest_connection_ids = session
1902 .db()
1903 .await
1904 .update_worktree(&request, session.connection_id)
1905 .await?;
1906
1907 broadcast(
1908 Some(session.connection_id),
1909 guest_connection_ids.iter().copied(),
1910 |connection_id| {
1911 session
1912 .peer
1913 .forward_send(session.connection_id, connection_id, request.clone())
1914 },
1915 );
1916 response.send(proto::Ack {})?;
1917 Ok(())
1918}
1919
1920/// Updates other participants with changes to the diagnostics
1921async fn update_diagnostic_summary(
1922 message: proto::UpdateDiagnosticSummary,
1923 session: Session,
1924) -> Result<()> {
1925 let guest_connection_ids = session
1926 .db()
1927 .await
1928 .update_diagnostic_summary(&message, session.connection_id)
1929 .await?;
1930
1931 broadcast(
1932 Some(session.connection_id),
1933 guest_connection_ids.iter().copied(),
1934 |connection_id| {
1935 session
1936 .peer
1937 .forward_send(session.connection_id, connection_id, message.clone())
1938 },
1939 );
1940
1941 Ok(())
1942}
1943
1944/// Updates other participants with changes to the worktree settings
1945async fn update_worktree_settings(
1946 message: proto::UpdateWorktreeSettings,
1947 session: Session,
1948) -> Result<()> {
1949 let guest_connection_ids = session
1950 .db()
1951 .await
1952 .update_worktree_settings(&message, session.connection_id)
1953 .await?;
1954
1955 broadcast(
1956 Some(session.connection_id),
1957 guest_connection_ids.iter().copied(),
1958 |connection_id| {
1959 session
1960 .peer
1961 .forward_send(session.connection_id, connection_id, message.clone())
1962 },
1963 );
1964
1965 Ok(())
1966}
1967
1968/// Notify other participants that a language server has started.
1969async fn start_language_server(
1970 request: proto::StartLanguageServer,
1971 session: Session,
1972) -> Result<()> {
1973 let guest_connection_ids = session
1974 .db()
1975 .await
1976 .start_language_server(&request, session.connection_id)
1977 .await?;
1978
1979 broadcast(
1980 Some(session.connection_id),
1981 guest_connection_ids.iter().copied(),
1982 |connection_id| {
1983 session
1984 .peer
1985 .forward_send(session.connection_id, connection_id, request.clone())
1986 },
1987 );
1988 Ok(())
1989}
1990
1991/// Notify other participants that a language server has changed.
1992async fn update_language_server(
1993 request: proto::UpdateLanguageServer,
1994 session: Session,
1995) -> Result<()> {
1996 let project_id = ProjectId::from_proto(request.project_id);
1997 let project_connection_ids = session
1998 .db()
1999 .await
2000 .project_connection_ids(project_id, session.connection_id)
2001 .await?;
2002 broadcast(
2003 Some(session.connection_id),
2004 project_connection_ids.iter().copied(),
2005 |connection_id| {
2006 session
2007 .peer
2008 .forward_send(session.connection_id, connection_id, request.clone())
2009 },
2010 );
2011 Ok(())
2012}
2013
2014/// forward a project request to the host. These requests should be read only
2015/// as guests are allowed to send them.
2016async fn forward_read_only_project_request<T>(
2017 request: T,
2018 response: Response<T>,
2019 session: Session,
2020) -> Result<()>
2021where
2022 T: EntityMessage + RequestMessage,
2023{
2024 let project_id = ProjectId::from_proto(request.remote_entity_id());
2025 let host_connection_id = session
2026 .db()
2027 .await
2028 .host_for_read_only_project_request(project_id, session.connection_id)
2029 .await?;
2030 let payload = session
2031 .peer
2032 .forward_request(session.connection_id, host_connection_id, request)
2033 .await?;
2034 response.send(payload)?;
2035 Ok(())
2036}
2037
2038/// forward a project request to the host. These requests are disallowed
2039/// for guests.
2040async fn forward_mutating_project_request<T>(
2041 request: T,
2042 response: Response<T>,
2043 session: Session,
2044) -> Result<()>
2045where
2046 T: EntityMessage + RequestMessage,
2047{
2048 let project_id = ProjectId::from_proto(request.remote_entity_id());
2049 let host_connection_id = session
2050 .db()
2051 .await
2052 .host_for_mutating_project_request(project_id, session.connection_id)
2053 .await?;
2054 let payload = session
2055 .peer
2056 .forward_request(session.connection_id, host_connection_id, request)
2057 .await?;
2058 response.send(payload)?;
2059 Ok(())
2060}
2061
2062/// Notify other participants that a new buffer has been created
2063async fn create_buffer_for_peer(
2064 request: proto::CreateBufferForPeer,
2065 session: Session,
2066) -> Result<()> {
2067 session
2068 .db()
2069 .await
2070 .check_user_is_project_host(
2071 ProjectId::from_proto(request.project_id),
2072 session.connection_id,
2073 )
2074 .await?;
2075 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2076 session
2077 .peer
2078 .forward_send(session.connection_id, peer_id.into(), request)?;
2079 Ok(())
2080}
2081
2082/// Notify other participants that a buffer has been updated. This is
2083/// allowed for guests as long as the update is limited to selections.
2084async fn update_buffer(
2085 request: proto::UpdateBuffer,
2086 response: Response<proto::UpdateBuffer>,
2087 session: Session,
2088) -> Result<()> {
2089 let project_id = ProjectId::from_proto(request.project_id);
2090 let mut guest_connection_ids;
2091 let mut host_connection_id = None;
2092
2093 let mut requires_write_permission = false;
2094
2095 for op in request.operations.iter() {
2096 match op.variant {
2097 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2098 Some(_) => requires_write_permission = true,
2099 }
2100 }
2101
2102 {
2103 let collaborators = session
2104 .db()
2105 .await
2106 .project_collaborators_for_buffer_update(
2107 project_id,
2108 session.connection_id,
2109 requires_write_permission,
2110 )
2111 .await?;
2112 guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
2113 for collaborator in collaborators.iter() {
2114 if collaborator.is_host {
2115 host_connection_id = Some(collaborator.connection_id);
2116 } else {
2117 guest_connection_ids.push(collaborator.connection_id);
2118 }
2119 }
2120 }
2121 let host_connection_id = host_connection_id.ok_or_else(|| anyhow!("host not found"))?;
2122
2123 broadcast(
2124 Some(session.connection_id),
2125 guest_connection_ids,
2126 |connection_id| {
2127 session
2128 .peer
2129 .forward_send(session.connection_id, connection_id, request.clone())
2130 },
2131 );
2132 if host_connection_id != session.connection_id {
2133 session
2134 .peer
2135 .forward_request(session.connection_id, host_connection_id, request.clone())
2136 .await?;
2137 }
2138
2139 response.send(proto::Ack {})?;
2140 Ok(())
2141}
2142
2143/// Notify other participants that a project has been updated.
2144async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2145 request: T,
2146 session: Session,
2147) -> Result<()> {
2148 let project_id = ProjectId::from_proto(request.remote_entity_id());
2149 let project_connection_ids = session
2150 .db()
2151 .await
2152 .project_connection_ids(project_id, session.connection_id)
2153 .await?;
2154
2155 broadcast(
2156 Some(session.connection_id),
2157 project_connection_ids.iter().copied(),
2158 |connection_id| {
2159 session
2160 .peer
2161 .forward_send(session.connection_id, connection_id, request.clone())
2162 },
2163 );
2164 Ok(())
2165}
2166
2167/// Start following another user in a call.
2168async fn follow(
2169 request: proto::Follow,
2170 response: Response<proto::Follow>,
2171 session: Session,
2172) -> Result<()> {
2173 let room_id = RoomId::from_proto(request.room_id);
2174 let project_id = request.project_id.map(ProjectId::from_proto);
2175 let leader_id = request
2176 .leader_id
2177 .ok_or_else(|| anyhow!("invalid leader id"))?
2178 .into();
2179 let follower_id = session.connection_id;
2180
2181 session
2182 .db()
2183 .await
2184 .check_room_participants(room_id, leader_id, session.connection_id)
2185 .await?;
2186
2187 let response_payload = session
2188 .peer
2189 .forward_request(session.connection_id, leader_id, request)
2190 .await?;
2191 response.send(response_payload)?;
2192
2193 if let Some(project_id) = project_id {
2194 let room = session
2195 .db()
2196 .await
2197 .follow(room_id, project_id, leader_id, follower_id)
2198 .await?;
2199 room_updated(&room, &session.peer);
2200 }
2201
2202 Ok(())
2203}
2204
2205/// Stop following another user in a call.
2206async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2207 let room_id = RoomId::from_proto(request.room_id);
2208 let project_id = request.project_id.map(ProjectId::from_proto);
2209 let leader_id = request
2210 .leader_id
2211 .ok_or_else(|| anyhow!("invalid leader id"))?
2212 .into();
2213 let follower_id = session.connection_id;
2214
2215 session
2216 .db()
2217 .await
2218 .check_room_participants(room_id, leader_id, session.connection_id)
2219 .await?;
2220
2221 session
2222 .peer
2223 .forward_send(session.connection_id, leader_id, request)?;
2224
2225 if let Some(project_id) = project_id {
2226 let room = session
2227 .db()
2228 .await
2229 .unfollow(room_id, project_id, leader_id, follower_id)
2230 .await?;
2231 room_updated(&room, &session.peer);
2232 }
2233
2234 Ok(())
2235}
2236
2237/// Notify everyone following you of your current location.
2238async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2239 let room_id = RoomId::from_proto(request.room_id);
2240 let database = session.db.lock().await;
2241
2242 let connection_ids = if let Some(project_id) = request.project_id {
2243 let project_id = ProjectId::from_proto(project_id);
2244 database
2245 .project_connection_ids(project_id, session.connection_id)
2246 .await?
2247 } else {
2248 database
2249 .room_connection_ids(room_id, session.connection_id)
2250 .await?
2251 };
2252
2253 // For now, don't send view update messages back to that view's current leader.
2254 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2255 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2256 _ => None,
2257 });
2258
2259 for connection_id in connection_ids.iter().cloned() {
2260 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2261 session
2262 .peer
2263 .forward_send(session.connection_id, connection_id, request.clone())?;
2264 }
2265 }
2266 Ok(())
2267}
2268
2269/// Get public data about users.
2270async fn get_users(
2271 request: proto::GetUsers,
2272 response: Response<proto::GetUsers>,
2273 session: Session,
2274) -> Result<()> {
2275 let user_ids = request
2276 .user_ids
2277 .into_iter()
2278 .map(UserId::from_proto)
2279 .collect();
2280 let users = session
2281 .db()
2282 .await
2283 .get_users_by_ids(user_ids)
2284 .await?
2285 .into_iter()
2286 .map(|user| proto::User {
2287 id: user.id.to_proto(),
2288 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2289 github_login: user.github_login,
2290 })
2291 .collect();
2292 response.send(proto::UsersResponse { users })?;
2293 Ok(())
2294}
2295
2296/// Search for users (to invite) buy Github login
2297async fn fuzzy_search_users(
2298 request: proto::FuzzySearchUsers,
2299 response: Response<proto::FuzzySearchUsers>,
2300 session: Session,
2301) -> Result<()> {
2302 let query = request.query;
2303 let users = match query.len() {
2304 0 => vec![],
2305 1 | 2 => session
2306 .db()
2307 .await
2308 .get_user_by_github_login(&query)
2309 .await?
2310 .into_iter()
2311 .collect(),
2312 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2313 };
2314 let users = users
2315 .into_iter()
2316 .filter(|user| user.id != session.user_id)
2317 .map(|user| proto::User {
2318 id: user.id.to_proto(),
2319 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2320 github_login: user.github_login,
2321 })
2322 .collect();
2323 response.send(proto::UsersResponse { users })?;
2324 Ok(())
2325}
2326
2327/// Send a contact request to another user.
2328async fn request_contact(
2329 request: proto::RequestContact,
2330 response: Response<proto::RequestContact>,
2331 session: Session,
2332) -> Result<()> {
2333 let requester_id = session.user_id;
2334 let responder_id = UserId::from_proto(request.responder_id);
2335 if requester_id == responder_id {
2336 return Err(anyhow!("cannot add yourself as a contact"))?;
2337 }
2338
2339 let notifications = session
2340 .db()
2341 .await
2342 .send_contact_request(requester_id, responder_id)
2343 .await?;
2344
2345 // Update outgoing contact requests of requester
2346 let mut update = proto::UpdateContacts::default();
2347 update.outgoing_requests.push(responder_id.to_proto());
2348 for connection_id in session
2349 .connection_pool()
2350 .await
2351 .user_connection_ids(requester_id)
2352 {
2353 session.peer.send(connection_id, update.clone())?;
2354 }
2355
2356 // Update incoming contact requests of responder
2357 let mut update = proto::UpdateContacts::default();
2358 update
2359 .incoming_requests
2360 .push(proto::IncomingContactRequest {
2361 requester_id: requester_id.to_proto(),
2362 });
2363 let connection_pool = session.connection_pool().await;
2364 for connection_id in connection_pool.user_connection_ids(responder_id) {
2365 session.peer.send(connection_id, update.clone())?;
2366 }
2367
2368 send_notifications(&connection_pool, &session.peer, notifications);
2369
2370 response.send(proto::Ack {})?;
2371 Ok(())
2372}
2373
2374/// Accept or decline a contact request
2375async fn respond_to_contact_request(
2376 request: proto::RespondToContactRequest,
2377 response: Response<proto::RespondToContactRequest>,
2378 session: Session,
2379) -> Result<()> {
2380 let responder_id = session.user_id;
2381 let requester_id = UserId::from_proto(request.requester_id);
2382 let db = session.db().await;
2383 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2384 db.dismiss_contact_notification(responder_id, requester_id)
2385 .await?;
2386 } else {
2387 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2388
2389 let notifications = db
2390 .respond_to_contact_request(responder_id, requester_id, accept)
2391 .await?;
2392 let requester_busy = db.is_user_busy(requester_id).await?;
2393 let responder_busy = db.is_user_busy(responder_id).await?;
2394
2395 let pool = session.connection_pool().await;
2396 // Update responder with new contact
2397 let mut update = proto::UpdateContacts::default();
2398 if accept {
2399 update
2400 .contacts
2401 .push(contact_for_user(requester_id, requester_busy, &pool));
2402 }
2403 update
2404 .remove_incoming_requests
2405 .push(requester_id.to_proto());
2406 for connection_id in pool.user_connection_ids(responder_id) {
2407 session.peer.send(connection_id, update.clone())?;
2408 }
2409
2410 // Update requester with new contact
2411 let mut update = proto::UpdateContacts::default();
2412 if accept {
2413 update
2414 .contacts
2415 .push(contact_for_user(responder_id, responder_busy, &pool));
2416 }
2417 update
2418 .remove_outgoing_requests
2419 .push(responder_id.to_proto());
2420
2421 for connection_id in pool.user_connection_ids(requester_id) {
2422 session.peer.send(connection_id, update.clone())?;
2423 }
2424
2425 send_notifications(&pool, &session.peer, notifications);
2426 }
2427
2428 response.send(proto::Ack {})?;
2429 Ok(())
2430}
2431
2432/// Remove a contact.
2433async fn remove_contact(
2434 request: proto::RemoveContact,
2435 response: Response<proto::RemoveContact>,
2436 session: Session,
2437) -> Result<()> {
2438 let requester_id = session.user_id;
2439 let responder_id = UserId::from_proto(request.user_id);
2440 let db = session.db().await;
2441 let (contact_accepted, deleted_notification_id) =
2442 db.remove_contact(requester_id, responder_id).await?;
2443
2444 let pool = session.connection_pool().await;
2445 // Update outgoing contact requests of requester
2446 let mut update = proto::UpdateContacts::default();
2447 if contact_accepted {
2448 update.remove_contacts.push(responder_id.to_proto());
2449 } else {
2450 update
2451 .remove_outgoing_requests
2452 .push(responder_id.to_proto());
2453 }
2454 for connection_id in pool.user_connection_ids(requester_id) {
2455 session.peer.send(connection_id, update.clone())?;
2456 }
2457
2458 // Update incoming contact requests of responder
2459 let mut update = proto::UpdateContacts::default();
2460 if contact_accepted {
2461 update.remove_contacts.push(requester_id.to_proto());
2462 } else {
2463 update
2464 .remove_incoming_requests
2465 .push(requester_id.to_proto());
2466 }
2467 for connection_id in pool.user_connection_ids(responder_id) {
2468 session.peer.send(connection_id, update.clone())?;
2469 if let Some(notification_id) = deleted_notification_id {
2470 session.peer.send(
2471 connection_id,
2472 proto::DeleteNotification {
2473 notification_id: notification_id.to_proto(),
2474 },
2475 )?;
2476 }
2477 }
2478
2479 response.send(proto::Ack {})?;
2480 Ok(())
2481}
2482
2483/// Creates a new channel.
2484async fn create_channel(
2485 request: proto::CreateChannel,
2486 response: Response<proto::CreateChannel>,
2487 session: Session,
2488) -> Result<()> {
2489 let db = session.db().await;
2490
2491 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
2492 let (channel, membership) = db
2493 .create_channel(&request.name, parent_id, session.user_id)
2494 .await?;
2495
2496 let root_id = channel.root_id();
2497 let channel = Channel::from_model(channel);
2498
2499 response.send(proto::CreateChannelResponse {
2500 channel: Some(channel.to_proto()),
2501 parent_id: request.parent_id,
2502 })?;
2503
2504 let mut connection_pool = session.connection_pool().await;
2505 if let Some(membership) = membership {
2506 connection_pool.subscribe_to_channel(
2507 membership.user_id,
2508 membership.channel_id,
2509 membership.role,
2510 );
2511 let update = proto::UpdateUserChannels {
2512 channel_memberships: vec![proto::ChannelMembership {
2513 channel_id: membership.channel_id.to_proto(),
2514 role: membership.role.into(),
2515 }],
2516 ..Default::default()
2517 };
2518 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2519 session.peer.send(connection_id, update.clone())?;
2520 }
2521 }
2522
2523 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2524 if !role.can_see_channel(channel.visibility) {
2525 continue;
2526 }
2527
2528 let update = proto::UpdateChannels {
2529 channels: vec![channel.to_proto()],
2530 ..Default::default()
2531 };
2532 session.peer.send(connection_id, update.clone())?;
2533 }
2534
2535 Ok(())
2536}
2537
2538/// Delete a channel
2539async fn delete_channel(
2540 request: proto::DeleteChannel,
2541 response: Response<proto::DeleteChannel>,
2542 session: Session,
2543) -> Result<()> {
2544 let db = session.db().await;
2545
2546 let channel_id = request.channel_id;
2547 let (root_channel, removed_channels) = db
2548 .delete_channel(ChannelId::from_proto(channel_id), session.user_id)
2549 .await?;
2550 response.send(proto::Ack {})?;
2551
2552 // Notify members of removed channels
2553 let mut update = proto::UpdateChannels::default();
2554 update
2555 .delete_channels
2556 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2557
2558 let connection_pool = session.connection_pool().await;
2559 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2560 session.peer.send(connection_id, update.clone())?;
2561 }
2562
2563 Ok(())
2564}
2565
2566/// Invite someone to join a channel.
2567async fn invite_channel_member(
2568 request: proto::InviteChannelMember,
2569 response: Response<proto::InviteChannelMember>,
2570 session: Session,
2571) -> Result<()> {
2572 let db = session.db().await;
2573 let channel_id = ChannelId::from_proto(request.channel_id);
2574 let invitee_id = UserId::from_proto(request.user_id);
2575 let InviteMemberResult {
2576 channel,
2577 notifications,
2578 } = db
2579 .invite_channel_member(
2580 channel_id,
2581 invitee_id,
2582 session.user_id,
2583 request.role().into(),
2584 )
2585 .await?;
2586
2587 let update = proto::UpdateChannels {
2588 channel_invitations: vec![channel.to_proto()],
2589 ..Default::default()
2590 };
2591
2592 let connection_pool = session.connection_pool().await;
2593 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2594 session.peer.send(connection_id, update.clone())?;
2595 }
2596
2597 send_notifications(&connection_pool, &session.peer, notifications);
2598
2599 response.send(proto::Ack {})?;
2600 Ok(())
2601}
2602
2603/// remove someone from a channel
2604async fn remove_channel_member(
2605 request: proto::RemoveChannelMember,
2606 response: Response<proto::RemoveChannelMember>,
2607 session: Session,
2608) -> Result<()> {
2609 let db = session.db().await;
2610 let channel_id = ChannelId::from_proto(request.channel_id);
2611 let member_id = UserId::from_proto(request.user_id);
2612
2613 let RemoveChannelMemberResult {
2614 membership_update,
2615 notification_id,
2616 } = db
2617 .remove_channel_member(channel_id, member_id, session.user_id)
2618 .await?;
2619
2620 let mut connection_pool = session.connection_pool().await;
2621 notify_membership_updated(
2622 &mut connection_pool,
2623 membership_update,
2624 member_id,
2625 &session.peer,
2626 );
2627 for connection_id in connection_pool.user_connection_ids(member_id) {
2628 if let Some(notification_id) = notification_id {
2629 session
2630 .peer
2631 .send(
2632 connection_id,
2633 proto::DeleteNotification {
2634 notification_id: notification_id.to_proto(),
2635 },
2636 )
2637 .trace_err();
2638 }
2639 }
2640
2641 response.send(proto::Ack {})?;
2642 Ok(())
2643}
2644
2645/// Toggle the channel between public and private.
2646/// Care is taken to maintain the invariant that public channels only descend from public channels,
2647/// (though members-only channels can appear at any point in the hierarchy).
2648async fn set_channel_visibility(
2649 request: proto::SetChannelVisibility,
2650 response: Response<proto::SetChannelVisibility>,
2651 session: Session,
2652) -> Result<()> {
2653 let db = session.db().await;
2654 let channel_id = ChannelId::from_proto(request.channel_id);
2655 let visibility = request.visibility().into();
2656
2657 let channel_model = db
2658 .set_channel_visibility(channel_id, visibility, session.user_id)
2659 .await?;
2660 let root_id = channel_model.root_id();
2661 let channel = Channel::from_model(channel_model);
2662
2663 let mut connection_pool = session.connection_pool().await;
2664 for (user_id, role) in connection_pool
2665 .channel_user_ids(root_id)
2666 .collect::<Vec<_>>()
2667 .into_iter()
2668 {
2669 let update = if role.can_see_channel(channel.visibility) {
2670 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2671 proto::UpdateChannels {
2672 channels: vec![channel.to_proto()],
2673 ..Default::default()
2674 }
2675 } else {
2676 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2677 proto::UpdateChannels {
2678 delete_channels: vec![channel.id.to_proto()],
2679 ..Default::default()
2680 }
2681 };
2682
2683 for connection_id in connection_pool.user_connection_ids(user_id) {
2684 session.peer.send(connection_id, update.clone())?;
2685 }
2686 }
2687
2688 response.send(proto::Ack {})?;
2689 Ok(())
2690}
2691
2692/// Alter the role for a user in the channel.
2693async fn set_channel_member_role(
2694 request: proto::SetChannelMemberRole,
2695 response: Response<proto::SetChannelMemberRole>,
2696 session: Session,
2697) -> Result<()> {
2698 let db = session.db().await;
2699 let channel_id = ChannelId::from_proto(request.channel_id);
2700 let member_id = UserId::from_proto(request.user_id);
2701 let result = db
2702 .set_channel_member_role(
2703 channel_id,
2704 session.user_id,
2705 member_id,
2706 request.role().into(),
2707 )
2708 .await?;
2709
2710 match result {
2711 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
2712 let mut connection_pool = session.connection_pool().await;
2713 notify_membership_updated(
2714 &mut connection_pool,
2715 membership_update,
2716 member_id,
2717 &session.peer,
2718 )
2719 }
2720 db::SetMemberRoleResult::InviteUpdated(channel) => {
2721 let update = proto::UpdateChannels {
2722 channel_invitations: vec![channel.to_proto()],
2723 ..Default::default()
2724 };
2725
2726 for connection_id in session
2727 .connection_pool()
2728 .await
2729 .user_connection_ids(member_id)
2730 {
2731 session.peer.send(connection_id, update.clone())?;
2732 }
2733 }
2734 }
2735
2736 response.send(proto::Ack {})?;
2737 Ok(())
2738}
2739
2740/// Change the name of a channel
2741async fn rename_channel(
2742 request: proto::RenameChannel,
2743 response: Response<proto::RenameChannel>,
2744 session: Session,
2745) -> Result<()> {
2746 let db = session.db().await;
2747 let channel_id = ChannelId::from_proto(request.channel_id);
2748 let channel_model = db
2749 .rename_channel(channel_id, session.user_id, &request.name)
2750 .await?;
2751 let root_id = channel_model.root_id();
2752 let channel = Channel::from_model(channel_model);
2753
2754 response.send(proto::RenameChannelResponse {
2755 channel: Some(channel.to_proto()),
2756 })?;
2757
2758 let connection_pool = session.connection_pool().await;
2759 let update = proto::UpdateChannels {
2760 channels: vec![channel.to_proto()],
2761 ..Default::default()
2762 };
2763 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2764 if role.can_see_channel(channel.visibility) {
2765 session.peer.send(connection_id, update.clone())?;
2766 }
2767 }
2768
2769 Ok(())
2770}
2771
2772/// Move a channel to a new parent.
2773async fn move_channel(
2774 request: proto::MoveChannel,
2775 response: Response<proto::MoveChannel>,
2776 session: Session,
2777) -> Result<()> {
2778 let channel_id = ChannelId::from_proto(request.channel_id);
2779 let to = ChannelId::from_proto(request.to);
2780
2781 let (root_id, channels) = session
2782 .db()
2783 .await
2784 .move_channel(channel_id, to, session.user_id)
2785 .await?;
2786
2787 let connection_pool = session.connection_pool().await;
2788 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2789 let channels = channels
2790 .iter()
2791 .filter_map(|channel| {
2792 if role.can_see_channel(channel.visibility) {
2793 Some(channel.to_proto())
2794 } else {
2795 None
2796 }
2797 })
2798 .collect::<Vec<_>>();
2799 if channels.is_empty() {
2800 continue;
2801 }
2802
2803 let update = proto::UpdateChannels {
2804 channels,
2805 ..Default::default()
2806 };
2807
2808 session.peer.send(connection_id, update.clone())?;
2809 }
2810
2811 response.send(Ack {})?;
2812 Ok(())
2813}
2814
2815/// Get the list of channel members
2816async fn get_channel_members(
2817 request: proto::GetChannelMembers,
2818 response: Response<proto::GetChannelMembers>,
2819 session: Session,
2820) -> Result<()> {
2821 let db = session.db().await;
2822 let channel_id = ChannelId::from_proto(request.channel_id);
2823 let members = db
2824 .get_channel_participant_details(channel_id, session.user_id)
2825 .await?;
2826 response.send(proto::GetChannelMembersResponse { members })?;
2827 Ok(())
2828}
2829
2830/// Accept or decline a channel invitation.
2831async fn respond_to_channel_invite(
2832 request: proto::RespondToChannelInvite,
2833 response: Response<proto::RespondToChannelInvite>,
2834 session: Session,
2835) -> Result<()> {
2836 let db = session.db().await;
2837 let channel_id = ChannelId::from_proto(request.channel_id);
2838 let RespondToChannelInvite {
2839 membership_update,
2840 notifications,
2841 } = db
2842 .respond_to_channel_invite(channel_id, session.user_id, request.accept)
2843 .await?;
2844
2845 let mut connection_pool = session.connection_pool().await;
2846 if let Some(membership_update) = membership_update {
2847 notify_membership_updated(
2848 &mut connection_pool,
2849 membership_update,
2850 session.user_id,
2851 &session.peer,
2852 );
2853 } else {
2854 let update = proto::UpdateChannels {
2855 remove_channel_invitations: vec![channel_id.to_proto()],
2856 ..Default::default()
2857 };
2858
2859 for connection_id in connection_pool.user_connection_ids(session.user_id) {
2860 session.peer.send(connection_id, update.clone())?;
2861 }
2862 };
2863
2864 send_notifications(&connection_pool, &session.peer, notifications);
2865
2866 response.send(proto::Ack {})?;
2867
2868 Ok(())
2869}
2870
2871/// Join the channels' room
2872async fn join_channel(
2873 request: proto::JoinChannel,
2874 response: Response<proto::JoinChannel>,
2875 session: Session,
2876) -> Result<()> {
2877 let channel_id = ChannelId::from_proto(request.channel_id);
2878 join_channel_internal(channel_id, Box::new(response), session).await
2879}
2880
2881trait JoinChannelInternalResponse {
2882 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
2883}
2884impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
2885 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2886 Response::<proto::JoinChannel>::send(self, result)
2887 }
2888}
2889impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
2890 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
2891 Response::<proto::JoinRoom>::send(self, result)
2892 }
2893}
2894
/// Shared implementation for joining a channel's room.
///
/// `join_channel` calls this with a `Response<proto::JoinChannel>`. The
/// `Response<proto::JoinRoom>` impl of `JoinChannelInternalResponse` suggests the
/// room-join path uses it too. In order, it:
/// 1. leaves any room this session is currently in;
/// 2. joins the channel's room in the database, getting back the joined room, an optional
///    membership update, and the caller's role in the channel;
/// 3. if a LiveKit client is configured, mints a token: a guest token with
///    `can_publish: false` for `ChannelRole::Guest`, a room token with `can_publish: true`
///    otherwise. If minting fails, `trace_err` yields `None` and the reply carries no
///    LiveKit connection info;
/// 4. replies to the caller, applies the membership update (if any) to the connection
///    pool, and broadcasts the room update;
/// 5. broadcasts the channel update and refreshes the user's contacts.
///
/// # Errors
/// Fails if leaving the previous room fails, the database join fails, the reply cannot be
/// sent, the joined room carries no channel, or updating contacts fails.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The `db` and connection-pool guards taken inside this block are dropped when it
    // ends, before `channel_updated` below takes the connection pool again.
    let joined_room = {
        leave_room_for_session(&session).await?;
        let db = session.db().await;

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id, session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and are not allowed to publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id.to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the caller before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id,
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A channel join is expected to produce a room that belongs to a channel.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id, &session).await?;
    Ok(())
}
2974
2975/// Start editing the channel notes
2976async fn join_channel_buffer(
2977 request: proto::JoinChannelBuffer,
2978 response: Response<proto::JoinChannelBuffer>,
2979 session: Session,
2980) -> Result<()> {
2981 let db = session.db().await;
2982 let channel_id = ChannelId::from_proto(request.channel_id);
2983
2984 let open_response = db
2985 .join_channel_buffer(channel_id, session.user_id, session.connection_id)
2986 .await?;
2987
2988 let collaborators = open_response.collaborators.clone();
2989 response.send(open_response)?;
2990
2991 let update = UpdateChannelBufferCollaborators {
2992 channel_id: channel_id.to_proto(),
2993 collaborators: collaborators.clone(),
2994 };
2995 channel_buffer_updated(
2996 session.connection_id,
2997 collaborators
2998 .iter()
2999 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3000 &update,
3001 &session.peer,
3002 );
3003
3004 Ok(())
3005}
3006
3007/// Edit the channel notes
3008async fn update_channel_buffer(
3009 request: proto::UpdateChannelBuffer,
3010 session: Session,
3011) -> Result<()> {
3012 let db = session.db().await;
3013 let channel_id = ChannelId::from_proto(request.channel_id);
3014
3015 let (collaborators, non_collaborators, epoch, version) = db
3016 .update_channel_buffer(channel_id, session.user_id, &request.operations)
3017 .await?;
3018
3019 channel_buffer_updated(
3020 session.connection_id,
3021 collaborators,
3022 &proto::UpdateChannelBuffer {
3023 channel_id: channel_id.to_proto(),
3024 operations: request.operations,
3025 },
3026 &session.peer,
3027 );
3028
3029 let pool = &*session.connection_pool().await;
3030
3031 broadcast(
3032 None,
3033 non_collaborators
3034 .iter()
3035 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3036 |peer_id| {
3037 session.peer.send(
3038 peer_id,
3039 proto::UpdateChannels {
3040 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3041 channel_id: channel_id.to_proto(),
3042 epoch: epoch as u64,
3043 version: version.clone(),
3044 }],
3045 ..Default::default()
3046 },
3047 )
3048 },
3049 );
3050
3051 Ok(())
3052}
3053
3054/// Rejoin the channel notes after a connection blip
3055async fn rejoin_channel_buffers(
3056 request: proto::RejoinChannelBuffers,
3057 response: Response<proto::RejoinChannelBuffers>,
3058 session: Session,
3059) -> Result<()> {
3060 let db = session.db().await;
3061 let buffers = db
3062 .rejoin_channel_buffers(&request.buffers, session.user_id, session.connection_id)
3063 .await?;
3064
3065 for rejoined_buffer in &buffers {
3066 let collaborators_to_notify = rejoined_buffer
3067 .buffer
3068 .collaborators
3069 .iter()
3070 .filter_map(|c| Some(c.peer_id?.into()));
3071 channel_buffer_updated(
3072 session.connection_id,
3073 collaborators_to_notify,
3074 &proto::UpdateChannelBufferCollaborators {
3075 channel_id: rejoined_buffer.buffer.channel_id,
3076 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3077 },
3078 &session.peer,
3079 );
3080 }
3081
3082 response.send(proto::RejoinChannelBuffersResponse {
3083 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3084 })?;
3085
3086 Ok(())
3087}
3088
3089/// Stop editing the channel notes
3090async fn leave_channel_buffer(
3091 request: proto::LeaveChannelBuffer,
3092 response: Response<proto::LeaveChannelBuffer>,
3093 session: Session,
3094) -> Result<()> {
3095 let db = session.db().await;
3096 let channel_id = ChannelId::from_proto(request.channel_id);
3097
3098 let left_buffer = db
3099 .leave_channel_buffer(channel_id, session.connection_id)
3100 .await?;
3101
3102 response.send(Ack {})?;
3103
3104 channel_buffer_updated(
3105 session.connection_id,
3106 left_buffer.connections,
3107 &proto::UpdateChannelBufferCollaborators {
3108 channel_id: channel_id.to_proto(),
3109 collaborators: left_buffer.collaborators,
3110 },
3111 &session.peer,
3112 );
3113
3114 Ok(())
3115}
3116
3117fn channel_buffer_updated<T: EnvelopedMessage>(
3118 sender_id: ConnectionId,
3119 collaborators: impl IntoIterator<Item = ConnectionId>,
3120 message: &T,
3121 peer: &Peer,
3122) {
3123 broadcast(Some(sender_id), collaborators, |peer_id| {
3124 peer.send(peer_id, message.clone())
3125 });
3126}
3127
3128fn send_notifications(
3129 connection_pool: &ConnectionPool,
3130 peer: &Peer,
3131 notifications: db::NotificationBatch,
3132) {
3133 for (user_id, notification) in notifications {
3134 for connection_id in connection_pool.user_connection_ids(user_id) {
3135 if let Err(error) = peer.send(
3136 connection_id,
3137 proto::AddNotification {
3138 notification: Some(notification.clone()),
3139 },
3140 ) {
3141 tracing::error!(
3142 "failed to send notification to {:?} {}",
3143 connection_id,
3144 error
3145 );
3146 }
3147 }
3148 }
3149}
3150
3151/// Send a message to the channel
3152async fn send_channel_message(
3153 request: proto::SendChannelMessage,
3154 response: Response<proto::SendChannelMessage>,
3155 session: Session,
3156) -> Result<()> {
3157 // Validate the message body.
3158 let body = request.body.trim().to_string();
3159 if body.len() > MAX_MESSAGE_LEN {
3160 return Err(anyhow!("message is too long"))?;
3161 }
3162 if body.is_empty() {
3163 return Err(anyhow!("message can't be blank"))?;
3164 }
3165
3166 // TODO: adjust mentions if body is trimmed
3167
3168 let timestamp = OffsetDateTime::now_utc();
3169 let nonce = request
3170 .nonce
3171 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3172
3173 let channel_id = ChannelId::from_proto(request.channel_id);
3174 let CreatedChannelMessage {
3175 message_id,
3176 participant_connection_ids,
3177 channel_members,
3178 notifications,
3179 } = session
3180 .db()
3181 .await
3182 .create_channel_message(
3183 channel_id,
3184 session.user_id,
3185 &body,
3186 &request.mentions,
3187 timestamp,
3188 nonce.clone().into(),
3189 match request.reply_to_message_id {
3190 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3191 None => None,
3192 },
3193 )
3194 .await?;
3195
3196 let message = proto::ChannelMessage {
3197 sender_id: session.user_id.to_proto(),
3198 id: message_id.to_proto(),
3199 body,
3200 mentions: request.mentions,
3201 timestamp: timestamp.unix_timestamp() as u64,
3202 nonce: Some(nonce),
3203 reply_to_message_id: request.reply_to_message_id,
3204 edited_at: None,
3205 };
3206 broadcast(
3207 Some(session.connection_id),
3208 participant_connection_ids,
3209 |connection| {
3210 session.peer.send(
3211 connection,
3212 proto::ChannelMessageSent {
3213 channel_id: channel_id.to_proto(),
3214 message: Some(message.clone()),
3215 },
3216 )
3217 },
3218 );
3219 response.send(proto::SendChannelMessageResponse {
3220 message: Some(message),
3221 })?;
3222
3223 let pool = &*session.connection_pool().await;
3224 broadcast(
3225 None,
3226 channel_members
3227 .iter()
3228 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3229 |peer_id| {
3230 session.peer.send(
3231 peer_id,
3232 proto::UpdateChannels {
3233 latest_channel_message_ids: vec![proto::ChannelMessageId {
3234 channel_id: channel_id.to_proto(),
3235 message_id: message_id.to_proto(),
3236 }],
3237 ..Default::default()
3238 },
3239 )
3240 },
3241 );
3242 send_notifications(pool, &session.peer, notifications);
3243
3244 Ok(())
3245}
3246
3247/// Delete a channel message
3248async fn remove_channel_message(
3249 request: proto::RemoveChannelMessage,
3250 response: Response<proto::RemoveChannelMessage>,
3251 session: Session,
3252) -> Result<()> {
3253 let channel_id = ChannelId::from_proto(request.channel_id);
3254 let message_id = MessageId::from_proto(request.message_id);
3255 let connection_ids = session
3256 .db()
3257 .await
3258 .remove_channel_message(channel_id, message_id, session.user_id)
3259 .await?;
3260 broadcast(Some(session.connection_id), connection_ids, |connection| {
3261 session.peer.send(connection, request.clone())
3262 });
3263 response.send(proto::Ack {})?;
3264 Ok(())
3265}
3266
3267async fn update_channel_message(
3268 request: proto::UpdateChannelMessage,
3269 response: Response<proto::UpdateChannelMessage>,
3270 session: Session,
3271) -> Result<()> {
3272 let channel_id = ChannelId::from_proto(request.channel_id);
3273 let message_id = MessageId::from_proto(request.message_id);
3274 let updated_at = OffsetDateTime::now_utc();
3275 let UpdatedChannelMessage {
3276 message_id,
3277 participant_connection_ids,
3278 notifications,
3279 reply_to_message_id,
3280 timestamp,
3281 } = session
3282 .db()
3283 .await
3284 .update_channel_message(
3285 channel_id,
3286 message_id,
3287 session.user_id,
3288 request.body.as_str(),
3289 &request.mentions,
3290 updated_at,
3291 )
3292 .await?;
3293
3294 let nonce = request
3295 .nonce
3296 .clone()
3297 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3298
3299 let message = proto::ChannelMessage {
3300 sender_id: session.user_id.to_proto(),
3301 id: message_id.to_proto(),
3302 body: request.body.clone(),
3303 mentions: request.mentions.clone(),
3304 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3305 nonce: Some(nonce),
3306 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3307 edited_at: Some(updated_at.unix_timestamp() as u64),
3308 };
3309
3310 response.send(proto::Ack {})?;
3311
3312 let pool = &*session.connection_pool().await;
3313 broadcast(
3314 Some(session.connection_id),
3315 participant_connection_ids,
3316 |connection| {
3317 session.peer.send(
3318 connection,
3319 proto::ChannelMessageUpdate {
3320 channel_id: channel_id.to_proto(),
3321 message: Some(message.clone()),
3322 },
3323 )
3324 },
3325 );
3326
3327 send_notifications(pool, &session.peer, notifications);
3328
3329 Ok(())
3330}
3331
3332/// Mark a channel message as read
3333async fn acknowledge_channel_message(
3334 request: proto::AckChannelMessage,
3335 session: Session,
3336) -> Result<()> {
3337 let channel_id = ChannelId::from_proto(request.channel_id);
3338 let message_id = MessageId::from_proto(request.message_id);
3339 let notifications = session
3340 .db()
3341 .await
3342 .observe_channel_message(channel_id, session.user_id, message_id)
3343 .await?;
3344 send_notifications(
3345 &*session.connection_pool().await,
3346 &session.peer,
3347 notifications,
3348 );
3349 Ok(())
3350}
3351
3352/// Mark a buffer version as synced
3353async fn acknowledge_buffer_version(
3354 request: proto::AckBufferOperation,
3355 session: Session,
3356) -> Result<()> {
3357 let buffer_id = BufferId::from_proto(request.buffer_id);
3358 session
3359 .db()
3360 .await
3361 .observe_buffer_version(
3362 buffer_id,
3363 session.user_id,
3364 request.epoch as i32,
3365 &request.version,
3366 )
3367 .await?;
3368 Ok(())
3369}
3370
3371struct CompleteWithLanguageModelRateLimit;
3372
3373impl RateLimit for CompleteWithLanguageModelRateLimit {
3374 fn capacity() -> usize {
3375 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3376 .ok()
3377 .and_then(|v| v.parse().ok())
3378 .unwrap_or(120) // Picked arbitrarily
3379 }
3380
3381 fn refill_duration() -> chrono::Duration {
3382 chrono::Duration::hours(1)
3383 }
3384
3385 fn db_name() -> &'static str {
3386 "complete-with-language-model"
3387 }
3388}
3389
3390async fn complete_with_language_model(
3391 request: proto::CompleteWithLanguageModel,
3392 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3393 session: Session,
3394 open_ai_api_key: Option<Arc<str>>,
3395 google_ai_api_key: Option<Arc<str>>,
3396) -> Result<()> {
3397 authorize_access_to_language_models(&session).await?;
3398 session
3399 .rate_limiter
3400 .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
3401 .await?;
3402
3403 if request.model.starts_with("gpt") {
3404 let api_key =
3405 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
3406 complete_with_open_ai(request, response, session, api_key).await?;
3407 } else if request.model.starts_with("gemini") {
3408 let api_key = google_ai_api_key
3409 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3410 complete_with_google_ai(request, response, session, api_key).await?;
3411 }
3412
3413 Ok(())
3414}
3415
3416async fn complete_with_open_ai(
3417 request: proto::CompleteWithLanguageModel,
3418 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3419 session: Session,
3420 api_key: Arc<str>,
3421) -> Result<()> {
3422 const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
3423
3424 let mut completion_stream = open_ai::stream_completion(
3425 &session.http_client,
3426 OPEN_AI_API_URL,
3427 &api_key,
3428 crate::ai::language_model_request_to_open_ai(request)?,
3429 )
3430 .await
3431 .context("open_ai::stream_completion request failed")?;
3432
3433 while let Some(event) = completion_stream.next().await {
3434 let event = event?;
3435 response.send(proto::LanguageModelResponse {
3436 choices: event
3437 .choices
3438 .into_iter()
3439 .map(|choice| proto::LanguageModelChoiceDelta {
3440 index: choice.index,
3441 delta: Some(proto::LanguageModelResponseMessage {
3442 role: choice.delta.role.map(|role| match role {
3443 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
3444 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
3445 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
3446 } as i32),
3447 content: choice.delta.content,
3448 }),
3449 finish_reason: choice.finish_reason,
3450 })
3451 .collect(),
3452 })?;
3453 }
3454
3455 Ok(())
3456}
3457
3458async fn complete_with_google_ai(
3459 request: proto::CompleteWithLanguageModel,
3460 response: StreamingResponse<proto::CompleteWithLanguageModel>,
3461 session: Session,
3462 api_key: Arc<str>,
3463) -> Result<()> {
3464 let mut stream = google_ai::stream_generate_content(
3465 &session.http_client,
3466 google_ai::API_URL,
3467 api_key.as_ref(),
3468 crate::ai::language_model_request_to_google_ai(request)?,
3469 )
3470 .await
3471 .context("google_ai::stream_generate_content request failed")?;
3472
3473 while let Some(event) = stream.next().await {
3474 let event = event?;
3475 response.send(proto::LanguageModelResponse {
3476 choices: event
3477 .candidates
3478 .unwrap_or_default()
3479 .into_iter()
3480 .map(|candidate| proto::LanguageModelChoiceDelta {
3481 index: candidate.index as u32,
3482 delta: Some(proto::LanguageModelResponseMessage {
3483 role: Some(match candidate.content.role {
3484 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
3485 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
3486 } as i32),
3487 content: Some(
3488 candidate
3489 .content
3490 .parts
3491 .into_iter()
3492 .filter_map(|part| match part {
3493 google_ai::Part::TextPart(part) => Some(part.text),
3494 google_ai::Part::InlineDataPart(_) => None,
3495 })
3496 .collect(),
3497 ),
3498 }),
3499 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
3500 })
3501 .collect(),
3502 })?;
3503 }
3504
3505 Ok(())
3506}
3507
3508struct CountTokensWithLanguageModelRateLimit;
3509
3510impl RateLimit for CountTokensWithLanguageModelRateLimit {
3511 fn capacity() -> usize {
3512 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
3513 .ok()
3514 .and_then(|v| v.parse().ok())
3515 .unwrap_or(600) // Picked arbitrarily
3516 }
3517
3518 fn refill_duration() -> chrono::Duration {
3519 chrono::Duration::hours(1)
3520 }
3521
3522 fn db_name() -> &'static str {
3523 "count-tokens-with-language-model"
3524 }
3525}
3526
3527async fn count_tokens_with_language_model(
3528 request: proto::CountTokensWithLanguageModel,
3529 response: Response<proto::CountTokensWithLanguageModel>,
3530 session: Session,
3531 google_ai_api_key: Option<Arc<str>>,
3532) -> Result<()> {
3533 authorize_access_to_language_models(&session).await?;
3534
3535 if !request.model.starts_with("gemini") {
3536 return Err(anyhow!(
3537 "counting tokens for model: {:?} is not supported",
3538 request.model
3539 ))?;
3540 }
3541
3542 session
3543 .rate_limiter
3544 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
3545 .await?;
3546
3547 let api_key = google_ai_api_key
3548 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
3549 let tokens_response = google_ai::count_tokens(
3550 &session.http_client,
3551 google_ai::API_URL,
3552 &api_key,
3553 crate::ai::count_tokens_request_to_google_ai(request)?,
3554 )
3555 .await?;
3556 response.send(proto::CountTokensResponse {
3557 token_count: tokens_response.total_tokens as u32,
3558 })?;
3559 Ok(())
3560}
3561
3562async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
3563 let db = session.db().await;
3564 let flags = db.get_user_flags(session.user_id).await?;
3565 if flags.iter().any(|flag| flag == "language-models") {
3566 Ok(())
3567 } else {
3568 Err(anyhow!("permission denied"))?
3569 }
3570}
3571
3572/// Start receiving chat updates for a channel
3573async fn join_channel_chat(
3574 request: proto::JoinChannelChat,
3575 response: Response<proto::JoinChannelChat>,
3576 session: Session,
3577) -> Result<()> {
3578 let channel_id = ChannelId::from_proto(request.channel_id);
3579
3580 let db = session.db().await;
3581 db.join_channel_chat(channel_id, session.connection_id, session.user_id)
3582 .await?;
3583 let messages = db
3584 .get_channel_messages(channel_id, session.user_id, MESSAGE_COUNT_PER_PAGE, None)
3585 .await?;
3586 response.send(proto::JoinChannelChatResponse {
3587 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3588 messages,
3589 })?;
3590 Ok(())
3591}
3592
3593/// Stop receiving chat updates for a channel
3594async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3595 let channel_id = ChannelId::from_proto(request.channel_id);
3596 session
3597 .db()
3598 .await
3599 .leave_channel_chat(channel_id, session.connection_id, session.user_id)
3600 .await?;
3601 Ok(())
3602}
3603
3604/// Retrieve the chat history for a channel
3605async fn get_channel_messages(
3606 request: proto::GetChannelMessages,
3607 response: Response<proto::GetChannelMessages>,
3608 session: Session,
3609) -> Result<()> {
3610 let channel_id = ChannelId::from_proto(request.channel_id);
3611 let messages = session
3612 .db()
3613 .await
3614 .get_channel_messages(
3615 channel_id,
3616 session.user_id,
3617 MESSAGE_COUNT_PER_PAGE,
3618 Some(MessageId::from_proto(request.before_message_id)),
3619 )
3620 .await?;
3621 response.send(proto::GetChannelMessagesResponse {
3622 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3623 messages,
3624 })?;
3625 Ok(())
3626}
3627
3628/// Retrieve specific chat messages
3629async fn get_channel_messages_by_id(
3630 request: proto::GetChannelMessagesById,
3631 response: Response<proto::GetChannelMessagesById>,
3632 session: Session,
3633) -> Result<()> {
3634 let message_ids = request
3635 .message_ids
3636 .iter()
3637 .map(|id| MessageId::from_proto(*id))
3638 .collect::<Vec<_>>();
3639 let messages = session
3640 .db()
3641 .await
3642 .get_channel_messages_by_id(session.user_id, &message_ids)
3643 .await?;
3644 response.send(proto::GetChannelMessagesResponse {
3645 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3646 messages,
3647 })?;
3648 Ok(())
3649}
3650
3651/// Retrieve the current users notifications
3652async fn get_notifications(
3653 request: proto::GetNotifications,
3654 response: Response<proto::GetNotifications>,
3655 session: Session,
3656) -> Result<()> {
3657 let notifications = session
3658 .db()
3659 .await
3660 .get_notifications(
3661 session.user_id,
3662 NOTIFICATION_COUNT_PER_PAGE,
3663 request
3664 .before_id
3665 .map(|id| db::NotificationId::from_proto(id)),
3666 )
3667 .await?;
3668 response.send(proto::GetNotificationsResponse {
3669 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3670 notifications,
3671 })?;
3672 Ok(())
3673}
3674
3675/// Mark notifications as read
3676async fn mark_notification_as_read(
3677 request: proto::MarkNotificationRead,
3678 response: Response<proto::MarkNotificationRead>,
3679 session: Session,
3680) -> Result<()> {
3681 let database = &session.db().await;
3682 let notifications = database
3683 .mark_notification_as_read_by_id(
3684 session.user_id,
3685 NotificationId::from_proto(request.notification_id),
3686 )
3687 .await?;
3688 send_notifications(
3689 &*session.connection_pool().await,
3690 &session.peer,
3691 notifications,
3692 );
3693 response.send(proto::Ack {})?;
3694 Ok(())
3695}
3696
3697/// Get the current users information
3698async fn get_private_user_info(
3699 _request: proto::GetPrivateUserInfo,
3700 response: Response<proto::GetPrivateUserInfo>,
3701 session: Session,
3702) -> Result<()> {
3703 let db = session.db().await;
3704
3705 let metrics_id = db.get_user_metrics_id(session.user_id).await?;
3706 let user = db
3707 .get_user_by_id(session.user_id)
3708 .await?
3709 .ok_or_else(|| anyhow!("user not found"))?;
3710 let flags = db.get_user_flags(session.user_id).await?;
3711
3712 response.send(proto::GetPrivateUserInfoResponse {
3713 metrics_id,
3714 staff: user.admin,
3715 flags,
3716 })?;
3717 Ok(())
3718}
3719
3720fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
3721 match message {
3722 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
3723 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
3724 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
3725 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
3726 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3727 code: frame.code.into(),
3728 reason: frame.reason,
3729 })),
3730 }
3731}
3732
3733fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3734 match message {
3735 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
3736 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
3737 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
3738 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
3739 AxumMessage::Close(frame) => {
3740 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3741 code: frame.code.into(),
3742 reason: frame.reason,
3743 }))
3744 }
3745 }
3746}
3747
3748fn notify_membership_updated(
3749 connection_pool: &mut ConnectionPool,
3750 result: MembershipUpdated,
3751 user_id: UserId,
3752 peer: &Peer,
3753) {
3754 for membership in &result.new_channels.channel_memberships {
3755 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3756 }
3757 for channel_id in &result.removed_channels {
3758 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3759 }
3760
3761 let user_channels_update = proto::UpdateUserChannels {
3762 channel_memberships: result
3763 .new_channels
3764 .channel_memberships
3765 .iter()
3766 .map(|cm| proto::ChannelMembership {
3767 channel_id: cm.channel_id.to_proto(),
3768 role: cm.role.into(),
3769 })
3770 .collect(),
3771 ..Default::default()
3772 };
3773
3774 let mut update = build_channels_update(result.new_channels, vec![]);
3775 update.delete_channels = result
3776 .removed_channels
3777 .into_iter()
3778 .map(|id| id.to_proto())
3779 .collect();
3780 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3781
3782 for connection_id in connection_pool.user_connection_ids(user_id) {
3783 peer.send(connection_id, user_channels_update.clone())
3784 .trace_err();
3785 peer.send(connection_id, update.clone()).trace_err();
3786 }
3787}
3788
3789fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3790 proto::UpdateUserChannels {
3791 channel_memberships: channels
3792 .channel_memberships
3793 .iter()
3794 .map(|m| proto::ChannelMembership {
3795 channel_id: m.channel_id.to_proto(),
3796 role: m.role.into(),
3797 })
3798 .collect(),
3799 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3800 observed_channel_message_id: channels.observed_channel_messages.clone(),
3801 }
3802}
3803
3804fn build_channels_update(
3805 channels: ChannelsForUser,
3806 channel_invites: Vec<db::Channel>,
3807) -> proto::UpdateChannels {
3808 let mut update = proto::UpdateChannels::default();
3809
3810 for channel in channels.channels {
3811 update.channels.push(channel.to_proto());
3812 }
3813
3814 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3815 update.latest_channel_message_ids = channels.latest_channel_messages;
3816
3817 for (channel_id, participants) in channels.channel_participants {
3818 update
3819 .channel_participants
3820 .push(proto::ChannelParticipants {
3821 channel_id: channel_id.to_proto(),
3822 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3823 });
3824 }
3825
3826 for channel in channel_invites {
3827 update.channel_invitations.push(channel.to_proto());
3828 }
3829 for project in channels.hosted_projects {
3830 update.hosted_projects.push(project);
3831 }
3832
3833 update
3834}
3835
3836fn build_initial_contacts_update(
3837 contacts: Vec<db::Contact>,
3838 pool: &ConnectionPool,
3839) -> proto::UpdateContacts {
3840 let mut update = proto::UpdateContacts::default();
3841
3842 for contact in contacts {
3843 match contact {
3844 db::Contact::Accepted { user_id, busy } => {
3845 update.contacts.push(contact_for_user(user_id, busy, &pool));
3846 }
3847 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3848 db::Contact::Incoming { user_id } => {
3849 update
3850 .incoming_requests
3851 .push(proto::IncomingContactRequest {
3852 requester_id: user_id.to_proto(),
3853 })
3854 }
3855 }
3856 }
3857
3858 update
3859}
3860
3861fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3862 proto::Contact {
3863 user_id: user_id.to_proto(),
3864 online: pool.is_user_online(user_id),
3865 busy,
3866 }
3867}
3868
3869fn room_updated(room: &proto::Room, peer: &Peer) {
3870 broadcast(
3871 None,
3872 room.participants
3873 .iter()
3874 .filter_map(|participant| Some(participant.peer_id?.into())),
3875 |peer_id| {
3876 peer.send(
3877 peer_id,
3878 proto::RoomUpdated {
3879 room: Some(room.clone()),
3880 },
3881 )
3882 },
3883 );
3884}
3885
3886fn channel_updated(
3887 channel: &db::channel::Model,
3888 room: &proto::Room,
3889 peer: &Peer,
3890 pool: &ConnectionPool,
3891) {
3892 let participants = room
3893 .participants
3894 .iter()
3895 .map(|p| p.user_id)
3896 .collect::<Vec<_>>();
3897
3898 broadcast(
3899 None,
3900 pool.channel_connection_ids(channel.root_id())
3901 .filter_map(|(channel_id, role)| {
3902 role.can_see_channel(channel.visibility).then(|| channel_id)
3903 }),
3904 |peer_id| {
3905 peer.send(
3906 peer_id,
3907 proto::UpdateChannels {
3908 channel_participants: vec![proto::ChannelParticipants {
3909 channel_id: channel.id.to_proto(),
3910 participant_user_ids: participants.clone(),
3911 }],
3912 ..Default::default()
3913 },
3914 )
3915 },
3916 );
3917}
3918
3919async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3920 let db = session.db().await;
3921
3922 let contacts = db.get_contacts(user_id).await?;
3923 let busy = db.is_user_busy(user_id).await?;
3924
3925 let pool = session.connection_pool().await;
3926 let updated_contact = contact_for_user(user_id, busy, &pool);
3927 for contact in contacts {
3928 if let db::Contact::Accepted {
3929 user_id: contact_user_id,
3930 ..
3931 } = contact
3932 {
3933 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3934 session
3935 .peer
3936 .send(
3937 contact_conn_id,
3938 proto::UpdateContacts {
3939 contacts: vec![updated_contact.clone()],
3940 remove_contacts: Default::default(),
3941 incoming_requests: Default::default(),
3942 remove_incoming_requests: Default::default(),
3943 outgoing_requests: Default::default(),
3944 remove_outgoing_requests: Default::default(),
3945 },
3946 )
3947 .trace_err();
3948 }
3949 }
3950 }
3951 Ok(())
3952}
3953
/// Remove this session's connection from the room it is in, if any, and fan
/// out the consequences.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything. Otherwise it:
/// - notifies the connections of every project left (via `project_left`);
/// - broadcasts the updated room to its remaining participants;
/// - if the room belongs to a channel, broadcasts the channel's new
///   participant list;
/// - sends `CallCanceled` to every user in `canceled_calls_to_user_ids`;
/// - refreshes contact status for the leaving user and those users;
/// - when a LiveKit client is configured, removes the user from the LiveKit
///   room, and deletes that room if the database marked it `deleted`.
///
/// LiveKit and send failures are logged via `trace_err` and do not fail the call.
async fn leave_room_for_session(session: &Session) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here and initialised inside the `if let` below. The temporaries
    // of that `if let` scrutinee (including the `session.db().await` result)
    // live until the end of the `if let`. Presumably the point is to release
    // them before the work that follows.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? {
        contacts_to_update.insert(session.user_id);

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // The LiveKit room name is taken out *before* `room` itself, so the
        // `room` broadcast below carries a default (empty) `live_kit_room`.
        // NOTE(review): confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` runs below;
    // that function calls `session.connection_pool().await` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id.to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
4027
4028async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4029 let left_channel_buffers = session
4030 .db()
4031 .await
4032 .leave_channel_buffers(session.connection_id)
4033 .await?;
4034
4035 for left_buffer in left_channel_buffers {
4036 channel_buffer_updated(
4037 session.connection_id,
4038 left_buffer.connections,
4039 &proto::UpdateChannelBufferCollaborators {
4040 channel_id: left_buffer.channel_id.to_proto(),
4041 collaborators: left_buffer.collaborators,
4042 },
4043 &session.peer,
4044 );
4045 }
4046
4047 Ok(())
4048}
4049
4050fn project_left(project: &db::LeftProject, session: &Session) {
4051 for connection_id in &project.connection_ids {
4052 if project.host_user_id == Some(session.user_id) {
4053 session
4054 .peer
4055 .send(
4056 *connection_id,
4057 proto::UnshareProject {
4058 project_id: project.id.to_proto(),
4059 },
4060 )
4061 .trace_err();
4062 } else {
4063 session
4064 .peer
4065 .send(
4066 *connection_id,
4067 proto::RemoveProjectCollaborator {
4068 project_id: project.id.to_proto(),
4069 peer_id: Some(session.connection_id.into()),
4070 },
4071 )
4072 .trace_err();
4073 }
4074 }
4075}
4076
4077pub trait ResultExt {
4078 type Ok;
4079
4080 fn trace_err(self) -> Option<Self::Ok>;
4081}
4082
4083impl<T, E> ResultExt for Result<T, E>
4084where
4085 E: std::fmt::Debug,
4086{
4087 type Ok = T;
4088
4089 fn trace_err(self) -> Option<T> {
4090 match self {
4091 Ok(value) => Some(value),
4092 Err(error) => {
4093 tracing::error!("{:?}", error);
4094 None
4095 }
4096 }
4097 }
4098}