1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
/// Grace window for a client to re-establish a dropped connection.
/// NOTE(review): the enforcement site is outside this chunk — confirm usage.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
/// Delay observed by `Server::start` before reclaiming stale rooms, channel
/// buffers, chat participants, and server rows left by a prior deployment.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification fetches (presumably used by the notification
// handlers defined later in this file — not visible in this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Live count of accepted connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names for per-message timing, recorded by `Server::add_handler`
// and `MessageContext::forward_request`.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// message's `TypeId`; yields a future that processes one envelope.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
/// One-shot responder handed to request handlers; tracks whether the handler
/// actually replied so `Server::add_request_handler` can detect a missing
/// response.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`; set once `send` runs.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request answered.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
131
132#[derive(Clone, Debug)]
133pub enum Principal {
134 User(User),
135 Impersonated { user: User, admin: User },
136}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// A [`Session`] paired with the tracing span of the message currently being
/// handled; derefs to the session for convenience.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
impl MessageContext {
    /// Forwards `request` from this session's connection to `receiver_id`
    /// and resolves with the receiver's response.
    ///
    /// Records how long the receiver took in the `host_waiting_ms` span
    /// field (on success or failure) and logs start/finish/error.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        // The timer starts when the future is created, which is also when
        // the "start forwarding" line is logged (not at first poll).
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                // Runs on completion regardless of outcome.
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
189
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    // Identity of the connected (possibly impersonated) user.
    principal: Principal,
    connection_id: ConnectionId,
    // Async mutex: `Session::db` awaits this lock before database access.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
/// Thin wrapper around the shared [`Database`], dereferencing to it.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
274
/// The collab RPC server: owns the peer, the registry of message handlers,
/// and the pool of active client connections.
pub struct Server {
    // Behind a mutex so tests can swap in a new id via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message `TypeId`; see `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` when the server is shutting down.
    teardown: watch::Sender<bool>,
}
283
/// Lock guard over the connection pool. The `PhantomData<Rc<()>>` makes the
/// guard `!Send`, preventing it from being held across `.await` points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable view of the server's live state (peer plus connection pool).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize what it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates the server and registers a handler for every RPC message the
    /// server understands.
    ///
    /// # Panics
    /// Panics (via `add_handler`) if the same message type is registered
    /// twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participant state.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests, forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): `LspExtRunnables` is registered as *mutating*
            // while its LSP-ext neighbors are read-only — confirm intended.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests, forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and misc.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            // Debugging, diffs, and LSP logs.
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
479
    /// Spawns the background task that reclaims resources orphaned by
    /// previous server instances: stale rooms, channel buffer collaborators,
    /// chat participants, worktree entries, and server rows.
    ///
    /// Waits `CLEANUP_TIMEOUT` first so clients of the old server get a
    /// chance to reconnect before their state is discarded.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // notify the remaining collaborators of the new set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the dead server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell users whose incoming calls were canceled by the
                        // cleanup above.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy/contact state to every accepted
                        // contact of each affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Empty rooms no longer need their LiveKit counterpart.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this participant sweep also ran before the
                // loops above — presumably an intentional second pass; confirm.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
650
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and flips the teardown flag so `handle_connection` loops exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Send errors (no live subscribers) are deliberately ignored.
        let _ = self.teardown.send(true);
    }
656
    /// Test-only: tears the server down, installs a new `ServerId`, resets
    /// the peer, and re-arms the teardown channel.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
664
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
669
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and span-timing instrumentation.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map is keyed by `TypeId::of::<M>()`, so this closure
                // only ever sees envelopes of type `M`; the downcast holds.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Millisecond timings with sub-ms precision (micros / 1000).
                    // queue = time between receipt and the handler starting.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
713
714 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
715 where
716 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
717 Fut: 'static + Send + Future<Output = Result<()>>,
718 M: EnvelopedMessage,
719 {
720 self.add_handler(move |envelope, session| handler(envelope.payload, session));
721 self
722 }
723
    /// Registers a handler for a request/response message type `M`.
    ///
    /// Wraps the handler with a [`Response`] and, after it finishes,
    /// enforces that a reply was produced:
    /// * `Ok` without a reply becomes a "handler did not send a response"
    ///   error;
    /// * `Err` is reported back to the requester via `respond_with_error`
    ///   and then propagated so `add_handler` logs it.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Flipped by `Response::send`; checked below.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // `?` converts the anyhow error into this fn's
                            // error type.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Give the client a structured proto error before
                        // propagating for logging.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
762
    /// Returns a future that runs the message loop for one client connection
    /// until the client disconnects, an I/O error occurs, or the server is
    /// torn down.
    ///
    /// `connection_guard` (if any) holds a slot against the global connection
    /// limit; it is released once the initial client update has been sent.
    /// `send_connection_id` lets the caller learn the assigned connection id.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Empty fields are filled in below once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client for outbound admin API calls (e.g. Supermaven).
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Connection fully established: release the slot held against
            // the global connection limit.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    // Backpressure: a semaphore permit must be available
                    // before the next message is taken off the wire.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Priority order: teardown, I/O errors, draining finished
                // handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler future
                                // completes, bounding in-flight handlers.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Abort any still-running foreground handlers, then run the
            // sign-out path.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
954
    /// Pushes the initial state a newly-connected client needs: the `Hello`
    /// handshake, contact list, channel subscriptions, and any pending
    /// incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller that initiated this
        // connection, if it asked for one.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's very first connection, prompt the contacts UI.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the contact snapshot while
                // the pool lock is held so the two stay consistent; the lock
                // is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver any call that was already ringing for this user.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1010
1011 pub async fn invite_code_redeemed(
1012 self: &Arc<Self>,
1013 inviter_id: UserId,
1014 invitee_id: UserId,
1015 ) -> Result<()> {
1016 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1017 && let Some(code) = &user.invite_code
1018 {
1019 let pool = self.connection_pool.lock();
1020 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1021 for connection_id in pool.user_connection_ids(inviter_id) {
1022 self.peer.send(
1023 connection_id,
1024 proto::UpdateContacts {
1025 contacts: vec![invitee_contact.clone()],
1026 ..Default::default()
1027 },
1028 )?;
1029 self.peer.send(
1030 connection_id,
1031 proto::UpdateInviteInfo {
1032 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1033 count: user.invite_count as u32,
1034 },
1035 )?;
1036 }
1037 }
1038 Ok(())
1039 }
1040
1041 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1042 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1043 && let Some(invite_code) = &user.invite_code
1044 {
1045 let pool = self.connection_pool.lock();
1046 for connection_id in pool.user_connection_ids(user_id) {
1047 self.peer.send(
1048 connection_id,
1049 proto::UpdateInviteInfo {
1050 url: format!(
1051 "{}{}",
1052 self.app_state.config.invite_link_prefix, invite_code
1053 ),
1054 count: user.invite_count as u32,
1055 },
1056 )?;
1057 }
1058 }
1059 Ok(())
1060 }
1061
1062 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1063 ServerSnapshot {
1064 connection_pool: ConnectionPoolGuard {
1065 guard: self.connection_pool.lock(),
1066 _not_send: PhantomData,
1067 },
1068 peer: &self.peer,
1069 }
1070 }
1071}
1072
/// Lets a `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is
/// expected.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1080
/// Mutable counterpart of the `Deref` impl above.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1086
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time a guard is
        // released, so corruption is caught close to where it happened.
        #[cfg(test)]
        self.check_invariants();
    }
}
1093
1094fn broadcast<F>(
1095 sender_id: Option<ConnectionId>,
1096 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1097 mut f: F,
1098) where
1099 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1100{
1101 for receiver_id in receiver_ids {
1102 if Some(receiver_id) != sender_id
1103 && let Err(error) = f(receiver_id)
1104 {
1105 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1106 }
1107 }
1108}
1109
1110pub struct ProtocolVersion(u32);
1111
1112impl Header for ProtocolVersion {
1113 fn name() -> &'static HeaderName {
1114 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1115 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1116 }
1117
1118 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1119 where
1120 Self: Sized,
1121 I: Iterator<Item = &'i axum::http::HeaderValue>,
1122 {
1123 let version = values
1124 .next()
1125 .ok_or_else(axum::headers::Error::invalid)?
1126 .to_str()
1127 .map_err(|_| axum::headers::Error::invalid())?
1128 .parse()
1129 .map_err(|_| axum::headers::Error::invalid())?;
1130 Ok(Self(version))
1131 }
1132
1133 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1134 values.extend([self.0.to_string().parse().unwrap()]);
1135 }
1136}
1137
1138pub struct AppVersionHeader(SemanticVersion);
1139impl Header for AppVersionHeader {
1140 fn name() -> &'static HeaderName {
1141 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1142 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1143 }
1144
1145 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1146 where
1147 Self: Sized,
1148 I: Iterator<Item = &'i axum::http::HeaderValue>,
1149 {
1150 let version = values
1151 .next()
1152 .ok_or_else(axum::headers::Error::invalid)?
1153 .to_str()
1154 .map_err(|_| axum::headers::Error::invalid())?
1155 .parse()
1156 .map_err(|_| axum::headers::Error::invalid())?;
1157 Ok(Self(version))
1158 }
1159
1160 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1161 values.extend([self.0.to_string().parse().unwrap()]);
1162 }
1163}
1164
1165#[derive(Debug)]
1166pub struct ReleaseChannelHeader(String);
1167
1168impl Header for ReleaseChannelHeader {
1169 fn name() -> &'static HeaderName {
1170 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1171 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1172 }
1173
1174 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1175 where
1176 Self: Sized,
1177 I: Iterator<Item = &'i axum::http::HeaderValue>,
1178 {
1179 Ok(Self(
1180 values
1181 .next()
1182 .ok_or_else(axum::headers::Error::invalid)?
1183 .to_str()
1184 .map_err(|_| axum::headers::Error::invalid())?
1185 .to_owned(),
1186 ))
1187 }
1188
1189 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1190 values.extend([self.0.parse().unwrap()]);
1191 }
1192}
1193
1194pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1195 Router::new()
1196 .route("/rpc", get(handle_websocket_request))
1197 .layer(
1198 ServiceBuilder::new()
1199 .layer(Extension(server.app_state.clone()))
1200 .layer(middleware::from_fn(auth::validate_header)),
1201 )
1202 .route("/metrics", get(handle_metrics))
1203 .layer(Extension(server))
1204}
1205
/// Axum handler for `/rpc`: validates protocol/app-version headers, enforces
/// the concurrent-connection limit, then upgrades to a WebSocket that is
/// handed off to `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients speaking an incompatible RPC protocol version.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is mandatory; without it we can't gate features.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // Versions below the collaboration floor are turned away.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's WebSocket message types to the tungstenite types the
        // RPC layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1283
