1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semver::Version;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
// How long a disconnected client has to reconnect before its state is torn down.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously tracked connections; enforced via `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently held `ConnectionGuard`s.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields used to record per-message timings.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler for one message type: receives the envelope, the session
// it arrived on, and the span to record timings into; see `Server::add_handler`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
/// Pending reply handle for a request of type `R`. Consuming it via `send`
/// flags the shared `responded` bit so the dispatch code in
/// `add_request_handler` can detect handlers that finished without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the dispatching wrapper; set to true once a reply was sent.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requesting peer.
    fn send(self, payload: R::Response) -> Result<()> {
        // The flag is set before the send, so a failed send still registers
        // as an attempted reply.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
131
132#[derive(Clone, Debug)]
133pub enum Principal {
134 User(User),
135 Impersonated { user: User, admin: User },
136}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every message handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` wherever a `Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
168impl MessageContext {
169 pub fn forward_request<T: RequestMessage>(
170 &self,
171 receiver_id: ConnectionId,
172 request: T,
173 ) -> impl Future<Output = anyhow::Result<T::Response>> {
174 let request_start_time = Instant::now();
175 let span = self.span.clone();
176 tracing::info!("start forwarding request");
177 self.peer
178 .forward_request(self.connection_id, receiver_id, request)
179 .inspect(move |_| {
180 span.record(
181 HOST_WAITING_MS,
182 request_start_time.elapsed().as_micros() as f64 / 1000.0,
183 );
184 })
185 .inspect_err(|_| tracing::error!("error forwarding request"))
186 .inspect_ok(|_| tracing::info!("finished forwarding request"))
187 }
188}
189
/// Per-connection state handed to message handlers (via `MessageContext`).
#[derive(Clone)]
struct Session {
    // Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle; a tokio mutex, so the guard can be held across awaits.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
/// The collab RPC server: owns the peer, the pool of client connections, and
/// the per-message-type dispatch table.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Dispatch table keyed by the message's `TypeId`; populated in `new`,
    // consulted in `handle_connection`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcast flag set to true while the server is tearing down.
    teardown: watch::Sender<bool>,
}
283
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` makes the
/// guard `!Send`, so futures holding it cannot migrate across threads.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard isn't itself `Serialize`; serialize what it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Builds a server with the given id and app state, registering the
    /// handler for every RPC message type the server understands.
    ///
    /// Handler kinds:
    /// - `add_request_handler`: request/response messages (a reply is required);
    /// - `add_message_handler`: fire-and-forget messages;
    /// - `forward_read_only_project_request` / `forward_mutating_project_request`:
    ///   proxy the request to the project host, with the mutating variant
    ///   presumably gated on write access — the distinction is enforced
    ///   elsewhere, not visible here;
    /// - `broadcast_project_message_from_host`: fan a host message out to guests.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and membership.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): LspExtRunnables is registered as mutating amid
            // otherwise read-only LSP extension requests — confirm intentional.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-to-guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Agent contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles modify the working
            // tree on the host but are registered read-only — confirm intentional.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
482
    /// Spawns the background cleanup task. After `CLEANUP_TIMEOUT`, it clears
    /// state left behind by previous server instances (stale chat
    /// participants, channel-buffer collaborators, room participants, old
    /// worktree entries, stale server rows) and notifies affected clients.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and tell the
                    // remaining collaborators about the new roster.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants; broadcast the refreshed
                        // room (and channel, if any) to remaining clients.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push updated busy status to each affected user's contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
653
654 pub fn teardown(&self) {
655 self.peer.teardown();
656 self.connection_pool.lock().reset();
657 let _ = self.teardown.send(true);
658 }
659
    /// Test-only: tears the server down, then restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Clear the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }
667
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
672
    /// Registers the type-erased dispatch entry for message type `M`,
    /// wrapping `handler` so that queue/processing durations are recorded on
    /// the message's span and the outcome is logged.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The dispatch table is keyed by TypeId, so this downcast
                // cannot fail for a correctly routed envelope.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since receipt; queue = total minus the
                    // handler's own processing time.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
716
717 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
718 where
719 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
720 Fut: 'static + Send + Future<Output = Result<()>>,
721 M: EnvelopedMessage,
722 {
723 self.add_handler(move |envelope, session| handler(envelope.payload, session));
724 self
725 }
726
727 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
728 where
729 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
730 Fut: Send + Future<Output = Result<()>>,
731 M: RequestMessage,
732 {
733 let handler = Arc::new(handler);
734 self.add_handler(move |envelope, session| {
735 let receipt = envelope.receipt();
736 let handler = handler.clone();
737 async move {
738 let peer = session.peer.clone();
739 let responded = Arc::new(AtomicBool::default());
740 let response = Response {
741 peer: peer.clone(),
742 responded: responded.clone(),
743 receipt,
744 };
745 match (handler)(envelope.payload, response, session).await {
746 Ok(()) => {
747 if responded.load(std::sync::atomic::Ordering::SeqCst) {
748 Ok(())
749 } else {
750 Err(anyhow!("handler did not send a response"))?
751 }
752 }
753 Err(error) => {
754 let proto_err = match &error {
755 Error::Internal(err) => err.to_proto(),
756 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
757 };
758 peer.respond_with_error(receipt, proto_err)?;
759 Err(error)
760 }
761 }
762 }
763 })
764 }
765
    /// Drives a single client connection: registers it with the peer, sends
    /// the initial hello/client update, then dispatches incoming messages to
    /// the registered handlers until the connection closes or the server
    /// tears down, finally running `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake complete: dropping the guard releases the
            // connection-capacity slot (see `ConnectionGuard::drop`), so the
            // limit only applies while connections are being set up.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased select: teardown and I/O errors are checked before
                // draining finished handlers or accepting new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes,
                                // bounding in-flight handlers per connection.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
957
    /// Sends the initial batch of state to a freshly-established client
    /// connection: the `Hello` handshake, the contact list, channel
    /// subscriptions, and any call that is currently ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client its own peer id; it is sent
        // before anything else on the connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Notify the caller (if one is waiting) that the connection id is live.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt their client to
                // show the contacts panel, and persist that this happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the contact list while the
                // pool lock is held, so the online/offline snapshot is
                // consistent with this connection's registration.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a call that was ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Let the user's contacts know they are now online.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1013
1014 pub async fn invite_code_redeemed(
1015 self: &Arc<Self>,
1016 inviter_id: UserId,
1017 invitee_id: UserId,
1018 ) -> Result<()> {
1019 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1020 && let Some(code) = &user.invite_code
1021 {
1022 let pool = self.connection_pool.lock();
1023 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1024 for connection_id in pool.user_connection_ids(inviter_id) {
1025 self.peer.send(
1026 connection_id,
1027 proto::UpdateContacts {
1028 contacts: vec![invitee_contact.clone()],
1029 ..Default::default()
1030 },
1031 )?;
1032 self.peer.send(
1033 connection_id,
1034 proto::UpdateInviteInfo {
1035 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1036 count: user.invite_count as u32,
1037 },
1038 )?;
1039 }
1040 }
1041 Ok(())
1042 }
1043
1044 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1045 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1046 && let Some(invite_code) = &user.invite_code
1047 {
1048 let pool = self.connection_pool.lock();
1049 for connection_id in pool.user_connection_ids(user_id) {
1050 self.peer.send(
1051 connection_id,
1052 proto::UpdateInviteInfo {
1053 url: format!(
1054 "{}{}",
1055 self.app_state.config.invite_link_prefix, invite_code
1056 ),
1057 count: user.invite_count as u32,
1058 },
1059 )?;
1060 }
1061 }
1062 Ok(())
1063 }
1064
    /// Captures a point-in-time view of the server: the locked connection
    /// pool (behind a guard) plus a reference to the peer.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker that keeps the guard from being sent across threads
                // while the lock is held.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1074}
