1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
/// How long a client has to reconnect (e.g. after a network blip) before its
/// server-side resources are considered abandoned.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously connected clients; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Current number of live connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names used to record per-message timing metrics.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler: takes an envelope, the session, and the message span,
// and returns a boxed future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
/// Handle given to request handlers for sending exactly one reply to the
/// requesting client.
struct Response<R> {
    peer: Arc<Peer>,
    // Identifies which in-flight request this response answers.
    receipt: Receipt<R>,
    // Shared with `add_request_handler`, which uses it to detect handlers
    // that return Ok without ever responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester, marking this response as sent.
    ///
    /// The flag is set before attempting delivery, so even a failed `respond`
    /// counts as "responded" to the caller's bookkeeping.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
131
132#[derive(Clone, Debug)]
133pub enum Principal {
134 User(User),
135 Impersonated { user: User, admin: User },
136}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// Per-message context handed to handlers: the connection's `Session` plus
/// the tracing span for this specific message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    // Span created in `handle_connection` for the message being processed;
    // handlers record timing fields (e.g. host_waiting_ms) on it.
    span: tracing::Span,
}

// Deref to `Session` so handlers can use the context wherever a session is
// expected without writing `.session` everywhere.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
168impl MessageContext {
169 pub fn forward_request<T: RequestMessage>(
170 &self,
171 receiver_id: ConnectionId,
172 request: T,
173 ) -> impl Future<Output = anyhow::Result<T::Response>> {
174 let request_start_time = Instant::now();
175 let span = self.span.clone();
176 tracing::info!("start forwarding request");
177 self.peer
178 .forward_request(self.connection_id, receiver_id, request)
179 .inspect(move |_| {
180 span.record(
181 HOST_WAITING_MS,
182 request_start_time.elapsed().as_micros() as f64 / 1000.0,
183 );
184 })
185 .inspect_err(|_| tracing::error!("error forwarding request"))
186 .inspect_ok(|_| tracing::info!("finished forwarding request"))
187 }
188}
189
/// Per-connection state shared with every message handler.
#[derive(Clone)]
struct Session {
    // Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; see `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 const fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 const fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
/// Newtype around the shared database, stored behind the session's async
/// mutex so handlers serialize their database access.
struct DbHandle(Arc<Database>);

// Deref straight to `Database` so callers can invoke queries on the guard.
impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
274
/// The collab RPC server: owns the peer, the registry of message handlers,
/// and the pool of live client connections.
pub struct Server {
    // Mutable only by `reset` in tests; identifies this server instance in the DB.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message TypeId; populated once in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` to all connection tasks when the server shuts down.
    teardown: watch::Sender<bool>,
}
283
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` makes this type `!Send`, preventing the guard
/// from being held across an await point (which could deadlock the pool).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable point-in-time view of the server, used for diagnostics.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard derefs to `ConnectionPool`; serialize through that target.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates the server and registers every RPC message handler.
    ///
    /// Each message type gets exactly one handler; `add_handler` panics on a
    /// duplicate registration, so this list is the single source of truth for
    /// which messages the server understands.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Connection liveness and call/room lifecycle.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests are forwarded to the host without
            // requiring write capability; mutating ones check capability first.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization; host-originated updates fan out to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and collaboration presence.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
475
    /// Starts the server's background cleanup task and returns immediately.
    ///
    /// The spawned task waits `CLEANUP_TIMEOUT` (so that pods terminated by
    /// kubernetes have finished shutting down), then removes resources still
    /// attributed to stale servers in this environment: chat participants,
    /// channel-buffer collaborators, room participants, worktree entries, and
    /// finally the stale server rows themselves. Affected clients are notified
    /// and empty LiveKit rooms are deleted. All failures are logged via
    /// `trace_err` and do not abort the remaining cleanup steps.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // First pass: drop chat participants belonging to dead servers.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale channel-buffer collaborators and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants and broadcast the
                        // refreshed room/channel state to survivors.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify users whose incoming calls were canceled.
                        // Scoped block keeps the pool lock from being held
                        // across the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to each affected user's
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // Second pass: catch participants that became stale while the
                // room cleanup above was running.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
646
    /// Shuts the server down: tears down the peer, clears the connection
    /// pool, and broadcasts `true` on the teardown channel so every
    /// `handle_connection` task exits its loop.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Receivers may already be gone; ignore the send result.
        let _ = self.teardown.send(true);
    }
652
    /// Test helper: tears the server down and re-initializes it under a new
    /// `ServerId`, re-arming the teardown channel with `false`.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
660
    /// Test helper: returns this server instance's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
665
    /// Registers `handler` for message type `M`, wrapping it in a type-erased
    /// closure that downcasts the envelope, runs the handler, and records
    /// timing fields (total / processing / queue durations) on the message
    /// span when the future completes.
    ///
    /// Panics if a handler for `M` was already registered — registration in
    /// `Server::new` must be unique per message type.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Safe: dispatch keyed the handler by this exact TypeId.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Durations in fractional milliseconds; queue time is the
                    // gap between receipt and the handler starting to run.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
709
710 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
711 where
712 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
713 Fut: 'static + Send + Future<Output = Result<()>>,
714 M: EnvelopedMessage,
715 {
716 self.add_handler(move |envelope, session| handler(envelope.payload, session));
717 self
718 }
719
    /// Registers a handler for a request message, which must reply exactly
    /// once via the provided `Response<M>`.
    ///
    /// After the handler runs: an `Ok` result without a sent response is
    /// converted into an error (the client must never be left hanging), and
    /// an `Err` result is relayed to the client via `respond_with_error`
    /// before being propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`; lets us detect
                // handlers that never responded.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors carry their own proto encoding;
                        // everything else is wrapped as a generic internal error.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
758
    /// Returns a future that drives one client connection for its entire
    /// lifetime: registers it with the peer, sends the initial client update,
    /// then loops dispatching incoming messages to the registered handlers
    /// until the connection closes or the server tears down, and finally runs
    /// `connection_lost` cleanup.
    ///
    /// Message handling is bounded by a semaphore of `MAX_CONCURRENT_HANDLERS`
    /// permits; foreground handlers run concurrently in a `FuturesUnordered`
    /// (see the deadlock note below) while background-capable messages are
    /// spawned onto the executor.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields are pre-declared empty so they can be recorded later
        // (connection_id is only known after `add_connection`).
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once shutdown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection-budget slot is only needed during setup; release
            // it once the connection is fully established.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of in-flight handlers = permits currently checked out.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Don't pull the next message off the wire until a handler
                // permit is available — this backpressures noisy clients.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // future finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-running foreground handlers, then run the
            // disconnect cleanup.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
950
    /// Sends the initial state a newly-connected client needs: the `Hello`
    /// handshake, the contact list, channel subscriptions (for versions that
    /// auto-subscribe), and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client the peer id assigned to this
        // connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller exactly once, if one is
        // waiting for it.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt the client to
                // show the contacts panel, then persist the flag.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact list
                // while holding the pool lock, so the snapshot the client
                // receives is consistent with the pool's state.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // If someone is currently ringing this user, surface the call
                // immediately on the new connection.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1006
1007 pub async fn invite_code_redeemed(
1008 self: &Arc<Self>,
1009 inviter_id: UserId,
1010 invitee_id: UserId,
1011 ) -> Result<()> {
1012 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1013 && let Some(code) = &user.invite_code
1014 {
1015 let pool = self.connection_pool.lock();
1016 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1017 for connection_id in pool.user_connection_ids(inviter_id) {
1018 self.peer.send(
1019 connection_id,
1020 proto::UpdateContacts {
1021 contacts: vec![invitee_contact.clone()],
1022 ..Default::default()
1023 },
1024 )?;
1025 self.peer.send(
1026 connection_id,
1027 proto::UpdateInviteInfo {
1028 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1029 count: user.invite_count as u32,
1030 },
1031 )?;
1032 }
1033 }
1034 Ok(())
1035 }
1036
1037 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1038 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1039 && let Some(invite_code) = &user.invite_code
1040 {
1041 let pool = self.connection_pool.lock();
1042 for connection_id in pool.user_connection_ids(user_id) {
1043 self.peer.send(
1044 connection_id,
1045 proto::UpdateInviteInfo {
1046 url: format!(
1047 "{}{}",
1048 self.app_state.config.invite_link_prefix, invite_code
1049 ),
1050 count: user.invite_count as u32,
1051 },
1052 )?;
1053 }
1054 }
1055 Ok(())
1056 }
1057
1058 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1059 ServerSnapshot {
1060 connection_pool: ConnectionPoolGuard {
1061 guard: self.connection_pool.lock(),
1062 _not_send: PhantomData,
1063 },
1064 peer: &self.peer,
1065 }
1066 }
1067}
1068
1069impl Deref for ConnectionPoolGuard<'_> {
1070 type Target = ConnectionPool;
1071
1072 fn deref(&self) -> &Self::Target {
1073 &self.guard
1074 }
1075}
1076
1077impl DerefMut for ConnectionPoolGuard<'_> {
1078 fn deref_mut(&mut self) -> &mut Self::Target {
1079 &mut self.guard
1080 }
1081}
1082
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, validate the pool's internal invariants every time
        // the guard is released; production builds compile this check out.
        #[cfg(test)]
        self.check_invariants();
    }
}
1089
1090fn broadcast<F>(
1091 sender_id: Option<ConnectionId>,
1092 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1093 mut f: F,
1094) where
1095 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1096{
1097 for receiver_id in receiver_ids {
1098 if Some(receiver_id) != sender_id
1099 && let Err(error) = f(receiver_id)
1100 {
1101 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1102 }
1103 }
1104}
1105
1106pub struct ProtocolVersion(u32);
1107
1108impl Header for ProtocolVersion {
1109 fn name() -> &'static HeaderName {
1110 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1111 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1112 }
1113
1114 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1115 where
1116 Self: Sized,
1117 I: Iterator<Item = &'i axum::http::HeaderValue>,
1118 {
1119 let version = values
1120 .next()
1121 .ok_or_else(axum::headers::Error::invalid)?
1122 .to_str()
1123 .map_err(|_| axum::headers::Error::invalid())?
1124 .parse()
1125 .map_err(|_| axum::headers::Error::invalid())?;
1126 Ok(Self(version))
1127 }
1128
1129 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1130 values.extend([self.0.to_string().parse().unwrap()]);
1131 }
1132}
1133
1134pub struct AppVersionHeader(SemanticVersion);
1135impl Header for AppVersionHeader {
1136 fn name() -> &'static HeaderName {
1137 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1138 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1139 }
1140
1141 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1142 where
1143 Self: Sized,
1144 I: Iterator<Item = &'i axum::http::HeaderValue>,
1145 {
1146 let version = values
1147 .next()
1148 .ok_or_else(axum::headers::Error::invalid)?
1149 .to_str()
1150 .map_err(|_| axum::headers::Error::invalid())?
1151 .parse()
1152 .map_err(|_| axum::headers::Error::invalid())?;
1153 Ok(Self(version))
1154 }
1155
1156 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1157 values.extend([self.0.to_string().parse().unwrap()]);
1158 }
1159}
1160
1161#[derive(Debug)]
1162pub struct ReleaseChannelHeader(String);
1163
1164impl Header for ReleaseChannelHeader {
1165 fn name() -> &'static HeaderName {
1166 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1167 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1168 }
1169
1170 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1171 where
1172 Self: Sized,
1173 I: Iterator<Item = &'i axum::http::HeaderValue>,
1174 {
1175 Ok(Self(
1176 values
1177 .next()
1178 .ok_or_else(axum::headers::Error::invalid)?
1179 .to_str()
1180 .map_err(|_| axum::headers::Error::invalid())?
1181 .to_owned(),
1182 ))
1183 }
1184
1185 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1186 values.extend([self.0.parse().unwrap()]);
1187 }
1188}
1189
1190pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1191 Router::new()
1192 .route("/rpc", get(handle_websocket_request))
1193 .layer(
1194 ServiceBuilder::new()
1195 .layer(Extension(server.app_state.clone()))
1196 .layer(middleware::from_fn(auth::validate_header)),
1197 )
1198 .route("/metrics", get(handle_metrics))
1199 .layer(Extension(server))
1200}
1201
1202pub async fn handle_websocket_request(
1203 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1204 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1205 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1206 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1207 Extension(server): Extension<Arc<Server>>,
1208 Extension(principal): Extension<Principal>,
1209 user_agent: Option<TypedHeader<UserAgent>>,
1210 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1211 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1212 ws: WebSocketUpgrade,
1213) -> axum::response::Response {
1214 if protocol_version != rpc::PROTOCOL_VERSION {
1215 return (
1216 StatusCode::UPGRADE_REQUIRED,
1217 "client must be upgraded".to_string(),
1218 )
1219 .into_response();
1220 }
1221
1222 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1223 return (
1224 StatusCode::UPGRADE_REQUIRED,
1225 "no version header found".to_string(),
1226 )
1227 .into_response();
1228 };
1229
1230 let release_channel = release_channel_header.map(|header| header.0.0);
1231
1232 if !version.can_collaborate() {
1233 return (
1234 StatusCode::UPGRADE_REQUIRED,
1235 "client must be upgraded".to_string(),
1236 )
1237 .into_response();
1238 }
1239
1240 let socket_address = socket_address.to_string();
1241
1242 // Acquire connection guard before WebSocket upgrade
1243 let connection_guard = match ConnectionGuard::try_acquire() {
1244 Ok(guard) => guard,
1245 Err(()) => {
1246 return (
1247 StatusCode::SERVICE_UNAVAILABLE,
1248 "Too many concurrent connections",
1249 )
1250 .into_response();
1251 }
1252 };
1253
1254 ws.on_upgrade(move |socket| {
1255 let socket = socket
1256 .map_ok(to_tungstenite_message)
1257 .err_into()
1258 .with(|message| async move { to_axum_message(message) });
1259 let connection = Connection::new(Box::pin(socket));
1260 async move {
1261 server
1262 .handle_connection(
1263 connection,
1264 socket_address,
1265 principal,
1266 version,
1267 release_channel,
1268 user_agent.map(|header| header.to_string()),
1269 country_code_header.map(|header| header.to_string()),
1270 system_id_header.map(|header| header.to_string()),
1271 None,
1272 Executor::Production,
1273 Some(connection_guard),
1274 )
1275 .await;
1276 }
1277 })
1278}
1279
1280pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1281 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1282 let connections_metric = CONNECTIONS_METRIC
1283 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1284
1285 let connections = server
1286 .connection_pool
1287 .lock()
1288 .connections()
1289 .filter(|connection| !connection.admin)
1290 .count();
1291 connections_metric.set(connections as _);
1292
1293 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1294 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1295 register_int_gauge!(
1296 "shared_projects",
1297 "number of open projects with one or more guests"
1298 )
1299 .unwrap()
1300 });
1301
1302 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1303 shared_projects_metric.set(shared_projects as _);
1304
1305 let encoder = prometheus::TextEncoder::new();
1306 let metric_families = prometheus::gather();
1307 let encoded_metrics = encoder
1308 .encode_to_string(&metric_families)
1309 .map_err(|err| anyhow!("{err}"))?;
1310 Ok(encoded_metrics)
1311}
1312
/// Cleans up after a dropped client connection: disconnects the peer,
/// removes the connection from the pool and database, then — after a
/// `RECONNECT_TIMEOUT` grace period, unless teardown fires first — releases
/// the session's rooms, channel buffers, and pending calls.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediately tear down peer state and the pool entry; the heavier
    // resource cleanup below is deferred to allow reconnection.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: mark the connection lost in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period elapsed without the client coming back: release all
        // of this connection's resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending incoming call if the user has no other
            // live connections that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Teardown was signaled before the timeout — skip the timed cleanup.
        // NOTE(review): presumably this means the session was superseded or
        // the server is shutting down; confirm against the sender of
        // `teardown`.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1359
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    // An empty Ack is all the client needs to consider the connection live.
    response.send(proto::Ack {})?;
    Ok(())
}
1369
1370/// Creates a new room for calling (outside of channels)
1371async fn create_room(
1372 _request: proto::CreateRoom,
1373 response: Response<proto::CreateRoom>,
1374 session: MessageContext,
1375) -> Result<()> {
1376 let livekit_room = nanoid::nanoid!(30);
1377
1378 let live_kit_connection_info = util::maybe!(async {
1379 let live_kit = session.app_state.livekit_client.as_ref();
1380 let live_kit = live_kit?;
1381 let user_id = session.user_id().to_string();
1382
1383 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1384
1385 Some(proto::LiveKitConnectionInfo {
1386 server_url: live_kit.url().into(),
1387 token,
1388 can_publish: true,
1389 })
1390 })
1391 .await;
1392
1393 let room = session
1394 .db()
1395 .await
1396 .create_room(session.user_id(), session.connection_id, &livekit_room)
1397 .await?;
1398
1399 response.send(proto::CreateRoomResponse {
1400 room: Some(room.clone()),
1401 live_kit_connection_info,
1402 })?;
1403
1404 update_user_contacts(session.user_id(), &session).await?;
1405 Ok(())
1406}
1407
1408/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1409async fn join_room(
1410 request: proto::JoinRoom,
1411 response: Response<proto::JoinRoom>,
1412 session: MessageContext,
1413) -> Result<()> {
1414 let room_id = RoomId::from_proto(request.id);
1415
1416 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1417
1418 if let Some(channel_id) = channel_id {
1419 return join_channel_internal(channel_id, Box::new(response), session).await;
1420 }
1421
1422 let joined_room = {
1423 let room = session
1424 .db()
1425 .await
1426 .join_room(room_id, session.user_id(), session.connection_id)
1427 .await?;
1428 room_updated(&room.room, &session.peer);
1429 room.into_inner()
1430 };
1431
1432 for connection_id in session
1433 .connection_pool()
1434 .await
1435 .user_connection_ids(session.user_id())
1436 {
1437 session
1438 .peer
1439 .send(
1440 connection_id,
1441 proto::CallCanceled {
1442 room_id: room_id.to_proto(),
1443 },
1444 )
1445 .trace_err();
1446 }
1447
1448 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1449 {
1450 live_kit
1451 .room_token(
1452 &joined_room.room.livekit_room,
1453 &session.user_id().to_string(),
1454 )
1455 .trace_err()
1456 .map(|token| proto::LiveKitConnectionInfo {
1457 server_url: live_kit.url().into(),
1458 token,
1459 can_publish: true,
1460 })
1461 } else {
1462 None
1463 };
1464
1465 response.send(proto::JoinRoomResponse {
1466 room: Some(joined_room.room),
1467 channel_id: None,
1468 live_kit_connection_info,
1469 })?;
1470
1471 update_user_contacts(session.user_id(), &session).await?;
1472 Ok(())
1473}
1474
1475/// Rejoin room is used to reconnect to a room after connection errors.
1476async fn rejoin_room(
1477 request: proto::RejoinRoom,
1478 response: Response<proto::RejoinRoom>,
1479 session: MessageContext,
1480) -> Result<()> {
1481 let room;
1482 let channel;
1483 {
1484 let mut rejoined_room = session
1485 .db()
1486 .await
1487 .rejoin_room(request, session.user_id(), session.connection_id)
1488 .await?;
1489
1490 response.send(proto::RejoinRoomResponse {
1491 room: Some(rejoined_room.room.clone()),
1492 reshared_projects: rejoined_room
1493 .reshared_projects
1494 .iter()
1495 .map(|project| proto::ResharedProject {
1496 id: project.id.to_proto(),
1497 collaborators: project
1498 .collaborators
1499 .iter()
1500 .map(|collaborator| collaborator.to_proto())
1501 .collect(),
1502 })
1503 .collect(),
1504 rejoined_projects: rejoined_room
1505 .rejoined_projects
1506 .iter()
1507 .map(|rejoined_project| rejoined_project.to_proto())
1508 .collect(),
1509 })?;
1510 room_updated(&rejoined_room.room, &session.peer);
1511
1512 for project in &rejoined_room.reshared_projects {
1513 for collaborator in &project.collaborators {
1514 session
1515 .peer
1516 .send(
1517 collaborator.connection_id,
1518 proto::UpdateProjectCollaborator {
1519 project_id: project.id.to_proto(),
1520 old_peer_id: Some(project.old_connection_id.into()),
1521 new_peer_id: Some(session.connection_id.into()),
1522 },
1523 )
1524 .trace_err();
1525 }
1526
1527 broadcast(
1528 Some(session.connection_id),
1529 project
1530 .collaborators
1531 .iter()
1532 .map(|collaborator| collaborator.connection_id),
1533 |connection_id| {
1534 session.peer.forward_send(
1535 session.connection_id,
1536 connection_id,
1537 proto::UpdateProject {
1538 project_id: project.id.to_proto(),
1539 worktrees: project.worktrees.clone(),
1540 },
1541 )
1542 },
1543 );
1544 }
1545
1546 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1547
1548 let rejoined_room = rejoined_room.into_inner();
1549
1550 room = rejoined_room.room;
1551 channel = rejoined_room.channel;
1552 }
1553
1554 if let Some(channel) = channel {
1555 channel_updated(
1556 &channel,
1557 &room,
1558 &session.peer,
1559 &*session.connection_pool().await,
1560 );
1561 }
1562
1563 update_user_contacts(session.user_id(), &session).await?;
1564 Ok(())
1565}
1566
/// Streams the state of every rejoined project back to the reconnecting
/// client and tells collaborators about the client's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: notify every collaborator in each project that this user's
    // peer id changed from the old connection to the current one.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream each project's worktree and repository state back
    // to the rejoining client itself.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // The update is final only once the completed scan has
                // caught up with the current scan id.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is chunked before sending; split_worktree_update
            // decides the chunking.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository updates accumulated while disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        // Finally, tell the client which repositories disappeared.
        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1651