1284pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1285 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1286 let connections_metric = CONNECTIONS_METRIC
1287 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1288
1289 let connections = server
1290 .connection_pool
1291 .lock()
1292 .connections()
1293 .filter(|connection| !connection.admin)
1294 .count();
1295 connections_metric.set(connections as _);
1296
1297 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1298 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1299 register_int_gauge!(
1300 "shared_projects",
1301 "number of open projects with one or more guests"
1302 )
1303 .unwrap()
1304 });
1305
1306 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1307 shared_projects_metric.set(shared_projects as _);
1308
1309 let encoder = prometheus::TextEncoder::new();
1310 let metric_families = prometheus::gather();
1311 let encoded_metrics = encoder
1312 .encode_to_string(&metric_families)
1313 .map_err(|err| anyhow!("{err}"))?;
1314 Ok(encoded_metrics)
1315}
1316
/// Cleans up after a dropped connection: removes it from the pool and, if the
/// user has not reconnected before `RECONNECT_TIMEOUT` elapses, releases
/// their room, channel buffers, and any still-ringing call.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: mark the connection as lost in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period: only tear down shared state if the client hasn't
        // reconnected when the timeout fires.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If this was the user's last connection anywhere, decline any
            // call still ringing for them.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown supersedes per-connection cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1363
1364/// Acknowledges a ping from a client, used to keep the connection alive.
1365async fn ping(
1366 _: proto::Ping,
1367 response: Response<proto::Ping>,
1368 _session: MessageContext,
1369) -> Result<()> {
1370 response.send(proto::Ack {})?;
1371 Ok(())
1372}
1373
1374/// Creates a new room for calling (outside of channels)
1375async fn create_room(
1376 _request: proto::CreateRoom,
1377 response: Response<proto::CreateRoom>,
1378 session: MessageContext,
1379) -> Result<()> {
1380 let livekit_room = nanoid::nanoid!(30);
1381
1382 let live_kit_connection_info = util::maybe!(async {
1383 let live_kit = session.app_state.livekit_client.as_ref();
1384 let live_kit = live_kit?;
1385 let user_id = session.user_id().to_string();
1386
1387 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1388
1389 Some(proto::LiveKitConnectionInfo {
1390 server_url: live_kit.url().into(),
1391 token,
1392 can_publish: true,
1393 })
1394 })
1395 .await;
1396
1397 let room = session
1398 .db()
1399 .await
1400 .create_room(session.user_id(), session.connection_id, &livekit_room)
1401 .await?;
1402
1403 response.send(proto::CreateRoomResponse {
1404 room: Some(room.clone()),
1405 live_kit_connection_info,
1406 })?;
1407
1408 update_user_contacts(session.user_id(), &session).await?;
1409 Ok(())
1410}
1411
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel are joined through the channel path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Answering on one device cancels the ringing call on the user's other
    // connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token when audio/video is configured; a token failure is
    // logged and the client joins without A/V.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1478
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // Restore this connection's membership and recover the projects it
        // was sharing (reshared) or participating in (rejoined) before the
        // disconnect.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this connection was hosting, tell the existing
        // collaborators about our new peer id and re-broadcast the current
        // project metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Release the db guard before the channel/contact updates below.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1570
/// Announces this session's new peer id to collaborators in each rejoined
/// project, then streams worktree, diagnostic, settings, and repository
/// state back to the rejoining client itself.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into chunks to respect message size
            // limits; all of them go to the rejoining connection itself.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1655