1075
/// Lets a `ConnectionPoolGuard` be used anywhere a `&ConnectionPool` is expected.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1083
/// Lets the pool be mutated through the guard while the lock is held.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1089
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In tests, verify the pool's internal invariants each time the guard
        // is released, so corruption is caught close to its source.
        #[cfg(test)]
        self.check_invariants();
    }
}
1096
1097fn broadcast<F>(
1098 sender_id: Option<ConnectionId>,
1099 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1100 mut f: F,
1101) where
1102 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1103{
1104 for receiver_id in receiver_ids {
1105 if Some(receiver_id) != sender_id
1106 && let Err(error) = f(receiver_id)
1107 {
1108 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1109 }
1110 }
1111}
1112
1113pub struct ProtocolVersion(u32);
1114
1115impl Header for ProtocolVersion {
1116 fn name() -> &'static HeaderName {
1117 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1118 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1119 }
1120
1121 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1122 where
1123 Self: Sized,
1124 I: Iterator<Item = &'i axum::http::HeaderValue>,
1125 {
1126 let version = values
1127 .next()
1128 .ok_or_else(axum::headers::Error::invalid)?
1129 .to_str()
1130 .map_err(|_| axum::headers::Error::invalid())?
1131 .parse()
1132 .map_err(|_| axum::headers::Error::invalid())?;
1133 Ok(Self(version))
1134 }
1135
1136 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1137 values.extend([self.0.to_string().parse().unwrap()]);
1138 }
1139}
1140
1141pub struct AppVersionHeader(Version);
1142impl Header for AppVersionHeader {
1143 fn name() -> &'static HeaderName {
1144 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1145 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1146 }
1147
1148 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1149 where
1150 Self: Sized,
1151 I: Iterator<Item = &'i axum::http::HeaderValue>,
1152 {
1153 let version = values
1154 .next()
1155 .ok_or_else(axum::headers::Error::invalid)?
1156 .to_str()
1157 .map_err(|_| axum::headers::Error::invalid())?
1158 .parse()
1159 .map_err(|_| axum::headers::Error::invalid())?;
1160 Ok(Self(version))
1161 }
1162
1163 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1164 values.extend([self.0.to_string().parse().unwrap()]);
1165 }
1166}
1167
1168#[derive(Debug)]
1169pub struct ReleaseChannelHeader(String);
1170
1171impl Header for ReleaseChannelHeader {
1172 fn name() -> &'static HeaderName {
1173 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1174 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1175 }
1176
1177 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1178 where
1179 Self: Sized,
1180 I: Iterator<Item = &'i axum::http::HeaderValue>,
1181 {
1182 Ok(Self(
1183 values
1184 .next()
1185 .ok_or_else(axum::headers::Error::invalid)?
1186 .to_str()
1187 .map_err(|_| axum::headers::Error::invalid())?
1188 .to_owned(),
1189 ))
1190 }
1191
1192 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1193 values.extend([self.0.parse().unwrap()]);
1194 }
1195}
1196
/// Builds the collab server's router: `/rpc` (authenticated WebSocket
/// endpoint) and `/metrics` (Prometheus scrape endpoint).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // This layer is added before `/metrics` is routed, so the auth
        // middleware applies to `/rpc` only; `/metrics` stays unauthenticated.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1208
1209pub async fn handle_websocket_request(
1210 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1211 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1212 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1213 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1214 Extension(server): Extension<Arc<Server>>,
1215 Extension(principal): Extension<Principal>,
1216 user_agent: Option<TypedHeader<UserAgent>>,
1217 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1218 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1219 ws: WebSocketUpgrade,
1220) -> axum::response::Response {
1221 if protocol_version != rpc::PROTOCOL_VERSION {
1222 return (
1223 StatusCode::UPGRADE_REQUIRED,
1224 "client must be upgraded".to_string(),
1225 )
1226 .into_response();
1227 }
1228
1229 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1230 return (
1231 StatusCode::UPGRADE_REQUIRED,
1232 "no version header found".to_string(),
1233 )
1234 .into_response();
1235 };
1236
1237 let release_channel = release_channel_header.map(|header| header.0.0);
1238
1239 if !version.can_collaborate() {
1240 return (
1241 StatusCode::UPGRADE_REQUIRED,
1242 "client must be upgraded".to_string(),
1243 )
1244 .into_response();
1245 }
1246
1247 let socket_address = socket_address.to_string();
1248
1249 // Acquire connection guard before WebSocket upgrade
1250 let connection_guard = match ConnectionGuard::try_acquire() {
1251 Ok(guard) => guard,
1252 Err(()) => {
1253 return (
1254 StatusCode::SERVICE_UNAVAILABLE,
1255 "Too many concurrent connections",
1256 )
1257 .into_response();
1258 }
1259 };
1260
1261 ws.on_upgrade(move |socket| {
1262 let socket = socket
1263 .map_ok(to_tungstenite_message)
1264 .err_into()
1265 .with(|message| async move { to_axum_message(message) });
1266 let connection = Connection::new(Box::pin(socket));
1267 async move {
1268 server
1269 .handle_connection(
1270 connection,
1271 socket_address,
1272 principal,
1273 version,
1274 release_channel,
1275 user_agent.map(|header| header.to_string()),
1276 country_code_header.map(|header| header.to_string()),
1277 system_id_header.map(|header| header.to_string()),
1278 None,
1279 Executor::Production,
1280 Some(connection_guard),
1281 )
1282 .await;
1283 }
1284 })
1285}
1286
1287pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1288 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1289 let connections_metric = CONNECTIONS_METRIC
1290 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1291
1292 let connections = server
1293 .connection_pool
1294 .lock()
1295 .connections()
1296 .filter(|connection| !connection.admin)
1297 .count();
1298 connections_metric.set(connections as _);
1299
1300 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1301 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1302 register_int_gauge!(
1303 "shared_projects",
1304 "number of open projects with one or more guests"
1305 )
1306 .unwrap()
1307 });
1308
1309 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1310 shared_projects_metric.set(shared_projects as _);
1311
1312 let encoder = prometheus::TextEncoder::new();
1313 let metric_families = prometheus::gather();
1314 let encoded_metrics = encoder
1315 .encode_to_string(&metric_families)
1316 .map_err(|err| anyhow!("{err}"))?;
1317 Ok(encoded_metrics)
1318}
1319
/// Cleans up server state after a client connection drops.
///
/// The connection is removed from the pool and marked lost in the database
/// immediately, but room/channel-buffer membership is only torn down after
/// `RECONNECT_TIMEOUT` elapses, giving the client a window to reconnect and
/// rejoin without losing its seat. A server teardown signal skips the wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection; failures are only logged.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if the user has no other live
            // connection that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server is shutting down: skip resource teardown entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1366
1367/// Acknowledges a ping from a client, used to keep the connection alive.
1368async fn ping(
1369 _: proto::Ping,
1370 response: Response<proto::Ping>,
1371 _session: MessageContext,
1372) -> Result<()> {
1373 response.send(proto::Ack {})?;
1374 Ok(())
1375}
1376
1377/// Creates a new room for calling (outside of channels)
1378async fn create_room(
1379 _request: proto::CreateRoom,
1380 response: Response<proto::CreateRoom>,
1381 session: MessageContext,
1382) -> Result<()> {
1383 let livekit_room = nanoid::nanoid!(30);
1384
1385 let live_kit_connection_info = util::maybe!(async {
1386 let live_kit = session.app_state.livekit_client.as_ref();
1387 let live_kit = live_kit?;
1388 let user_id = session.user_id().to_string();
1389
1390 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1391
1392 Some(proto::LiveKitConnectionInfo {
1393 server_url: live_kit.url().into(),
1394 token,
1395 can_publish: true,
1396 })
1397 })
1398 .await;
1399
1400 let room = session
1401 .db()
1402 .await
1403 .create_room(session.user_id(), session.connection_id, &livekit_room)
1404 .await?;
1405
1406 response.send(proto::CreateRoomResponse {
1407 room: Some(room.clone()),
1408 live_kit_connection_info,
1409 })?;
1410
1411 update_user_contacts(session.user_id(), &session).await?;
1412 Ok(())
1413}
1414
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms go through the channel join path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Joining answers the call: dismiss the ringing UI on the user's other
    // connections. Send failures are logged, not fatal.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; the room remains usable without one.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1481
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Sends the current room state back to the client, then fans the
/// reconnection out: existing collaborators learn the client's new peer id
/// and worktree list, and the client is re-sent the state of the projects
/// it had joined.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the database guard happens inside this block;
    // `room`/`channel` carry the results out so the guard can be dropped
    // before the channel/contact updates below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each collaborator at the host's new connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree list to the other
            // collaborators (excluding the rejoining connection itself).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, refresh the channel's observers too.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1573