/// leave room disconnects from the room.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: MessageContext,
) -> Result<()> {
    // Tear down this connection's room state first, then acknowledge.
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1662
1663/// Updates the permissions of someone else in the room.
1664async fn set_room_participant_role(
1665 request: proto::SetRoomParticipantRole,
1666 response: Response<proto::SetRoomParticipantRole>,
1667 session: MessageContext,
1668) -> Result<()> {
1669 let user_id = UserId::from_proto(request.user_id);
1670 let role = ChannelRole::from(request.role());
1671
1672 let (livekit_room, can_publish) = {
1673 let room = session
1674 .db()
1675 .await
1676 .set_room_participant_role(
1677 session.user_id(),
1678 RoomId::from_proto(request.room_id),
1679 user_id,
1680 role,
1681 )
1682 .await?;
1683
1684 let livekit_room = room.livekit_room.clone();
1685 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1686 room_updated(&room, &session.peer);
1687 (livekit_room, can_publish)
1688 };
1689
1690 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1691 live_kit
1692 .update_participant(
1693 livekit_room.clone(),
1694 request.user_id.to_string(),
1695 livekit_api::proto::ParticipantPermission {
1696 can_subscribe: true,
1697 can_publish,
1698 can_publish_data: can_publish,
1699 hidden: false,
1700 recorder: false,
1701 },
1702 )
1703 .await
1704 .trace_err();
1705 }
1706
1707 response.send(proto::Ack {})?;
1708 Ok(())
1709}
1710
1711/// Call someone else into the current room
1712async fn call(
1713 request: proto::Call,
1714 response: Response<proto::Call>,
1715 session: MessageContext,
1716) -> Result<()> {
1717 let room_id = RoomId::from_proto(request.room_id);
1718 let calling_user_id = session.user_id();
1719 let calling_connection_id = session.connection_id;
1720 let called_user_id = UserId::from_proto(request.called_user_id);
1721 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1722 if !session
1723 .db()
1724 .await
1725 .has_contact(calling_user_id, called_user_id)
1726 .await?
1727 {
1728 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1729 }
1730
1731 let incoming_call = {
1732 let (room, incoming_call) = &mut *session
1733 .db()
1734 .await
1735 .call(
1736 room_id,
1737 calling_user_id,
1738 calling_connection_id,
1739 called_user_id,
1740 initial_project_id,
1741 )
1742 .await?;
1743 room_updated(room, &session.peer);
1744 mem::take(incoming_call)
1745 };
1746 update_user_contacts(called_user_id, &session).await?;
1747
1748 let mut calls = session
1749 .connection_pool()
1750 .await
1751 .user_connection_ids(called_user_id)
1752 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1753 .collect::<FuturesUnordered<_>>();
1754
1755 while let Some(call_response) = calls.next().await {
1756 match call_response.as_ref() {
1757 Ok(_) => {
1758 response.send(proto::Ack {})?;
1759 return Ok(());
1760 }
1761 Err(_) => {
1762 call_response.trace_err();
1763 }
1764 }
1765 }
1766
1767 {
1768 let room = session
1769 .db()
1770 .await
1771 .call_failed(room_id, called_user_id)
1772 .await?;
1773 room_updated(&room, &session.peer);
1774 }
1775 update_user_contacts(called_user_id, &session).await?;
1776
1777 Err(anyhow!("failed to ring user"))?
1778}
1779
1780/// Cancel an outgoing call.
1781async fn cancel_call(
1782 request: proto::CancelCall,
1783 response: Response<proto::CancelCall>,
1784 session: MessageContext,
1785) -> Result<()> {
1786 let called_user_id = UserId::from_proto(request.called_user_id);
1787 let room_id = RoomId::from_proto(request.room_id);
1788 {
1789 let room = session
1790 .db()
1791 .await
1792 .cancel_call(room_id, session.connection_id, called_user_id)
1793 .await?;
1794 room_updated(&room, &session.peer);
1795 }
1796
1797 for connection_id in session
1798 .connection_pool()
1799 .await
1800 .user_connection_ids(called_user_id)
1801 {
1802 session
1803 .peer
1804 .send(
1805 connection_id,
1806 proto::CallCanceled {
1807 room_id: room_id.to_proto(),
1808 },
1809 )
1810 .trace_err();
1811 }
1812 response.send(proto::Ack {})?;
1813
1814 update_user_contacts(called_user_id, &session).await?;
1815 Ok(())
1816}
1817
1818/// Decline an incoming call.
1819async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1820 let room_id = RoomId::from_proto(message.room_id);
1821 {
1822 let room = session
1823 .db()
1824 .await
1825 .decline_call(Some(room_id), session.user_id())
1826 .await?
1827 .context("declining call")?;
1828 room_updated(&room, &session.peer);
1829 }
1830
1831 for connection_id in session
1832 .connection_pool()
1833 .await
1834 .user_connection_ids(session.user_id())
1835 {
1836 session
1837 .peer
1838 .send(
1839 connection_id,
1840 proto::CallCanceled {
1841 room_id: room_id.to_proto(),
1842 },
1843 )
1844 .trace_err();
1845 }
1846 update_user_contacts(session.user_id(), &session).await?;
1847 Ok(())
1848}
1849
1850/// Updates other participants in the room with your current location.
1851async fn update_participant_location(
1852 request: proto::UpdateParticipantLocation,
1853 response: Response<proto::UpdateParticipantLocation>,
1854 session: MessageContext,
1855) -> Result<()> {
1856 let room_id = RoomId::from_proto(request.room_id);
1857 let location = request.location.context("invalid location")?;
1858
1859 let db = session.db().await;
1860 let room = db
1861 .update_room_participant_location(room_id, session.connection_id, location)
1862 .await?;
1863
1864 room_updated(&room, &session.peer);
1865 response.send(proto::Ack {})?;
1866 Ok(())
1867}
1868
1869/// Share a project into the room.
1870async fn share_project(
1871 request: proto::ShareProject,
1872 response: Response<proto::ShareProject>,
1873 session: MessageContext,
1874) -> Result<()> {
1875 let (project_id, room) = &*session
1876 .db()
1877 .await
1878 .share_project(
1879 RoomId::from_proto(request.room_id),
1880 session.connection_id,
1881 &request.worktrees,
1882 request.is_ssh_project,
1883 request.windows_paths.unwrap_or(false),
1884 )
1885 .await?;
1886 response.send(proto::ShareProjectResponse {
1887 project_id: project_id.to_proto(),
1888 })?;
1889 room_updated(room, &session.peer);
1890
1891 Ok(())
1892}
1893
1894/// Unshare a project from the room.
1895async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1896 let project_id = ProjectId::from_proto(message.project_id);
1897 unshare_project_internal(project_id, session.connection_id, &session).await
1898}
1899
1900async fn unshare_project_internal(
1901 project_id: ProjectId,
1902 connection_id: ConnectionId,
1903 session: &Session,
1904) -> Result<()> {
1905 let delete = {
1906 let room_guard = session
1907 .db()
1908 .await
1909 .unshare_project(project_id, connection_id)
1910 .await?;
1911
1912 let (delete, room, guest_connection_ids) = &*room_guard;
1913
1914 let message = proto::UnshareProject {
1915 project_id: project_id.to_proto(),
1916 };
1917
1918 broadcast(
1919 Some(connection_id),
1920 guest_connection_ids.iter().copied(),
1921 |conn_id| session.peer.send(conn_id, message.clone()),
1922 );
1923 if let Some(room) = room {
1924 room_updated(room, &session.peer);
1925 }
1926
1927 *delete
1928 };
1929
1930 if delete {
1931 let db = session.db().await;
1932 db.delete_project(project_id).await?;
1933 }
1934
1935 Ok(())
1936}
1937
1938/// Join someone elses shared project.
1939async fn join_project(
1940 request: proto::JoinProject,
1941 response: Response<proto::JoinProject>,
1942 session: MessageContext,
1943) -> Result<()> {
1944 let project_id = ProjectId::from_proto(request.project_id);
1945
1946 tracing::info!(%project_id, "join project");
1947
1948 let db = session.db().await;
1949 let (project, replica_id) = &mut *db
1950 .join_project(
1951 project_id,
1952 session.connection_id,
1953 session.user_id(),
1954 request.committer_name.clone(),
1955 request.committer_email.clone(),
1956 )
1957 .await?;
1958 drop(db);
1959 tracing::info!(%project_id, "join remote project");
1960 let collaborators = project
1961 .collaborators
1962 .iter()
1963 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1964 .map(|collaborator| collaborator.to_proto())
1965 .collect::<Vec<_>>();
1966 let project_id = project.id;
1967 let guest_user_id = session.user_id();
1968
1969 let worktrees = project
1970 .worktrees
1971 .iter()
1972 .map(|(id, worktree)| proto::WorktreeMetadata {
1973 id: *id,
1974 root_name: worktree.root_name.clone(),
1975 visible: worktree.visible,
1976 abs_path: worktree.abs_path.clone(),
1977 })
1978 .collect::<Vec<_>>();
1979
1980 let add_project_collaborator = proto::AddProjectCollaborator {
1981 project_id: project_id.to_proto(),
1982 collaborator: Some(proto::Collaborator {
1983 peer_id: Some(session.connection_id.into()),
1984 replica_id: replica_id.0 as u32,
1985 user_id: guest_user_id.to_proto(),
1986 is_host: false,
1987 committer_name: request.committer_name.clone(),
1988 committer_email: request.committer_email.clone(),
1989 }),
1990 };
1991
1992 for collaborator in &collaborators {
1993 session
1994 .peer
1995 .send(
1996 collaborator.peer_id.unwrap().into(),
1997 add_project_collaborator.clone(),
1998 )
1999 .trace_err();
2000 }
2001
2002 // First, we send the metadata associated with each worktree.
2003 let (language_servers, language_server_capabilities) = project
2004 .language_servers
2005 .clone()
2006 .into_iter()
2007 .map(|server| (server.server, server.capabilities))
2008 .unzip();
2009 response.send(proto::JoinProjectResponse {
2010 project_id: project.id.0 as u64,
2011 worktrees,
2012 replica_id: replica_id.0 as u32,
2013 collaborators,
2014 language_servers,
2015 language_server_capabilities,
2016 role: project.role.into(),
2017 windows_paths: project.path_style == PathStyle::Windows,
2018 })?;
2019
2020 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2021 // Stream this worktree's entries.
2022 let message = proto::UpdateWorktree {
2023 project_id: project_id.to_proto(),
2024 worktree_id,
2025 abs_path: worktree.abs_path.clone(),
2026 root_name: worktree.root_name,
2027 updated_entries: worktree.entries,
2028 removed_entries: Default::default(),
2029 scan_id: worktree.scan_id,
2030 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2031 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2032 removed_repositories: Default::default(),
2033 };
2034 for update in proto::split_worktree_update(message) {
2035 session.peer.send(session.connection_id, update.clone())?;
2036 }
2037
2038 // Stream this worktree's diagnostics.
2039 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2040 if let Some(summary) = worktree_diagnostics.next() {
2041 let message = proto::UpdateDiagnosticSummary {
2042 project_id: project.id.to_proto(),
2043 worktree_id: worktree.id,
2044 summary: Some(summary),
2045 more_summaries: worktree_diagnostics.collect(),
2046 };
2047 session.peer.send(session.connection_id, message)?;
2048 }
2049
2050 for settings_file in worktree.settings_files {
2051 session.peer.send(
2052 session.connection_id,
2053 proto::UpdateWorktreeSettings {
2054 project_id: project_id.to_proto(),
2055 worktree_id: worktree.id,
2056 path: settings_file.path,
2057 content: Some(settings_file.content),
2058 kind: Some(settings_file.kind.to_proto() as i32),
2059 },
2060 )?;
2061 }
2062 }
2063
2064 for repository in mem::take(&mut project.repositories) {
2065 for update in split_repository_update(repository) {
2066 session.peer.send(session.connection_id, update)?;
2067 }
2068 }
2069
2070 for language_server in &project.language_servers {
2071 session.peer.send(
2072 session.connection_id,
2073 proto::UpdateLanguageServer {
2074 project_id: project_id.to_proto(),
2075 server_name: Some(language_server.server.name.clone()),
2076 language_server_id: language_server.server.id,
2077 variant: Some(
2078 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2079 proto::LspDiskBasedDiagnosticsUpdated {},
2080 ),
2081 ),
2082 },
2083 )?;
2084 }
2085
2086 Ok(())
2087}
2088
2089/// Leave someone elses shared project.
2090async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2091 let sender_id = session.connection_id;
2092 let project_id = ProjectId::from_proto(request.project_id);
2093 let db = session.db().await;
2094
2095 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2096 tracing::info!(
2097 %project_id,
2098 "leave project"
2099 );
2100
2101 project_left(project, &session);
2102 if let Some(room) = room {
2103 room_updated(room, &session.peer);
2104 }
2105
2106 Ok(())
2107}
2108
2109/// Updates other participants with changes to the project
2110async fn update_project(
2111 request: proto::UpdateProject,
2112 response: Response<proto::UpdateProject>,
2113 session: MessageContext,
2114) -> Result<()> {
2115 let project_id = ProjectId::from_proto(request.project_id);
2116 let (room, guest_connection_ids) = &*session
2117 .db()
2118 .await
2119 .update_project(project_id, session.connection_id, &request.worktrees)
2120 .await?;
2121 broadcast(
2122 Some(session.connection_id),
2123 guest_connection_ids.iter().copied(),
2124 |connection_id| {
2125 session
2126 .peer
2127 .forward_send(session.connection_id, connection_id, request.clone())
2128 },
2129 );
2130 if let Some(room) = room {
2131 room_updated(room, &session.peer);
2132 }
2133 response.send(proto::Ack {})?;
2134
2135 Ok(())
2136}
2137
2138/// Updates other participants with changes to the worktree
2139async fn update_worktree(
2140 request: proto::UpdateWorktree,
2141 response: Response<proto::UpdateWorktree>,
2142 session: MessageContext,
2143) -> Result<()> {
2144 let guest_connection_ids = session
2145 .db()
2146 .await
2147 .update_worktree(&request, session.connection_id)
2148 .await?;
2149
2150 broadcast(
2151 Some(session.connection_id),
2152 guest_connection_ids.iter().copied(),
2153 |connection_id| {
2154 session
2155 .peer
2156 .forward_send(session.connection_id, connection_id, request.clone())
2157 },
2158 );
2159 response.send(proto::Ack {})?;
2160 Ok(())
2161}
2162
2163async fn update_repository(
2164 request: proto::UpdateRepository,
2165 response: Response<proto::UpdateRepository>,
2166 session: MessageContext,
2167) -> Result<()> {
2168 let guest_connection_ids = session
2169 .db()
2170 .await
2171 .update_repository(&request, session.connection_id)
2172 .await?;
2173
2174 broadcast(
2175 Some(session.connection_id),
2176 guest_connection_ids.iter().copied(),
2177 |connection_id| {
2178 session
2179 .peer
2180 .forward_send(session.connection_id, connection_id, request.clone())
2181 },
2182 );
2183 response.send(proto::Ack {})?;
2184 Ok(())
2185}
2186
2187async fn remove_repository(
2188 request: proto::RemoveRepository,
2189 response: Response<proto::RemoveRepository>,
2190 session: MessageContext,
2191) -> Result<()> {
2192 let guest_connection_ids = session
2193 .db()
2194 .await
2195 .remove_repository(&request, session.connection_id)
2196 .await?;
2197
2198 broadcast(
2199 Some(session.connection_id),
2200 guest_connection_ids.iter().copied(),
2201 |connection_id| {
2202 session
2203 .peer
2204 .forward_send(session.connection_id, connection_id, request.clone())
2205 },
2206 );
2207 response.send(proto::Ack {})?;
2208 Ok(())
2209}
2210
2211/// Updates other participants with changes to the diagnostics
2212async fn update_diagnostic_summary(
2213 message: proto::UpdateDiagnosticSummary,
2214 session: MessageContext,
2215) -> Result<()> {
2216 let guest_connection_ids = session
2217 .db()
2218 .await
2219 .update_diagnostic_summary(&message, session.connection_id)
2220 .await?;
2221
2222 broadcast(
2223 Some(session.connection_id),
2224 guest_connection_ids.iter().copied(),
2225 |connection_id| {
2226 session
2227 .peer
2228 .forward_send(session.connection_id, connection_id, message.clone())
2229 },
2230 );
2231
2232 Ok(())
2233}
2234
2235/// Updates other participants with changes to the worktree settings
2236async fn update_worktree_settings(
2237 message: proto::UpdateWorktreeSettings,
2238 session: MessageContext,
2239) -> Result<()> {
2240 let guest_connection_ids = session
2241 .db()
2242 .await
2243 .update_worktree_settings(&message, session.connection_id)
2244 .await?;
2245
2246 broadcast(
2247 Some(session.connection_id),
2248 guest_connection_ids.iter().copied(),
2249 |connection_id| {
2250 session
2251 .peer
2252 .forward_send(session.connection_id, connection_id, message.clone())
2253 },
2254 );
2255
2256 Ok(())
2257}
2258
2259/// Notify other participants that a language server has started.
2260async fn start_language_server(
2261 request: proto::StartLanguageServer,
2262 session: MessageContext,
2263) -> Result<()> {
2264 let guest_connection_ids = session
2265 .db()
2266 .await
2267 .start_language_server(&request, session.connection_id)
2268 .await?;
2269
2270 broadcast(
2271 Some(session.connection_id),
2272 guest_connection_ids.iter().copied(),
2273 |connection_id| {
2274 session
2275 .peer
2276 .forward_send(session.connection_id, connection_id, request.clone())
2277 },
2278 );
2279 Ok(())
2280}
2281
/// Notify other participants that a language server has changed.
///
/// Metadata updates that carry refreshed server capabilities are also
/// persisted, so that late joiners receive current capabilities (see the
/// capability snapshot sent in `join_project`'s response).
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist capabilities only when this update actually contains them.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Relay the update to the project's other connections.
    // NOTE(review): the trailing `true` flag differs from the `false` used in
    // `broadcast_project_message_from_host` — presumably it widens the
    // recipient set; confirm against `project_connection_ids`'s signature.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2311