1656/// leave room disconnects from the room.
1657async fn leave_room(
1658 _: proto::LeaveRoom,
1659 response: Response<proto::LeaveRoom>,
1660 session: MessageContext,
1661) -> Result<()> {
1662 leave_room_for_session(&session, session.connection_id).await?;
1663 response.send(proto::Ack {})?;
1664 Ok(())
1665}
1666
1667/// Updates the permissions of someone else in the room.
1668async fn set_room_participant_role(
1669 request: proto::SetRoomParticipantRole,
1670 response: Response<proto::SetRoomParticipantRole>,
1671 session: MessageContext,
1672) -> Result<()> {
1673 let user_id = UserId::from_proto(request.user_id);
1674 let role = ChannelRole::from(request.role());
1675
1676 let (livekit_room, can_publish) = {
1677 let room = session
1678 .db()
1679 .await
1680 .set_room_participant_role(
1681 session.user_id(),
1682 RoomId::from_proto(request.room_id),
1683 user_id,
1684 role,
1685 )
1686 .await?;
1687
1688 let livekit_room = room.livekit_room.clone();
1689 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1690 room_updated(&room, &session.peer);
1691 (livekit_room, can_publish)
1692 };
1693
1694 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1695 live_kit
1696 .update_participant(
1697 livekit_room.clone(),
1698 request.user_id.to_string(),
1699 livekit_api::proto::ParticipantPermission {
1700 can_subscribe: true,
1701 can_publish,
1702 can_publish_data: can_publish,
1703 hidden: false,
1704 recorder: false,
1705 },
1706 )
1707 .await
1708 .trace_err();
1709 }
1710
1711 response.send(proto::Ack {})?;
1712 Ok(())
1713}
1714
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts can be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call on the room and broadcast the updated room
    // state to its current participants.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the callee has; the first successful response
    // wins and completes this request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log and keep waiting for the remaining connections.
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: roll the pending call back and notify participants.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1783
1784/// Cancel an outgoing call.
1785async fn cancel_call(
1786 request: proto::CancelCall,
1787 response: Response<proto::CancelCall>,
1788 session: MessageContext,
1789) -> Result<()> {
1790 let called_user_id = UserId::from_proto(request.called_user_id);
1791 let room_id = RoomId::from_proto(request.room_id);
1792 {
1793 let room = session
1794 .db()
1795 .await
1796 .cancel_call(room_id, session.connection_id, called_user_id)
1797 .await?;
1798 room_updated(&room, &session.peer);
1799 }
1800
1801 for connection_id in session
1802 .connection_pool()
1803 .await
1804 .user_connection_ids(called_user_id)
1805 {
1806 session
1807 .peer
1808 .send(
1809 connection_id,
1810 proto::CallCanceled {
1811 room_id: room_id.to_proto(),
1812 },
1813 )
1814 .trace_err();
1815 }
1816 response.send(proto::Ack {})?;
1817
1818 update_user_contacts(called_user_id, &session).await?;
1819 Ok(())
1820}
1821
1822/// Decline an incoming call.
1823async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1824 let room_id = RoomId::from_proto(message.room_id);
1825 {
1826 let room = session
1827 .db()
1828 .await
1829 .decline_call(Some(room_id), session.user_id())
1830 .await?
1831 .context("declining call")?;
1832 room_updated(&room, &session.peer);
1833 }
1834
1835 for connection_id in session
1836 .connection_pool()
1837 .await
1838 .user_connection_ids(session.user_id())
1839 {
1840 session
1841 .peer
1842 .send(
1843 connection_id,
1844 proto::CallCanceled {
1845 room_id: room_id.to_proto(),
1846 },
1847 )
1848 .trace_err();
1849 }
1850 update_user_contacts(session.user_id(), &session).await?;
1851 Ok(())
1852}
1853
1854/// Updates other participants in the room with your current location.
1855async fn update_participant_location(
1856 request: proto::UpdateParticipantLocation,
1857 response: Response<proto::UpdateParticipantLocation>,
1858 session: MessageContext,
1859) -> Result<()> {
1860 let room_id = RoomId::from_proto(request.room_id);
1861 let location = request.location.context("invalid location")?;
1862
1863 let db = session.db().await;
1864 let room = db
1865 .update_room_participant_location(room_id, session.connection_id, location)
1866 .await?;
1867
1868 room_updated(&room, &session.peer);
1869 response.send(proto::Ack {})?;
1870 Ok(())
1871}
1872
1873/// Share a project into the room.
1874async fn share_project(
1875 request: proto::ShareProject,
1876 response: Response<proto::ShareProject>,
1877 session: MessageContext,
1878) -> Result<()> {
1879 let (project_id, room) = &*session
1880 .db()
1881 .await
1882 .share_project(
1883 RoomId::from_proto(request.room_id),
1884 session.connection_id,
1885 &request.worktrees,
1886 request.is_ssh_project,
1887 request.windows_paths.unwrap_or(false),
1888 )
1889 .await?;
1890 response.send(proto::ShareProjectResponse {
1891 project_id: project_id.to_proto(),
1892 })?;
1893 room_updated(room, &session.peer);
1894
1895 Ok(())
1896}
1897
1898/// Unshare a project from the room.
1899async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1900 let project_id = ProjectId::from_proto(message.project_id);
1901 unshare_project_internal(project_id, session.connection_id, &session).await
1902}
1903
1904async fn unshare_project_internal(
1905 project_id: ProjectId,
1906 connection_id: ConnectionId,
1907 session: &Session,
1908) -> Result<()> {
1909 let delete = {
1910 let room_guard = session
1911 .db()
1912 .await
1913 .unshare_project(project_id, connection_id)
1914 .await?;
1915
1916 let (delete, room, guest_connection_ids) = &*room_guard;
1917
1918 let message = proto::UnshareProject {
1919 project_id: project_id.to_proto(),
1920 };
1921
1922 broadcast(
1923 Some(connection_id),
1924 guest_connection_ids.iter().copied(),
1925 |conn_id| session.peer.send(conn_id, message.clone()),
1926 );
1927 if let Some(room) = room {
1928 room_updated(room, &session.peer);
1929 }
1930
1931 *delete
1932 };
1933
1934 if delete {
1935 let db = session.db().await;
1936 db.delete_project(project_id).await?;
1937 }
1938
1939 Ok(())
1940}
1941
1942/// Join someone elses shared project.
1943async fn join_project(
1944 request: proto::JoinProject,
1945 response: Response<proto::JoinProject>,
1946 session: MessageContext,
1947) -> Result<()> {
1948 let project_id = ProjectId::from_proto(request.project_id);
1949
1950 tracing::info!(%project_id, "join project");
1951
1952 let db = session.db().await;
1953 let (project, replica_id) = &mut *db
1954 .join_project(
1955 project_id,
1956 session.connection_id,
1957 session.user_id(),
1958 request.committer_name.clone(),
1959 request.committer_email.clone(),
1960 )
1961 .await?;
1962 drop(db);
1963 tracing::info!(%project_id, "join remote project");
1964 let collaborators = project
1965 .collaborators
1966 .iter()
1967 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1968 .map(|collaborator| collaborator.to_proto())
1969 .collect::<Vec<_>>();
1970 let project_id = project.id;
1971 let guest_user_id = session.user_id();
1972
1973 let worktrees = project
1974 .worktrees
1975 .iter()
1976 .map(|(id, worktree)| proto::WorktreeMetadata {
1977 id: *id,
1978 root_name: worktree.root_name.clone(),
1979 visible: worktree.visible,
1980 abs_path: worktree.abs_path.clone(),
1981 })
1982 .collect::<Vec<_>>();
1983
1984 let add_project_collaborator = proto::AddProjectCollaborator {
1985 project_id: project_id.to_proto(),
1986 collaborator: Some(proto::Collaborator {
1987 peer_id: Some(session.connection_id.into()),
1988 replica_id: replica_id.0 as u32,
1989 user_id: guest_user_id.to_proto(),
1990 is_host: false,
1991 committer_name: request.committer_name.clone(),
1992 committer_email: request.committer_email.clone(),
1993 }),
1994 };
1995
1996 for collaborator in &collaborators {
1997 session
1998 .peer
1999 .send(
2000 collaborator.peer_id.unwrap().into(),
2001 add_project_collaborator.clone(),
2002 )
2003 .trace_err();
2004 }
2005
2006 // First, we send the metadata associated with each worktree.
2007 let (language_servers, language_server_capabilities) = project
2008 .language_servers
2009 .clone()
2010 .into_iter()
2011 .map(|server| (server.server, server.capabilities))
2012 .unzip();
2013 response.send(proto::JoinProjectResponse {
2014 project_id: project.id.0 as u64,
2015 worktrees,
2016 replica_id: replica_id.0 as u32,
2017 collaborators,
2018 language_servers,
2019 language_server_capabilities,
2020 role: project.role.into(),
2021 windows_paths: project.path_style == PathStyle::Windows,
2022 })?;
2023
2024 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2025 // Stream this worktree's entries.
2026 let message = proto::UpdateWorktree {
2027 project_id: project_id.to_proto(),
2028 worktree_id,
2029 abs_path: worktree.abs_path.clone(),
2030 root_name: worktree.root_name,
2031 updated_entries: worktree.entries,
2032 removed_entries: Default::default(),
2033 scan_id: worktree.scan_id,
2034 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2035 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2036 removed_repositories: Default::default(),
2037 };
2038 for update in proto::split_worktree_update(message) {
2039 session.peer.send(session.connection_id, update.clone())?;
2040 }
2041
2042 // Stream this worktree's diagnostics.
2043 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2044 if let Some(summary) = worktree_diagnostics.next() {
2045 let message = proto::UpdateDiagnosticSummary {
2046 project_id: project.id.to_proto(),
2047 worktree_id: worktree.id,
2048 summary: Some(summary),
2049 more_summaries: worktree_diagnostics.collect(),
2050 };
2051 session.peer.send(session.connection_id, message)?;
2052 }
2053
2054 for settings_file in worktree.settings_files {
2055 session.peer.send(
2056 session.connection_id,
2057 proto::UpdateWorktreeSettings {
2058 project_id: project_id.to_proto(),
2059 worktree_id: worktree.id,
2060 path: settings_file.path,
2061 content: Some(settings_file.content),
2062 kind: Some(settings_file.kind.to_proto() as i32),
2063 },
2064 )?;
2065 }
2066 }
2067
2068 for repository in mem::take(&mut project.repositories) {
2069 for update in split_repository_update(repository) {
2070 session.peer.send(session.connection_id, update)?;
2071 }
2072 }
2073
2074 for language_server in &project.language_servers {
2075 session.peer.send(
2076 session.connection_id,
2077 proto::UpdateLanguageServer {
2078 project_id: project_id.to_proto(),
2079 server_name: Some(language_server.server.name.clone()),
2080 language_server_id: language_server.server.id,
2081 variant: Some(
2082 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2083 proto::LspDiskBasedDiagnosticsUpdated {},
2084 ),
2085 ),
2086 },
2087 )?;
2088 }
2089
2090 Ok(())
2091}
2092
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Detach this connection from the project; the database returns the room
    // the project belongs to (if any) and the updated project state.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell the remaining collaborators that this peer left, then refresh the
    // room state for all call participants.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2112
