1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
/// How long a disconnected client has to reconnect before its server-side
/// state (rooms, shared projects, etc.) is considered abandoned.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification fetches.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously accepted connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Live count of connections currently holding a `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing span field names used to record per-message timing metrics.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one RPC message type, stored in `Server::handlers`
/// keyed by the message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
/// Handle for answering a request message, tracking whether a response was
/// actually sent so `add_request_handler` can flag handlers that forget to
/// respond.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the request-handler wrapper; flipped to true by `send`.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request as
    /// responded. The flag is set before the send so a send failure still
    /// counts as an attempted response.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
131
/// The authenticated identity behind a connection: either a plain user, or a
/// user whose session is being driven by an impersonating admin.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
    Impersonated { user: User, admin: User },
}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// Context handed to each message handler: the connection's `Session` plus
/// the tracing span created for this particular message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Let handlers use a `MessageContext` anywhere a `&Session` is expected.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
impl MessageContext {
    /// Forwards `request` to the peer identified by `receiver_id` and returns
    /// a future resolving to that peer's response. The time spent waiting on
    /// the receiving host is recorded in the `host_waiting_ms` span field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs on completion regardless of success or failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Elapsed microseconds converted to fractional milliseconds.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
189
/// Per-connection state shared by every message handler running on behalf of
/// that connection.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquire through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Client-reported system identifier; currently unused server-side.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
/// Newtype over the shared `Database` so it can sit behind the session's
/// async mutex while exposing `Database`'s API via `Deref`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
274
/// The collab RPC server: owns the peer, the connection pool, and the table
/// of registered message handlers.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Handlers keyed by the `TypeId` of the message type they service.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` to all connection tasks when the server shuts down.
    teardown: watch::Sender<bool>,
}
283
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, so it cannot be held across an await point that
/// might move the task to another thread while the lock is held.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard derefs to `ConnectionPool`; serialize the pool itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates the server and registers a handler for every supported RPC
    /// message type. `add_handler` panics if the same message type is
    /// registered twice, so this list is authoritative.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Liveness, calls, and room lifecycle.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/repository sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): registered as mutating although the surrounding
            // LspExt handlers are read-only — confirm this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts (assistant) and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
477
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT`, cleans
    /// up resources abandoned by previous server instances: stale chat
    /// participants, channel-buffer collaborators, room participants, old
    /// worktree entries, and stale server rows. Remaining clients and
    /// affected contacts are notified of the resulting state changes.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // push the refreshed collaborator set to the remaining
                    // connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove the stale server's participants from the room
                        // and broadcast the refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the livekit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callers whose calls were canceled by the cleanup.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute busy state for affected users and fan the
                        // updated contact out to each of their contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
648
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every connection task (subscribed to the `teardown`
    /// watch) to exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
654
    /// Test-only: tears the server down and re-initializes it under a new
    /// `ServerId`, clearing the teardown flag so connections are accepted
    /// again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
662
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
667
    /// Registers the type-erased handler for message type `M`.
    ///
    /// The stored closure downcasts the envelope back to `TypedEnvelope<M>`,
    /// invokes `handler`, and records timing metrics on the message's span
    /// once the returned future completes.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are looked up by the payload's TypeId, so this
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Durations in fractional milliseconds; queue time is
                    // total time minus the handler's own processing time.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
711
712 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
713 where
714 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
715 Fut: 'static + Send + Future<Output = Result<()>>,
716 M: EnvelopedMessage,
717 {
718 self.add_handler(move |envelope, session| handler(envelope.payload, session));
719 self
720 }
721
722 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
723 where
724 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
725 Fut: Send + Future<Output = Result<()>>,
726 M: RequestMessage,
727 {
728 let handler = Arc::new(handler);
729 self.add_handler(move |envelope, session| {
730 let receipt = envelope.receipt();
731 let handler = handler.clone();
732 async move {
733 let peer = session.peer.clone();
734 let responded = Arc::new(AtomicBool::default());
735 let response = Response {
736 peer: peer.clone(),
737 responded: responded.clone(),
738 receipt,
739 };
740 match (handler)(envelope.payload, response, session).await {
741 Ok(()) => {
742 if responded.load(std::sync::atomic::Ordering::SeqCst) {
743 Ok(())
744 } else {
745 Err(anyhow!("handler did not send a response"))?
746 }
747 }
748 Err(error) => {
749 let proto_err = match &error {
750 Error::Internal(err) => err.to_proto(),
751 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
752 };
753 peer.respond_with_error(receipt, proto_err)?;
754 Err(error)
755 }
756 }
757 }
758 })
759 }
760
    /// Runs a single client connection to completion.
    ///
    /// Registers the connection with the peer, sends the initial client
    /// update, then loops dispatching incoming messages to registered
    /// handlers until the connection closes or the server tears down, and
    /// finally runs `connection_lost` cleanup. The returned future must be
    /// polled to drive the connection.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once shutdown has started.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client used by the Supermaven admin API below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Dropping the guard decrements the global connection counter, so
            // the limit applies to connections that are still setting up.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                // Permits held = handlers currently in flight.
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    // Back-pressure: wait for a free handler slot before
                    // pulling the next message off the wire.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released when the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
952
    /// Sends the initial handshake and state to a freshly-accepted client
    /// connection: a `Hello` message, the contact list, channel subscriptions
    /// (for supporting client versions), and any pending incoming call.
    ///
    /// `send_connection_id`, when present, is signalled with the new
    /// connection id right after the `Hello` has been sent, so a caller can
    /// await connection establishment.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` tells the client its own peer id and must go out first.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiving end may already be gone; ignore send failure.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    // Very first connection for this user: ask the client to
                    // show contacts, then persist that this happened.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and build the initial contacts
                    // update under the same pool lock so the two agree.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Let this user's contacts know their presence changed.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1008