2312/// forward a project request to the host. These requests should be read only
2313/// as guests are allowed to send them.
2314async fn forward_read_only_project_request<T>(
2315 request: T,
2316 response: Response<T>,
2317 session: MessageContext,
2318) -> Result<()>
2319where
2320 T: EntityMessage + RequestMessage,
2321{
2322 let project_id = ProjectId::from_proto(request.remote_entity_id());
2323 let host_connection_id = session
2324 .db()
2325 .await
2326 .host_for_read_only_project_request(project_id, session.connection_id)
2327 .await?;
2328 let payload = session.forward_request(host_connection_id, request).await?;
2329 response.send(payload)?;
2330 Ok(())
2331}
2332
2333/// forward a project request to the host. These requests are disallowed
2334/// for guests.
2335async fn forward_mutating_project_request<T>(
2336 request: T,
2337 response: Response<T>,
2338 session: MessageContext,
2339) -> Result<()>
2340where
2341 T: EntityMessage + RequestMessage,
2342{
2343 let project_id = ProjectId::from_proto(request.remote_entity_id());
2344
2345 let host_connection_id = session
2346 .db()
2347 .await
2348 .host_for_mutating_project_request(project_id, session.connection_id)
2349 .await?;
2350 let payload = session.forward_request(host_connection_id, request).await?;
2351 response.send(payload)?;
2352 Ok(())
2353}
2354
2355async fn lsp_query(
2356 request: proto::LspQuery,
2357 response: Response<proto::LspQuery>,
2358 session: MessageContext,
2359) -> Result<()> {
2360 let (name, should_write) = request.query_name_and_write_permissions();
2361 tracing::Span::current().record("lsp_query_request", name);
2362 tracing::info!("lsp_query message received");
2363 if should_write {
2364 forward_mutating_project_request(request, response, session).await
2365 } else {
2366 forward_read_only_project_request(request, response, session).await
2367 }
2368}
2369
2370/// Notify other participants that a new buffer has been created
2371async fn create_buffer_for_peer(
2372 request: proto::CreateBufferForPeer,
2373 session: MessageContext,
2374) -> Result<()> {
2375 session
2376 .db()
2377 .await
2378 .check_user_is_project_host(
2379 ProjectId::from_proto(request.project_id),
2380 session.connection_id,
2381 )
2382 .await?;
2383 let peer_id = request.peer_id.context("invalid peer id")?;
2384 session
2385 .peer
2386 .forward_send(session.connection_id, peer_id.into(), request)?;
2387 Ok(())
2388}
2389
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates are permitted with read-only access; any other
    // operation kind escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard so it is dropped before we await the host's
    // acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget the update to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host must receive (and acknowledge) the update before we ack the
    // sender — unless the sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2436