2113/// Updates other participants with changes to the project
2114async fn update_project(
2115 request: proto::UpdateProject,
2116 response: Response<proto::UpdateProject>,
2117 session: MessageContext,
2118) -> Result<()> {
2119 let project_id = ProjectId::from_proto(request.project_id);
2120 let (room, guest_connection_ids) = &*session
2121 .db()
2122 .await
2123 .update_project(project_id, session.connection_id, &request.worktrees)
2124 .await?;
2125 broadcast(
2126 Some(session.connection_id),
2127 guest_connection_ids.iter().copied(),
2128 |connection_id| {
2129 session
2130 .peer
2131 .forward_send(session.connection_id, connection_id, request.clone())
2132 },
2133 );
2134 if let Some(room) = room {
2135 room_updated(room, &session.peer);
2136 }
2137 response.send(proto::Ack {})?;
2138
2139 Ok(())
2140}
2141
2142/// Updates other participants with changes to the worktree
2143async fn update_worktree(
2144 request: proto::UpdateWorktree,
2145 response: Response<proto::UpdateWorktree>,
2146 session: MessageContext,
2147) -> Result<()> {
2148 let guest_connection_ids = session
2149 .db()
2150 .await
2151 .update_worktree(&request, session.connection_id)
2152 .await?;
2153
2154 broadcast(
2155 Some(session.connection_id),
2156 guest_connection_ids.iter().copied(),
2157 |connection_id| {
2158 session
2159 .peer
2160 .forward_send(session.connection_id, connection_id, request.clone())
2161 },
2162 );
2163 response.send(proto::Ack {})?;
2164 Ok(())
2165}
2166
2167async fn update_repository(
2168 request: proto::UpdateRepository,
2169 response: Response<proto::UpdateRepository>,
2170 session: MessageContext,
2171) -> Result<()> {
2172 let guest_connection_ids = session
2173 .db()
2174 .await
2175 .update_repository(&request, session.connection_id)
2176 .await?;
2177
2178 broadcast(
2179 Some(session.connection_id),
2180 guest_connection_ids.iter().copied(),
2181 |connection_id| {
2182 session
2183 .peer
2184 .forward_send(session.connection_id, connection_id, request.clone())
2185 },
2186 );
2187 response.send(proto::Ack {})?;
2188 Ok(())
2189}
2190
2191async fn remove_repository(
2192 request: proto::RemoveRepository,
2193 response: Response<proto::RemoveRepository>,
2194 session: MessageContext,
2195) -> Result<()> {
2196 let guest_connection_ids = session
2197 .db()
2198 .await
2199 .remove_repository(&request, session.connection_id)
2200 .await?;
2201
2202 broadcast(
2203 Some(session.connection_id),
2204 guest_connection_ids.iter().copied(),
2205 |connection_id| {
2206 session
2207 .peer
2208 .forward_send(session.connection_id, connection_id, request.clone())
2209 },
2210 );
2211 response.send(proto::Ack {})?;
2212 Ok(())
2213}
2214
2215/// Updates other participants with changes to the diagnostics
2216async fn update_diagnostic_summary(
2217 message: proto::UpdateDiagnosticSummary,
2218 session: MessageContext,
2219) -> Result<()> {
2220 let guest_connection_ids = session
2221 .db()
2222 .await
2223 .update_diagnostic_summary(&message, session.connection_id)
2224 .await?;
2225
2226 broadcast(
2227 Some(session.connection_id),
2228 guest_connection_ids.iter().copied(),
2229 |connection_id| {
2230 session
2231 .peer
2232 .forward_send(session.connection_id, connection_id, message.clone())
2233 },
2234 );
2235
2236 Ok(())
2237}
2238
2239/// Updates other participants with changes to the worktree settings
2240async fn update_worktree_settings(
2241 message: proto::UpdateWorktreeSettings,
2242 session: MessageContext,
2243) -> Result<()> {
2244 let guest_connection_ids = session
2245 .db()
2246 .await
2247 .update_worktree_settings(&message, session.connection_id)
2248 .await?;
2249
2250 broadcast(
2251 Some(session.connection_id),
2252 guest_connection_ids.iter().copied(),
2253 |connection_id| {
2254 session
2255 .peer
2256 .forward_send(session.connection_id, connection_id, message.clone())
2257 },
2258 );
2259
2260 Ok(())
2261}
2262
2263/// Notify other participants that a language server has started.
2264async fn start_language_server(
2265 request: proto::StartLanguageServer,
2266 session: MessageContext,
2267) -> Result<()> {
2268 let guest_connection_ids = session
2269 .db()
2270 .await
2271 .start_language_server(&request, session.connection_id)
2272 .await?;
2273
2274 broadcast(
2275 Some(session.connection_id),
2276 guest_connection_ids.iter().copied(),
2277 |connection_id| {
2278 session
2279 .peer
2280 .forward_send(session.connection_id, connection_id, request.clone())
2281 },
2282 );
2283 Ok(())
2284}
2285
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capability changes are persisted so the current capabilities can be
    // included when new guests join the project later.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the update out to every other connection in the project.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2315
2316/// forward a project request to the host. These requests should be read only
2317/// as guests are allowed to send them.
2318async fn forward_read_only_project_request<T>(
2319 request: T,
2320 response: Response<T>,
2321 session: MessageContext,
2322) -> Result<()>
2323where
2324 T: EntityMessage + RequestMessage,
2325{
2326 let project_id = ProjectId::from_proto(request.remote_entity_id());
2327 let host_connection_id = session
2328 .db()
2329 .await
2330 .host_for_read_only_project_request(project_id, session.connection_id)
2331 .await?;
2332 let payload = session.forward_request(host_connection_id, request).await?;
2333 response.send(payload)?;
2334 Ok(())
2335}
2336
2337/// forward a project request to the host. These requests are disallowed
2338/// for guests.
2339async fn forward_mutating_project_request<T>(
2340 request: T,
2341 response: Response<T>,
2342 session: MessageContext,
2343) -> Result<()>
2344where
2345 T: EntityMessage + RequestMessage,
2346{
2347 let project_id = ProjectId::from_proto(request.remote_entity_id());
2348
2349 let host_connection_id = session
2350 .db()
2351 .await
2352 .host_for_mutating_project_request(project_id, session.connection_id)
2353 .await?;
2354 let payload = session.forward_request(host_connection_id, request).await?;
2355 response.send(payload)?;
2356 Ok(())
2357}
2358
2359async fn lsp_query(
2360 request: proto::LspQuery,
2361 response: Response<proto::LspQuery>,
2362 session: MessageContext,
2363) -> Result<()> {
2364 let (name, should_write) = request.query_name_and_write_permissions();
2365 tracing::Span::current().record("lsp_query_request", name);
2366 tracing::info!("lsp_query message received");
2367 if should_write {
2368 forward_mutating_project_request(request, response, session).await
2369 } else {
2370 forward_read_only_project_request(request, response, session).await
2371 }
2372}
2373
2374/// Notify other participants that a new buffer has been created
2375async fn create_buffer_for_peer(
2376 request: proto::CreateBufferForPeer,
2377 session: MessageContext,
2378) -> Result<()> {
2379 session
2380 .db()
2381 .await
2382 .check_user_is_project_host(
2383 ProjectId::from_proto(request.project_id),
2384 session.connection_id,
2385 )
2386 .await?;
2387 let peer_id = request.peer_id.context("invalid peer id")?;
2388 session
2389 .peer
2390 .forward_send(session.connection_id, peer_id.into(), request)?;
2391 Ok(())
2392}
2393
2394/// Notify other participants that a new image has been created
2395async fn create_image_for_peer(
2396 request: proto::CreateImageForPeer,
2397 session: MessageContext,
2398) -> Result<()> {
2399 session
2400 .db()
2401 .await
2402 .check_user_is_project_host(
2403 ProjectId::from_proto(request.project_id),
2404 session.connection_id,
2405 )
2406 .await?;
2407 let peer_id = request.peer_id.context("invalid peer id")?;
2408 session
2409 .peer
2410 .forward_send(session.connection_id, peer_id.into(), request)?;
2411 Ok(())
2412}
2413
2414/// Notify other participants that a buffer has been updated. This is
2415/// allowed for guests as long as the update is limited to selections.
2416async fn update_buffer(
2417 request: proto::UpdateBuffer,
2418 response: Response<proto::UpdateBuffer>,
2419 session: MessageContext,
2420) -> Result<()> {
2421 let project_id = ProjectId::from_proto(request.project_id);
2422 let mut capability = Capability::ReadOnly;
2423
2424 for op in request.operations.iter() {
2425 match op.variant {
2426 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2427 Some(_) => capability = Capability::ReadWrite,
2428 }
2429 }
2430
2431 let host = {
2432 let guard = session
2433 .db()
2434 .await
2435 .connections_for_buffer_update(project_id, session.connection_id, capability)
2436 .await?;
2437
2438 let (host, guests) = &*guard;
2439
2440 broadcast(
2441 Some(session.connection_id),
2442 guests.clone(),
2443 |connection_id| {
2444 session
2445 .peer
2446 .forward_send(session.connection_id, connection_id, request.clone())
2447 },
2448 );
2449
2450 *host
2451 };
2452
2453 if host != session.connection_id {
2454 session.forward_request(host, request.clone()).await?;
2455 }
2456
2457 response.send(proto::Ack {})?;
2458 Ok(())
2459}
2460
/// Broadcasts an assistant-context update to the project, enforcing that
/// read-only guests may only send selection updates.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Classify the required capability, mirroring `update_buffer`: buffer
    // operations that are absent-or-selection-only count as read-only; any
    // other operation kind (or a missing inner buffer operation) is a write.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike buffer updates, the host is included in the fan-out here.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2502