1009 pub async fn invite_code_redeemed(
1010 self: &Arc<Self>,
1011 inviter_id: UserId,
1012 invitee_id: UserId,
1013 ) -> Result<()> {
1014 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1015 && let Some(code) = &user.invite_code
1016 {
1017 let pool = self.connection_pool.lock();
1018 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1019 for connection_id in pool.user_connection_ids(inviter_id) {
1020 self.peer.send(
1021 connection_id,
1022 proto::UpdateContacts {
1023 contacts: vec![invitee_contact.clone()],
1024 ..Default::default()
1025 },
1026 )?;
1027 self.peer.send(
1028 connection_id,
1029 proto::UpdateInviteInfo {
1030 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1031 count: user.invite_count as u32,
1032 },
1033 )?;
1034 }
1035 }
1036 Ok(())
1037 }
1038
1039 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1040 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1041 && let Some(invite_code) = &user.invite_code
1042 {
1043 let pool = self.connection_pool.lock();
1044 for connection_id in pool.user_connection_ids(user_id) {
1045 self.peer.send(
1046 connection_id,
1047 proto::UpdateInviteInfo {
1048 url: format!(
1049 "{}{}",
1050 self.app_state.config.invite_link_prefix, invite_code
1051 ),
1052 count: user.invite_count as u32,
1053 },
1054 )?;
1055 }
1056 }
1057 Ok(())
1058 }
1059
1060 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1061 ServerSnapshot {
1062 connection_pool: ConnectionPoolGuard {
1063 guard: self.connection_pool.lock(),
1064 _not_send: PhantomData,
1065 },
1066 peer: &self.peer,
1067 }
1068 }
1069}
1070
1071impl Deref for ConnectionPoolGuard<'_> {
1072 type Target = ConnectionPool;
1073
1074 fn deref(&self) -> &Self::Target {
1075 &self.guard
1076 }
1077}
1078
1079impl DerefMut for ConnectionPoolGuard<'_> {
1080 fn deref_mut(&mut self) -> &mut Self::Target {
1081 &mut self.guard
1082 }
1083}
1084
/// When the guard is released in test builds, verify the pool's internal
/// invariants so corruption surfaces at the point the lock is dropped rather
/// than at some later use.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1091
1092fn broadcast<F>(
1093 sender_id: Option<ConnectionId>,
1094 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1095 mut f: F,
1096) where
1097 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1098{
1099 for receiver_id in receiver_ids {
1100 if Some(receiver_id) != sender_id
1101 && let Err(error) = f(receiver_id)
1102 {
1103 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104 }
1105 }
1106}
1107
/// Typed `x-zed-protocol-version` request header carrying the RPC protocol
/// version the client speaks.
pub struct ProtocolVersion(u32);
1109
1110impl Header for ProtocolVersion {
1111 fn name() -> &'static HeaderName {
1112 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1113 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1114 }
1115
1116 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1117 where
1118 Self: Sized,
1119 I: Iterator<Item = &'i axum::http::HeaderValue>,
1120 {
1121 let version = values
1122 .next()
1123 .ok_or_else(axum::headers::Error::invalid)?
1124 .to_str()
1125 .map_err(|_| axum::headers::Error::invalid())?
1126 .parse()
1127 .map_err(|_| axum::headers::Error::invalid())?;
1128 Ok(Self(version))
1129 }
1130
1131 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1132 values.extend([self.0.to_string().parse().unwrap()]);
1133 }
1134}
1135
1136pub struct AppVersionHeader(SemanticVersion);
1137impl Header for AppVersionHeader {
1138 fn name() -> &'static HeaderName {
1139 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1140 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1141 }
1142
1143 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1144 where
1145 Self: Sized,
1146 I: Iterator<Item = &'i axum::http::HeaderValue>,
1147 {
1148 let version = values
1149 .next()
1150 .ok_or_else(axum::headers::Error::invalid)?
1151 .to_str()
1152 .map_err(|_| axum::headers::Error::invalid())?
1153 .parse()
1154 .map_err(|_| axum::headers::Error::invalid())?;
1155 Ok(Self(version))
1156 }
1157
1158 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1159 values.extend([self.0.to_string().parse().unwrap()]);
1160 }
1161}
1162
1163#[derive(Debug)]
1164pub struct ReleaseChannelHeader(String);
1165
1166impl Header for ReleaseChannelHeader {
1167 fn name() -> &'static HeaderName {
1168 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1169 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1170 }
1171
1172 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1173 where
1174 Self: Sized,
1175 I: Iterator<Item = &'i axum::http::HeaderValue>,
1176 {
1177 Ok(Self(
1178 values
1179 .next()
1180 .ok_or_else(axum::headers::Error::invalid)?
1181 .to_str()
1182 .map_err(|_| axum::headers::Error::invalid())?
1183 .to_owned(),
1184 ))
1185 }
1186
1187 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1188 values.extend([self.0.parse().unwrap()]);
1189 }
1190}
1191
1192pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1193 Router::new()
1194 .route("/rpc", get(handle_websocket_request))
1195 .layer(
1196 ServiceBuilder::new()
1197 .layer(Extension(server.app_state.clone()))
1198 .layer(middleware::from_fn(auth::validate_header)),
1199 )
1200 .route("/metrics", get(handle_metrics))
1201 .layer(Extension(server))
1202}
1203
1204pub async fn handle_websocket_request(
1205 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1206 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1207 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1208 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1209 Extension(server): Extension<Arc<Server>>,
1210 Extension(principal): Extension<Principal>,
1211 user_agent: Option<TypedHeader<UserAgent>>,
1212 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1213 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1214 ws: WebSocketUpgrade,
1215) -> axum::response::Response {
1216 if protocol_version != rpc::PROTOCOL_VERSION {
1217 return (
1218 StatusCode::UPGRADE_REQUIRED,
1219 "client must be upgraded".to_string(),
1220 )
1221 .into_response();
1222 }
1223
1224 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1225 return (
1226 StatusCode::UPGRADE_REQUIRED,
1227 "no version header found".to_string(),
1228 )
1229 .into_response();
1230 };
1231
1232 let release_channel = release_channel_header.map(|header| header.0.0);
1233
1234 if !version.can_collaborate() {
1235 return (
1236 StatusCode::UPGRADE_REQUIRED,
1237 "client must be upgraded".to_string(),
1238 )
1239 .into_response();
1240 }
1241
1242 let socket_address = socket_address.to_string();
1243
1244 // Acquire connection guard before WebSocket upgrade
1245 let connection_guard = match ConnectionGuard::try_acquire() {
1246 Ok(guard) => guard,
1247 Err(()) => {
1248 return (
1249 StatusCode::SERVICE_UNAVAILABLE,
1250 "Too many concurrent connections",
1251 )
1252 .into_response();
1253 }
1254 };
1255
1256 ws.on_upgrade(move |socket| {
1257 let socket = socket
1258 .map_ok(to_tungstenite_message)
1259 .err_into()
1260 .with(|message| async move { to_axum_message(message) });
1261 let connection = Connection::new(Box::pin(socket));
1262 async move {
1263 server
1264 .handle_connection(
1265 connection,
1266 socket_address,
1267 principal,
1268 version,
1269 release_channel,
1270 user_agent.map(|header| header.to_string()),
1271 country_code_header.map(|header| header.to_string()),
1272 system_id_header.map(|header| header.to_string()),
1273 None,
1274 Executor::Production,
1275 Some(connection_guard),
1276 )
1277 .await;
1278 }
1279 })
1280}
1281
1282pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1283 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1284 let connections_metric = CONNECTIONS_METRIC
1285 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1286
1287 let connections = server
1288 .connection_pool
1289 .lock()
1290 .connections()
1291 .filter(|connection| !connection.admin)
1292 .count();
1293 connections_metric.set(connections as _);
1294
1295 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1296 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1297 register_int_gauge!(
1298 "shared_projects",
1299 "number of open projects with one or more guests"
1300 )
1301 .unwrap()
1302 });
1303
1304 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1305 shared_projects_metric.set(shared_projects as _);
1306
1307 let encoder = prometheus::TextEncoder::new();
1308 let metric_families = prometheus::gather();
1309 let encoded_metrics = encoder
1310 .encode_to_string(&metric_families)
1311 .map_err(|err| anyhow!("{err}"))?;
1312 Ok(encoded_metrics)
1313}
1314
/// Cleans up after a dropped client connection.
///
/// Immediately disconnects the peer and removes the connection from the pool
/// and database, then waits up to `RECONNECT_TIMEOUT` before releasing the
/// user's rooms and channel buffers, giving the client a window in which to
/// reconnect. A teardown signal short-circuits the wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client didn't reconnect in time: release everything it was
            // participating in.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last open
            // connection; another device could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1361