1574fn notify_rejoined_projects(
1575 rejoined_projects: &mut Vec<RejoinedProject>,
1576 session: &Session,
1577) -> Result<()> {
1578 for project in rejoined_projects.iter() {
1579 for collaborator in &project.collaborators {
1580 session
1581 .peer
1582 .send(
1583 collaborator.connection_id,
1584 proto::UpdateProjectCollaborator {
1585 project_id: project.id.to_proto(),
1586 old_peer_id: Some(project.old_connection_id.into()),
1587 new_peer_id: Some(session.connection_id.into()),
1588 },
1589 )
1590 .trace_err();
1591 }
1592 }
1593
1594 for project in rejoined_projects {
1595 for worktree in mem::take(&mut project.worktrees) {
1596 // Stream this worktree's entries.
1597 let message = proto::UpdateWorktree {
1598 project_id: project.id.to_proto(),
1599 worktree_id: worktree.id,
1600 abs_path: worktree.abs_path.clone(),
1601 root_name: worktree.root_name,
1602 updated_entries: worktree.updated_entries,
1603 removed_entries: worktree.removed_entries,
1604 scan_id: worktree.scan_id,
1605 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1606 updated_repositories: worktree.updated_repositories,
1607 removed_repositories: worktree.removed_repositories,
1608 };
1609 for update in proto::split_worktree_update(message) {
1610 session.peer.send(session.connection_id, update)?;
1611 }
1612
1613 // Stream this worktree's diagnostics.
1614 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1615 if let Some(summary) = worktree_diagnostics.next() {
1616 let message = proto::UpdateDiagnosticSummary {
1617 project_id: project.id.to_proto(),
1618 worktree_id: worktree.id,
1619 summary: Some(summary),
1620 more_summaries: worktree_diagnostics.collect(),
1621 };
1622 session.peer.send(session.connection_id, message)?;
1623 }
1624
1625 for settings_file in worktree.settings_files {
1626 session.peer.send(
1627 session.connection_id,
1628 proto::UpdateWorktreeSettings {
1629 project_id: project.id.to_proto(),
1630 worktree_id: worktree.id,
1631 path: settings_file.path,
1632 content: Some(settings_file.content),
1633 kind: Some(settings_file.kind.to_proto().into()),
1634 },
1635 )?;
1636 }
1637 }
1638
1639 for repository in mem::take(&mut project.updated_repositories) {
1640 for update in split_repository_update(repository) {
1641 session.peer.send(session.connection_id, update)?;
1642 }
1643 }
1644
1645 for id in mem::take(&mut project.removed_repositories) {
1646 session.peer.send(
1647 session.connection_id,
1648 proto::RemoveRepository {
1649 project_id: project.id.to_proto(),
1650 id,
1651 },
1652 )?;
1653 }
1654 }
1655
1656 Ok(())
1657}
1658
1659/// leave room disconnects from the room.
1660async fn leave_room(
1661 _: proto::LeaveRoom,
1662 response: Response<proto::LeaveRoom>,
1663 session: MessageContext,
1664) -> Result<()> {
1665 leave_room_for_session(&session, session.connection_id).await?;
1666 response.send(proto::Ack {})?;
1667 Ok(())
1668}
1669
1670/// Updates the permissions of someone else in the room.
1671async fn set_room_participant_role(
1672 request: proto::SetRoomParticipantRole,
1673 response: Response<proto::SetRoomParticipantRole>,
1674 session: MessageContext,
1675) -> Result<()> {
1676 let user_id = UserId::from_proto(request.user_id);
1677 let role = ChannelRole::from(request.role());
1678
1679 let (livekit_room, can_publish) = {
1680 let room = session
1681 .db()
1682 .await
1683 .set_room_participant_role(
1684 session.user_id(),
1685 RoomId::from_proto(request.room_id),
1686 user_id,
1687 role,
1688 )
1689 .await?;
1690
1691 let livekit_room = room.livekit_room.clone();
1692 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1693 room_updated(&room, &session.peer);
1694 (livekit_room, can_publish)
1695 };
1696
1697 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1698 live_kit
1699 .update_participant(
1700 livekit_room.clone(),
1701 request.user_id.to_string(),
1702 livekit_api::proto::ParticipantPermission {
1703 can_subscribe: true,
1704 can_publish,
1705 can_publish_data: can_publish,
1706 hidden: false,
1707 recorder: false,
1708 },
1709 )
1710 .await
1711 .trace_err();
1712 }
1713
1714 response.send(proto::Ack {})?;
1715 Ok(())
1716}
1717
/// Call someone else into the current room
///
/// Rings every connection of the callee; the first one to answer settles
/// the call. If none answers, the call is rolled back in the database and
/// an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the database and update the room before ringing.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every live connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful response wins; failed requests are logged and we
    // keep waiting on the remaining connections.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back and refresh contacts again.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1786
1787/// Cancel an outgoing call.
1788async fn cancel_call(
1789 request: proto::CancelCall,
1790 response: Response<proto::CancelCall>,
1791 session: MessageContext,
1792) -> Result<()> {
1793 let called_user_id = UserId::from_proto(request.called_user_id);
1794 let room_id = RoomId::from_proto(request.room_id);
1795 {
1796 let room = session
1797 .db()
1798 .await
1799 .cancel_call(room_id, session.connection_id, called_user_id)
1800 .await?;
1801 room_updated(&room, &session.peer);
1802 }
1803
1804 for connection_id in session
1805 .connection_pool()
1806 .await
1807 .user_connection_ids(called_user_id)
1808 {
1809 session
1810 .peer
1811 .send(
1812 connection_id,
1813 proto::CallCanceled {
1814 room_id: room_id.to_proto(),
1815 },
1816 )
1817 .trace_err();
1818 }
1819 response.send(proto::Ack {})?;
1820
1821 update_user_contacts(called_user_id, &session).await?;
1822 Ok(())
1823}
1824
1825/// Decline an incoming call.
1826async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1827 let room_id = RoomId::from_proto(message.room_id);
1828 {
1829 let room = session
1830 .db()
1831 .await
1832 .decline_call(Some(room_id), session.user_id())
1833 .await?
1834 .context("declining call")?;
1835 room_updated(&room, &session.peer);
1836 }
1837
1838 for connection_id in session
1839 .connection_pool()
1840 .await
1841 .user_connection_ids(session.user_id())
1842 {
1843 session
1844 .peer
1845 .send(
1846 connection_id,
1847 proto::CallCanceled {
1848 room_id: room_id.to_proto(),
1849 },
1850 )
1851 .trace_err();
1852 }
1853 update_user_contacts(session.user_id(), &session).await?;
1854 Ok(())
1855}
1856
1857/// Updates other participants in the room with your current location.
1858async fn update_participant_location(
1859 request: proto::UpdateParticipantLocation,
1860 response: Response<proto::UpdateParticipantLocation>,
1861 session: MessageContext,
1862) -> Result<()> {
1863 let room_id = RoomId::from_proto(request.room_id);
1864 let location = request.location.context("invalid location")?;
1865
1866 let db = session.db().await;
1867 let room = db
1868 .update_room_participant_location(room_id, session.connection_id, location)
1869 .await?;
1870
1871 room_updated(&room, &session.peer);
1872 response.send(proto::Ack {})?;
1873 Ok(())
1874}
1875
1876/// Share a project into the room.
1877async fn share_project(
1878 request: proto::ShareProject,
1879 response: Response<proto::ShareProject>,
1880 session: MessageContext,
1881) -> Result<()> {
1882 let (project_id, room) = &*session
1883 .db()
1884 .await
1885 .share_project(
1886 RoomId::from_proto(request.room_id),
1887 session.connection_id,
1888 &request.worktrees,
1889 request.is_ssh_project,
1890 request.windows_paths.unwrap_or(false),
1891 )
1892 .await?;
1893 response.send(proto::ShareProjectResponse {
1894 project_id: project_id.to_proto(),
1895 })?;
1896 room_updated(room, &session.peer);
1897
1898 Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903 let project_id = ProjectId::from_proto(message.project_id);
1904 unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908 project_id: ProjectId,
1909 connection_id: ConnectionId,
1910 session: &Session,
1911) -> Result<()> {
1912 let delete = {
1913 let room_guard = session
1914 .db()
1915 .await
1916 .unshare_project(project_id, connection_id)
1917 .await?;
1918
1919 let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921 let message = proto::UnshareProject {
1922 project_id: project_id.to_proto(),
1923 };
1924
1925 broadcast(
1926 Some(connection_id),
1927 guest_connection_ids.iter().copied(),
1928 |conn_id| session.peer.send(conn_id, message.clone()),
1929 );
1930 if let Some(room) = room {
1931 room_updated(room, &session.peer);
1932 }
1933
1934 *delete
1935 };
1936
1937 if delete {
1938 let db = session.db().await;
1939 db.delete_project(project_id).await?;
1940 }
1941
1942 Ok(())
1943}
1944
/// Join someone elses shared project.
///
/// Registers the caller as a collaborator, announces them to the existing
/// collaborators, and then streams the project's state (worktree entries,
/// diagnostics, settings, repositories, language servers) to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joining connection.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new guest to the existing collaborators. Send failures
    // are logged rather than aborting the join.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // NOTE(review): this uses `worktree.id` whereas the entries update
        // above uses the map key `worktree_id` — presumably identical;
        // confirm they cannot diverge.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state, split into size-bounded chunks.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Nudge the client to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2095
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Remove this connection from the project; the database returns the
    // project's final state plus the room it was shared in (if any) so
    // both can be refreshed below.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell the remaining collaborators this peer left, then sync room state.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2115