2503/// Notify other participants that a project has been updated.
2504async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2505 request: T,
2506 session: MessageContext,
2507) -> Result<()> {
2508 let project_id = ProjectId::from_proto(request.remote_entity_id());
2509 let project_connection_ids = session
2510 .db()
2511 .await
2512 .project_connection_ids(project_id, session.connection_id, false)
2513 .await?;
2514
2515 broadcast(
2516 Some(session.connection_id),
2517 project_connection_ids.iter().copied(),
2518 |connection_id| {
2519 session
2520 .peer
2521 .forward_send(session.connection_id, connection_id, request.clone())
2522 },
2523 );
2524 Ok(())
2525}
2526
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both the leader and the follower must be participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Forward the follow request to the leader and relay its reply back to
    // the follower before recording anything.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // The follow relationship is only tracked when it is scoped to a project;
    // recording it updates the room state for everyone.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2558
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both the leader and the follower must be participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, this is fire-and-forget: no reply from the leader is
    // awaited before updating the database.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Follow relationships are only tracked per-project; removing one updates
    // the room state for everyone.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2587
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let database = session.db.lock().await;

    // Scope the fan-out: project-specific updates go to the project's
    // connections, otherwise to every other connection in the room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Also skip echoing the update back to the sender itself.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2619
2620/// Get public data about users.
2621async fn get_users(
2622 request: proto::GetUsers,
2623 response: Response<proto::GetUsers>,
2624 session: MessageContext,
2625) -> Result<()> {
2626 let user_ids = request
2627 .user_ids
2628 .into_iter()
2629 .map(UserId::from_proto)
2630 .collect();
2631 let users = session
2632 .db()
2633 .await
2634 .get_users_by_ids(user_ids)
2635 .await?
2636 .into_iter()
2637 .map(|user| proto::User {
2638 id: user.id.to_proto(),
2639 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2640 github_login: user.github_login,
2641 name: user.name,
2642 })
2643 .collect();
2644 response.send(proto::UsersResponse { users })?;
2645 Ok(())
2646}
2647
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries (note: `len()` is byte length, not chars) are
    // resolved as an exact login lookup; longer ones use fuzzy search
    // capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never include the searching user in their own results.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2679