/// Broadcast a context operation to the project's host and guests, deriving
/// the required capability from the operation's contents (mirrors the
/// permission logic in `update_buffer`).
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Selection-only buffer operations may come from read-only guests; any
    // other operation requires write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast and no
    // acknowledgement is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2478
2479/// Notify other participants that a project has been updated.
2480async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2481 request: T,
2482 session: MessageContext,
2483) -> Result<()> {
2484 let project_id = ProjectId::from_proto(request.remote_entity_id());
2485 let project_connection_ids = session
2486 .db()
2487 .await
2488 .project_connection_ids(project_id, session.connection_id, false)
2489 .await?;
2490
2491 broadcast(
2492 Some(session.connection_id),
2493 project_connection_ids.iter().copied(),
2494 |connection_id| {
2495 session
2496 .peer
2497 .forward_send(session.connection_id, connection_id, request.clone())
2498 },
2499 );
2500 Ok(())
2501}
2502
2503/// Start following another user in a call.
2504async fn follow(
2505 request: proto::Follow,
2506 response: Response<proto::Follow>,
2507 session: MessageContext,
2508) -> Result<()> {
2509 let room_id = RoomId::from_proto(request.room_id);
2510 let project_id = request.project_id.map(ProjectId::from_proto);
2511 let leader_id = request.leader_id.context("invalid leader id")?.into();
2512 let follower_id = session.connection_id;
2513
2514 session
2515 .db()
2516 .await
2517 .check_room_participants(room_id, leader_id, session.connection_id)
2518 .await?;
2519
2520 let response_payload = session.forward_request(leader_id, request).await?;
2521 response.send(response_payload)?;
2522
2523 if let Some(project_id) = project_id {
2524 let room = session
2525 .db()
2526 .await
2527 .follow(room_id, project_id, leader_id, follower_id)
2528 .await?;
2529 room_updated(&room, &session.peer);
2530 }
2531
2532 Ok(())
2533}
2534
2535/// Stop following another user in a call.
2536async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2537 let room_id = RoomId::from_proto(request.room_id);
2538 let project_id = request.project_id.map(ProjectId::from_proto);
2539 let leader_id = request.leader_id.context("invalid leader id")?.into();
2540 let follower_id = session.connection_id;
2541
2542 session
2543 .db()
2544 .await
2545 .check_room_participants(room_id, leader_id, session.connection_id)
2546 .await?;
2547
2548 session
2549 .peer
2550 .forward_send(session.connection_id, leader_id, request)?;
2551
2552 if let Some(project_id) = project_id {
2553 let room = session
2554 .db()
2555 .await
2556 .unfollow(room_id, project_id, leader_id, follower_id)
2557 .await?;
2558 room_updated(&room, &session.peer);
2559 }
2560
2561 Ok(())
2562}
2563
2564/// Notify everyone following you of your current location.
2565async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2566 let room_id = RoomId::from_proto(request.room_id);
2567 let database = session.db.lock().await;
2568
2569 let connection_ids = if let Some(project_id) = request.project_id {
2570 let project_id = ProjectId::from_proto(project_id);
2571 database
2572 .project_connection_ids(project_id, session.connection_id, true)
2573 .await?
2574 } else {
2575 database
2576 .room_connection_ids(room_id, session.connection_id)
2577 .await?
2578 };
2579
2580 // For now, don't send view update messages back to that view's current leader.
2581 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2582 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2583 _ => None,
2584 });
2585
2586 for connection_id in connection_ids.iter().cloned() {
2587 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2588 session
2589 .peer
2590 .forward_send(session.connection_id, connection_id, request.clone())?;
2591 }
2592 }
2593 Ok(())
2594}
2595
2596/// Get public data about users.
2597async fn get_users(
2598 request: proto::GetUsers,
2599 response: Response<proto::GetUsers>,
2600 session: MessageContext,
2601) -> Result<()> {
2602 let user_ids = request
2603 .user_ids
2604 .into_iter()
2605 .map(UserId::from_proto)
2606 .collect();
2607 let users = session
2608 .db()
2609 .await
2610 .get_users_by_ids(user_ids)
2611 .await?
2612 .into_iter()
2613 .map(|user| proto::User {
2614 id: user.id.to_proto(),
2615 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2616 github_login: user.github_login,
2617 name: user.name,
2618 })
2619 .collect();
2620 response.send(proto::UsersResponse { users })?;
2621 Ok(())
2622}
2623
2624/// Search for users (to invite) buy Github login
2625async fn fuzzy_search_users(
2626 request: proto::FuzzySearchUsers,
2627 response: Response<proto::FuzzySearchUsers>,
2628 session: MessageContext,
2629) -> Result<()> {
2630 let query = request.query;
2631 let users = match query.len() {
2632 0 => vec![],
2633 1 | 2 => session
2634 .db()
2635 .await
2636 .get_user_by_github_login(&query)
2637 .await?
2638 .into_iter()
2639 .collect(),
2640 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2641 };
2642 let users = users
2643 .into_iter()
2644 .filter(|user| user.id != session.user_id())
2645 .map(|user| proto::User {
2646 id: user.id.to_proto(),
2647 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2648 github_login: user.github_login,
2649 name: user.name,
2650 })
2651 .collect();
2652 response.send(proto::UsersResponse { users })?;
2653 Ok(())
2654}
2655
2656/// Send a contact request to another user.
2657async fn request_contact(
2658 request: proto::RequestContact,
2659 response: Response<proto::RequestContact>,
2660 session: MessageContext,
2661) -> Result<()> {
2662 let requester_id = session.user_id();
2663 let responder_id = UserId::from_proto(request.responder_id);
2664 if requester_id == responder_id {
2665 return Err(anyhow!("cannot add yourself as a contact"))?;
2666 }
2667
2668 let notifications = session
2669 .db()
2670 .await
2671 .send_contact_request(requester_id, responder_id)
2672 .await?;
2673
2674 // Update outgoing contact requests of requester
2675 let mut update = proto::UpdateContacts::default();
2676 update.outgoing_requests.push(responder_id.to_proto());
2677 for connection_id in session
2678 .connection_pool()
2679 .await
2680 .user_connection_ids(requester_id)
2681 {
2682 session.peer.send(connection_id, update.clone())?;
2683 }
2684
2685 // Update incoming contact requests of responder
2686 let mut update = proto::UpdateContacts::default();
2687 update
2688 .incoming_requests
2689 .push(proto::IncomingContactRequest {
2690 requester_id: requester_id.to_proto(),
2691 });
2692 let connection_pool = session.connection_pool().await;
2693 for connection_id in connection_pool.user_connection_ids(responder_id) {
2694 session.peer.send(connection_id, update.clone())?;
2695 }
2696
2697 send_notifications(&connection_pool, &session.peer, notifications);
2698
2699 response.send(proto::Ack {})?;
2700 Ok(())
2701}
2702
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismissal only clears the notification; the request itself stays
    // pending so the responder can still act on it later.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        // Record the accept/decline and collect the resulting notifications.
        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included with each new contact entry below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Either way, the incoming request disappears from the responder.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Either way, the outgoing request disappears from the requester.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2760