2116/// Updates other participants with changes to the project
2117async fn update_project(
2118 request: proto::UpdateProject,
2119 response: Response<proto::UpdateProject>,
2120 session: MessageContext,
2121) -> Result<()> {
2122 let project_id = ProjectId::from_proto(request.project_id);
2123 let (room, guest_connection_ids) = &*session
2124 .db()
2125 .await
2126 .update_project(project_id, session.connection_id, &request.worktrees)
2127 .await?;
2128 broadcast(
2129 Some(session.connection_id),
2130 guest_connection_ids.iter().copied(),
2131 |connection_id| {
2132 session
2133 .peer
2134 .forward_send(session.connection_id, connection_id, request.clone())
2135 },
2136 );
2137 if let Some(room) = room {
2138 room_updated(room, &session.peer);
2139 }
2140 response.send(proto::Ack {})?;
2141
2142 Ok(())
2143}
2144
2145/// Updates other participants with changes to the worktree
2146async fn update_worktree(
2147 request: proto::UpdateWorktree,
2148 response: Response<proto::UpdateWorktree>,
2149 session: MessageContext,
2150) -> Result<()> {
2151 let guest_connection_ids = session
2152 .db()
2153 .await
2154 .update_worktree(&request, session.connection_id)
2155 .await?;
2156
2157 broadcast(
2158 Some(session.connection_id),
2159 guest_connection_ids.iter().copied(),
2160 |connection_id| {
2161 session
2162 .peer
2163 .forward_send(session.connection_id, connection_id, request.clone())
2164 },
2165 );
2166 response.send(proto::Ack {})?;
2167 Ok(())
2168}
2169
2170async fn update_repository(
2171 request: proto::UpdateRepository,
2172 response: Response<proto::UpdateRepository>,
2173 session: MessageContext,
2174) -> Result<()> {
2175 let guest_connection_ids = session
2176 .db()
2177 .await
2178 .update_repository(&request, session.connection_id)
2179 .await?;
2180
2181 broadcast(
2182 Some(session.connection_id),
2183 guest_connection_ids.iter().copied(),
2184 |connection_id| {
2185 session
2186 .peer
2187 .forward_send(session.connection_id, connection_id, request.clone())
2188 },
2189 );
2190 response.send(proto::Ack {})?;
2191 Ok(())
2192}
2193
2194async fn remove_repository(
2195 request: proto::RemoveRepository,
2196 response: Response<proto::RemoveRepository>,
2197 session: MessageContext,
2198) -> Result<()> {
2199 let guest_connection_ids = session
2200 .db()
2201 .await
2202 .remove_repository(&request, session.connection_id)
2203 .await?;
2204
2205 broadcast(
2206 Some(session.connection_id),
2207 guest_connection_ids.iter().copied(),
2208 |connection_id| {
2209 session
2210 .peer
2211 .forward_send(session.connection_id, connection_id, request.clone())
2212 },
2213 );
2214 response.send(proto::Ack {})?;
2215 Ok(())
2216}
2217
2218/// Updates other participants with changes to the diagnostics
2219async fn update_diagnostic_summary(
2220 message: proto::UpdateDiagnosticSummary,
2221 session: MessageContext,
2222) -> Result<()> {
2223 let guest_connection_ids = session
2224 .db()
2225 .await
2226 .update_diagnostic_summary(&message, session.connection_id)
2227 .await?;
2228
2229 broadcast(
2230 Some(session.connection_id),
2231 guest_connection_ids.iter().copied(),
2232 |connection_id| {
2233 session
2234 .peer
2235 .forward_send(session.connection_id, connection_id, message.clone())
2236 },
2237 );
2238
2239 Ok(())
2240}
2241
2242/// Updates other participants with changes to the worktree settings
2243async fn update_worktree_settings(
2244 message: proto::UpdateWorktreeSettings,
2245 session: MessageContext,
2246) -> Result<()> {
2247 let guest_connection_ids = session
2248 .db()
2249 .await
2250 .update_worktree_settings(&message, session.connection_id)
2251 .await?;
2252
2253 broadcast(
2254 Some(session.connection_id),
2255 guest_connection_ids.iter().copied(),
2256 |connection_id| {
2257 session
2258 .peer
2259 .forward_send(session.connection_id, connection_id, message.clone())
2260 },
2261 );
2262
2263 Ok(())
2264}
2265
2266/// Notify other participants that a language server has started.
2267async fn start_language_server(
2268 request: proto::StartLanguageServer,
2269 session: MessageContext,
2270) -> Result<()> {
2271 let guest_connection_ids = session
2272 .db()
2273 .await
2274 .start_language_server(&request, session.connection_id)
2275 .await?;
2276
2277 broadcast(
2278 Some(session.connection_id),
2279 guest_connection_ids.iter().copied(),
2280 |connection_id| {
2281 session
2282 .peer
2283 .forward_send(session.connection_id, connection_id, request.clone())
2284 },
2285 );
2286 Ok(())
2287}
2288
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist refreshed server capabilities so late joiners receive them.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the trailing `true` presumably widens the recipient set
    // (other handlers here pass `false`) — confirm against
    // `project_connection_ids`'s signature.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2318
2319/// forward a project request to the host. These requests should be read only
2320/// as guests are allowed to send them.
2321async fn forward_read_only_project_request<T>(
2322 request: T,
2323 response: Response<T>,
2324 session: MessageContext,
2325) -> Result<()>
2326where
2327 T: EntityMessage + RequestMessage,
2328{
2329 let project_id = ProjectId::from_proto(request.remote_entity_id());
2330 let host_connection_id = session
2331 .db()
2332 .await
2333 .host_for_read_only_project_request(project_id, session.connection_id)
2334 .await?;
2335 let payload = session.forward_request(host_connection_id, request).await?;
2336 response.send(payload)?;
2337 Ok(())
2338}
2339
2340/// forward a project request to the host. These requests are disallowed
2341/// for guests.
2342async fn forward_mutating_project_request<T>(
2343 request: T,
2344 response: Response<T>,
2345 session: MessageContext,
2346) -> Result<()>
2347where
2348 T: EntityMessage + RequestMessage,
2349{
2350 let project_id = ProjectId::from_proto(request.remote_entity_id());
2351
2352 let host_connection_id = session
2353 .db()
2354 .await
2355 .host_for_mutating_project_request(project_id, session.connection_id)
2356 .await?;
2357 let payload = session.forward_request(host_connection_id, request).await?;
2358 response.send(payload)?;
2359 Ok(())
2360}
2361
2362async fn lsp_query(
2363 request: proto::LspQuery,
2364 response: Response<proto::LspQuery>,
2365 session: MessageContext,
2366) -> Result<()> {
2367 let (name, should_write) = request.query_name_and_write_permissions();
2368 tracing::Span::current().record("lsp_query_request", name);
2369 tracing::info!("lsp_query message received");
2370 if should_write {
2371 forward_mutating_project_request(request, response, session).await
2372 } else {
2373 forward_read_only_project_request(request, response, session).await
2374 }
2375}
2376
2377/// Notify other participants that a new buffer has been created
2378async fn create_buffer_for_peer(
2379 request: proto::CreateBufferForPeer,
2380 session: MessageContext,
2381) -> Result<()> {
2382 session
2383 .db()
2384 .await
2385 .check_user_is_project_host(
2386 ProjectId::from_proto(request.project_id),
2387 session.connection_id,
2388 )
2389 .await?;
2390 let peer_id = request.peer_id.context("invalid peer id")?;
2391 session
2392 .peer
2393 .forward_send(session.connection_id, peer_id.into(), request)?;
2394 Ok(())
2395}
2396
2397/// Notify other participants that a new image has been created
2398async fn create_image_for_peer(
2399 request: proto::CreateImageForPeer,
2400 session: MessageContext,
2401) -> Result<()> {
2402 session
2403 .db()
2404 .await
2405 .check_user_is_project_host(
2406 ProjectId::from_proto(request.project_id),
2407 session.connection_id,
2408 )
2409 .await?;
2410 let peer_id = request.peer_id.context("invalid peer id")?;
2411 session
2412 .peer
2413 .forward_send(session.connection_id, peer_id.into(), request)?;
2414 Ok(())
2415}
2416
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates (or empty operations) may come from read-only
    // participants; any other operation requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // The database guard is confined to this block so it is released
        // before awaiting the host's acknowledgment below.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the update out to guests without waiting for responses.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host must acknowledge the update before we ack the sender —
    // unless the sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2463