2680/// Send a contact request to another user.
2681async fn request_contact(
2682 request: proto::RequestContact,
2683 response: Response<proto::RequestContact>,
2684 session: MessageContext,
2685) -> Result<()> {
2686 let requester_id = session.user_id();
2687 let responder_id = UserId::from_proto(request.responder_id);
2688 if requester_id == responder_id {
2689 return Err(anyhow!("cannot add yourself as a contact"))?;
2690 }
2691
2692 let notifications = session
2693 .db()
2694 .await
2695 .send_contact_request(requester_id, responder_id)
2696 .await?;
2697
2698 // Update outgoing contact requests of requester
2699 let mut update = proto::UpdateContacts::default();
2700 update.outgoing_requests.push(responder_id.to_proto());
2701 for connection_id in session
2702 .connection_pool()
2703 .await
2704 .user_connection_ids(requester_id)
2705 {
2706 session.peer.send(connection_id, update.clone())?;
2707 }
2708
2709 // Update incoming contact requests of responder
2710 let mut update = proto::UpdateContacts::default();
2711 update
2712 .incoming_requests
2713 .push(proto::IncomingContactRequest {
2714 requester_id: requester_id.to_proto(),
2715 });
2716 let connection_pool = session.connection_pool().await;
2717 for connection_id in connection_pool.user_connection_ids(responder_id) {
2718 session.peer.send(connection_id, update.clone())?;
2719 }
2720
2721 send_notifications(&connection_pool, &session.peer, notifications);
2722
2723 response.send(proto::Ack {})?;
2724 Ok(())
2725}
2726
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismissing only clears the notification; the request itself stays
    // pending (neither accepted nor declined).
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2784
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` tells us whether an accepted contact was removed or
    // merely a pending request was retracted; the message lists differ.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the now-stale notification on each of the responder's
        // connections, if one was deleted.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2835
2836fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2837 version.0.minor() < 139
2838}
2839
2840async fn subscribe_to_channels(
2841 _: proto::SubscribeToChannels,
2842 session: MessageContext,
2843) -> Result<()> {
2844 subscribe_user_to_channels(session.user_id(), &session).await?;
2845 Ok(())
2846}
2847
/// Records channel subscriptions for `user_id` in the connection pool and
/// sends the initial channel state to this session's connection.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register each membership so future channel broadcasts reach this user.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Two messages: the user's own memberships first, then the full channel
    // update (which consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2864
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    // `parent_id` of None creates a new root channel.
    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out notifications.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // A membership is only returned for a new root channel (the creator
    // becomes a member); subscribe their connections and tell them about it.
    // NOTE(review): inferred from the Option — confirm against db::create_channel.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone subscribed to the tree's root,
    // filtered by whether their role may see its visibility level.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2919
/// Delete a channel
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    // The database returns the root of the affected tree plus every channel
    // id removed by the deletion (plural — presumably descendants too).
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2947
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Surface the invitation on all of the invitee's active connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Deliver the notifications created by the database call above.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2984
2985/// remove someone from a channel
2986async fn remove_channel_member(
2987 request: proto::RemoveChannelMember,
2988 response: Response<proto::RemoveChannelMember>,
2989 session: MessageContext,
2990) -> Result<()> {
2991 let db = session.db().await;
2992 let channel_id = ChannelId::from_proto(request.channel_id);
2993 let member_id = UserId::from_proto(request.user_id);
2994
2995 let RemoveChannelMemberResult {
2996 membership_update,
2997 notification_id,
2998 } = db
2999 .remove_channel_member(channel_id, member_id, session.user_id())
3000 .await?;
3001
3002 let mut connection_pool = session.connection_pool().await;
3003 notify_membership_updated(
3004 &mut connection_pool,
3005 membership_update,
3006 member_id,
3007 &session.peer,
3008 );
3009 for connection_id in connection_pool.user_connection_ids(member_id) {
3010 if let Some(notification_id) = notification_id {
3011 session
3012 .peer
3013 .send(
3014 connection_id,
3015 proto::DeleteNotification {
3016 notification_id: notification_id.to_proto(),
3017 },
3018 )
3019 .trace_err();
3020 }
3021 }
3022
3023 response.send(proto::Ack {})?;
3024 Ok(())
3025}
3026
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user list into a Vec first: the loop body mutates the
    // pool's subscriptions, so we cannot keep the `channel_user_ids` borrow.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get the updated entry; users
        // who no longer can are unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3073
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an actual member (membership changed) or only an
    // invitee (pending invite's role changed); each needs different fan-out.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Refresh the invitation entry on the invitee's connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3121
/// Change the name of a channel
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    // Acknowledge the rename to the requester before broadcasting.
    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    // Broadcast the renamed channel to every connection subscribed to the
    // tree's root whose role may see this channel's visibility level.
    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3153
3154/// Move a channel to a new parent.
3155async fn move_channel(
3156 request: proto::MoveChannel,
3157 response: Response<proto::MoveChannel>,
3158 session: MessageContext,
3159) -> Result<()> {
3160 let channel_id = ChannelId::from_proto(request.channel_id);
3161 let to = ChannelId::from_proto(request.to);
3162
3163 let (root_id, channels) = session
3164 .db()
3165 .await
3166 .move_channel(channel_id, to, session.user_id())
3167 .await?;
3168
3169 let connection_pool = session.connection_pool().await;
3170 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3171 let channels = channels
3172 .iter()
3173 .filter_map(|channel| {
3174 if role.can_see_channel(channel.visibility) {
3175 Some(channel.to_proto())
3176 } else {
3177 None
3178 }
3179 })
3180 .collect::<Vec<_>>();
3181 if channels.is_empty() {
3182 continue;
3183 }
3184
3185 let update = proto::UpdateChannels {
3186 channels,
3187 ..Default::default()
3188 };
3189
3190 session.peer.send(connection_id, update.clone())?;
3191 }
3192
3193 response.send(Ack {})?;
3194 Ok(())
3195}
3196
3197async fn reorder_channel(
3198 request: proto::ReorderChannel,
3199 response: Response<proto::ReorderChannel>,
3200 session: MessageContext,
3201) -> Result<()> {
3202 let channel_id = ChannelId::from_proto(request.channel_id);
3203 let direction = request.direction();
3204
3205 let updated_channels = session
3206 .db()
3207 .await
3208 .reorder_channel(channel_id, direction, session.user_id())
3209 .await?;
3210
3211 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3212 let connection_pool = session.connection_pool().await;
3213 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3214 let channels = updated_channels
3215 .iter()
3216 .filter_map(|channel| {
3217 if role.can_see_channel(channel.visibility) {
3218 Some(channel.to_proto())
3219 } else {
3220 None
3221 }
3222 })
3223 .collect::<Vec<_>>();
3224
3225 if channels.is_empty() {
3226 continue;
3227 }
3228
3229 let update = proto::UpdateChannels {
3230 channels,
3231 ..Default::default()
3232 };
3233
3234 session.peer.send(connection_id, update.clone())?;
3235 }
3236 }
3237
3238 response.send(Ack {})?;
3239 Ok(())
3240}
3241
3242/// Get the list of channel members
3243async fn get_channel_members(
3244 request: proto::GetChannelMembers,
3245 response: Response<proto::GetChannelMembers>,
3246 session: MessageContext,
3247) -> Result<()> {
3248 let db = session.db().await;
3249 let channel_id = ChannelId::from_proto(request.channel_id);
3250 let limit = if request.limit == 0 {
3251 u16::MAX as u64
3252 } else {
3253 request.limit
3254 };
3255 let (members, users) = db
3256 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3257 .await?;
3258 response.send(proto::GetChannelMembersResponse { members, users })?;
3259 Ok(())
3260}
3261
3262/// Accept or decline a channel invitation.
3263async fn respond_to_channel_invite(
3264 request: proto::RespondToChannelInvite,
3265 response: Response<proto::RespondToChannelInvite>,
3266 session: MessageContext,
3267) -> Result<()> {
3268 let db = session.db().await;
3269 let channel_id = ChannelId::from_proto(request.channel_id);
3270 let RespondToChannelInvite {
3271 membership_update,
3272 notifications,
3273 } = db
3274 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3275 .await?;
3276
3277 let mut connection_pool = session.connection_pool().await;
3278 if let Some(membership_update) = membership_update {
3279 notify_membership_updated(
3280 &mut connection_pool,
3281 membership_update,
3282 session.user_id(),
3283 &session.peer,
3284 );
3285 } else {
3286 let update = proto::UpdateChannels {
3287 remove_channel_invitations: vec![channel_id.to_proto()],
3288 ..Default::default()
3289 };
3290
3291 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3292 session.peer.send(connection_id, update.clone())?;
3293 }
3294 };
3295
3296 send_notifications(&connection_pool, &session.peer, notifications);
3297
3298 response.send(proto::Ack {})?;
3299
3300 Ok(())
3301}
3302
3303/// Join the channels' room
3304async fn join_channel(
3305 request: proto::JoinChannel,
3306 response: Response<proto::JoinChannel>,
3307 session: MessageContext,
3308) -> Result<()> {
3309 let channel_id = ChannelId::from_proto(request.channel_id);
3310 join_channel_internal(channel_id, Box::new(response), session).await
3311}
3312
/// Abstraction over the two RPC response types that carry a
/// `JoinRoomResponse` payload, so `join_channel_internal` can serve both
/// the `JoinChannel` and `JoinRoom` requests.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// `JoinChannel` replies with a `JoinRoomResponse`; forward to the inherent
// `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// `JoinRoom` replies with a `JoinRoomResponse` as well; forward to the
// inherent `Response::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3326
/// Shared implementation of joining a channel's call, used by both the
/// `JoinChannel` and `JoinRoom` RPCs (see `JoinChannelInternalResponse`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard first — `leave_room_for_session`
            // acquires it again internally — then re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, everyone else may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting updates to other clients.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // With the database guard released, notify channel subscribers and
    // refresh this user's contacts.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3420