2761/// Remove a contact.
2762async fn remove_contact(
2763 request: proto::RemoveContact,
2764 response: Response<proto::RemoveContact>,
2765 session: MessageContext,
2766) -> Result<()> {
2767 let requester_id = session.user_id();
2768 let responder_id = UserId::from_proto(request.user_id);
2769 let db = session.db().await;
2770 let (contact_accepted, deleted_notification_id) =
2771 db.remove_contact(requester_id, responder_id).await?;
2772
2773 let pool = session.connection_pool().await;
2774 // Update outgoing contact requests of requester
2775 let mut update = proto::UpdateContacts::default();
2776 if contact_accepted {
2777 update.remove_contacts.push(responder_id.to_proto());
2778 } else {
2779 update
2780 .remove_outgoing_requests
2781 .push(responder_id.to_proto());
2782 }
2783 for connection_id in pool.user_connection_ids(requester_id) {
2784 session.peer.send(connection_id, update.clone())?;
2785 }
2786
2787 // Update incoming contact requests of responder
2788 let mut update = proto::UpdateContacts::default();
2789 if contact_accepted {
2790 update.remove_contacts.push(requester_id.to_proto());
2791 } else {
2792 update
2793 .remove_incoming_requests
2794 .push(requester_id.to_proto());
2795 }
2796 for connection_id in pool.user_connection_ids(responder_id) {
2797 session.peer.send(connection_id, update.clone())?;
2798 if let Some(notification_id) = deleted_notification_id {
2799 session.peer.send(
2800 connection_id,
2801 proto::DeleteNotification {
2802 notification_id: notification_id.to_proto(),
2803 },
2804 )?;
2805 }
2806 }
2807
2808 response.send(proto::Ack {})?;
2809 Ok(())
2810}
2811
/// Whether the server should subscribe this client to its channels without an
/// explicit request. Clients below minor version 139 are auto-subscribed;
/// newer clients send `SubscribeToChannels` themselves (see
/// `subscribe_to_channels`).
const fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2815
2816async fn subscribe_to_channels(
2817 _: proto::SubscribeToChannels,
2818 session: MessageContext,
2819) -> Result<()> {
2820 subscribe_user_to_channels(session.user_id(), &session).await?;
2821 Ok(())
2822}
2823
2824async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2825 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2826 let mut pool = session.connection_pool().await;
2827 for membership in &channels_for_user.channel_memberships {
2828 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2829 }
2830 session.peer.send(
2831 session.connection_id,
2832 build_update_user_channels(&channels_for_user),
2833 )?;
2834 session.peer.send(
2835 session.connection_id,
2836 build_channels_update(channels_for_user),
2837 )?;
2838 Ok(())
2839}
2840
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is `Some` only when creating the channel produced a new
    // membership row for the creator.
    // NOTE(review): inferred from the `Option` handling below; confirm
    // against `Database::create_channel`.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Ack the creator before broadcasting to everyone else.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Subscribe the new member before sending, so subsequent channel
        // updates reach them, then push their updated membership list.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to every connection subscribed to its root,
    // skipping roles that cannot see a channel of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2895