1362/// Acknowledges a ping from a client, used to keep the connection alive.
1363async fn ping(
1364 _: proto::Ping,
1365 response: Response<proto::Ping>,
1366 _session: MessageContext,
1367) -> Result<()> {
1368 response.send(proto::Ack {})?;
1369 Ok(())
1370}
1371
1372/// Creates a new room for calling (outside of channels)
1373async fn create_room(
1374 _request: proto::CreateRoom,
1375 response: Response<proto::CreateRoom>,
1376 session: MessageContext,
1377) -> Result<()> {
1378 let livekit_room = nanoid::nanoid!(30);
1379
1380 let live_kit_connection_info = util::maybe!(async {
1381 let live_kit = session.app_state.livekit_client.as_ref();
1382 let live_kit = live_kit?;
1383 let user_id = session.user_id().to_string();
1384
1385 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1386
1387 Some(proto::LiveKitConnectionInfo {
1388 server_url: live_kit.url().into(),
1389 token,
1390 can_publish: true,
1391 })
1392 })
1393 .await;
1394
1395 let room = session
1396 .db()
1397 .await
1398 .create_room(session.user_id(), session.connection_id, &livekit_room)
1399 .await?;
1400
1401 response.send(proto::CreateRoomResponse {
1402 room: Some(room.clone()),
1403 live_kit_connection_info,
1404 })?;
1405
1406 update_user_contacts(session.user_id(), &session).await?;
1407 Ok(())
1408}
1409
1410/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1411async fn join_room(
1412 request: proto::JoinRoom,
1413 response: Response<proto::JoinRoom>,
1414 session: MessageContext,
1415) -> Result<()> {
1416 let room_id = RoomId::from_proto(request.id);
1417
1418 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1419
1420 if let Some(channel_id) = channel_id {
1421 return join_channel_internal(channel_id, Box::new(response), session).await;
1422 }
1423
1424 let joined_room = {
1425 let room = session
1426 .db()
1427 .await
1428 .join_room(room_id, session.user_id(), session.connection_id)
1429 .await?;
1430 room_updated(&room.room, &session.peer);
1431 room.into_inner()
1432 };
1433
1434 for connection_id in session
1435 .connection_pool()
1436 .await
1437 .user_connection_ids(session.user_id())
1438 {
1439 session
1440 .peer
1441 .send(
1442 connection_id,
1443 proto::CallCanceled {
1444 room_id: room_id.to_proto(),
1445 },
1446 )
1447 .trace_err();
1448 }
1449
1450 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1451 {
1452 live_kit
1453 .room_token(
1454 &joined_room.room.livekit_room,
1455 &session.user_id().to_string(),
1456 )
1457 .trace_err()
1458 .map(|token| proto::LiveKitConnectionInfo {
1459 server_url: live_kit.url().into(),
1460 token,
1461 can_publish: true,
1462 })
1463 } else {
1464 None
1465 };
1466
1467 response.send(proto::JoinRoomResponse {
1468 room: Some(joined_room.room),
1469 channel_id: None,
1470 live_kit_connection_info,
1471 })?;
1472
1473 update_user_contacts(session.user_id(), &session).await?;
1474 Ok(())
1475}
1476
1477/// Rejoin room is used to reconnect to a room after connection errors.
1478async fn rejoin_room(
1479 request: proto::RejoinRoom,
1480 response: Response<proto::RejoinRoom>,
1481 session: MessageContext,
1482) -> Result<()> {
1483 let room;
1484 let channel;
1485 {
1486 let mut rejoined_room = session
1487 .db()
1488 .await
1489 .rejoin_room(request, session.user_id(), session.connection_id)
1490 .await?;
1491
1492 response.send(proto::RejoinRoomResponse {
1493 room: Some(rejoined_room.room.clone()),
1494 reshared_projects: rejoined_room
1495 .reshared_projects
1496 .iter()
1497 .map(|project| proto::ResharedProject {
1498 id: project.id.to_proto(),
1499 collaborators: project
1500 .collaborators
1501 .iter()
1502 .map(|collaborator| collaborator.to_proto())
1503 .collect(),
1504 })
1505 .collect(),
1506 rejoined_projects: rejoined_room
1507 .rejoined_projects
1508 .iter()
1509 .map(|rejoined_project| rejoined_project.to_proto())
1510 .collect(),
1511 })?;
1512 room_updated(&rejoined_room.room, &session.peer);
1513
1514 for project in &rejoined_room.reshared_projects {
1515 for collaborator in &project.collaborators {
1516 session
1517 .peer
1518 .send(
1519 collaborator.connection_id,
1520 proto::UpdateProjectCollaborator {
1521 project_id: project.id.to_proto(),
1522 old_peer_id: Some(project.old_connection_id.into()),
1523 new_peer_id: Some(session.connection_id.into()),
1524 },
1525 )
1526 .trace_err();
1527 }
1528
1529 broadcast(
1530 Some(session.connection_id),
1531 project
1532 .collaborators
1533 .iter()
1534 .map(|collaborator| collaborator.connection_id),
1535 |connection_id| {
1536 session.peer.forward_send(
1537 session.connection_id,
1538 connection_id,
1539 proto::UpdateProject {
1540 project_id: project.id.to_proto(),
1541 worktrees: project.worktrees.clone(),
1542 },
1543 )
1544 },
1545 );
1546 }
1547
1548 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1549
1550 let rejoined_room = rejoined_room.into_inner();
1551
1552 room = rejoined_room.room;
1553 channel = rejoined_room.channel;
1554 }
1555
1556 if let Some(channel) = channel {
1557 channel_updated(
1558 &channel,
1559 &room,
1560 &session.peer,
1561 &*session.connection_pool().await,
1562 );
1563 }
1564
1565 update_user_contacts(session.user_id(), &session).await?;
1566 Ok(())
1567}
1568
/// Replays the state of each rejoined project to the reconnecting client and
/// informs the other collaborators of its new peer id.
///
/// Worktree entries and repository state are sent as split update messages
/// (via `split_worktree_update` / `split_repository_update`) rather than one
/// large message each.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First, point every collaborator at the client's new connection id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Then stream each project's state back to the rejoining client itself.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Each settings file travels as its own message.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1653
1654/// leave room disconnects from the room.
1655async fn leave_room(
1656 _: proto::LeaveRoom,
1657 response: Response<proto::LeaveRoom>,
1658 session: MessageContext,
1659) -> Result<()> {
1660 leave_room_for_session(&session, session.connection_id).await?;
1661 response.send(proto::Ack {})?;
1662 Ok(())
1663}
1664
1665/// Updates the permissions of someone else in the room.
1666async fn set_room_participant_role(
1667 request: proto::SetRoomParticipantRole,
1668 response: Response<proto::SetRoomParticipantRole>,
1669 session: MessageContext,
1670) -> Result<()> {
1671 let user_id = UserId::from_proto(request.user_id);
1672 let role = ChannelRole::from(request.role());
1673
1674 let (livekit_room, can_publish) = {
1675 let room = session
1676 .db()
1677 .await
1678 .set_room_participant_role(
1679 session.user_id(),
1680 RoomId::from_proto(request.room_id),
1681 user_id,
1682 role,
1683 )
1684 .await?;
1685
1686 let livekit_room = room.livekit_room.clone();
1687 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1688 room_updated(&room, &session.peer);
1689 (livekit_room, can_publish)
1690 };
1691
1692 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1693 live_kit
1694 .update_participant(
1695 livekit_room.clone(),
1696 request.user_id.to_string(),
1697 livekit_api::proto::ParticipantPermission {
1698 can_subscribe: true,
1699 can_publish,
1700 can_publish_data: can_publish,
1701 hidden: false,
1702 recorder: false,
1703 },
1704 )
1705 .await
1706 .trace_err();
1707 }
1708
1709 response.send(proto::Ack {})?;
1710 Ok(())
1711}
1712
1713/// Call someone else into the current room
1714async fn call(
1715 request: proto::Call,
1716 response: Response<proto::Call>,
1717 session: MessageContext,
1718) -> Result<()> {
1719 let room_id = RoomId::from_proto(request.room_id);
1720 let calling_user_id = session.user_id();
1721 let calling_connection_id = session.connection_id;
1722 let called_user_id = UserId::from_proto(request.called_user_id);
1723 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1724 if !session
1725 .db()
1726 .await
1727 .has_contact(calling_user_id, called_user_id)
1728 .await?
1729 {
1730 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1731 }
1732
1733 let incoming_call = {
1734 let (room, incoming_call) = &mut *session
1735 .db()
1736 .await
1737 .call(
1738 room_id,
1739 calling_user_id,
1740 calling_connection_id,
1741 called_user_id,
1742 initial_project_id,
1743 )
1744 .await?;
1745 room_updated(room, &session.peer);
1746 mem::take(incoming_call)
1747 };
1748 update_user_contacts(called_user_id, &session).await?;
1749
1750 let mut calls = session
1751 .connection_pool()
1752 .await
1753 .user_connection_ids(called_user_id)
1754 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1755 .collect::<FuturesUnordered<_>>();
1756
1757 while let Some(call_response) = calls.next().await {
1758 match call_response.as_ref() {
1759 Ok(_) => {
1760 response.send(proto::Ack {})?;
1761 return Ok(());
1762 }
1763 Err(_) => {
1764 call_response.trace_err();
1765 }
1766 }
1767 }
1768
1769 {
1770 let room = session
1771 .db()
1772 .await
1773 .call_failed(room_id, called_user_id)
1774 .await?;
1775 room_updated(&room, &session.peer);
1776 }
1777 update_user_contacts(called_user_id, &session).await?;
1778
1779 Err(anyhow!("failed to ring user"))?
1780}
1781
1782/// Cancel an outgoing call.
1783async fn cancel_call(
1784 request: proto::CancelCall,
1785 response: Response<proto::CancelCall>,
1786 session: MessageContext,
1787) -> Result<()> {
1788 let called_user_id = UserId::from_proto(request.called_user_id);
1789 let room_id = RoomId::from_proto(request.room_id);
1790 {
1791 let room = session
1792 .db()
1793 .await
1794 .cancel_call(room_id, session.connection_id, called_user_id)
1795 .await?;
1796 room_updated(&room, &session.peer);
1797 }
1798
1799 for connection_id in session
1800 .connection_pool()
1801 .await
1802 .user_connection_ids(called_user_id)
1803 {
1804 session
1805 .peer
1806 .send(
1807 connection_id,
1808 proto::CallCanceled {
1809 room_id: room_id.to_proto(),
1810 },
1811 )
1812 .trace_err();
1813 }
1814 response.send(proto::Ack {})?;
1815
1816 update_user_contacts(called_user_id, &session).await?;
1817 Ok(())
1818}
1819
1820/// Decline an incoming call.
1821async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1822 let room_id = RoomId::from_proto(message.room_id);
1823 {
1824 let room = session
1825 .db()
1826 .await
1827 .decline_call(Some(room_id), session.user_id())
1828 .await?
1829 .context("declining call")?;
1830 room_updated(&room, &session.peer);
1831 }
1832
1833 for connection_id in session
1834 .connection_pool()
1835 .await
1836 .user_connection_ids(session.user_id())
1837 {
1838 session
1839 .peer
1840 .send(
1841 connection_id,
1842 proto::CallCanceled {
1843 room_id: room_id.to_proto(),
1844 },
1845 )
1846 .trace_err();
1847 }
1848 update_user_contacts(session.user_id(), &session).await?;
1849 Ok(())
1850}
1851
1852/// Updates other participants in the room with your current location.
1853async fn update_participant_location(
1854 request: proto::UpdateParticipantLocation,
1855 response: Response<proto::UpdateParticipantLocation>,
1856 session: MessageContext,
1857) -> Result<()> {
1858 let room_id = RoomId::from_proto(request.room_id);
1859 let location = request.location.context("invalid location")?;
1860
1861 let db = session.db().await;
1862 let room = db
1863 .update_room_participant_location(room_id, session.connection_id, location)
1864 .await?;
1865
1866 room_updated(&room, &session.peer);
1867 response.send(proto::Ack {})?;
1868 Ok(())
1869}
1870
1871/// Share a project into the room.
1872async fn share_project(
1873 request: proto::ShareProject,
1874 response: Response<proto::ShareProject>,
1875 session: MessageContext,
1876) -> Result<()> {
1877 let (project_id, room) = &*session
1878 .db()
1879 .await
1880 .share_project(
1881 RoomId::from_proto(request.room_id),
1882 session.connection_id,
1883 &request.worktrees,
1884 request.is_ssh_project,
1885 request.windows_paths.unwrap_or(false),
1886 )
1887 .await?;
1888 response.send(proto::ShareProjectResponse {
1889 project_id: project_id.to_proto(),
1890 })?;
1891 room_updated(room, &session.peer);
1892
1893 Ok(())
1894}
1895
1896/// Unshare a project from the room.
1897async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1898 let project_id = ProjectId::from_proto(message.project_id);
1899 unshare_project_internal(project_id, session.connection_id, &session).await
1900}
1901
1902async fn unshare_project_internal(
1903 project_id: ProjectId,
1904 connection_id: ConnectionId,
1905 session: &Session,
1906) -> Result<()> {
1907 let delete = {
1908 let room_guard = session
1909 .db()
1910 .await
1911 .unshare_project(project_id, connection_id)
1912 .await?;
1913
1914 let (delete, room, guest_connection_ids) = &*room_guard;
1915
1916 let message = proto::UnshareProject {
1917 project_id: project_id.to_proto(),
1918 };
1919
1920 broadcast(
1921 Some(connection_id),
1922 guest_connection_ids.iter().copied(),
1923 |conn_id| session.peer.send(conn_id, message.clone()),
1924 );
1925 if let Some(room) = room {
1926 room_updated(room, &session.peer);
1927 }
1928
1929 *delete
1930 };
1931
1932 if delete {
1933 let db = session.db().await;
1934 db.delete_project(project_id).await?;
1935 }
1936
1937 Ok(())
1938}
1939
1940/// Join someone elses shared project.
1941async fn join_project(
1942 request: proto::JoinProject,
1943 response: Response<proto::JoinProject>,
1944 session: MessageContext,
1945) -> Result<()> {
1946 let project_id = ProjectId::from_proto(request.project_id);
1947
1948 tracing::info!(%project_id, "join project");
1949
1950 let db = session.db().await;
1951 let (project, replica_id) = &mut *db
1952 .join_project(
1953 project_id,
1954 session.connection_id,
1955 session.user_id(),
1956 request.committer_name.clone(),
1957 request.committer_email.clone(),
1958 )
1959 .await?;
1960 drop(db);
1961 tracing::info!(%project_id, "join remote project");
1962 let collaborators = project
1963 .collaborators
1964 .iter()
1965 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1966 .map(|collaborator| collaborator.to_proto())
1967 .collect::<Vec<_>>();
1968 let project_id = project.id;
1969 let guest_user_id = session.user_id();
1970
1971 let worktrees = project
1972 .worktrees
1973 .iter()
1974 .map(|(id, worktree)| proto::WorktreeMetadata {
1975 id: *id,
1976 root_name: worktree.root_name.clone(),
1977 visible: worktree.visible,
1978 abs_path: worktree.abs_path.clone(),
1979 })
1980 .collect::<Vec<_>>();
1981
1982 let add_project_collaborator = proto::AddProjectCollaborator {
1983 project_id: project_id.to_proto(),
1984 collaborator: Some(proto::Collaborator {
1985 peer_id: Some(session.connection_id.into()),
1986 replica_id: replica_id.0 as u32,
1987 user_id: guest_user_id.to_proto(),
1988 is_host: false,
1989 committer_name: request.committer_name.clone(),
1990 committer_email: request.committer_email.clone(),
1991 }),
1992 };
1993
1994 for collaborator in &collaborators {
1995 session
1996 .peer
1997 .send(
1998 collaborator.peer_id.unwrap().into(),
1999 add_project_collaborator.clone(),
2000 )
2001 .trace_err();
2002 }
2003
2004 // First, we send the metadata associated with each worktree.
2005 let (language_servers, language_server_capabilities) = project
2006 .language_servers
2007 .clone()
2008 .into_iter()
2009 .map(|server| (server.server, server.capabilities))
2010 .unzip();
2011 response.send(proto::JoinProjectResponse {
2012 project_id: project.id.0 as u64,
2013 worktrees,
2014 replica_id: replica_id.0 as u32,
2015 collaborators,
2016 language_servers,
2017 language_server_capabilities,
2018 role: project.role.into(),
2019 windows_paths: project.path_style == PathStyle::Windows,
2020 })?;
2021
2022 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2023 // Stream this worktree's entries.
2024 let message = proto::UpdateWorktree {
2025 project_id: project_id.to_proto(),
2026 worktree_id,
2027 abs_path: worktree.abs_path.clone(),
2028 root_name: worktree.root_name,
2029 updated_entries: worktree.entries,
2030 removed_entries: Default::default(),
2031 scan_id: worktree.scan_id,
2032 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2033 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2034 removed_repositories: Default::default(),
2035 };
2036 for update in proto::split_worktree_update(message) {
2037 session.peer.send(session.connection_id, update.clone())?;
2038 }
2039
2040 // Stream this worktree's diagnostics.
2041 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2042 if let Some(summary) = worktree_diagnostics.next() {
2043 let message = proto::UpdateDiagnosticSummary {
2044 project_id: project.id.to_proto(),
2045 worktree_id: worktree.id,
2046 summary: Some(summary),
2047 more_summaries: worktree_diagnostics.collect(),
2048 };
2049 session.peer.send(session.connection_id, message)?;
2050 }
2051
2052 for settings_file in worktree.settings_files {
2053 session.peer.send(
2054 session.connection_id,
2055 proto::UpdateWorktreeSettings {
2056 project_id: project_id.to_proto(),
2057 worktree_id: worktree.id,
2058 path: settings_file.path,
2059 content: Some(settings_file.content),
2060 kind: Some(settings_file.kind.to_proto() as i32),
2061 },
2062 )?;
2063 }
2064 }
2065
2066 for repository in mem::take(&mut project.repositories) {
2067 for update in split_repository_update(repository) {
2068 session.peer.send(session.connection_id, update)?;
2069 }
2070 }
2071
2072 for language_server in &project.language_servers {
2073 session.peer.send(
2074 session.connection_id,
2075 proto::UpdateLanguageServer {
2076 project_id: project_id.to_proto(),
2077 server_name: Some(language_server.server.name.clone()),
2078 language_server_id: language_server.server.id,
2079 variant: Some(
2080 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2081 proto::LspDiskBasedDiagnosticsUpdated {},
2082 ),
2083 ),
2084 },
2085 )?;
2086 }
2087
2088 Ok(())
2089}
2090
2091/// Leave someone elses shared project.
2092async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2093 let sender_id = session.connection_id;
2094 let project_id = ProjectId::from_proto(request.project_id);
2095 let db = session.db().await;
2096
2097 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2098 tracing::info!(
2099 %project_id,
2100 "leave project"
2101 );
2102
2103 project_left(project, &session);
2104 if let Some(room) = room {
2105 room_updated(room, &session.peer);
2106 }
2107
2108 Ok(())
2109}
2110
2111/// Updates other participants with changes to the project
2112async fn update_project(
2113 request: proto::UpdateProject,
2114 response: Response<proto::UpdateProject>,
2115 session: MessageContext,
2116) -> Result<()> {
2117 let project_id = ProjectId::from_proto(request.project_id);
2118 let (room, guest_connection_ids) = &*session
2119 .db()
2120 .await
2121 .update_project(project_id, session.connection_id, &request.worktrees)
2122 .await?;
2123 broadcast(
2124 Some(session.connection_id),
2125 guest_connection_ids.iter().copied(),
2126 |connection_id| {
2127 session
2128 .peer
2129 .forward_send(session.connection_id, connection_id, request.clone())
2130 },
2131 );
2132 if let Some(room) = room {
2133 room_updated(room, &session.peer);
2134 }
2135 response.send(proto::Ack {})?;
2136
2137 Ok(())
2138}
2139
2140/// Updates other participants with changes to the worktree
2141async fn update_worktree(
2142 request: proto::UpdateWorktree,
2143 response: Response<proto::UpdateWorktree>,
2144 session: MessageContext,
2145) -> Result<()> {
2146 let guest_connection_ids = session
2147 .db()
2148 .await
2149 .update_worktree(&request, session.connection_id)
2150 .await?;
2151
2152 broadcast(
2153 Some(session.connection_id),
2154 guest_connection_ids.iter().copied(),
2155 |connection_id| {
2156 session
2157 .peer
2158 .forward_send(session.connection_id, connection_id, request.clone())
2159 },
2160 );
2161 response.send(proto::Ack {})?;
2162 Ok(())
2163}
2164
2165async fn update_repository(
2166 request: proto::UpdateRepository,
2167 response: Response<proto::UpdateRepository>,
2168 session: MessageContext,
2169) -> Result<()> {
2170 let guest_connection_ids = session
2171 .db()
2172 .await
2173 .update_repository(&request, session.connection_id)
2174 .await?;
2175
2176 broadcast(
2177 Some(session.connection_id),
2178 guest_connection_ids.iter().copied(),
2179 |connection_id| {
2180 session
2181 .peer
2182 .forward_send(session.connection_id, connection_id, request.clone())
2183 },
2184 );
2185 response.send(proto::Ack {})?;
2186 Ok(())
2187}
2188
2189async fn remove_repository(
2190 request: proto::RemoveRepository,
2191 response: Response<proto::RemoveRepository>,
2192 session: MessageContext,
2193) -> Result<()> {
2194 let guest_connection_ids = session
2195 .db()
2196 .await
2197 .remove_repository(&request, session.connection_id)
2198 .await?;
2199
2200 broadcast(
2201 Some(session.connection_id),
2202 guest_connection_ids.iter().copied(),
2203 |connection_id| {
2204 session
2205 .peer
2206 .forward_send(session.connection_id, connection_id, request.clone())
2207 },
2208 );
2209 response.send(proto::Ack {})?;
2210 Ok(())
2211}
2212
2213/// Updates other participants with changes to the diagnostics
2214async fn update_diagnostic_summary(
2215 message: proto::UpdateDiagnosticSummary,
2216 session: MessageContext,
2217) -> Result<()> {
2218 let guest_connection_ids = session
2219 .db()
2220 .await
2221 .update_diagnostic_summary(&message, session.connection_id)
2222 .await?;
2223
2224 broadcast(
2225 Some(session.connection_id),
2226 guest_connection_ids.iter().copied(),
2227 |connection_id| {
2228 session
2229 .peer
2230 .forward_send(session.connection_id, connection_id, message.clone())
2231 },
2232 );
2233
2234 Ok(())
2235}
2236
2237/// Updates other participants with changes to the worktree settings
2238async fn update_worktree_settings(
2239 message: proto::UpdateWorktreeSettings,
2240 session: MessageContext,
2241) -> Result<()> {
2242 let guest_connection_ids = session
2243 .db()
2244 .await
2245 .update_worktree_settings(&message, session.connection_id)
2246 .await?;
2247
2248 broadcast(
2249 Some(session.connection_id),
2250 guest_connection_ids.iter().copied(),
2251 |connection_id| {
2252 session
2253 .peer
2254 .forward_send(session.connection_id, connection_id, message.clone())
2255 },
2256 );
2257
2258 Ok(())
2259}
2260
2261/// Notify other participants that a language server has started.
2262async fn start_language_server(
2263 request: proto::StartLanguageServer,
2264 session: MessageContext,
2265) -> Result<()> {
2266 let guest_connection_ids = session
2267 .db()
2268 .await
2269 .start_language_server(&request, session.connection_id)
2270 .await?;
2271
2272 broadcast(
2273 Some(session.connection_id),
2274 guest_connection_ids.iter().copied(),
2275 |connection_id| {
2276 session
2277 .peer
2278 .forward_send(session.connection_id, connection_id, request.clone())
2279 },
2280 );
2281 Ok(())
2282}
2283
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // The guard is intentionally held for both DB calls below.
    let db = session.db().await;

    // Persist updated server capabilities so they can be served to
    // participants who join the project later.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the trailing `true` differs from the `false` used in
    // broadcast_project_message_from_host — presumably it widens the
    // recipient set (e.g. includes the host); confirm in the db layer.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2313