/// Applies a shared-context operation, deriving the required capability from
/// the operation's contents and relaying it to all participants.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Only selection-only buffer operations (or no-ops) are allowed from
    // read-only participants; everything else needs write capability.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included as a plain broadcast
    // recipient and no acknowledgment is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2505
2506/// Notify other participants that a project has been updated.
2507async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2508 request: T,
2509 session: MessageContext,
2510) -> Result<()> {
2511 let project_id = ProjectId::from_proto(request.remote_entity_id());
2512 let project_connection_ids = session
2513 .db()
2514 .await
2515 .project_connection_ids(project_id, session.connection_id, false)
2516 .await?;
2517
2518 broadcast(
2519 Some(session.connection_id),
2520 project_connection_ids.iter().copied(),
2521 |connection_id| {
2522 session
2523 .peer
2524 .forward_send(session.connection_id, connection_id, request.clone())
2525 },
2526 );
2527 Ok(())
2528}
2529
2530/// Start following another user in a call.
2531async fn follow(
2532 request: proto::Follow,
2533 response: Response<proto::Follow>,
2534 session: MessageContext,
2535) -> Result<()> {
2536 let room_id = RoomId::from_proto(request.room_id);
2537 let project_id = request.project_id.map(ProjectId::from_proto);
2538 let leader_id = request.leader_id.context("invalid leader id")?.into();
2539 let follower_id = session.connection_id;
2540
2541 session
2542 .db()
2543 .await
2544 .check_room_participants(room_id, leader_id, session.connection_id)
2545 .await?;
2546
2547 let response_payload = session.forward_request(leader_id, request).await?;
2548 response.send(response_payload)?;
2549
2550 if let Some(project_id) = project_id {
2551 let room = session
2552 .db()
2553 .await
2554 .follow(room_id, project_id, leader_id, follower_id)
2555 .await?;
2556 room_updated(&room, &session.peer);
2557 }
2558
2559 Ok(())
2560}
2561
2562/// Stop following another user in a call.
2563async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2564 let room_id = RoomId::from_proto(request.room_id);
2565 let project_id = request.project_id.map(ProjectId::from_proto);
2566 let leader_id = request.leader_id.context("invalid leader id")?.into();
2567 let follower_id = session.connection_id;
2568
2569 session
2570 .db()
2571 .await
2572 .check_room_participants(room_id, leader_id, session.connection_id)
2573 .await?;
2574
2575 session
2576 .peer
2577 .forward_send(session.connection_id, leader_id, request)?;
2578
2579 if let Some(project_id) = project_id {
2580 let room = session
2581 .db()
2582 .await
2583 .unfollow(room_id, project_id, leader_id, follower_id)
2584 .await?;
2585 room_updated(&room, &session.peer);
2586 }
2587
2588 Ok(())
2589}
2590
2591/// Notify everyone following you of your current location.
2592async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2593 let room_id = RoomId::from_proto(request.room_id);
2594 let database = session.db.lock().await;
2595
2596 let connection_ids = if let Some(project_id) = request.project_id {
2597 let project_id = ProjectId::from_proto(project_id);
2598 database
2599 .project_connection_ids(project_id, session.connection_id, true)
2600 .await?
2601 } else {
2602 database
2603 .room_connection_ids(room_id, session.connection_id)
2604 .await?
2605 };
2606
2607 // For now, don't send view update messages back to that view's current leader.
2608 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2609 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2610 _ => None,
2611 });
2612
2613 for connection_id in connection_ids.iter().cloned() {
2614 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2615 session
2616 .peer
2617 .forward_send(session.connection_id, connection_id, request.clone())?;
2618 }
2619 }
2620 Ok(())
2621}
2622
2623/// Get public data about users.
2624async fn get_users(
2625 request: proto::GetUsers,
2626 response: Response<proto::GetUsers>,
2627 session: MessageContext,
2628) -> Result<()> {
2629 let user_ids = request
2630 .user_ids
2631 .into_iter()
2632 .map(UserId::from_proto)
2633 .collect();
2634 let users = session
2635 .db()
2636 .await
2637 .get_users_by_ids(user_ids)
2638 .await?
2639 .into_iter()
2640 .map(|user| proto::User {
2641 id: user.id.to_proto(),
2642 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2643 github_login: user.github_login,
2644 name: user.name,
2645 })
2646 .collect();
2647 response.send(proto::UsersResponse { users })?;
2648 Ok(())
2649}
2650
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries (1-2 bytes — `len` is a byte count) use an exact
    // login lookup; longer ones use fuzzy search capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the searching user from their own results.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2682
2683/// Send a contact request to another user.
2684async fn request_contact(
2685 request: proto::RequestContact,
2686 response: Response<proto::RequestContact>,
2687 session: MessageContext,
2688) -> Result<()> {
2689 let requester_id = session.user_id();
2690 let responder_id = UserId::from_proto(request.responder_id);
2691 if requester_id == responder_id {
2692 return Err(anyhow!("cannot add yourself as a contact"))?;
2693 }
2694
2695 let notifications = session
2696 .db()
2697 .await
2698 .send_contact_request(requester_id, responder_id)
2699 .await?;
2700
2701 // Update outgoing contact requests of requester
2702 let mut update = proto::UpdateContacts::default();
2703 update.outgoing_requests.push(responder_id.to_proto());
2704 for connection_id in session
2705 .connection_pool()
2706 .await
2707 .user_connection_ids(requester_id)
2708 {
2709 session.peer.send(connection_id, update.clone())?;
2710 }
2711
2712 // Update incoming contact requests of responder
2713 let mut update = proto::UpdateContacts::default();
2714 update
2715 .incoming_requests
2716 .push(proto::IncomingContactRequest {
2717 requester_id: requester_id.to_proto(),
2718 });
2719 let connection_pool = session.connection_pool().await;
2720 for connection_id in connection_pool.user_connection_ids(responder_id) {
2721 session.peer.send(connection_id, update.clone())?;
2722 }
2723
2724 send_notifications(&connection_pool, &session.peer, notifications);
2725
2726 response.send(proto::Ack {})?;
2727 Ok(())
2728}
2729
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only hides the notification; the request itself remains.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is sent with each contact so both sides render
        // availability correctly.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2787
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the responder's stale contact-request notification.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2838
2839fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2840 version.0.minor < 139
2841}
2842
/// Handles a client's explicit request to start receiving channel updates.
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
2850
/// Subscribes every channel membership of `user_id` in the connection pool,
/// then sends this session's connection its initial channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // The pool guard is held across the sends below to keep subscription
    // state and the pushed snapshot consistent.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Membership roles are sent before the channel graph itself.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2867
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is only Some when the creation results in a new membership
    // row for the creator — NOTE(review): inferred from the Option handling
    // below; confirm in `Database::create_channel`.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out updates.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        // Tell all of the creator's connections about their new membership.
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone subscribed to the channel tree
    // who is allowed to see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2922