2896/// Delete a channel
2897async fn delete_channel(
2898 request: proto::DeleteChannel,
2899 response: Response<proto::DeleteChannel>,
2900 session: MessageContext,
2901) -> Result<()> {
2902 let db = session.db().await;
2903
2904 let channel_id = request.channel_id;
2905 let (root_channel, removed_channels) = db
2906 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2907 .await?;
2908 response.send(proto::Ack {})?;
2909
2910 // Notify members of removed channels
2911 let mut update = proto::UpdateChannels::default();
2912 update
2913 .delete_channels
2914 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2915
2916 let connection_pool = session.connection_pool().await;
2917 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2918 session.peer.send(connection_id, update.clone())?;
2919 }
2920
2921 Ok(())
2922}
2923
2924/// Invite someone to join a channel.
2925async fn invite_channel_member(
2926 request: proto::InviteChannelMember,
2927 response: Response<proto::InviteChannelMember>,
2928 session: MessageContext,
2929) -> Result<()> {
2930 let db = session.db().await;
2931 let channel_id = ChannelId::from_proto(request.channel_id);
2932 let invitee_id = UserId::from_proto(request.user_id);
2933 let InviteMemberResult {
2934 channel,
2935 notifications,
2936 } = db
2937 .invite_channel_member(
2938 channel_id,
2939 invitee_id,
2940 session.user_id(),
2941 request.role().into(),
2942 )
2943 .await?;
2944
2945 let update = proto::UpdateChannels {
2946 channel_invitations: vec![channel.to_proto()],
2947 ..Default::default()
2948 };
2949
2950 let connection_pool = session.connection_pool().await;
2951 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2952 session.peer.send(connection_id, update.clone())?;
2953 }
2954
2955 send_notifications(&connection_pool, &session.peer, notifications);
2956
2957 response.send(proto::Ack {})?;
2958 Ok(())
2959}
2960
2961/// remove someone from a channel
2962async fn remove_channel_member(
2963 request: proto::RemoveChannelMember,
2964 response: Response<proto::RemoveChannelMember>,
2965 session: MessageContext,
2966) -> Result<()> {
2967 let db = session.db().await;
2968 let channel_id = ChannelId::from_proto(request.channel_id);
2969 let member_id = UserId::from_proto(request.user_id);
2970
2971 let RemoveChannelMemberResult {
2972 membership_update,
2973 notification_id,
2974 } = db
2975 .remove_channel_member(channel_id, member_id, session.user_id())
2976 .await?;
2977
2978 let mut connection_pool = session.connection_pool().await;
2979 notify_membership_updated(
2980 &mut connection_pool,
2981 membership_update,
2982 member_id,
2983 &session.peer,
2984 );
2985 for connection_id in connection_pool.user_connection_ids(member_id) {
2986 if let Some(notification_id) = notification_id {
2987 session
2988 .peer
2989 .send(
2990 connection_id,
2991 proto::DeleteNotification {
2992 notification_id: notification_id.to_proto(),
2993 },
2994 )
2995 .trace_err();
2996 }
2997 }
2998
2999 response.send(proto::Ack {})?;
3000 Ok(())
3001}
3002
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the loop body mutates the pool's
    // subscriptions, so we cannot keep borrowing an iterator from it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can see the channel under its new visibility are
        // (re)subscribed and sent the channel; the rest are unsubscribed and
        // told to drop it locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3049
3050/// Alter the role for a user in the channel.
3051async fn set_channel_member_role(
3052 request: proto::SetChannelMemberRole,
3053 response: Response<proto::SetChannelMemberRole>,
3054 session: MessageContext,
3055) -> Result<()> {
3056 let db = session.db().await;
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 let member_id = UserId::from_proto(request.user_id);
3059 let result = db
3060 .set_channel_member_role(
3061 channel_id,
3062 session.user_id(),
3063 member_id,
3064 request.role().into(),
3065 )
3066 .await?;
3067
3068 match result {
3069 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3070 let mut connection_pool = session.connection_pool().await;
3071 notify_membership_updated(
3072 &mut connection_pool,
3073 membership_update,
3074 member_id,
3075 &session.peer,
3076 )
3077 }
3078 db::SetMemberRoleResult::InviteUpdated(channel) => {
3079 let update = proto::UpdateChannels {
3080 channel_invitations: vec![channel.to_proto()],
3081 ..Default::default()
3082 };
3083
3084 for connection_id in session
3085 .connection_pool()
3086 .await
3087 .user_connection_ids(member_id)
3088 {
3089 session.peer.send(connection_id, update.clone())?;
3090 }
3091 }
3092 }
3093
3094 response.send(proto::Ack {})?;
3095 Ok(())
3096}
3097
3098/// Change the name of a channel
3099async fn rename_channel(
3100 request: proto::RenameChannel,
3101 response: Response<proto::RenameChannel>,
3102 session: MessageContext,
3103) -> Result<()> {
3104 let db = session.db().await;
3105 let channel_id = ChannelId::from_proto(request.channel_id);
3106 let channel_model = db
3107 .rename_channel(channel_id, session.user_id(), &request.name)
3108 .await?;
3109 let root_id = channel_model.root_id();
3110 let channel = Channel::from_model(channel_model);
3111
3112 response.send(proto::RenameChannelResponse {
3113 channel: Some(channel.to_proto()),
3114 })?;
3115
3116 let connection_pool = session.connection_pool().await;
3117 let update = proto::UpdateChannels {
3118 channels: vec![channel.to_proto()],
3119 ..Default::default()
3120 };
3121 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3122 if role.can_see_channel(channel.visibility) {
3123 session.peer.send(connection_id, update.clone())?;
3124 }
3125 }
3126
3127 Ok(())
3128}
3129
3130/// Move a channel to a new parent.
3131async fn move_channel(
3132 request: proto::MoveChannel,
3133 response: Response<proto::MoveChannel>,
3134 session: MessageContext,
3135) -> Result<()> {
3136 let channel_id = ChannelId::from_proto(request.channel_id);
3137 let to = ChannelId::from_proto(request.to);
3138
3139 let (root_id, channels) = session
3140 .db()
3141 .await
3142 .move_channel(channel_id, to, session.user_id())
3143 .await?;
3144
3145 let connection_pool = session.connection_pool().await;
3146 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3147 let channels = channels
3148 .iter()
3149 .filter_map(|channel| {
3150 if role.can_see_channel(channel.visibility) {
3151 Some(channel.to_proto())
3152 } else {
3153 None
3154 }
3155 })
3156 .collect::<Vec<_>>();
3157 if channels.is_empty() {
3158 continue;
3159 }
3160
3161 let update = proto::UpdateChannels {
3162 channels,
3163 ..Default::default()
3164 };
3165
3166 session.peer.send(connection_id, update.clone())?;
3167 }
3168
3169 response.send(Ack {})?;
3170 Ok(())
3171}
3172
3173async fn reorder_channel(
3174 request: proto::ReorderChannel,
3175 response: Response<proto::ReorderChannel>,
3176 session: MessageContext,
3177) -> Result<()> {
3178 let channel_id = ChannelId::from_proto(request.channel_id);
3179 let direction = request.direction();
3180
3181 let updated_channels = session
3182 .db()
3183 .await
3184 .reorder_channel(channel_id, direction, session.user_id())
3185 .await?;
3186
3187 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3188 let connection_pool = session.connection_pool().await;
3189 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3190 let channels = updated_channels
3191 .iter()
3192 .filter_map(|channel| {
3193 if role.can_see_channel(channel.visibility) {
3194 Some(channel.to_proto())
3195 } else {
3196 None
3197 }
3198 })
3199 .collect::<Vec<_>>();
3200
3201 if channels.is_empty() {
3202 continue;
3203 }
3204
3205 let update = proto::UpdateChannels {
3206 channels,
3207 ..Default::default()
3208 };
3209
3210 session.peer.send(connection_id, update.clone())?;
3211 }
3212 }
3213
3214 response.send(Ack {})?;
3215 Ok(())
3216}
3217
3218/// Get the list of channel members
3219async fn get_channel_members(
3220 request: proto::GetChannelMembers,
3221 response: Response<proto::GetChannelMembers>,
3222 session: MessageContext,
3223) -> Result<()> {
3224 let db = session.db().await;
3225 let channel_id = ChannelId::from_proto(request.channel_id);
3226 let limit = if request.limit == 0 {
3227 u16::MAX as u64
3228 } else {
3229 request.limit
3230 };
3231 let (members, users) = db
3232 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3233 .await?;
3234 response.send(proto::GetChannelMembersResponse { members, users })?;
3235 Ok(())
3236}
3237
/// Accept or decline a channel invitation.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    // The database records the accept/decline and returns an optional
    // membership update (present when membership changed) plus any
    // notifications to deliver.
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        // Membership changed: subscribe this user's connections to the
        // channel tree and push the newly visible channels to them.
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        // No membership change: just clear the invitation on every one of
        // the user's connected clients.
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    // Deliver follow-up notifications (e.g. to the inviter) after the
    // channel state has been updated.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3278