2314/// forward a project request to the host. These requests should be read only
2315/// as guests are allowed to send them.
2316async fn forward_read_only_project_request<T>(
2317 request: T,
2318 response: Response<T>,
2319 session: MessageContext,
2320) -> Result<()>
2321where
2322 T: EntityMessage + RequestMessage,
2323{
2324 let project_id = ProjectId::from_proto(request.remote_entity_id());
2325 let host_connection_id = session
2326 .db()
2327 .await
2328 .host_for_read_only_project_request(project_id, session.connection_id)
2329 .await?;
2330 let payload = session.forward_request(host_connection_id, request).await?;
2331 response.send(payload)?;
2332 Ok(())
2333}
2334
2335/// forward a project request to the host. These requests are disallowed
2336/// for guests.
2337async fn forward_mutating_project_request<T>(
2338 request: T,
2339 response: Response<T>,
2340 session: MessageContext,
2341) -> Result<()>
2342where
2343 T: EntityMessage + RequestMessage,
2344{
2345 let project_id = ProjectId::from_proto(request.remote_entity_id());
2346
2347 let host_connection_id = session
2348 .db()
2349 .await
2350 .host_for_mutating_project_request(project_id, session.connection_id)
2351 .await?;
2352 let payload = session.forward_request(host_connection_id, request).await?;
2353 response.send(payload)?;
2354 Ok(())
2355}
2356
2357async fn lsp_query(
2358 request: proto::LspQuery,
2359 response: Response<proto::LspQuery>,
2360 session: MessageContext,
2361) -> Result<()> {
2362 let (name, should_write) = request.query_name_and_write_permissions();
2363 tracing::Span::current().record("lsp_query_request", name);
2364 tracing::info!("lsp_query message received");
2365 if should_write {
2366 forward_mutating_project_request(request, response, session).await
2367 } else {
2368 forward_read_only_project_request(request, response, session).await
2369 }
2370}
2371
2372/// Notify other participants that a new buffer has been created
2373async fn create_buffer_for_peer(
2374 request: proto::CreateBufferForPeer,
2375 session: MessageContext,
2376) -> Result<()> {
2377 session
2378 .db()
2379 .await
2380 .check_user_is_project_host(
2381 ProjectId::from_proto(request.project_id),
2382 session.connection_id,
2383 )
2384 .await?;
2385 let peer_id = request.peer_id.context("invalid peer id")?;
2386 session
2387 .peer
2388 .forward_send(session.connection_id, peer_id.into(), request)?;
2389 Ok(())
2390}
2391
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations may come from read-only guests;
    // any other operation kind upgrades the required capability to write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the db guard: it must be released before we await the host's
    // acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests receive the update fire-and-forget...
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // ...but the host (when it isn't the sender) must acknowledge the update
    // before we ack the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2438