2923/// Delete a channel
2924async fn delete_channel(
2925 request: proto::DeleteChannel,
2926 response: Response<proto::DeleteChannel>,
2927 session: MessageContext,
2928) -> Result<()> {
2929 let db = session.db().await;
2930
2931 let channel_id = request.channel_id;
2932 let (root_channel, removed_channels) = db
2933 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2934 .await?;
2935 response.send(proto::Ack {})?;
2936
2937 // Notify members of removed channels
2938 let mut update = proto::UpdateChannels::default();
2939 update
2940 .delete_channels
2941 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2942
2943 let connection_pool = session.connection_pool().await;
2944 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2945 session.peer.send(connection_id, update.clone())?;
2946 }
2947
2948 Ok(())
2949}
2950
2951/// Invite someone to join a channel.
2952async fn invite_channel_member(
2953 request: proto::InviteChannelMember,
2954 response: Response<proto::InviteChannelMember>,
2955 session: MessageContext,
2956) -> Result<()> {
2957 let db = session.db().await;
2958 let channel_id = ChannelId::from_proto(request.channel_id);
2959 let invitee_id = UserId::from_proto(request.user_id);
2960 let InviteMemberResult {
2961 channel,
2962 notifications,
2963 } = db
2964 .invite_channel_member(
2965 channel_id,
2966 invitee_id,
2967 session.user_id(),
2968 request.role().into(),
2969 )
2970 .await?;
2971
2972 let update = proto::UpdateChannels {
2973 channel_invitations: vec![channel.to_proto()],
2974 ..Default::default()
2975 };
2976
2977 let connection_pool = session.connection_pool().await;
2978 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2979 session.peer.send(connection_id, update.clone())?;
2980 }
2981
2982 send_notifications(&connection_pool, &session.peer, notifications);
2983
2984 response.send(proto::Ack {})?;
2985 Ok(())
2986}
2987
2988/// remove someone from a channel
2989async fn remove_channel_member(
2990 request: proto::RemoveChannelMember,
2991 response: Response<proto::RemoveChannelMember>,
2992 session: MessageContext,
2993) -> Result<()> {
2994 let db = session.db().await;
2995 let channel_id = ChannelId::from_proto(request.channel_id);
2996 let member_id = UserId::from_proto(request.user_id);
2997
2998 let RemoveChannelMemberResult {
2999 membership_update,
3000 notification_id,
3001 } = db
3002 .remove_channel_member(channel_id, member_id, session.user_id())
3003 .await?;
3004
3005 let mut connection_pool = session.connection_pool().await;
3006 notify_membership_updated(
3007 &mut connection_pool,
3008 membership_update,
3009 member_id,
3010 &session.peer,
3011 );
3012 for connection_id in connection_pool.user_connection_ids(member_id) {
3013 if let Some(notification_id) = notification_id {
3014 session
3015 .peer
3016 .send(
3017 connection_id,
3018 proto::DeleteNotification {
3019 notification_id: notification_id.to_proto(),
3020 },
3021 )
3022 .trace_err();
3023 }
3024 }
3025
3026 response.send(proto::Ack {})?;
3027 Ok(())
3028}
3029
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the loop body mutates the pool's
    // subscriptions, so the `channel_user_ids` borrow cannot be held across it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get the refreshed record and a
        // subscription; everyone else is unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3076
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The member had already joined: push full membership/channel state.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // Only a pending invite existed: just refresh the invitation record.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3124
3125/// Change the name of a channel
3126async fn rename_channel(
3127 request: proto::RenameChannel,
3128 response: Response<proto::RenameChannel>,
3129 session: MessageContext,
3130) -> Result<()> {
3131 let db = session.db().await;
3132 let channel_id = ChannelId::from_proto(request.channel_id);
3133 let channel_model = db
3134 .rename_channel(channel_id, session.user_id(), &request.name)
3135 .await?;
3136 let root_id = channel_model.root_id();
3137 let channel = Channel::from_model(channel_model);
3138
3139 response.send(proto::RenameChannelResponse {
3140 channel: Some(channel.to_proto()),
3141 })?;
3142
3143 let connection_pool = session.connection_pool().await;
3144 let update = proto::UpdateChannels {
3145 channels: vec![channel.to_proto()],
3146 ..Default::default()
3147 };
3148 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149 if role.can_see_channel(channel.visibility) {
3150 session.peer.send(connection_id, update.clone())?;
3151 }
3152 }
3153
3154 Ok(())
3155}
3156
3157/// Move a channel to a new parent.
3158async fn move_channel(
3159 request: proto::MoveChannel,
3160 response: Response<proto::MoveChannel>,
3161 session: MessageContext,
3162) -> Result<()> {
3163 let channel_id = ChannelId::from_proto(request.channel_id);
3164 let to = ChannelId::from_proto(request.to);
3165
3166 let (root_id, channels) = session
3167 .db()
3168 .await
3169 .move_channel(channel_id, to, session.user_id())
3170 .await?;
3171
3172 let connection_pool = session.connection_pool().await;
3173 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3174 let channels = channels
3175 .iter()
3176 .filter_map(|channel| {
3177 if role.can_see_channel(channel.visibility) {
3178 Some(channel.to_proto())
3179 } else {
3180 None
3181 }
3182 })
3183 .collect::<Vec<_>>();
3184 if channels.is_empty() {
3185 continue;
3186 }
3187
3188 let update = proto::UpdateChannels {
3189 channels,
3190 ..Default::default()
3191 };
3192
3193 session.peer.send(connection_id, update.clone())?;
3194 }
3195
3196 response.send(Ack {})?;
3197 Ok(())
3198}
3199
3200async fn reorder_channel(
3201 request: proto::ReorderChannel,
3202 response: Response<proto::ReorderChannel>,
3203 session: MessageContext,
3204) -> Result<()> {
3205 let channel_id = ChannelId::from_proto(request.channel_id);
3206 let direction = request.direction();
3207
3208 let updated_channels = session
3209 .db()
3210 .await
3211 .reorder_channel(channel_id, direction, session.user_id())
3212 .await?;
3213
3214 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3215 let connection_pool = session.connection_pool().await;
3216 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3217 let channels = updated_channels
3218 .iter()
3219 .filter_map(|channel| {
3220 if role.can_see_channel(channel.visibility) {
3221 Some(channel.to_proto())
3222 } else {
3223 None
3224 }
3225 })
3226 .collect::<Vec<_>>();
3227
3228 if channels.is_empty() {
3229 continue;
3230 }
3231
3232 let update = proto::UpdateChannels {
3233 channels,
3234 ..Default::default()
3235 };
3236
3237 session.peer.send(connection_id, update.clone())?;
3238 }
3239 }
3240
3241 response.send(Ack {})?;
3242 Ok(())
3243}
3244
3245/// Get the list of channel members
3246async fn get_channel_members(
3247 request: proto::GetChannelMembers,
3248 response: Response<proto::GetChannelMembers>,
3249 session: MessageContext,
3250) -> Result<()> {
3251 let db = session.db().await;
3252 let channel_id = ChannelId::from_proto(request.channel_id);
3253 let limit = if request.limit == 0 {
3254 u16::MAX as u64
3255 } else {
3256 request.limit
3257 };
3258 let (members, users) = db
3259 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3260 .await?;
3261 response.send(proto::GetChannelMembersResponse { members, users })?;
3262 Ok(())
3263}
3264
3265/// Accept or decline a channel invitation.
3266async fn respond_to_channel_invite(
3267 request: proto::RespondToChannelInvite,
3268 response: Response<proto::RespondToChannelInvite>,
3269 session: MessageContext,
3270) -> Result<()> {
3271 let db = session.db().await;
3272 let channel_id = ChannelId::from_proto(request.channel_id);
3273 let RespondToChannelInvite {
3274 membership_update,
3275 notifications,
3276 } = db
3277 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3278 .await?;
3279
3280 let mut connection_pool = session.connection_pool().await;
3281 if let Some(membership_update) = membership_update {
3282 notify_membership_updated(
3283 &mut connection_pool,
3284 membership_update,
3285 session.user_id(),
3286 &session.peer,
3287 );
3288 } else {
3289 let update = proto::UpdateChannels {
3290 remove_channel_invitations: vec![channel_id.to_proto()],
3291 ..Default::default()
3292 };
3293
3294 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3295 session.peer.send(connection_id, update.clone())?;
3296 }
3297 };
3298
3299 send_notifications(&connection_pool, &session.peer, notifications);
3300
3301 response.send(proto::Ack {})?;
3302
3303 Ok(())
3304}
3305
3306/// Join the channels' room
3307async fn join_channel(
3308 request: proto::JoinChannel,
3309 response: Response<proto::JoinChannel>,
3310 session: MessageContext,
3311) -> Result<()> {
3312 let channel_id = ChannelId::from_proto(request.channel_id);
3313 join_channel_internal(channel_id, Box::new(response), session).await
3314}
3315
/// Abstraction over the two request types that can initiate a channel join
/// (`JoinChannel` and `JoinRoom`), both of which reply with a
/// `JoinRoomResponse`. Lets `join_channel_internal` serve either entry point.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3329
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers. Cleans up any stale room
/// connection, joins the channel in the database, optionally mints LiveKit
/// credentials, and then notifies the user's other clients, the room, the
/// channel subscribers, and the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` takes the database guard itself, so
            // release ours first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when call media is configured: guests get a
        // non-publishing token, everyone else may publish. Token failures are
        // logged and degrade to "no call media" rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out notifications.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining a public channel may have granted a new membership.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Announce the updated participant list to the channel's subscribers.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // The user is now busy; let their contacts know.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3423