/// Join the channels' room
///
/// Thin wrapper around `join_channel_internal`, which is shared with the
/// `JoinRoom` handler via the `JoinChannelInternalResponse` trait.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    join_channel_internal(channel_id, Box::new(response), session).await
}
3288
/// Abstraction over the response handle so that both `JoinChannel` and
/// `JoinRoom` requests can share `join_channel_internal`; each replies with
/// a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to dispatch to the inherent `Response::send`,
        // not back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3302
/// Shared implementation for joining a channel's room (used by both the
/// `JoinChannel` and `JoinRoom` handlers).
///
/// Cleans up any stale room connection left by an unclean shutdown, joins
/// the channel's room in the database, mints a LiveKit token when voice is
/// configured, and then notifies room participants, channel subscribers,
/// and the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` re-acquires the database guard, so
            // drop ours first and take it back afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token if voice is configured: guests get a token
        // with `can_publish == false`, all other roles may publish. Token
        // minting failures are logged and leave the info as `None`.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have changed the user's channel membership (presumably
        // when joining grants access to a public channel — confirm in
        // `Database::join_channel`); if so, update their subscriptions.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's existing participants about the new occupant.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the updated participant list to channel subscribers.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Joining a room flips the user's busy state for their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3396
3397/// Start editing the channel notes
3398async fn join_channel_buffer(
3399 request: proto::JoinChannelBuffer,
3400 response: Response<proto::JoinChannelBuffer>,
3401 session: MessageContext,
3402) -> Result<()> {
3403 let db = session.db().await;
3404 let channel_id = ChannelId::from_proto(request.channel_id);
3405
3406 let open_response = db
3407 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3408 .await?;
3409
3410 let collaborators = open_response.collaborators.clone();
3411 response.send(open_response)?;
3412
3413 let update = UpdateChannelBufferCollaborators {
3414 channel_id: channel_id.to_proto(),
3415 collaborators: collaborators.clone(),
3416 };
3417 channel_buffer_updated(
3418 session.connection_id,
3419 collaborators
3420 .iter()
3421 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3422 &update,
3423 &session.peer,
3424 );
3425
3426 Ok(())
3427}
3428
3429/// Edit the channel notes
3430async fn update_channel_buffer(
3431 request: proto::UpdateChannelBuffer,
3432 session: MessageContext,
3433) -> Result<()> {
3434 let db = session.db().await;
3435 let channel_id = ChannelId::from_proto(request.channel_id);
3436
3437 let (collaborators, epoch, version) = db
3438 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3439 .await?;
3440
3441 channel_buffer_updated(
3442 session.connection_id,
3443 collaborators.clone(),
3444 &proto::UpdateChannelBuffer {
3445 channel_id: channel_id.to_proto(),
3446 operations: request.operations,
3447 },
3448 &session.peer,
3449 );
3450
3451 let pool = &*session.connection_pool().await;
3452
3453 let non_collaborators =
3454 pool.channel_connection_ids(channel_id)
3455 .filter_map(|(connection_id, _)| {
3456 if collaborators.contains(&connection_id) {
3457 None
3458 } else {
3459 Some(connection_id)
3460 }
3461 });
3462
3463 broadcast(None, non_collaborators, |peer_id| {
3464 session.peer.send(
3465 peer_id,
3466 proto::UpdateChannels {
3467 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3468 channel_id: channel_id.to_proto(),
3469 epoch: epoch as u64,
3470 version: version.clone(),
3471 }],
3472 ..Default::default()
3473 },
3474 )
3475 });
3476
3477 Ok(())
3478}
3479
/// Rejoin the channel notes after a connection blip
async fn rejoin_channel_buffers(
    request: proto::RejoinChannelBuffers,
    response: Response<proto::RejoinChannelBuffers>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    // Re-associate this (new) connection with each buffer the client says it
    // still has open; the db returns the buffers that were rejoined.
    let buffers = db
        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
        .await?;

    for rejoined_buffer in &buffers {
        // Send the refreshed collaborator list (including this user's new
        // connection) to the other collaborators with a live peer id.
        let collaborators_to_notify = rejoined_buffer
            .buffer
            .collaborators
            .iter()
            .filter_map(|c| Some(c.peer_id?.into()));
        channel_buffer_updated(
            session.connection_id,
            collaborators_to_notify,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: rejoined_buffer.buffer.channel_id,
                collaborators: rejoined_buffer.buffer.collaborators.clone(),
            },
            &session.peer,
        );
    }

    response.send(proto::RejoinChannelBuffersResponse {
        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
    })?;

    Ok(())
}
3514
3515/// Stop editing the channel notes
3516async fn leave_channel_buffer(
3517 request: proto::LeaveChannelBuffer,
3518 response: Response<proto::LeaveChannelBuffer>,
3519 session: MessageContext,
3520) -> Result<()> {
3521 let db = session.db().await;
3522 let channel_id = ChannelId::from_proto(request.channel_id);
3523
3524 let left_buffer = db
3525 .leave_channel_buffer(channel_id, session.connection_id)
3526 .await?;
3527
3528 response.send(Ack {})?;
3529
3530 channel_buffer_updated(
3531 session.connection_id,
3532 left_buffer.connections,
3533 &proto::UpdateChannelBufferCollaborators {
3534 channel_id: channel_id.to_proto(),
3535 collaborators: left_buffer.collaborators,
3536 },
3537 &session.peer,
3538 );
3539
3540 Ok(())
3541}
3542
3543fn channel_buffer_updated<T: EnvelopedMessage>(
3544 sender_id: ConnectionId,
3545 collaborators: impl IntoIterator<Item = ConnectionId>,
3546 message: &T,
3547 peer: &Peer,
3548) {
3549 broadcast(Some(sender_id), collaborators, |peer_id| {
3550 peer.send(peer_id, message.clone())
3551 });
3552}
3553
3554fn send_notifications(
3555 connection_pool: &ConnectionPool,
3556 peer: &Peer,
3557 notifications: db::NotificationBatch,
3558) {
3559 for (user_id, notification) in notifications {
3560 for connection_id in connection_pool.user_connection_ids(user_id) {
3561 if let Err(error) = peer.send(
3562 connection_id,
3563 proto::AddNotification {
3564 notification: Some(notification.clone()),
3565 },
3566 ) {
3567 tracing::error!(
3568 "failed to send notification to {:?} {}",
3569 connection_id,
3570 error
3571 );
3572 }
3573 }
3574 }
3575}
3576
3577/// Send a message to the channel
3578async fn send_channel_message(
3579 _request: proto::SendChannelMessage,
3580 _response: Response<proto::SendChannelMessage>,
3581 _session: MessageContext,
3582) -> Result<()> {
3583 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Delete a channel message
3587async fn remove_channel_message(
3588 _request: proto::RemoveChannelMessage,
3589 _response: Response<proto::RemoveChannelMessage>,
3590 _session: MessageContext,
3591) -> Result<()> {
3592 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595async fn update_channel_message(
3596 _request: proto::UpdateChannelMessage,
3597 _response: Response<proto::UpdateChannelMessage>,
3598 _session: MessageContext,
3599) -> Result<()> {
3600 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Mark a channel message as read
3604async fn acknowledge_channel_message(
3605 _request: proto::AckChannelMessage,
3606 _session: MessageContext,
3607) -> Result<()> {
3608 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3609}
3610
3611/// Mark a buffer version as synced
3612async fn acknowledge_buffer_version(
3613 request: proto::AckBufferOperation,
3614 session: MessageContext,
3615) -> Result<()> {
3616 let buffer_id = BufferId::from_proto(request.buffer_id);
3617 session
3618 .db()
3619 .await
3620 .observe_buffer_version(
3621 buffer_id,
3622 session.user_id(),
3623 request.epoch as i32,
3624 &request.version,
3625 )
3626 .await?;
3627 Ok(())
3628}
3629
3630/// Get a Supermaven API key for the user
3631async fn get_supermaven_api_key(
3632 _request: proto::GetSupermavenApiKey,
3633 response: Response<proto::GetSupermavenApiKey>,
3634 session: MessageContext,
3635) -> Result<()> {
3636 let user_id: String = session.user_id().to_string();
3637 if !session.is_staff() {
3638 return Err(anyhow!("supermaven not enabled for this account"))?;
3639 }
3640
3641 let email = session.email().context("user must have an email")?;
3642
3643 let supermaven_admin_api = session
3644 .supermaven_client
3645 .as_ref()
3646 .context("supermaven not configured")?;
3647
3648 let result = supermaven_admin_api
3649 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3650 .await?;
3651
3652 response.send(proto::GetSupermavenApiKeyResponse {
3653 api_key: result.api_key,
3654 })?;
3655
3656 Ok(())
3657}
3658
3659/// Start receiving chat updates for a channel
3660async fn join_channel_chat(
3661 _request: proto::JoinChannelChat,
3662 _response: Response<proto::JoinChannelChat>,
3663 _session: MessageContext,
3664) -> Result<()> {
3665 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3666}
3667
3668/// Stop receiving chat updates for a channel
3669async fn leave_channel_chat(
3670 _request: proto::LeaveChannelChat,
3671 _session: MessageContext,
3672) -> Result<()> {
3673 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3674}
3675
3676/// Retrieve the chat history for a channel
3677async fn get_channel_messages(
3678 _request: proto::GetChannelMessages,
3679 _response: Response<proto::GetChannelMessages>,
3680 _session: MessageContext,
3681) -> Result<()> {
3682 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3683}
3684
3685/// Retrieve specific chat messages
3686async fn get_channel_messages_by_id(
3687 _request: proto::GetChannelMessagesById,
3688 _response: Response<proto::GetChannelMessagesById>,
3689 _session: MessageContext,
3690) -> Result<()> {
3691 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3692}
3693
3694/// Retrieve the current users notifications
3695async fn get_notifications(
3696 request: proto::GetNotifications,
3697 response: Response<proto::GetNotifications>,
3698 session: MessageContext,
3699) -> Result<()> {
3700 let notifications = session
3701 .db()
3702 .await
3703 .get_notifications(
3704 session.user_id(),
3705 NOTIFICATION_COUNT_PER_PAGE,
3706 request.before_id.map(db::NotificationId::from_proto),
3707 )
3708 .await?;
3709 response.send(proto::GetNotificationsResponse {
3710 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3711 notifications,
3712 })?;
3713 Ok(())
3714}
3715
3716/// Mark notifications as read
3717async fn mark_notification_as_read(
3718 request: proto::MarkNotificationRead,
3719 response: Response<proto::MarkNotificationRead>,
3720 session: MessageContext,
3721) -> Result<()> {
3722 let database = &session.db().await;
3723 let notifications = database
3724 .mark_notification_as_read_by_id(
3725 session.user_id(),
3726 NotificationId::from_proto(request.notification_id),
3727 )
3728 .await?;
3729 send_notifications(
3730 &*session.connection_pool().await,
3731 &session.peer,
3732 notifications,
3733 );
3734 response.send(proto::Ack {})?;
3735 Ok(())
3736}
3737
3738fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3739 let message = match message {
3740 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3741 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3742 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3743 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3744 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3745 code: frame.code.into(),
3746 reason: frame.reason.as_str().to_owned().into(),
3747 })),
3748 // We should never receive a frame while reading the message, according
3749 // to the `tungstenite` maintainers:
3750 //
3751 // > It cannot occur when you read messages from the WebSocket, but it
3752 // > can be used when you want to send the raw frames (e.g. you want to
3753 // > send the frames to the WebSocket without composing the full message first).
3754 // >
3755 // > — https://github.com/snapview/tungstenite-rs/issues/268
3756 TungsteniteMessage::Frame(_) => {
3757 bail!("received an unexpected frame while reading the message")
3758 }
3759 };
3760
3761 Ok(message)
3762}
3763
3764fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3765 match message {
3766 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3767 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3768 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3769 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3770 AxumMessage::Close(frame) => {
3771 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3772 code: frame.code.into(),
3773 reason: frame.reason.as_ref().into(),
3774 }))
3775 }
3776 }
3777}
3778
/// Sync a user's connection-pool subscriptions and connected clients after
/// their channel membership changed (invite accepted, role changed, or
/// membership removed).
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Bring the in-memory pool's channel subscriptions in line with the
    // database result before broadcasting anything.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this mirrors `build_update_user_channels` but leaves
    // `observed_channel_buffer_version` at its default — presumably
    // intentional for membership changes; confirm before unifying the two.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    // Channel tree update: new/changed channels, deletions, and removal of
    // the (now resolved) invitation for this channel.
    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send both messages to every connection the user has open; failures are
    // logged and ignored so one dead connection doesn't block the others.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3819