/// Forward a context update to the other participants in the project.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirrors `update_buffer`: selection-only buffer operations are permitted
    // for read-only participants; anything else requires write access.
    // Missing inner operations are treated conservatively as writes.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the fan-out and no
    // acknowledgement is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2480
2481/// Notify other participants that a project has been updated.
2482async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2483 request: T,
2484 session: MessageContext,
2485) -> Result<()> {
2486 let project_id = ProjectId::from_proto(request.remote_entity_id());
2487 let project_connection_ids = session
2488 .db()
2489 .await
2490 .project_connection_ids(project_id, session.connection_id, false)
2491 .await?;
2492
2493 broadcast(
2494 Some(session.connection_id),
2495 project_connection_ids.iter().copied(),
2496 |connection_id| {
2497 session
2498 .peer
2499 .forward_send(session.connection_id, connection_id, request.clone())
2500 },
2501 );
2502 Ok(())
2503}
2504
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Reject the request unless both leader and follower are in the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay its reply to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // The follow relationship is only persisted when it is tied to a project;
    // in that case the whole room is told about the updated state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2536
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Reject the request unless both leader and follower are in the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader we stopped following; no reply is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Mirror `follow`: the relationship is only persisted per-project, so the
    // db and the room state are only updated when a project id was supplied.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2565
2566/// Notify everyone following you of your current location.
2567async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2568 let room_id = RoomId::from_proto(request.room_id);
2569 let database = session.db.lock().await;
2570
2571 let connection_ids = if let Some(project_id) = request.project_id {
2572 let project_id = ProjectId::from_proto(project_id);
2573 database
2574 .project_connection_ids(project_id, session.connection_id, true)
2575 .await?
2576 } else {
2577 database
2578 .room_connection_ids(room_id, session.connection_id)
2579 .await?
2580 };
2581
2582 // For now, don't send view update messages back to that view's current leader.
2583 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2584 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2585 _ => None,
2586 });
2587
2588 for connection_id in connection_ids.iter().cloned() {
2589 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2590 session
2591 .peer
2592 .forward_send(session.connection_id, connection_id, request.clone())?;
2593 }
2594 }
2595 Ok(())
2596}
2597
2598/// Get public data about users.
2599async fn get_users(
2600 request: proto::GetUsers,
2601 response: Response<proto::GetUsers>,
2602 session: MessageContext,
2603) -> Result<()> {
2604 let user_ids = request
2605 .user_ids
2606 .into_iter()
2607 .map(UserId::from_proto)
2608 .collect();
2609 let users = session
2610 .db()
2611 .await
2612 .get_users_by_ids(user_ids)
2613 .await?
2614 .into_iter()
2615 .map(|user| proto::User {
2616 id: user.id.to_proto(),
2617 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2618 github_login: user.github_login,
2619 name: user.name,
2620 })
2621 .collect();
2622 response.send(proto::UsersResponse { users })?;
2623 Ok(())
2624}
2625
2626/// Search for users (to invite) buy Github login
2627async fn fuzzy_search_users(
2628 request: proto::FuzzySearchUsers,
2629 response: Response<proto::FuzzySearchUsers>,
2630 session: MessageContext,
2631) -> Result<()> {
2632 let query = request.query;
2633 let users = match query.len() {
2634 0 => vec![],
2635 1 | 2 => session
2636 .db()
2637 .await
2638 .get_user_by_github_login(&query)
2639 .await?
2640 .into_iter()
2641 .collect(),
2642 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2643 };
2644 let users = users
2645 .into_iter()
2646 .filter(|user| user.id != session.user_id())
2647 .map(|user| proto::User {
2648 id: user.id.to_proto(),
2649 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2650 github_login: user.github_login,
2651 name: user.name,
2652 })
2653 .collect();
2654 response.send(proto::UsersResponse { users })?;
2655 Ok(())
2656}
2657
2658/// Send a contact request to another user.
2659async fn request_contact(
2660 request: proto::RequestContact,
2661 response: Response<proto::RequestContact>,
2662 session: MessageContext,
2663) -> Result<()> {
2664 let requester_id = session.user_id();
2665 let responder_id = UserId::from_proto(request.responder_id);
2666 if requester_id == responder_id {
2667 return Err(anyhow!("cannot add yourself as a contact"))?;
2668 }
2669
2670 let notifications = session
2671 .db()
2672 .await
2673 .send_contact_request(requester_id, responder_id)
2674 .await?;
2675
2676 // Update outgoing contact requests of requester
2677 let mut update = proto::UpdateContacts::default();
2678 update.outgoing_requests.push(responder_id.to_proto());
2679 for connection_id in session
2680 .connection_pool()
2681 .await
2682 .user_connection_ids(requester_id)
2683 {
2684 session.peer.send(connection_id, update.clone())?;
2685 }
2686
2687 // Update incoming contact requests of responder
2688 let mut update = proto::UpdateContacts::default();
2689 update
2690 .incoming_requests
2691 .push(proto::IncomingContactRequest {
2692 requester_id: requester_id.to_proto(),
2693 });
2694 let connection_pool = session.connection_pool().await;
2695 for connection_id in connection_pool.user_connection_ids(responder_id) {
2696 session.peer.send(connection_id, update.clone())?;
2697 }
2698
2699 send_notifications(&connection_pool, &session.peer, notifications);
2700
2701 response.send(proto::Ack {})?;
2702 Ok(())
2703}
2704
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; no contact state changes.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request disappears.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // The requester's pending outgoing request is resolved either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2762
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes an established contact from a
    // still-pending request; that decides which lists are pruned below.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the contact-request notification, if the db deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2813