3421/// Start editing the channel notes
3422async fn join_channel_buffer(
3423 request: proto::JoinChannelBuffer,
3424 response: Response<proto::JoinChannelBuffer>,
3425 session: MessageContext,
3426) -> Result<()> {
3427 let db = session.db().await;
3428 let channel_id = ChannelId::from_proto(request.channel_id);
3429
3430 let open_response = db
3431 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3432 .await?;
3433
3434 let collaborators = open_response.collaborators.clone();
3435 response.send(open_response)?;
3436
3437 let update = UpdateChannelBufferCollaborators {
3438 channel_id: channel_id.to_proto(),
3439 collaborators: collaborators.clone(),
3440 };
3441 channel_buffer_updated(
3442 session.connection_id,
3443 collaborators
3444 .iter()
3445 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3446 &update,
3447 &session.peer,
3448 );
3449
3450 Ok(())
3451}
3452
3453/// Edit the channel notes
3454async fn update_channel_buffer(
3455 request: proto::UpdateChannelBuffer,
3456 session: MessageContext,
3457) -> Result<()> {
3458 let db = session.db().await;
3459 let channel_id = ChannelId::from_proto(request.channel_id);
3460
3461 let (collaborators, epoch, version) = db
3462 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3463 .await?;
3464
3465 channel_buffer_updated(
3466 session.connection_id,
3467 collaborators.clone(),
3468 &proto::UpdateChannelBuffer {
3469 channel_id: channel_id.to_proto(),
3470 operations: request.operations,
3471 },
3472 &session.peer,
3473 );
3474
3475 let pool = &*session.connection_pool().await;
3476
3477 let non_collaborators =
3478 pool.channel_connection_ids(channel_id)
3479 .filter_map(|(connection_id, _)| {
3480 if collaborators.contains(&connection_id) {
3481 None
3482 } else {
3483 Some(connection_id)
3484 }
3485 });
3486
3487 broadcast(None, non_collaborators, |peer_id| {
3488 session.peer.send(
3489 peer_id,
3490 proto::UpdateChannels {
3491 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3492 channel_id: channel_id.to_proto(),
3493 epoch: epoch as u64,
3494 version: version.clone(),
3495 }],
3496 ..Default::default()
3497 },
3498 )
3499 });
3500
3501 Ok(())
3502}
3503
3504/// Rejoin the channel notes after a connection blip
3505async fn rejoin_channel_buffers(
3506 request: proto::RejoinChannelBuffers,
3507 response: Response<proto::RejoinChannelBuffers>,
3508 session: MessageContext,
3509) -> Result<()> {
3510 let db = session.db().await;
3511 let buffers = db
3512 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3513 .await?;
3514
3515 for rejoined_buffer in &buffers {
3516 let collaborators_to_notify = rejoined_buffer
3517 .buffer
3518 .collaborators
3519 .iter()
3520 .filter_map(|c| Some(c.peer_id?.into()));
3521 channel_buffer_updated(
3522 session.connection_id,
3523 collaborators_to_notify,
3524 &proto::UpdateChannelBufferCollaborators {
3525 channel_id: rejoined_buffer.buffer.channel_id,
3526 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3527 },
3528 &session.peer,
3529 );
3530 }
3531
3532 response.send(proto::RejoinChannelBuffersResponse {
3533 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3534 })?;
3535
3536 Ok(())
3537}
3538
3539/// Stop editing the channel notes
3540async fn leave_channel_buffer(
3541 request: proto::LeaveChannelBuffer,
3542 response: Response<proto::LeaveChannelBuffer>,
3543 session: MessageContext,
3544) -> Result<()> {
3545 let db = session.db().await;
3546 let channel_id = ChannelId::from_proto(request.channel_id);
3547
3548 let left_buffer = db
3549 .leave_channel_buffer(channel_id, session.connection_id)
3550 .await?;
3551
3552 response.send(Ack {})?;
3553
3554 channel_buffer_updated(
3555 session.connection_id,
3556 left_buffer.connections,
3557 &proto::UpdateChannelBufferCollaborators {
3558 channel_id: channel_id.to_proto(),
3559 collaborators: left_buffer.collaborators,
3560 },
3561 &session.peer,
3562 );
3563
3564 Ok(())
3565}
3566
3567fn channel_buffer_updated<T: EnvelopedMessage>(
3568 sender_id: ConnectionId,
3569 collaborators: impl IntoIterator<Item = ConnectionId>,
3570 message: &T,
3571 peer: &Peer,
3572) {
3573 broadcast(Some(sender_id), collaborators, |peer_id| {
3574 peer.send(peer_id, message.clone())
3575 });
3576}
3577
3578fn send_notifications(
3579 connection_pool: &ConnectionPool,
3580 peer: &Peer,
3581 notifications: db::NotificationBatch,
3582) {
3583 for (user_id, notification) in notifications {
3584 for connection_id in connection_pool.user_connection_ids(user_id) {
3585 if let Err(error) = peer.send(
3586 connection_id,
3587 proto::AddNotification {
3588 notification: Some(notification.clone()),
3589 },
3590 ) {
3591 tracing::error!(
3592 "failed to send notification to {:?} {}",
3593 connection_id,
3594 error
3595 );
3596 }
3597 }
3598 }
3599}
3600
3601/// Send a message to the channel
3602async fn send_channel_message(
3603 _request: proto::SendChannelMessage,
3604 _response: Response<proto::SendChannelMessage>,
3605 _session: MessageContext,
3606) -> Result<()> {
3607 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3608}
3609
3610/// Delete a channel message
3611async fn remove_channel_message(
3612 _request: proto::RemoveChannelMessage,
3613 _response: Response<proto::RemoveChannelMessage>,
3614 _session: MessageContext,
3615) -> Result<()> {
3616 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3617}
3618
3619async fn update_channel_message(
3620 _request: proto::UpdateChannelMessage,
3621 _response: Response<proto::UpdateChannelMessage>,
3622 _session: MessageContext,
3623) -> Result<()> {
3624 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3625}
3626
3627/// Mark a channel message as read
3628async fn acknowledge_channel_message(
3629 _request: proto::AckChannelMessage,
3630 _session: MessageContext,
3631) -> Result<()> {
3632 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3633}
3634
3635/// Mark a buffer version as synced
3636async fn acknowledge_buffer_version(
3637 request: proto::AckBufferOperation,
3638 session: MessageContext,
3639) -> Result<()> {
3640 let buffer_id = BufferId::from_proto(request.buffer_id);
3641 session
3642 .db()
3643 .await
3644 .observe_buffer_version(
3645 buffer_id,
3646 session.user_id(),
3647 request.epoch as i32,
3648 &request.version,
3649 )
3650 .await?;
3651 Ok(())
3652}
3653
3654/// Get a Supermaven API key for the user
3655async fn get_supermaven_api_key(
3656 _request: proto::GetSupermavenApiKey,
3657 response: Response<proto::GetSupermavenApiKey>,
3658 session: MessageContext,
3659) -> Result<()> {
3660 let user_id: String = session.user_id().to_string();
3661 if !session.is_staff() {
3662 return Err(anyhow!("supermaven not enabled for this account"))?;
3663 }
3664
3665 let email = session.email().context("user must have an email")?;
3666
3667 let supermaven_admin_api = session
3668 .supermaven_client
3669 .as_ref()
3670 .context("supermaven not configured")?;
3671
3672 let result = supermaven_admin_api
3673 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3674 .await?;
3675
3676 response.send(proto::GetSupermavenApiKeyResponse {
3677 api_key: result.api_key,
3678 })?;
3679
3680 Ok(())
3681}
3682
3683/// Start receiving chat updates for a channel
3684async fn join_channel_chat(
3685 _request: proto::JoinChannelChat,
3686 _response: Response<proto::JoinChannelChat>,
3687 _session: MessageContext,
3688) -> Result<()> {
3689 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3690}
3691
3692/// Stop receiving chat updates for a channel
3693async fn leave_channel_chat(
3694 _request: proto::LeaveChannelChat,
3695 _session: MessageContext,
3696) -> Result<()> {
3697 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3698}
3699
3700/// Retrieve the chat history for a channel
3701async fn get_channel_messages(
3702 _request: proto::GetChannelMessages,
3703 _response: Response<proto::GetChannelMessages>,
3704 _session: MessageContext,
3705) -> Result<()> {
3706 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3707}
3708
3709/// Retrieve specific chat messages
3710async fn get_channel_messages_by_id(
3711 _request: proto::GetChannelMessagesById,
3712 _response: Response<proto::GetChannelMessagesById>,
3713 _session: MessageContext,
3714) -> Result<()> {
3715 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3716}
3717
3718/// Retrieve the current users notifications
3719async fn get_notifications(
3720 request: proto::GetNotifications,
3721 response: Response<proto::GetNotifications>,
3722 session: MessageContext,
3723) -> Result<()> {
3724 let notifications = session
3725 .db()
3726 .await
3727 .get_notifications(
3728 session.user_id(),
3729 NOTIFICATION_COUNT_PER_PAGE,
3730 request.before_id.map(db::NotificationId::from_proto),
3731 )
3732 .await?;
3733 response.send(proto::GetNotificationsResponse {
3734 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3735 notifications,
3736 })?;
3737 Ok(())
3738}
3739
3740/// Mark notifications as read
3741async fn mark_notification_as_read(
3742 request: proto::MarkNotificationRead,
3743 response: Response<proto::MarkNotificationRead>,
3744 session: MessageContext,
3745) -> Result<()> {
3746 let database = &session.db().await;
3747 let notifications = database
3748 .mark_notification_as_read_by_id(
3749 session.user_id(),
3750 NotificationId::from_proto(request.notification_id),
3751 )
3752 .await?;
3753 send_notifications(
3754 &*session.connection_pool().await,
3755 &session.peer,
3756 notifications,
3757 );
3758 response.send(proto::Ack {})?;
3759 Ok(())
3760}
3761
3762fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3763 let message = match message {
3764 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3765 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3766 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3767 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3768 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3769 code: frame.code.into(),
3770 reason: frame.reason.as_str().to_owned().into(),
3771 })),
3772 // We should never receive a frame while reading the message, according
3773 // to the `tungstenite` maintainers:
3774 //
3775 // > It cannot occur when you read messages from the WebSocket, but it
3776 // > can be used when you want to send the raw frames (e.g. you want to
3777 // > send the frames to the WebSocket without composing the full message first).
3778 // >
3779 // > — https://github.com/snapview/tungstenite-rs/issues/268
3780 TungsteniteMessage::Frame(_) => {
3781 bail!("received an unexpected frame while reading the message")
3782 }
3783 };
3784
3785 Ok(message)
3786}
3787
3788fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3789 match message {
3790 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3791 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3792 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3793 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3794 AxumMessage::Close(frame) => {
3795 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3796 code: frame.code.into(),
3797 reason: frame.reason.as_ref().into(),
3798 }))
3799 }
3800 }
3801}
3802
3803fn notify_membership_updated(
3804 connection_pool: &mut ConnectionPool,
3805 result: MembershipUpdated,
3806 user_id: UserId,
3807 peer: &Peer,
3808) {
3809 for membership in &result.new_channels.channel_memberships {
3810 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3811 }
3812 for channel_id in &result.removed_channels {
3813 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3814 }
3815
3816 let user_channels_update = proto::UpdateUserChannels {
3817 channel_memberships: result
3818 .new_channels
3819 .channel_memberships
3820 .iter()
3821 .map(|cm| proto::ChannelMembership {
3822 channel_id: cm.channel_id.to_proto(),
3823 role: cm.role.into(),
3824 })
3825 .collect(),
3826 ..Default::default()
3827 };
3828
3829 let mut update = build_channels_update(result.new_channels);
3830 update.delete_channels = result
3831 .removed_channels
3832 .into_iter()
3833 .map(|id| id.to_proto())
3834 .collect();
3835 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3836
3837 for connection_id in connection_pool.user_connection_ids(user_id) {
3838 peer.send(connection_id, user_channels_update.clone())
3839 .trace_err();
3840 peer.send(connection_id, update.clone()).trace_err();
3841 }
3842}
3843
3844fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3845 proto::UpdateUserChannels {
3846 channel_memberships: channels
3847 .channel_memberships
3848 .iter()
3849 .map(|m| proto::ChannelMembership {
3850 channel_id: m.channel_id.to_proto(),
3851 role: m.role.into(),
3852 })
3853 .collect(),
3854 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3855 }
3856}
3857
3858fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3859 let mut update = proto::UpdateChannels::default();
3860
3861 for channel in channels.channels {
3862 update.channels.push(channel.to_proto());
3863 }
3864
3865 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3866
3867 for (channel_id, participants) in channels.channel_participants {
3868 update
3869 .channel_participants
3870 .push(proto::ChannelParticipants {
3871 channel_id: channel_id.to_proto(),
3872 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3873 });
3874 }
3875
3876 for channel in channels.invited_channels {
3877 update.channel_invitations.push(channel.to_proto());
3878 }
3879
3880 update
3881}
3882
3883fn build_initial_contacts_update(
3884 contacts: Vec<db::Contact>,
3885 pool: &ConnectionPool,
3886) -> proto::UpdateContacts {
3887 let mut update = proto::UpdateContacts::default();
3888
3889 for contact in contacts {
3890 match contact {
3891 db::Contact::Accepted { user_id, busy } => {
3892 update.contacts.push(contact_for_user(user_id, busy, pool));
3893 }
3894 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3895 db::Contact::Incoming { user_id } => {
3896 update
3897 .incoming_requests
3898 .push(proto::IncomingContactRequest {
3899 requester_id: user_id.to_proto(),
3900 })
3901 }
3902 }
3903 }
3904
3905 update
3906}
3907
3908fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3909 proto::Contact {
3910 user_id: user_id.to_proto(),
3911 online: pool.is_user_online(user_id),
3912 busy,
3913 }
3914}
3915
3916fn room_updated(room: &proto::Room, peer: &Peer) {
3917 broadcast(
3918 None,
3919 room.participants
3920 .iter()
3921 .filter_map(|participant| Some(participant.peer_id?.into())),
3922 |peer_id| {
3923 peer.send(
3924 peer_id,
3925 proto::RoomUpdated {
3926 room: Some(room.clone()),
3927 },
3928 )
3929 },
3930 );
3931}
3932
3933fn channel_updated(
3934 channel: &db::channel::Model,
3935 room: &proto::Room,
3936 peer: &Peer,
3937 pool: &ConnectionPool,
3938) {
3939 let participants = room
3940 .participants
3941 .iter()
3942 .map(|p| p.user_id)
3943 .collect::<Vec<_>>();
3944
3945 broadcast(
3946 None,
3947 pool.channel_connection_ids(channel.root_id())
3948 .filter_map(|(channel_id, role)| {
3949 role.can_see_channel(channel.visibility)
3950 .then_some(channel_id)
3951 }),
3952 |peer_id| {
3953 peer.send(
3954 peer_id,
3955 proto::UpdateChannels {
3956 channel_participants: vec![proto::ChannelParticipants {
3957 channel_id: channel.id.to_proto(),
3958 participant_user_ids: participants.clone(),
3959 }],
3960 ..Default::default()
3961 },
3962 )
3963 },
3964 );
3965}
3966
3967async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3968 let db = session.db().await;
3969
3970 let contacts = db.get_contacts(user_id).await?;
3971 let busy = db.is_user_busy(user_id).await?;
3972
3973 let pool = session.connection_pool().await;
3974 let updated_contact = contact_for_user(user_id, busy, &pool);
3975 for contact in contacts {
3976 if let db::Contact::Accepted {
3977 user_id: contact_user_id,
3978 ..
3979 } = contact
3980 {
3981 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3982 session
3983 .peer
3984 .send(
3985 contact_conn_id,
3986 proto::UpdateContacts {
3987 contacts: vec![updated_contact.clone()],
3988 remove_contacts: Default::default(),
3989 incoming_requests: Default::default(),
3990 remove_incoming_requests: Default::default(),
3991 outgoing_requests: Default::default(),
3992 remove_outgoing_requests: Default::default(),
3993 },
3994 )
3995 .trace_err();
3996 }
3997 }
3998 }
3999 Ok(())
4000}
4001
/// Remove `connection_id` from the room it occupies (if any), notifying
/// remaining participants, canceled callees, affected contacts, and LiveKit.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are only assigned on the branch where
    // the connection was actually in a room; otherwise we return early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell each left project's remaining connections that we left.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move what we still need out of `left_room` before the database
        // guard (held by the `if let` temporary) is released.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Notify users whose incoming calls were canceled by this departure, and
    // remember to refresh their contact status too. The block scopes the
    // connection-pool guard so it is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: remove the participant, and delete the
    // LiveKit room when the database reported the room itself was deleted.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4075