3820fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3821 proto::UpdateUserChannels {
3822 channel_memberships: channels
3823 .channel_memberships
3824 .iter()
3825 .map(|m| proto::ChannelMembership {
3826 channel_id: m.channel_id.to_proto(),
3827 role: m.role.into(),
3828 })
3829 .collect(),
3830 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3831 }
3832}
3833
3834fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3835 let mut update = proto::UpdateChannels::default();
3836
3837 for channel in channels.channels {
3838 update.channels.push(channel.to_proto());
3839 }
3840
3841 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3842
3843 for (channel_id, participants) in channels.channel_participants {
3844 update
3845 .channel_participants
3846 .push(proto::ChannelParticipants {
3847 channel_id: channel_id.to_proto(),
3848 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3849 });
3850 }
3851
3852 for channel in channels.invited_channels {
3853 update.channel_invitations.push(channel.to_proto());
3854 }
3855
3856 update
3857}
3858
3859fn build_initial_contacts_update(
3860 contacts: Vec<db::Contact>,
3861 pool: &ConnectionPool,
3862) -> proto::UpdateContacts {
3863 let mut update = proto::UpdateContacts::default();
3864
3865 for contact in contacts {
3866 match contact {
3867 db::Contact::Accepted { user_id, busy } => {
3868 update.contacts.push(contact_for_user(user_id, busy, pool));
3869 }
3870 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3871 db::Contact::Incoming { user_id } => {
3872 update
3873 .incoming_requests
3874 .push(proto::IncomingContactRequest {
3875 requester_id: user_id.to_proto(),
3876 })
3877 }
3878 }
3879 }
3880
3881 update
3882}
3883
3884fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3885 proto::Contact {
3886 user_id: user_id.to_proto(),
3887 online: pool.is_user_online(user_id),
3888 busy,
3889 }
3890}
3891
3892fn room_updated(room: &proto::Room, peer: &Peer) {
3893 broadcast(
3894 None,
3895 room.participants
3896 .iter()
3897 .filter_map(|participant| Some(participant.peer_id?.into())),
3898 |peer_id| {
3899 peer.send(
3900 peer_id,
3901 proto::RoomUpdated {
3902 room: Some(room.clone()),
3903 },
3904 )
3905 },
3906 );
3907}
3908
3909fn channel_updated(
3910 channel: &db::channel::Model,
3911 room: &proto::Room,
3912 peer: &Peer,
3913 pool: &ConnectionPool,
3914) {
3915 let participants = room
3916 .participants
3917 .iter()
3918 .map(|p| p.user_id)
3919 .collect::<Vec<_>>();
3920
3921 broadcast(
3922 None,
3923 pool.channel_connection_ids(channel.root_id())
3924 .filter_map(|(channel_id, role)| {
3925 role.can_see_channel(channel.visibility)
3926 .then_some(channel_id)
3927 }),
3928 |peer_id| {
3929 peer.send(
3930 peer_id,
3931 proto::UpdateChannels {
3932 channel_participants: vec![proto::ChannelParticipants {
3933 channel_id: channel.id.to_proto(),
3934 participant_user_ids: participants.clone(),
3935 }],
3936 ..Default::default()
3937 },
3938 )
3939 },
3940 );
3941}
3942
3943async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3944 let db = session.db().await;
3945
3946 let contacts = db.get_contacts(user_id).await?;
3947 let busy = db.is_user_busy(user_id).await?;
3948
3949 let pool = session.connection_pool().await;
3950 let updated_contact = contact_for_user(user_id, busy, &pool);
3951 for contact in contacts {
3952 if let db::Contact::Accepted {
3953 user_id: contact_user_id,
3954 ..
3955 } = contact
3956 {
3957 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3958 session
3959 .peer
3960 .send(
3961 contact_conn_id,
3962 proto::UpdateContacts {
3963 contacts: vec![updated_contact.clone()],
3964 remove_contacts: Default::default(),
3965 incoming_requests: Default::default(),
3966 remove_incoming_requests: Default::default(),
3967 outgoing_requests: Default::default(),
3968 remove_outgoing_requests: Default::default(),
3969 },
3970 )
3971 .trace_err();
3972 }
3973 }
3974 }
3975 Ok(())
3976}
3977
/// Remove `connection_id` from its current room and propagate the departure.
///
/// Handles the full fan-out: notifying collaborators of left projects,
/// updating the room and (if any) its channel, cancelling pending outgoing
/// calls, refreshing affected contacts' busy state, and cleaning up LiveKit.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up-front so the `if let` branch below can initialize them all;
    // the early return covers the "not in a room" case.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell collaborators about every project this connection had open.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list for
    // channel subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Leaving cancelled this user's pending outgoing calls; tell each
        // callee's connections and mark them for a contact refresh.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: remove the participant, and delete the
    // LiveKit room when the database reports the room itself was deleted.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4051
4052async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4053 let left_channel_buffers = session
4054 .db()
4055 .await
4056 .leave_channel_buffers(session.connection_id)
4057 .await?;
4058
4059 for left_buffer in left_channel_buffers {
4060 channel_buffer_updated(
4061 session.connection_id,
4062 left_buffer.connections,
4063 &proto::UpdateChannelBufferCollaborators {
4064 channel_id: left_buffer.channel_id.to_proto(),
4065 collaborators: left_buffer.collaborators,
4066 },
4067 &session.peer,
4068 );
4069 }
4070
4071 Ok(())
4072}
4073
4074fn project_left(project: &db::LeftProject, session: &Session) {
4075 for connection_id in &project.connection_ids {
4076 if project.should_unshare {
4077 session
4078 .peer
4079 .send(
4080 *connection_id,
4081 proto::UnshareProject {
4082 project_id: project.id.to_proto(),
4083 },
4084 )
4085 .trace_err();
4086 } else {
4087 session
4088 .peer
4089 .send(
4090 *connection_id,
4091 proto::RemoveProjectCollaborator {
4092 project_id: project.id.to_proto(),
4093 peer_id: Some(session.connection_id.into()),
4094 },
4095 )
4096 .trace_err();
4097 }
4098 }
4099}
4100
/// Extension trait for logging-and-discarding errors.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) at `error` level and convert the result into
    /// an `Option`, discarding the error value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4106
4107impl<T, E> ResultExt for Result<T, E>
4108where
4109 E: std::fmt::Debug,
4110{
4111 type Ok = T;
4112
4113 #[track_caller]
4114 fn trace_err(self) -> Option<T> {
4115 match self {
4116 Ok(value) => Some(value),
4117 Err(error) => {
4118 tracing::error!("{:?}", error);
4119 None
4120 }
4121 }
4122 }
4123}