2814fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2815 version.0.minor() < 139
2816}
2817
2818async fn subscribe_to_channels(
2819 _: proto::SubscribeToChannels,
2820 session: MessageContext,
2821) -> Result<()> {
2822 subscribe_user_to_channels(session.user_id(), &session).await?;
2823 Ok(())
2824}
2825
/// Subscribe a user to all channels they belong to and send this session's
/// connection the initial channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register the subscriptions in the connection pool so that subsequent
    // channel updates are fanned out to this user.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send the membership/role data first, then the channels payload (which
    // consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2842
2843/// Creates a new channel.
2844async fn create_channel(
2845 request: proto::CreateChannel,
2846 response: Response<proto::CreateChannel>,
2847 session: MessageContext,
2848) -> Result<()> {
2849 let db = session.db().await;
2850
2851 let parent_id = request.parent_id.map(ChannelId::from_proto);
2852 let (channel, membership) = db
2853 .create_channel(&request.name, parent_id, session.user_id())
2854 .await?;
2855
2856 let root_id = channel.root_id();
2857 let channel = Channel::from_model(channel);
2858
2859 response.send(proto::CreateChannelResponse {
2860 channel: Some(channel.to_proto()),
2861 parent_id: request.parent_id,
2862 })?;
2863
2864 let mut connection_pool = session.connection_pool().await;
2865 if let Some(membership) = membership {
2866 connection_pool.subscribe_to_channel(
2867 membership.user_id,
2868 membership.channel_id,
2869 membership.role,
2870 );
2871 let update = proto::UpdateUserChannels {
2872 channel_memberships: vec![proto::ChannelMembership {
2873 channel_id: membership.channel_id.to_proto(),
2874 role: membership.role.into(),
2875 }],
2876 ..Default::default()
2877 };
2878 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2879 session.peer.send(connection_id, update.clone())?;
2880 }
2881 }
2882
2883 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2884 if !role.can_see_channel(channel.visibility) {
2885 continue;
2886 }
2887
2888 let update = proto::UpdateChannels {
2889 channels: vec![channel.to_proto()],
2890 ..Default::default()
2891 };
2892 session.peer.send(connection_id, update.clone())?;
2893 }
2894
2895 Ok(())
2896}
2897
2898/// Delete a channel
2899async fn delete_channel(
2900 request: proto::DeleteChannel,
2901 response: Response<proto::DeleteChannel>,
2902 session: MessageContext,
2903) -> Result<()> {
2904 let db = session.db().await;
2905
2906 let channel_id = request.channel_id;
2907 let (root_channel, removed_channels) = db
2908 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2909 .await?;
2910 response.send(proto::Ack {})?;
2911
2912 // Notify members of removed channels
2913 let mut update = proto::UpdateChannels::default();
2914 update
2915 .delete_channels
2916 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2917
2918 let connection_pool = session.connection_pool().await;
2919 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2920 session.peer.send(connection_id, update.clone())?;
2921 }
2922
2923 Ok(())
2924}
2925
2926/// Invite someone to join a channel.
2927async fn invite_channel_member(
2928 request: proto::InviteChannelMember,
2929 response: Response<proto::InviteChannelMember>,
2930 session: MessageContext,
2931) -> Result<()> {
2932 let db = session.db().await;
2933 let channel_id = ChannelId::from_proto(request.channel_id);
2934 let invitee_id = UserId::from_proto(request.user_id);
2935 let InviteMemberResult {
2936 channel,
2937 notifications,
2938 } = db
2939 .invite_channel_member(
2940 channel_id,
2941 invitee_id,
2942 session.user_id(),
2943 request.role().into(),
2944 )
2945 .await?;
2946
2947 let update = proto::UpdateChannels {
2948 channel_invitations: vec![channel.to_proto()],
2949 ..Default::default()
2950 };
2951
2952 let connection_pool = session.connection_pool().await;
2953 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2954 session.peer.send(connection_id, update.clone())?;
2955 }
2956
2957 send_notifications(&connection_pool, &session.peer, notifications);
2958
2959 response.send(proto::Ack {})?;
2960 Ok(())
2961}
2962
2963/// remove someone from a channel
2964async fn remove_channel_member(
2965 request: proto::RemoveChannelMember,
2966 response: Response<proto::RemoveChannelMember>,
2967 session: MessageContext,
2968) -> Result<()> {
2969 let db = session.db().await;
2970 let channel_id = ChannelId::from_proto(request.channel_id);
2971 let member_id = UserId::from_proto(request.user_id);
2972
2973 let RemoveChannelMemberResult {
2974 membership_update,
2975 notification_id,
2976 } = db
2977 .remove_channel_member(channel_id, member_id, session.user_id())
2978 .await?;
2979
2980 let mut connection_pool = session.connection_pool().await;
2981 notify_membership_updated(
2982 &mut connection_pool,
2983 membership_update,
2984 member_id,
2985 &session.peer,
2986 );
2987 for connection_id in connection_pool.user_connection_ids(member_id) {
2988 if let Some(notification_id) = notification_id {
2989 session
2990 .peer
2991 .send(
2992 connection_id,
2993 proto::DeleteNotification {
2994 notification_id: notification_id.to_proto(),
2995 },
2996 )
2997 .trace_err();
2998 }
2999 }
3000
3001 response.send(proto::Ack {})?;
3002 Ok(())
3003}
3004
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the loop body mutates the pool's
    // subscriptions, so the `channel_user_ids` borrow can't be held open.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get (re)subscribed and receive
        // the updated channel; everyone else is unsubscribed and told to
        // delete it locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3051
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The db distinguishes an actual member from a user with a pending invite.
    match result {
        // Existing member: broadcast the full membership change.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // Pending invite: only refresh the invitation shown to the invitee.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3099
3100/// Change the name of a channel
3101async fn rename_channel(
3102 request: proto::RenameChannel,
3103 response: Response<proto::RenameChannel>,
3104 session: MessageContext,
3105) -> Result<()> {
3106 let db = session.db().await;
3107 let channel_id = ChannelId::from_proto(request.channel_id);
3108 let channel_model = db
3109 .rename_channel(channel_id, session.user_id(), &request.name)
3110 .await?;
3111 let root_id = channel_model.root_id();
3112 let channel = Channel::from_model(channel_model);
3113
3114 response.send(proto::RenameChannelResponse {
3115 channel: Some(channel.to_proto()),
3116 })?;
3117
3118 let connection_pool = session.connection_pool().await;
3119 let update = proto::UpdateChannels {
3120 channels: vec![channel.to_proto()],
3121 ..Default::default()
3122 };
3123 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3124 if role.can_see_channel(channel.visibility) {
3125 session.peer.send(connection_id, update.clone())?;
3126 }
3127 }
3128
3129 Ok(())
3130}
3131
3132/// Move a channel to a new parent.
3133async fn move_channel(
3134 request: proto::MoveChannel,
3135 response: Response<proto::MoveChannel>,
3136 session: MessageContext,
3137) -> Result<()> {
3138 let channel_id = ChannelId::from_proto(request.channel_id);
3139 let to = ChannelId::from_proto(request.to);
3140
3141 let (root_id, channels) = session
3142 .db()
3143 .await
3144 .move_channel(channel_id, to, session.user_id())
3145 .await?;
3146
3147 let connection_pool = session.connection_pool().await;
3148 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149 let channels = channels
3150 .iter()
3151 .filter_map(|channel| {
3152 if role.can_see_channel(channel.visibility) {
3153 Some(channel.to_proto())
3154 } else {
3155 None
3156 }
3157 })
3158 .collect::<Vec<_>>();
3159 if channels.is_empty() {
3160 continue;
3161 }
3162
3163 let update = proto::UpdateChannels {
3164 channels,
3165 ..Default::default()
3166 };
3167
3168 session.peer.send(connection_id, update.clone())?;
3169 }
3170
3171 response.send(Ack {})?;
3172 Ok(())
3173}
3174
3175async fn reorder_channel(
3176 request: proto::ReorderChannel,
3177 response: Response<proto::ReorderChannel>,
3178 session: MessageContext,
3179) -> Result<()> {
3180 let channel_id = ChannelId::from_proto(request.channel_id);
3181 let direction = request.direction();
3182
3183 let updated_channels = session
3184 .db()
3185 .await
3186 .reorder_channel(channel_id, direction, session.user_id())
3187 .await?;
3188
3189 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3190 let connection_pool = session.connection_pool().await;
3191 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3192 let channels = updated_channels
3193 .iter()
3194 .filter_map(|channel| {
3195 if role.can_see_channel(channel.visibility) {
3196 Some(channel.to_proto())
3197 } else {
3198 None
3199 }
3200 })
3201 .collect::<Vec<_>>();
3202
3203 if channels.is_empty() {
3204 continue;
3205 }
3206
3207 let update = proto::UpdateChannels {
3208 channels,
3209 ..Default::default()
3210 };
3211
3212 session.peer.send(connection_id, update.clone())?;
3213 }
3214 }
3215
3216 response.send(Ack {})?;
3217 Ok(())
3218}
3219
3220/// Get the list of channel members
3221async fn get_channel_members(
3222 request: proto::GetChannelMembers,
3223 response: Response<proto::GetChannelMembers>,
3224 session: MessageContext,
3225) -> Result<()> {
3226 let db = session.db().await;
3227 let channel_id = ChannelId::from_proto(request.channel_id);
3228 let limit = if request.limit == 0 {
3229 u16::MAX as u64
3230 } else {
3231 request.limit
3232 };
3233 let (members, users) = db
3234 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3235 .await?;
3236 response.send(proto::GetChannelMembersResponse { members, users })?;
3237 Ok(())
3238}
3239
3240/// Accept or decline a channel invitation.
3241async fn respond_to_channel_invite(
3242 request: proto::RespondToChannelInvite,
3243 response: Response<proto::RespondToChannelInvite>,
3244 session: MessageContext,
3245) -> Result<()> {
3246 let db = session.db().await;
3247 let channel_id = ChannelId::from_proto(request.channel_id);
3248 let RespondToChannelInvite {
3249 membership_update,
3250 notifications,
3251 } = db
3252 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3253 .await?;
3254
3255 let mut connection_pool = session.connection_pool().await;
3256 if let Some(membership_update) = membership_update {
3257 notify_membership_updated(
3258 &mut connection_pool,
3259 membership_update,
3260 session.user_id(),
3261 &session.peer,
3262 );
3263 } else {
3264 let update = proto::UpdateChannels {
3265 remove_channel_invitations: vec![channel_id.to_proto()],
3266 ..Default::default()
3267 };
3268
3269 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3270 session.peer.send(connection_id, update.clone())?;
3271 }
3272 };
3273
3274 send_notifications(&connection_pool, &session.peer, notifications);
3275
3276 response.send(proto::Ack {})?;
3277
3278 Ok(())
3279}
3280
3281/// Join the channels' room
3282async fn join_channel(
3283 request: proto::JoinChannel,
3284 response: Response<proto::JoinChannel>,
3285 session: MessageContext,
3286) -> Result<()> {
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288 join_channel_internal(channel_id, Box::new(response), session).await
3289}
3290
/// Abstraction over the two response types that can answer a room join:
/// `JoinChannel` (explicit channel join) and `JoinRoom`, both of which
/// reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3304
/// Shared join path for `JoinChannel` and `JoinRoom` requests: joins the
/// channel's room, issues LiveKit credentials, replies, then broadcasts
/// room/channel/contact updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before the cleanup (which acquires it
            // again internally), then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a restricted (non-publishing) LiveKit token; everyone
        // else can publish. Token failures are logged and degrade to "no
        // LiveKit connection info" rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's current participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the updated participant list to channel subscribers.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3398