4076async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4077 let left_channel_buffers = session
4078 .db()
4079 .await
4080 .leave_channel_buffers(session.connection_id)
4081 .await?;
4082
4083 for left_buffer in left_channel_buffers {
4084 channel_buffer_updated(
4085 session.connection_id,
4086 left_buffer.connections,
4087 &proto::UpdateChannelBufferCollaborators {
4088 channel_id: left_buffer.channel_id.to_proto(),
4089 collaborators: left_buffer.collaborators,
4090 },
4091 &session.peer,
4092 );
4093 }
4094
4095 Ok(())
4096}
4097
4098fn project_left(project: &db::LeftProject, session: &Session) {
4099 for connection_id in &project.connection_ids {
4100 if project.should_unshare {
4101 session
4102 .peer
4103 .send(
4104 *connection_id,
4105 proto::UnshareProject {
4106 project_id: project.id.to_proto(),
4107 },
4108 )
4109 .trace_err();
4110 } else {
4111 session
4112 .peer
4113 .send(
4114 *connection_id,
4115 proto::RemoveProjectCollaborator {
4116 project_id: project.id.to_proto(),
4117 peer_id: Some(session.connection_id.into()),
4118 },
4119 )
4120 .trace_err();
4121 }
4122 }
4123}
4124
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) and convert the result into an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4130
4131impl<T, E> ResultExt for Result<T, E>
4132where
4133 E: std::fmt::Debug,
4134{
4135 type Ok = T;
4136
4137 #[track_caller]
4138 fn trace_err(self) -> Option<T> {
4139 match self {
4140 Ok(value) => Some(value),
4141 Err(error) => {
4142 tracing::error!("{:?}", error);
4143 None
4144 }
4145 }
4146 }
4147}