3424/// Start editing the channel notes
3425async fn join_channel_buffer(
3426 request: proto::JoinChannelBuffer,
3427 response: Response<proto::JoinChannelBuffer>,
3428 session: MessageContext,
3429) -> Result<()> {
3430 let db = session.db().await;
3431 let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433 let open_response = db
3434 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3435 .await?;
3436
3437 let collaborators = open_response.collaborators.clone();
3438 response.send(open_response)?;
3439
3440 let update = UpdateChannelBufferCollaborators {
3441 channel_id: channel_id.to_proto(),
3442 collaborators: collaborators.clone(),
3443 };
3444 channel_buffer_updated(
3445 session.connection_id,
3446 collaborators
3447 .iter()
3448 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3449 &update,
3450 &session.peer,
3451 );
3452
3453 Ok(())
3454}
3455
3456/// Edit the channel notes
3457async fn update_channel_buffer(
3458 request: proto::UpdateChannelBuffer,
3459 session: MessageContext,
3460) -> Result<()> {
3461 let db = session.db().await;
3462 let channel_id = ChannelId::from_proto(request.channel_id);
3463
3464 let (collaborators, epoch, version) = db
3465 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3466 .await?;
3467
3468 channel_buffer_updated(
3469 session.connection_id,
3470 collaborators.clone(),
3471 &proto::UpdateChannelBuffer {
3472 channel_id: channel_id.to_proto(),
3473 operations: request.operations,
3474 },
3475 &session.peer,
3476 );
3477
3478 let pool = &*session.connection_pool().await;
3479
3480 let non_collaborators =
3481 pool.channel_connection_ids(channel_id)
3482 .filter_map(|(connection_id, _)| {
3483 if collaborators.contains(&connection_id) {
3484 None
3485 } else {
3486 Some(connection_id)
3487 }
3488 });
3489
3490 broadcast(None, non_collaborators, |peer_id| {
3491 session.peer.send(
3492 peer_id,
3493 proto::UpdateChannels {
3494 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3495 channel_id: channel_id.to_proto(),
3496 epoch: epoch as u64,
3497 version: version.clone(),
3498 }],
3499 ..Default::default()
3500 },
3501 )
3502 });
3503
3504 Ok(())
3505}
3506
3507/// Rejoin the channel notes after a connection blip
3508async fn rejoin_channel_buffers(
3509 request: proto::RejoinChannelBuffers,
3510 response: Response<proto::RejoinChannelBuffers>,
3511 session: MessageContext,
3512) -> Result<()> {
3513 let db = session.db().await;
3514 let buffers = db
3515 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3516 .await?;
3517
3518 for rejoined_buffer in &buffers {
3519 let collaborators_to_notify = rejoined_buffer
3520 .buffer
3521 .collaborators
3522 .iter()
3523 .filter_map(|c| Some(c.peer_id?.into()));
3524 channel_buffer_updated(
3525 session.connection_id,
3526 collaborators_to_notify,
3527 &proto::UpdateChannelBufferCollaborators {
3528 channel_id: rejoined_buffer.buffer.channel_id,
3529 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3530 },
3531 &session.peer,
3532 );
3533 }
3534
3535 response.send(proto::RejoinChannelBuffersResponse {
3536 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3537 })?;
3538
3539 Ok(())
3540}
3541
3542/// Stop editing the channel notes
3543async fn leave_channel_buffer(
3544 request: proto::LeaveChannelBuffer,
3545 response: Response<proto::LeaveChannelBuffer>,
3546 session: MessageContext,
3547) -> Result<()> {
3548 let db = session.db().await;
3549 let channel_id = ChannelId::from_proto(request.channel_id);
3550
3551 let left_buffer = db
3552 .leave_channel_buffer(channel_id, session.connection_id)
3553 .await?;
3554
3555 response.send(Ack {})?;
3556
3557 channel_buffer_updated(
3558 session.connection_id,
3559 left_buffer.connections,
3560 &proto::UpdateChannelBufferCollaborators {
3561 channel_id: channel_id.to_proto(),
3562 collaborators: left_buffer.collaborators,
3563 },
3564 &session.peer,
3565 );
3566
3567 Ok(())
3568}
3569
3570fn channel_buffer_updated<T: EnvelopedMessage>(
3571 sender_id: ConnectionId,
3572 collaborators: impl IntoIterator<Item = ConnectionId>,
3573 message: &T,
3574 peer: &Peer,
3575) {
3576 broadcast(Some(sender_id), collaborators, |peer_id| {
3577 peer.send(peer_id, message.clone())
3578 });
3579}
3580
3581fn send_notifications(
3582 connection_pool: &ConnectionPool,
3583 peer: &Peer,
3584 notifications: db::NotificationBatch,
3585) {
3586 for (user_id, notification) in notifications {
3587 for connection_id in connection_pool.user_connection_ids(user_id) {
3588 if let Err(error) = peer.send(
3589 connection_id,
3590 proto::AddNotification {
3591 notification: Some(notification.clone()),
3592 },
3593 ) {
3594 tracing::error!(
3595 "failed to send notification to {:?} {}",
3596 connection_id,
3597 error
3598 );
3599 }
3600 }
3601 }
3602}
3603
/// Send a message to the channel
///
/// Chat was removed from Zed; this handler remains only so that older
/// clients receive a descriptive error rather than an unknown-message
/// failure.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3612
/// Delete a channel message
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3621
/// Edit a channel message.
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3629
/// Mark a channel message as read
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3637
3638/// Mark a buffer version as synced
3639async fn acknowledge_buffer_version(
3640 request: proto::AckBufferOperation,
3641 session: MessageContext,
3642) -> Result<()> {
3643 let buffer_id = BufferId::from_proto(request.buffer_id);
3644 session
3645 .db()
3646 .await
3647 .observe_buffer_version(
3648 buffer_id,
3649 session.user_id(),
3650 request.epoch as i32,
3651 &request.version,
3652 )
3653 .await?;
3654 Ok(())
3655}
3656
3657/// Get a Supermaven API key for the user
3658async fn get_supermaven_api_key(
3659 _request: proto::GetSupermavenApiKey,
3660 response: Response<proto::GetSupermavenApiKey>,
3661 session: MessageContext,
3662) -> Result<()> {
3663 let user_id: String = session.user_id().to_string();
3664 if !session.is_staff() {
3665 return Err(anyhow!("supermaven not enabled for this account"))?;
3666 }
3667
3668 let email = session.email().context("user must have an email")?;
3669
3670 let supermaven_admin_api = session
3671 .supermaven_client
3672 .as_ref()
3673 .context("supermaven not configured")?;
3674
3675 let result = supermaven_admin_api
3676 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3677 .await?;
3678
3679 response.send(proto::GetSupermavenApiKeyResponse {
3680 api_key: result.api_key,
3681 })?;
3682
3683 Ok(())
3684}
3685
/// Start receiving chat updates for a channel
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3694
/// Stop receiving chat updates for a channel
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3702
/// Retrieve the chat history for a channel
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3711
/// Retrieve specific chat messages
///
/// Chat was removed from Zed; this stub keeps older clients' requests from
/// failing with an unknown-message error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3720
3721/// Retrieve the current users notifications
3722async fn get_notifications(
3723 request: proto::GetNotifications,
3724 response: Response<proto::GetNotifications>,
3725 session: MessageContext,
3726) -> Result<()> {
3727 let notifications = session
3728 .db()
3729 .await
3730 .get_notifications(
3731 session.user_id(),
3732 NOTIFICATION_COUNT_PER_PAGE,
3733 request.before_id.map(db::NotificationId::from_proto),
3734 )
3735 .await?;
3736 response.send(proto::GetNotificationsResponse {
3737 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3738 notifications,
3739 })?;
3740 Ok(())
3741}
3742
3743/// Mark notifications as read
3744async fn mark_notification_as_read(
3745 request: proto::MarkNotificationRead,
3746 response: Response<proto::MarkNotificationRead>,
3747 session: MessageContext,
3748) -> Result<()> {
3749 let database = &session.db().await;
3750 let notifications = database
3751 .mark_notification_as_read_by_id(
3752 session.user_id(),
3753 NotificationId::from_proto(request.notification_id),
3754 )
3755 .await?;
3756 send_notifications(
3757 &*session.connection_pool().await,
3758 &session.peer,
3759 notifications,
3760 );
3761 response.send(proto::Ack {})?;
3762 Ok(())
3763}
3764
3765fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3766 let message = match message {
3767 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3768 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3769 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3770 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3771 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3772 code: frame.code.into(),
3773 reason: frame.reason.as_str().to_owned().into(),
3774 })),
3775 // We should never receive a frame while reading the message, according
3776 // to the `tungstenite` maintainers:
3777 //
3778 // > It cannot occur when you read messages from the WebSocket, but it
3779 // > can be used when you want to send the raw frames (e.g. you want to
3780 // > send the frames to the WebSocket without composing the full message first).
3781 // >
3782 // > — https://github.com/snapview/tungstenite-rs/issues/268
3783 TungsteniteMessage::Frame(_) => {
3784 bail!("received an unexpected frame while reading the message")
3785 }
3786 };
3787
3788 Ok(message)
3789}
3790
3791fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3792 match message {
3793 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3794 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3795 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3796 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3797 AxumMessage::Close(frame) => {
3798 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3799 code: frame.code.into(),
3800 reason: frame.reason.as_ref().into(),
3801 }))
3802 }
3803 }
3804}
3805
3806fn notify_membership_updated(
3807 connection_pool: &mut ConnectionPool,
3808 result: MembershipUpdated,
3809 user_id: UserId,
3810 peer: &Peer,
3811) {
3812 for membership in &result.new_channels.channel_memberships {
3813 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3814 }
3815 for channel_id in &result.removed_channels {
3816 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3817 }
3818
3819 let user_channels_update = proto::UpdateUserChannels {
3820 channel_memberships: result
3821 .new_channels
3822 .channel_memberships
3823 .iter()
3824 .map(|cm| proto::ChannelMembership {
3825 channel_id: cm.channel_id.to_proto(),
3826 role: cm.role.into(),
3827 })
3828 .collect(),
3829 ..Default::default()
3830 };
3831
3832 let mut update = build_channels_update(result.new_channels);
3833 update.delete_channels = result
3834 .removed_channels
3835 .into_iter()
3836 .map(|id| id.to_proto())
3837 .collect();
3838 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3839
3840 for connection_id in connection_pool.user_connection_ids(user_id) {
3841 peer.send(connection_id, user_channels_update.clone())
3842 .trace_err();
3843 peer.send(connection_id, update.clone()).trace_err();
3844 }
3845}
3846
3847fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3848 proto::UpdateUserChannels {
3849 channel_memberships: channels
3850 .channel_memberships
3851 .iter()
3852 .map(|m| proto::ChannelMembership {
3853 channel_id: m.channel_id.to_proto(),
3854 role: m.role.into(),
3855 })
3856 .collect(),
3857 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3858 }
3859}
3860
3861fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3862 let mut update = proto::UpdateChannels::default();
3863
3864 for channel in channels.channels {
3865 update.channels.push(channel.to_proto());
3866 }
3867
3868 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3869
3870 for (channel_id, participants) in channels.channel_participants {
3871 update
3872 .channel_participants
3873 .push(proto::ChannelParticipants {
3874 channel_id: channel_id.to_proto(),
3875 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3876 });
3877 }
3878
3879 for channel in channels.invited_channels {
3880 update.channel_invitations.push(channel.to_proto());
3881 }
3882
3883 update
3884}
3885
3886fn build_initial_contacts_update(
3887 contacts: Vec<db::Contact>,
3888 pool: &ConnectionPool,
3889) -> proto::UpdateContacts {
3890 let mut update = proto::UpdateContacts::default();
3891
3892 for contact in contacts {
3893 match contact {
3894 db::Contact::Accepted { user_id, busy } => {
3895 update.contacts.push(contact_for_user(user_id, busy, pool));
3896 }
3897 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3898 db::Contact::Incoming { user_id } => {
3899 update
3900 .incoming_requests
3901 .push(proto::IncomingContactRequest {
3902 requester_id: user_id.to_proto(),
3903 })
3904 }
3905 }
3906 }
3907
3908 update
3909}
3910
3911fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3912 proto::Contact {
3913 user_id: user_id.to_proto(),
3914 online: pool.is_user_online(user_id),
3915 busy,
3916 }
3917}
3918
3919fn room_updated(room: &proto::Room, peer: &Peer) {
3920 broadcast(
3921 None,
3922 room.participants
3923 .iter()
3924 .filter_map(|participant| Some(participant.peer_id?.into())),
3925 |peer_id| {
3926 peer.send(
3927 peer_id,
3928 proto::RoomUpdated {
3929 room: Some(room.clone()),
3930 },
3931 )
3932 },
3933 );
3934}
3935
3936fn channel_updated(
3937 channel: &db::channel::Model,
3938 room: &proto::Room,
3939 peer: &Peer,
3940 pool: &ConnectionPool,
3941) {
3942 let participants = room
3943 .participants
3944 .iter()
3945 .map(|p| p.user_id)
3946 .collect::<Vec<_>>();
3947
3948 broadcast(
3949 None,
3950 pool.channel_connection_ids(channel.root_id())
3951 .filter_map(|(channel_id, role)| {
3952 role.can_see_channel(channel.visibility)
3953 .then_some(channel_id)
3954 }),
3955 |peer_id| {
3956 peer.send(
3957 peer_id,
3958 proto::UpdateChannels {
3959 channel_participants: vec![proto::ChannelParticipants {
3960 channel_id: channel.id.to_proto(),
3961 participant_user_ids: participants.clone(),
3962 }],
3963 ..Default::default()
3964 },
3965 )
3966 },
3967 );
3968}
3969
3970async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3971 let db = session.db().await;
3972
3973 let contacts = db.get_contacts(user_id).await?;
3974 let busy = db.is_user_busy(user_id).await?;
3975
3976 let pool = session.connection_pool().await;
3977 let updated_contact = contact_for_user(user_id, busy, &pool);
3978 for contact in contacts {
3979 if let db::Contact::Accepted {
3980 user_id: contact_user_id,
3981 ..
3982 } = contact
3983 {
3984 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3985 session
3986 .peer
3987 .send(
3988 contact_conn_id,
3989 proto::UpdateContacts {
3990 contacts: vec![updated_contact.clone()],
3991 remove_contacts: Default::default(),
3992 incoming_requests: Default::default(),
3993 remove_incoming_requests: Default::default(),
3994 outgoing_requests: Default::default(),
3995 remove_outgoing_requests: Default::default(),
3996 },
3997 )
3998 .trace_err();
3999 }
4000 }
4001 }
4002 Ok(())
4003}
4004
/// Tear down `connection_id`'s presence in its room: unshare/leave its
/// projects, notify the room and channel, cancel any outgoing calls, refresh
/// affected contacts, and clean up the LiveKit room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The db guard is scoped to this `if` so it is released before the
    // contact updates and LiveKit calls below.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        // `mem::take` moves these out of the borrowed `left_room` without
        // cloning.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // This connection was not in a room; nothing to do.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Cancel any calls this user had outstanding to other users.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Presence changed for this user and everyone whose call was canceled.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // Tear the LiveKit room down entirely once the last participant left.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4078