3399/// Start editing the channel notes
3400async fn join_channel_buffer(
3401 request: proto::JoinChannelBuffer,
3402 response: Response<proto::JoinChannelBuffer>,
3403 session: MessageContext,
3404) -> Result<()> {
3405 let db = session.db().await;
3406 let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408 let open_response = db
3409 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3410 .await?;
3411
3412 let collaborators = open_response.collaborators.clone();
3413 response.send(open_response)?;
3414
3415 let update = UpdateChannelBufferCollaborators {
3416 channel_id: channel_id.to_proto(),
3417 collaborators: collaborators.clone(),
3418 };
3419 channel_buffer_updated(
3420 session.connection_id,
3421 collaborators
3422 .iter()
3423 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3424 &update,
3425 &session.peer,
3426 );
3427
3428 Ok(())
3429}
3430
/// Edit the channel notes
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Persist the edit operations and get back the active collaborator
    // connections plus the buffer's new epoch and version.
    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to everyone currently editing the buffer
    // (the sender is excluded by `channel_buffer_updated`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers who aren't editing the buffer only need the new
    // version stamp.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3481
3482/// Rejoin the channel notes after a connection blip
3483async fn rejoin_channel_buffers(
3484 request: proto::RejoinChannelBuffers,
3485 response: Response<proto::RejoinChannelBuffers>,
3486 session: MessageContext,
3487) -> Result<()> {
3488 let db = session.db().await;
3489 let buffers = db
3490 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3491 .await?;
3492
3493 for rejoined_buffer in &buffers {
3494 let collaborators_to_notify = rejoined_buffer
3495 .buffer
3496 .collaborators
3497 .iter()
3498 .filter_map(|c| Some(c.peer_id?.into()));
3499 channel_buffer_updated(
3500 session.connection_id,
3501 collaborators_to_notify,
3502 &proto::UpdateChannelBufferCollaborators {
3503 channel_id: rejoined_buffer.buffer.channel_id,
3504 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3505 },
3506 &session.peer,
3507 );
3508 }
3509
3510 response.send(proto::RejoinChannelBuffersResponse {
3511 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3512 })?;
3513
3514 Ok(())
3515}
3516
3517/// Stop editing the channel notes
3518async fn leave_channel_buffer(
3519 request: proto::LeaveChannelBuffer,
3520 response: Response<proto::LeaveChannelBuffer>,
3521 session: MessageContext,
3522) -> Result<()> {
3523 let db = session.db().await;
3524 let channel_id = ChannelId::from_proto(request.channel_id);
3525
3526 let left_buffer = db
3527 .leave_channel_buffer(channel_id, session.connection_id)
3528 .await?;
3529
3530 response.send(Ack {})?;
3531
3532 channel_buffer_updated(
3533 session.connection_id,
3534 left_buffer.connections,
3535 &proto::UpdateChannelBufferCollaborators {
3536 channel_id: channel_id.to_proto(),
3537 collaborators: left_buffer.collaborators,
3538 },
3539 &session.peer,
3540 );
3541
3542 Ok(())
3543}
3544
/// Broadcast `message` to every collaborator connection except `sender_id`.
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3555
3556fn send_notifications(
3557 connection_pool: &ConnectionPool,
3558 peer: &Peer,
3559 notifications: db::NotificationBatch,
3560) {
3561 for (user_id, notification) in notifications {
3562 for connection_id in connection_pool.user_connection_ids(user_id) {
3563 if let Err(error) = peer.send(
3564 connection_id,
3565 proto::AddNotification {
3566 notification: Some(notification.clone()),
3567 },
3568 ) {
3569 tracing::error!(
3570 "failed to send notification to {:?} {}",
3571 connection_id,
3572 error
3573 );
3574 }
3575 }
3576 }
3577}
3578
3579/// Send a message to the channel
3580async fn send_channel_message(
3581 _request: proto::SendChannelMessage,
3582 _response: Response<proto::SendChannelMessage>,
3583 _session: MessageContext,
3584) -> Result<()> {
3585 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3586}
3587
3588/// Delete a channel message
3589async fn remove_channel_message(
3590 _request: proto::RemoveChannelMessage,
3591 _response: Response<proto::RemoveChannelMessage>,
3592 _session: MessageContext,
3593) -> Result<()> {
3594 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3595}
3596
3597async fn update_channel_message(
3598 _request: proto::UpdateChannelMessage,
3599 _response: Response<proto::UpdateChannelMessage>,
3600 _session: MessageContext,
3601) -> Result<()> {
3602 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3603}
3604
3605/// Mark a channel message as read
3606async fn acknowledge_channel_message(
3607 _request: proto::AckChannelMessage,
3608 _session: MessageContext,
3609) -> Result<()> {
3610 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3611}
3612
3613/// Mark a buffer version as synced
3614async fn acknowledge_buffer_version(
3615 request: proto::AckBufferOperation,
3616 session: MessageContext,
3617) -> Result<()> {
3618 let buffer_id = BufferId::from_proto(request.buffer_id);
3619 session
3620 .db()
3621 .await
3622 .observe_buffer_version(
3623 buffer_id,
3624 session.user_id(),
3625 request.epoch as i32,
3626 &request.version,
3627 )
3628 .await?;
3629 Ok(())
3630}
3631
3632/// Get a Supermaven API key for the user
3633async fn get_supermaven_api_key(
3634 _request: proto::GetSupermavenApiKey,
3635 response: Response<proto::GetSupermavenApiKey>,
3636 session: MessageContext,
3637) -> Result<()> {
3638 let user_id: String = session.user_id().to_string();
3639 if !session.is_staff() {
3640 return Err(anyhow!("supermaven not enabled for this account"))?;
3641 }
3642
3643 let email = session.email().context("user must have an email")?;
3644
3645 let supermaven_admin_api = session
3646 .supermaven_client
3647 .as_ref()
3648 .context("supermaven not configured")?;
3649
3650 let result = supermaven_admin_api
3651 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3652 .await?;
3653
3654 response.send(proto::GetSupermavenApiKeyResponse {
3655 api_key: result.api_key,
3656 })?;
3657
3658 Ok(())
3659}
3660
3661/// Start receiving chat updates for a channel
3662async fn join_channel_chat(
3663 _request: proto::JoinChannelChat,
3664 _response: Response<proto::JoinChannelChat>,
3665 _session: MessageContext,
3666) -> Result<()> {
3667 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3668}
3669
3670/// Stop receiving chat updates for a channel
3671async fn leave_channel_chat(
3672 _request: proto::LeaveChannelChat,
3673 _session: MessageContext,
3674) -> Result<()> {
3675 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3676}
3677
3678/// Retrieve the chat history for a channel
3679async fn get_channel_messages(
3680 _request: proto::GetChannelMessages,
3681 _response: Response<proto::GetChannelMessages>,
3682 _session: MessageContext,
3683) -> Result<()> {
3684 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3685}
3686
3687/// Retrieve specific chat messages
3688async fn get_channel_messages_by_id(
3689 _request: proto::GetChannelMessagesById,
3690 _response: Response<proto::GetChannelMessagesById>,
3691 _session: MessageContext,
3692) -> Result<()> {
3693 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3694}
3695
3696/// Retrieve the current users notifications
3697async fn get_notifications(
3698 request: proto::GetNotifications,
3699 response: Response<proto::GetNotifications>,
3700 session: MessageContext,
3701) -> Result<()> {
3702 let notifications = session
3703 .db()
3704 .await
3705 .get_notifications(
3706 session.user_id(),
3707 NOTIFICATION_COUNT_PER_PAGE,
3708 request.before_id.map(db::NotificationId::from_proto),
3709 )
3710 .await?;
3711 response.send(proto::GetNotificationsResponse {
3712 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3713 notifications,
3714 })?;
3715 Ok(())
3716}
3717
3718/// Mark notifications as read
3719async fn mark_notification_as_read(
3720 request: proto::MarkNotificationRead,
3721 response: Response<proto::MarkNotificationRead>,
3722 session: MessageContext,
3723) -> Result<()> {
3724 let database = &session.db().await;
3725 let notifications = database
3726 .mark_notification_as_read_by_id(
3727 session.user_id(),
3728 NotificationId::from_proto(request.notification_id),
3729 )
3730 .await?;
3731 send_notifications(
3732 &*session.connection_pool().await,
3733 &session.peer,
3734 notifications,
3735 );
3736 response.send(proto::Ack {})?;
3737 Ok(())
3738}
3739
3740fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3741 let message = match message {
3742 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3743 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3744 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3745 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3746 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3747 code: frame.code.into(),
3748 reason: frame.reason.as_str().to_owned().into(),
3749 })),
3750 // We should never receive a frame while reading the message, according
3751 // to the `tungstenite` maintainers:
3752 //
3753 // > It cannot occur when you read messages from the WebSocket, but it
3754 // > can be used when you want to send the raw frames (e.g. you want to
3755 // > send the frames to the WebSocket without composing the full message first).
3756 // >
3757 // > — https://github.com/snapview/tungstenite-rs/issues/268
3758 TungsteniteMessage::Frame(_) => {
3759 bail!("received an unexpected frame while reading the message")
3760 }
3761 };
3762
3763 Ok(message)
3764}
3765
3766fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3767 match message {
3768 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3769 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3770 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3771 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3772 AxumMessage::Close(frame) => {
3773 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3774 code: frame.code.into(),
3775 reason: frame.reason.as_ref().into(),
3776 }))
3777 }
3778 }
3779}
3780
/// Push a user's updated channel membership to all of their connections.
///
/// Syncs the connection pool's channel subscriptions with the database
/// result, then sends both a `proto::UpdateUserChannels` (role per channel)
/// and a `proto::UpdateChannels` (channel list deltas) to every connection
/// belonging to `user_id`.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the in-memory pool's subscriptions in step with the database
    // before broadcasting.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this duplicates the mapping in
    // `build_update_user_channels` but leaves
    // `observed_channel_buffer_version` at its default — confirm that
    // omission is intentional.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Delivery failures are logged and ignored so the remaining connections
    // still receive their updates.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3821
3822fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3823 proto::UpdateUserChannels {
3824 channel_memberships: channels
3825 .channel_memberships
3826 .iter()
3827 .map(|m| proto::ChannelMembership {
3828 channel_id: m.channel_id.to_proto(),
3829 role: m.role.into(),
3830 })
3831 .collect(),
3832 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3833 }
3834}
3835
3836fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3837 let mut update = proto::UpdateChannels::default();
3838
3839 for channel in channels.channels {
3840 update.channels.push(channel.to_proto());
3841 }
3842
3843 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3844
3845 for (channel_id, participants) in channels.channel_participants {
3846 update
3847 .channel_participants
3848 .push(proto::ChannelParticipants {
3849 channel_id: channel_id.to_proto(),
3850 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3851 });
3852 }
3853
3854 for channel in channels.invited_channels {
3855 update.channel_invitations.push(channel.to_proto());
3856 }
3857
3858 update
3859}
3860
3861fn build_initial_contacts_update(
3862 contacts: Vec<db::Contact>,
3863 pool: &ConnectionPool,
3864) -> proto::UpdateContacts {
3865 let mut update = proto::UpdateContacts::default();
3866
3867 for contact in contacts {
3868 match contact {
3869 db::Contact::Accepted { user_id, busy } => {
3870 update.contacts.push(contact_for_user(user_id, busy, pool));
3871 }
3872 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3873 db::Contact::Incoming { user_id } => {
3874 update
3875 .incoming_requests
3876 .push(proto::IncomingContactRequest {
3877 requester_id: user_id.to_proto(),
3878 })
3879 }
3880 }
3881 }
3882
3883 update
3884}
3885
3886fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3887 proto::Contact {
3888 user_id: user_id.to_proto(),
3889 online: pool.is_user_online(user_id),
3890 busy,
3891 }
3892}
3893
3894fn room_updated(room: &proto::Room, peer: &Peer) {
3895 broadcast(
3896 None,
3897 room.participants
3898 .iter()
3899 .filter_map(|participant| Some(participant.peer_id?.into())),
3900 |peer_id| {
3901 peer.send(
3902 peer_id,
3903 proto::RoomUpdated {
3904 room: Some(room.clone()),
3905 },
3906 )
3907 },
3908 );
3909}
3910
3911fn channel_updated(
3912 channel: &db::channel::Model,
3913 room: &proto::Room,
3914 peer: &Peer,
3915 pool: &ConnectionPool,
3916) {
3917 let participants = room
3918 .participants
3919 .iter()
3920 .map(|p| p.user_id)
3921 .collect::<Vec<_>>();
3922
3923 broadcast(
3924 None,
3925 pool.channel_connection_ids(channel.root_id())
3926 .filter_map(|(channel_id, role)| {
3927 role.can_see_channel(channel.visibility)
3928 .then_some(channel_id)
3929 }),
3930 |peer_id| {
3931 peer.send(
3932 peer_id,
3933 proto::UpdateChannels {
3934 channel_participants: vec![proto::ChannelParticipants {
3935 channel_id: channel.id.to_proto(),
3936 participant_user_ids: participants.clone(),
3937 }],
3938 ..Default::default()
3939 },
3940 )
3941 },
3942 );
3943}
3944
3945async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3946 let db = session.db().await;
3947
3948 let contacts = db.get_contacts(user_id).await?;
3949 let busy = db.is_user_busy(user_id).await?;
3950
3951 let pool = session.connection_pool().await;
3952 let updated_contact = contact_for_user(user_id, busy, &pool);
3953 for contact in contacts {
3954 if let db::Contact::Accepted {
3955 user_id: contact_user_id,
3956 ..
3957 } = contact
3958 {
3959 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3960 session
3961 .peer
3962 .send(
3963 contact_conn_id,
3964 proto::UpdateContacts {
3965 contacts: vec![updated_contact.clone()],
3966 remove_contacts: Default::default(),
3967 incoming_requests: Default::default(),
3968 remove_incoming_requests: Default::default(),
3969 outgoing_requests: Default::default(),
3970 remove_outgoing_requests: Default::default(),
3971 },
3972 )
3973 .trace_err();
3974 }
3975 }
3976 }
3977 Ok(())
3978}
3979
/// Remove `connection_id` from the room it is in (if any) and fan out all
/// resulting updates: project departures, room/channel participant lists,
/// canceled outgoing calls, contact busy-status, and LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Notify collaborators of each project this connection was part of.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to do.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell every connection of each user whose incoming call was
        // canceled by this departure.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Busy status may have changed for everyone affected above.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // Tear down the LiveKit room when the database reports it deleted.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4053
4054async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4055 let left_channel_buffers = session
4056 .db()
4057 .await
4058 .leave_channel_buffers(session.connection_id)
4059 .await?;
4060
4061 for left_buffer in left_channel_buffers {
4062 channel_buffer_updated(
4063 session.connection_id,
4064 left_buffer.connections,
4065 &proto::UpdateChannelBufferCollaborators {
4066 channel_id: left_buffer.channel_id.to_proto(),
4067 collaborators: left_buffer.collaborators,
4068 },
4069 &session.peer,
4070 );
4071 }
4072
4073 Ok(())
4074}
4075
4076fn project_left(project: &db::LeftProject, session: &Session) {
4077 for connection_id in &project.connection_ids {
4078 if project.should_unshare {
4079 session
4080 .peer
4081 .send(
4082 *connection_id,
4083 proto::UnshareProject {
4084 project_id: project.id.to_proto(),
4085 },
4086 )
4087 .trace_err();
4088 } else {
4089 session
4090 .peer
4091 .send(
4092 *connection_id,
4093 proto::RemoveProjectCollaborator {
4094 project_id: project.id.to_proto(),
4095 peer_id: Some(session.connection_id.into()),
4096 },
4097 )
4098 .trace_err();
4099 }
4100 }
4101}
4102
/// Extension trait for logging and discarding errors.
pub trait ResultExt {
    type Ok;

    /// Logs the error (if any) and converts the result into an `Option`,
    /// discarding the error value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4108
4109impl<T, E> ResultExt for Result<T, E>
4110where
4111 E: std::fmt::Debug,
4112{
4113 type Ok = T;
4114
4115 #[track_caller]
4116 fn trace_err(self) -> Option<T> {
4117 match self {
4118 Ok(value) => Some(value),
4119 Err(error) => {
4120 tracing::error!("{:?}", error);
4121 None
4122 }
4123 }
4124 }
4125}