4079async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4080 let left_channel_buffers = session
4081 .db()
4082 .await
4083 .leave_channel_buffers(session.connection_id)
4084 .await?;
4085
4086 for left_buffer in left_channel_buffers {
4087 channel_buffer_updated(
4088 session.connection_id,
4089 left_buffer.connections,
4090 &proto::UpdateChannelBufferCollaborators {
4091 channel_id: left_buffer.channel_id.to_proto(),
4092 collaborators: left_buffer.collaborators,
4093 },
4094 &session.peer,
4095 );
4096 }
4097
4098 Ok(())
4099}
4100
4101fn project_left(project: &db::LeftProject, session: &Session) {
4102 for connection_id in &project.connection_ids {
4103 if project.should_unshare {
4104 session
4105 .peer
4106 .send(
4107 *connection_id,
4108 proto::UnshareProject {
4109 project_id: project.id.to_proto(),
4110 },
4111 )
4112 .trace_err();
4113 } else {
4114 session
4115 .peer
4116 .send(
4117 *connection_id,
4118 proto::RemoveProjectCollaborator {
4119 project_id: project.id.to_proto(),
4120 peer_id: Some(session.connection_id.into()),
4121 },
4122 )
4123 .trace_err();
4124 }
4125 }
4126}
4127
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) at ERROR level and convert the result into an
    /// `Option`, discarding the failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4133
4134impl<T, E> ResultExt for Result<T, E>
4135where
4136 E: std::fmt::Debug,
4137{
4138 type Ok = T;
4139
4140 #[track_caller]
4141 fn trace_err(self) -> Option<T> {
4142 match self {
4143 Ok(value) => Some(value),
4144 Err(error) => {
4145 tracing::error!("{:?}", error);
4146 None
4147 }
4148 }
4149 }
4150}