1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when returning notifications to clients.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Hard cap on simultaneously connected clients, enforced via `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Count of currently held connection slots (see `ConnectionGuard`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing span fields used to record message-handling timings.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one RPC message type, returning a boxed future
/// that performs the work. Registered in `Server::new`, keyed by `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
/// Handle a request handler uses to reply to the specific request it is
/// servicing.
struct Response<R> {
    peer: Arc<Peer>,
    /// Receipt identifying the request being answered.
    receipt: Receipt<R>,
    /// Set once a response has been sent; inspected by `add_request_handler`
    /// to detect handlers that return without responding.
    responded: Arc<AtomicBool>,
}
123
124impl<R: RequestMessage> Response<R> {
125 fn send(self, payload: R::Response) -> Result<()> {
126 self.responded.store(true, SeqCst);
127 self.peer.respond(self.receipt, payload)?;
128 Ok(())
129 }
130}
131
132#[derive(Clone, Debug)]
133pub enum Principal {
134 User(User),
135 Impersonated { user: User, admin: User },
136}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every message handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` anywhere a `Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
impl MessageContext {
    /// Forwards a request to another connection, returning a future that
    /// resolves to that peer's response.
    ///
    /// Records how long the receiving side took to answer in the
    /// `host_waiting_ms` span field (recorded on completion whether the
    /// forward succeeded or failed), and logs the start and end of the
    /// forwarding operation.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs on the future's output, success or failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Microsecond precision reported as fractional milliseconds.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
189
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
/// Newtype wrapper around the shared database that dereferences to `Database`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
274
/// The collab RPC server: accepts client connections, routes incoming
/// messages to registered handlers, and tracks connected clients.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is shutting down so connection
    /// tasks can exit their loops.
    teardown: watch::Sender<bool>,
}
283
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, so it cannot be held across an await point that
/// might migrate the task to another thread.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable snapshot of the server's peer and connection-pool state.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The pool guard serializes as the `ConnectionPool` it protects.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates the server for the given id and registers a handler for every
    /// RPC message type it services.
    ///
    /// # Panics
    /// Panics (via `add_handler`) if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participants.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host that don't modify it.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): `LspExtRunnables` is registered as mutating while the
            // surrounding LspExt requests are read-only — confirm this is intended.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Requests forwarded to the project host that may modify it.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            // LSP queries and buffer synchronization.
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            // Host-originated notifications rebroadcast to guests.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the project host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` look like mutating
            // operations but are registered read-only — confirm the intended
            // capability level for these.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            // Debugging, commit UI, and miscellaneous project messages.
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
474
    /// Spawns the background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` elapses (giving terminated pods time to shut
    /// down gracefully), the task removes stale state left behind by previous
    /// server instances: stale channel-chat participants, stale channel-buffer
    /// collaborators (notifying remaining collaborators), stale room
    /// participants (broadcasting refreshed room/channel/contact state and
    /// canceling pending calls), old worktree entries, and finally the stale
    /// server rows themselves. All failures are logged and tolerated.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining collaborators about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the stale server
                        // and broadcast the refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only tear down the LiveKit room once nobody is left.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose pending calls were canceled by the
                        // room refresh.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed (busy) contact state of each affected
                        // user to all of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): stale chat participants were already deleted at
                // the top of this task — presumably this second pass catches rows
                // affected by the cleanup above; confirm both calls are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
645
    /// Shuts the server down: drops peer connections, resets the connection
    /// pool, and signals `teardown` so per-connection tasks exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
651
    /// Test-only: tears the server down and restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Clear the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }

    /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
664
    /// Registers the handler for message type `M`, wrapping it with tracing
    /// and timing instrumentation.
    ///
    /// The wrapper records three span fields when the handler completes:
    /// total time since the message was received, time spent in the handler
    /// itself, and their difference (time spent queued).
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The router looked this handler up by `M`'s `TypeId`, so the
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Durations are recorded as fractional milliseconds.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
708
    /// Registers a handler for a fire-and-forget message: the handler receives
    /// only the payload and no response is sent back.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }
718
    /// Registers a handler for a request message, which must reply through the
    /// provided `Response`.
    ///
    /// If the handler returns `Ok` without having responded, that is turned
    /// into an error; if it returns `Err`, the error is converted to a proto
    /// error and sent back to the requesting client.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`; lets us detect handlers
                // that return `Ok` without ever responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Reply with the error so the client's request future
                        // resolves instead of timing out.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
757
    /// Returns a future that runs the message loop for a newly accepted client
    /// connection until the connection closes or the server tears down.
    ///
    /// The future registers the connection with the peer, builds the
    /// per-connection `Session`, sends the initial client update, then
    /// dispatches incoming messages to the registered handlers — foreground
    /// messages in arrival order (bounded by a semaphore), background messages
    /// on detached tasks. On exit it runs `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Supermaven integration is only active when an admin API key is
            // configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Release the global connection slot once the handshake completed.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler slot is available.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
949
    /// Sends the initial messages a freshly accepted client needs: the
    /// `Hello` handshake, first-connection prompts, the contact list,
    /// channel subscriptions, and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client the peer id assigned to this
        // connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id back to the caller, if requested.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt them to add
                // contacts and persist that the prompt was shown.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact update
                // while holding the pool lock, so the snapshot is consistent.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // If someone is currently calling this user, deliver the ring
                // to the new connection.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1005
1006 pub async fn invite_code_redeemed(
1007 self: &Arc<Self>,
1008 inviter_id: UserId,
1009 invitee_id: UserId,
1010 ) -> Result<()> {
1011 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1012 && let Some(code) = &user.invite_code
1013 {
1014 let pool = self.connection_pool.lock();
1015 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1016 for connection_id in pool.user_connection_ids(inviter_id) {
1017 self.peer.send(
1018 connection_id,
1019 proto::UpdateContacts {
1020 contacts: vec![invitee_contact.clone()],
1021 ..Default::default()
1022 },
1023 )?;
1024 self.peer.send(
1025 connection_id,
1026 proto::UpdateInviteInfo {
1027 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1028 count: user.invite_count as u32,
1029 },
1030 )?;
1031 }
1032 }
1033 Ok(())
1034 }
1035
1036 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1037 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1038 && let Some(invite_code) = &user.invite_code
1039 {
1040 let pool = self.connection_pool.lock();
1041 for connection_id in pool.user_connection_ids(user_id) {
1042 self.peer.send(
1043 connection_id,
1044 proto::UpdateInviteInfo {
1045 url: format!(
1046 "{}{}",
1047 self.app_state.config.invite_link_prefix, invite_code
1048 ),
1049 count: user.invite_count as u32,
1050 },
1051 )?;
1052 }
1053 }
1054 Ok(())
1055 }
1056
    /// Captures a snapshot of the server state: the connection pool (held
    /// under lock for the snapshot's lifetime) and a reference to the peer.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker field; see `ConnectionPoolGuard`'s definition.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1066}
1067
/// Lets a `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is
/// expected, forwarding to the underlying lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1075
/// Mutable counterpart of the `Deref` impl above-the-hood lock guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1081
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released, so corruption is caught close to its cause.
        #[cfg(test)]
        self.check_invariants();
    }
}
1088
1089fn broadcast<F>(
1090 sender_id: Option<ConnectionId>,
1091 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1092 mut f: F,
1093) where
1094 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1095{
1096 for receiver_id in receiver_ids {
1097 if Some(receiver_id) != sender_id
1098 && let Err(error) = f(receiver_id)
1099 {
1100 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1101 }
1102 }
1103}
1104
1105pub struct ProtocolVersion(u32);
1106
1107impl Header for ProtocolVersion {
1108 fn name() -> &'static HeaderName {
1109 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1110 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1111 }
1112
1113 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1114 where
1115 Self: Sized,
1116 I: Iterator<Item = &'i axum::http::HeaderValue>,
1117 {
1118 let version = values
1119 .next()
1120 .ok_or_else(axum::headers::Error::invalid)?
1121 .to_str()
1122 .map_err(|_| axum::headers::Error::invalid())?
1123 .parse()
1124 .map_err(|_| axum::headers::Error::invalid())?;
1125 Ok(Self(version))
1126 }
1127
1128 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1129 values.extend([self.0.to_string().parse().unwrap()]);
1130 }
1131}
1132
1133pub struct AppVersionHeader(SemanticVersion);
1134impl Header for AppVersionHeader {
1135 fn name() -> &'static HeaderName {
1136 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1138 }
1139
1140 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141 where
1142 Self: Sized,
1143 I: Iterator<Item = &'i axum::http::HeaderValue>,
1144 {
1145 let version = values
1146 .next()
1147 .ok_or_else(axum::headers::Error::invalid)?
1148 .to_str()
1149 .map_err(|_| axum::headers::Error::invalid())?
1150 .parse()
1151 .map_err(|_| axum::headers::Error::invalid())?;
1152 Ok(Self(version))
1153 }
1154
1155 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156 values.extend([self.0.to_string().parse().unwrap()]);
1157 }
1158}
1159
1160#[derive(Debug)]
1161pub struct ReleaseChannelHeader(String);
1162
1163impl Header for ReleaseChannelHeader {
1164 fn name() -> &'static HeaderName {
1165 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1166 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1167 }
1168
1169 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1170 where
1171 Self: Sized,
1172 I: Iterator<Item = &'i axum::http::HeaderValue>,
1173 {
1174 Ok(Self(
1175 values
1176 .next()
1177 .ok_or_else(axum::headers::Error::invalid)?
1178 .to_str()
1179 .map_err(|_| axum::headers::Error::invalid())?
1180 .to_owned(),
1181 ))
1182 }
1183
1184 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1185 values.extend([self.0.parse().unwrap()]);
1186 }
1187}
1188
/// Builds the RPC router: `/rpc` performs the auth-gated websocket upgrade
/// and `/metrics` serves the Prometheus scrape endpoint.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // These layers wrap only the routes added above them, so the auth
        // middleware applies to `/rpc` but not to `/metrics`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1200
/// Axum handler that upgrades `/rpc` requests into collab RPC connections.
///
/// Rejects clients whose protocol version doesn't match, that omit the
/// app-version header, or whose app version is too old to collaborate, and
/// enforces the concurrent-connection cap before performing the upgrade.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The wire protocol version must match exactly; an old client has to
    // update before it can talk to this server.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is mandatory: without it we can't tell whether
    // this client build is allowed to collaborate.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt between axum's and tungstenite's message types so the shared
        // `Connection` machinery can run over this socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1278
/// Serves Prometheus metrics: refreshes the connection and shared-project
/// gauges, then renders everything registered in the default registry in the
/// text exposition format.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    // Gauges are registered lazily on first scrape and reused afterwards.
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded so the gauge reflects real users.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Encode the whole registry, including the gauges refreshed above.
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1311
/// Cleans up after a client's connection drops.
///
/// The connection is removed from the peer and pool immediately, but room
/// and channel-buffer resources are only released after `RECONNECT_TIMEOUT`
/// elapses, giving the client a grace period to reconnect. A teardown signal
/// short-circuits the wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when the user has no other live
            // connections that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1358
1359/// Acknowledges a ping from a client, used to keep the connection alive.
1360async fn ping(
1361 _: proto::Ping,
1362 response: Response<proto::Ping>,
1363 _session: MessageContext,
1364) -> Result<()> {
1365 response.send(proto::Ack {})?;
1366 Ok(())
1367}
1368
1369/// Creates a new room for calling (outside of channels)
1370async fn create_room(
1371 _request: proto::CreateRoom,
1372 response: Response<proto::CreateRoom>,
1373 session: MessageContext,
1374) -> Result<()> {
1375 let livekit_room = nanoid::nanoid!(30);
1376
1377 let live_kit_connection_info = util::maybe!(async {
1378 let live_kit = session.app_state.livekit_client.as_ref();
1379 let live_kit = live_kit?;
1380 let user_id = session.user_id().to_string();
1381
1382 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1383
1384 Some(proto::LiveKitConnectionInfo {
1385 server_url: live_kit.url().into(),
1386 token,
1387 can_publish: true,
1388 })
1389 })
1390 .await;
1391
1392 let room = session
1393 .db()
1394 .await
1395 .create_room(session.user_id(), session.connection_id, &livekit_room)
1396 .await?;
1397
1398 response.send(proto::CreateRoomResponse {
1399 room: Some(room.clone()),
1400 live_kit_connection_info,
1401 })?;
1402
1403 update_user_contacts(session.user_id(), &session).await?;
1404 Ok(())
1405}
1406
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel are joined through the channel path instead,
    // so channel membership rules apply.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Joining from one device cancels the incoming ring on all of the user's
    // connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token when a client is configured; failures are logged
    // and leave the caller without audio/video rather than failing the join.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1473
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Responds with the rejoined room state, re-announces the client's new
/// connection id to collaborators on reshared projects, streams rejoined
/// project state, and finally refreshes channel and contact state.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the db guard returned by `rejoin_room` to this block; only plain
    // data escapes it via `into_inner`.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this host's peer id changed from
            // the old connection to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree metadata to everyone else.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // For channel-backed rooms, also push the updated channel state.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1565
1566fn notify_rejoined_projects(
1567 rejoined_projects: &mut Vec<RejoinedProject>,
1568 session: &Session,
1569) -> Result<()> {
1570 for project in rejoined_projects.iter() {
1571 for collaborator in &project.collaborators {
1572 session
1573 .peer
1574 .send(
1575 collaborator.connection_id,
1576 proto::UpdateProjectCollaborator {
1577 project_id: project.id.to_proto(),
1578 old_peer_id: Some(project.old_connection_id.into()),
1579 new_peer_id: Some(session.connection_id.into()),
1580 },
1581 )
1582 .trace_err();
1583 }
1584 }
1585
1586 for project in rejoined_projects {
1587 for worktree in mem::take(&mut project.worktrees) {
1588 // Stream this worktree's entries.
1589 let message = proto::UpdateWorktree {
1590 project_id: project.id.to_proto(),
1591 worktree_id: worktree.id,
1592 abs_path: worktree.abs_path.clone(),
1593 root_name: worktree.root_name,
1594 updated_entries: worktree.updated_entries,
1595 removed_entries: worktree.removed_entries,
1596 scan_id: worktree.scan_id,
1597 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1598 updated_repositories: worktree.updated_repositories,
1599 removed_repositories: worktree.removed_repositories,
1600 };
1601 for update in proto::split_worktree_update(message) {
1602 session.peer.send(session.connection_id, update)?;
1603 }
1604
1605 // Stream this worktree's diagnostics.
1606 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1607 if let Some(summary) = worktree_diagnostics.next() {
1608 let message = proto::UpdateDiagnosticSummary {
1609 project_id: project.id.to_proto(),
1610 worktree_id: worktree.id,
1611 summary: Some(summary),
1612 more_summaries: worktree_diagnostics.collect(),
1613 };
1614 session.peer.send(session.connection_id, message)?;
1615 }
1616
1617 for settings_file in worktree.settings_files {
1618 session.peer.send(
1619 session.connection_id,
1620 proto::UpdateWorktreeSettings {
1621 project_id: project.id.to_proto(),
1622 worktree_id: worktree.id,
1623 path: settings_file.path,
1624 content: Some(settings_file.content),
1625 kind: Some(settings_file.kind.to_proto().into()),
1626 },
1627 )?;
1628 }
1629 }
1630
1631 for repository in mem::take(&mut project.updated_repositories) {
1632 for update in split_repository_update(repository) {
1633 session.peer.send(session.connection_id, update)?;
1634 }
1635 }
1636
1637 for id in mem::take(&mut project.removed_repositories) {
1638 session.peer.send(
1639 session.connection_id,
1640 proto::RemoveRepository {
1641 project_id: project.id.to_proto(),
1642 id,
1643 },
1644 )?;
1645 }
1646 }
1647
1648 Ok(())
1649}
1650
1651/// leave room disconnects from the room.
1652async fn leave_room(
1653 _: proto::LeaveRoom,
1654 response: Response<proto::LeaveRoom>,
1655 session: MessageContext,
1656) -> Result<()> {
1657 leave_room_for_session(&session, session.connection_id).await?;
1658 response.send(proto::Ack {})?;
1659 Ok(())
1660}
1661
1662/// Updates the permissions of someone else in the room.
1663async fn set_room_participant_role(
1664 request: proto::SetRoomParticipantRole,
1665 response: Response<proto::SetRoomParticipantRole>,
1666 session: MessageContext,
1667) -> Result<()> {
1668 let user_id = UserId::from_proto(request.user_id);
1669 let role = ChannelRole::from(request.role());
1670
1671 let (livekit_room, can_publish) = {
1672 let room = session
1673 .db()
1674 .await
1675 .set_room_participant_role(
1676 session.user_id(),
1677 RoomId::from_proto(request.room_id),
1678 user_id,
1679 role,
1680 )
1681 .await?;
1682
1683 let livekit_room = room.livekit_room.clone();
1684 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1685 room_updated(&room, &session.peer);
1686 (livekit_room, can_publish)
1687 };
1688
1689 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1690 live_kit
1691 .update_participant(
1692 livekit_room.clone(),
1693 request.user_id.to_string(),
1694 livekit_api::proto::ParticipantPermission {
1695 can_subscribe: true,
1696 can_publish,
1697 can_publish_data: can_publish,
1698 hidden: false,
1699 recorder: false,
1700 },
1701 )
1702 .await
1703 .trace_err();
1704 }
1705
1706 response.send(proto::Ack {})?;
1707 Ok(())
1708}
1709
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call and broadcast the new room state; the db layer hands
    // back the ring message to deliver to the callee.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the callee has; the first acknowledgement
    // completes the call.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Keep trying the remaining connections; just log the failure.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back and refresh contacts again.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1778
1779/// Cancel an outgoing call.
1780async fn cancel_call(
1781 request: proto::CancelCall,
1782 response: Response<proto::CancelCall>,
1783 session: MessageContext,
1784) -> Result<()> {
1785 let called_user_id = UserId::from_proto(request.called_user_id);
1786 let room_id = RoomId::from_proto(request.room_id);
1787 {
1788 let room = session
1789 .db()
1790 .await
1791 .cancel_call(room_id, session.connection_id, called_user_id)
1792 .await?;
1793 room_updated(&room, &session.peer);
1794 }
1795
1796 for connection_id in session
1797 .connection_pool()
1798 .await
1799 .user_connection_ids(called_user_id)
1800 {
1801 session
1802 .peer
1803 .send(
1804 connection_id,
1805 proto::CallCanceled {
1806 room_id: room_id.to_proto(),
1807 },
1808 )
1809 .trace_err();
1810 }
1811 response.send(proto::Ack {})?;
1812
1813 update_user_contacts(called_user_id, &session).await?;
1814 Ok(())
1815}
1816
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Record the decline and broadcast the resulting room state.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Stop the ring on every one of the user's connections, since declining
    // from one device dismisses the call everywhere.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1848
1849/// Updates other participants in the room with your current location.
1850async fn update_participant_location(
1851 request: proto::UpdateParticipantLocation,
1852 response: Response<proto::UpdateParticipantLocation>,
1853 session: MessageContext,
1854) -> Result<()> {
1855 let room_id = RoomId::from_proto(request.room_id);
1856 let location = request.location.context("invalid location")?;
1857
1858 let db = session.db().await;
1859 let room = db
1860 .update_room_participant_location(room_id, session.connection_id, location)
1861 .await?;
1862
1863 room_updated(&room, &session.peer);
1864 response.send(proto::Ack {})?;
1865 Ok(())
1866}
1867
1868/// Share a project into the room.
1869async fn share_project(
1870 request: proto::ShareProject,
1871 response: Response<proto::ShareProject>,
1872 session: MessageContext,
1873) -> Result<()> {
1874 let (project_id, room) = &*session
1875 .db()
1876 .await
1877 .share_project(
1878 RoomId::from_proto(request.room_id),
1879 session.connection_id,
1880 &request.worktrees,
1881 request.is_ssh_project,
1882 request.windows_paths.unwrap_or(false),
1883 )
1884 .await?;
1885 response.send(proto::ShareProjectResponse {
1886 project_id: project_id.to_proto(),
1887 })?;
1888 room_updated(room, &session.peer);
1889
1890 Ok(())
1891}
1892
1893/// Unshare a project from the room.
1894async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1895 let project_id = ProjectId::from_proto(message.project_id);
1896 unshare_project_internal(project_id, session.connection_id, &session).await
1897}
1898
/// Shared unshare implementation, used by the `UnshareProject` handler and
/// other cleanup paths that unshare on a host's behalf.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest the project is gone (skipping the unsharer), then
        // broadcast the room change when the project belonged to a room.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The db layer decides whether the project record should be removed
    // entirely; if so, delete it after the guard above has been released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1936
/// Join someone elses shared project.
///
/// Registers the guest as a collaborator, announces them to existing
/// collaborators, responds with project metadata, and then streams worktree
/// entries, diagnostics, settings, repositories, and language-server state.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the db guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Existing collaborators, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new guest to every existing collaborator.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Worktree updates can be large, so they are chunked before sending.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // NOTE(review): this block uses `worktree.id` while the loop key is
        // `worktree_id` — presumably identical; confirm.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Prompt the client to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2087
2088/// Leave someone elses shared project.
2089async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2090 let sender_id = session.connection_id;
2091 let project_id = ProjectId::from_proto(request.project_id);
2092 let db = session.db().await;
2093
2094 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2095 tracing::info!(
2096 %project_id,
2097 "leave project"
2098 );
2099
2100 project_left(project, &session);
2101 if let Some(room) = room {
2102 room_updated(room, &session.peer);
2103 }
2104
2105 Ok(())
2106}
2107
2108/// Updates other participants with changes to the project
2109async fn update_project(
2110 request: proto::UpdateProject,
2111 response: Response<proto::UpdateProject>,
2112 session: MessageContext,
2113) -> Result<()> {
2114 let project_id = ProjectId::from_proto(request.project_id);
2115 let (room, guest_connection_ids) = &*session
2116 .db()
2117 .await
2118 .update_project(project_id, session.connection_id, &request.worktrees)
2119 .await?;
2120 broadcast(
2121 Some(session.connection_id),
2122 guest_connection_ids.iter().copied(),
2123 |connection_id| {
2124 session
2125 .peer
2126 .forward_send(session.connection_id, connection_id, request.clone())
2127 },
2128 );
2129 if let Some(room) = room {
2130 room_updated(room, &session.peer);
2131 }
2132 response.send(proto::Ack {})?;
2133
2134 Ok(())
2135}
2136
2137/// Updates other participants with changes to the worktree
2138async fn update_worktree(
2139 request: proto::UpdateWorktree,
2140 response: Response<proto::UpdateWorktree>,
2141 session: MessageContext,
2142) -> Result<()> {
2143 let guest_connection_ids = session
2144 .db()
2145 .await
2146 .update_worktree(&request, session.connection_id)
2147 .await?;
2148
2149 broadcast(
2150 Some(session.connection_id),
2151 guest_connection_ids.iter().copied(),
2152 |connection_id| {
2153 session
2154 .peer
2155 .forward_send(session.connection_id, connection_id, request.clone())
2156 },
2157 );
2158 response.send(proto::Ack {})?;
2159 Ok(())
2160}
2161
2162async fn update_repository(
2163 request: proto::UpdateRepository,
2164 response: Response<proto::UpdateRepository>,
2165 session: MessageContext,
2166) -> Result<()> {
2167 let guest_connection_ids = session
2168 .db()
2169 .await
2170 .update_repository(&request, session.connection_id)
2171 .await?;
2172
2173 broadcast(
2174 Some(session.connection_id),
2175 guest_connection_ids.iter().copied(),
2176 |connection_id| {
2177 session
2178 .peer
2179 .forward_send(session.connection_id, connection_id, request.clone())
2180 },
2181 );
2182 response.send(proto::Ack {})?;
2183 Ok(())
2184}
2185
2186async fn remove_repository(
2187 request: proto::RemoveRepository,
2188 response: Response<proto::RemoveRepository>,
2189 session: MessageContext,
2190) -> Result<()> {
2191 let guest_connection_ids = session
2192 .db()
2193 .await
2194 .remove_repository(&request, session.connection_id)
2195 .await?;
2196
2197 broadcast(
2198 Some(session.connection_id),
2199 guest_connection_ids.iter().copied(),
2200 |connection_id| {
2201 session
2202 .peer
2203 .forward_send(session.connection_id, connection_id, request.clone())
2204 },
2205 );
2206 response.send(proto::Ack {})?;
2207 Ok(())
2208}
2209
2210/// Updates other participants with changes to the diagnostics
2211async fn update_diagnostic_summary(
2212 message: proto::UpdateDiagnosticSummary,
2213 session: MessageContext,
2214) -> Result<()> {
2215 let guest_connection_ids = session
2216 .db()
2217 .await
2218 .update_diagnostic_summary(&message, session.connection_id)
2219 .await?;
2220
2221 broadcast(
2222 Some(session.connection_id),
2223 guest_connection_ids.iter().copied(),
2224 |connection_id| {
2225 session
2226 .peer
2227 .forward_send(session.connection_id, connection_id, message.clone())
2228 },
2229 );
2230
2231 Ok(())
2232}
2233
2234/// Updates other participants with changes to the worktree settings
2235async fn update_worktree_settings(
2236 message: proto::UpdateWorktreeSettings,
2237 session: MessageContext,
2238) -> Result<()> {
2239 let guest_connection_ids = session
2240 .db()
2241 .await
2242 .update_worktree_settings(&message, session.connection_id)
2243 .await?;
2244
2245 broadcast(
2246 Some(session.connection_id),
2247 guest_connection_ids.iter().copied(),
2248 |connection_id| {
2249 session
2250 .peer
2251 .forward_send(session.connection_id, connection_id, message.clone())
2252 },
2253 );
2254
2255 Ok(())
2256}
2257
2258/// Notify other participants that a language server has started.
2259async fn start_language_server(
2260 request: proto::StartLanguageServer,
2261 session: MessageContext,
2262) -> Result<()> {
2263 let guest_connection_ids = session
2264 .db()
2265 .await
2266 .start_language_server(&request, session.connection_id)
2267 .await?;
2268
2269 broadcast(
2270 Some(session.connection_id),
2271 guest_connection_ids.iter().copied(),
2272 |connection_id| {
2273 session
2274 .peer
2275 .forward_send(session.connection_id, connection_id, request.clone())
2276 },
2277 );
2278 Ok(())
2279}
2280
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist refreshed capabilities so they can be served to later joiners
    // (the join-project response includes `language_server_capabilities`).
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the meaning of the trailing `true` flag isn't visible
    // here — presumably it widens the recipient set; confirm against
    // `Database::project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2310
2311/// forward a project request to the host. These requests should be read only
2312/// as guests are allowed to send them.
2313async fn forward_read_only_project_request<T>(
2314 request: T,
2315 response: Response<T>,
2316 session: MessageContext,
2317) -> Result<()>
2318where
2319 T: EntityMessage + RequestMessage,
2320{
2321 let project_id = ProjectId::from_proto(request.remote_entity_id());
2322 let host_connection_id = session
2323 .db()
2324 .await
2325 .host_for_read_only_project_request(project_id, session.connection_id)
2326 .await?;
2327 let payload = session.forward_request(host_connection_id, request).await?;
2328 response.send(payload)?;
2329 Ok(())
2330}
2331
2332/// forward a project request to the host. These requests are disallowed
2333/// for guests.
2334async fn forward_mutating_project_request<T>(
2335 request: T,
2336 response: Response<T>,
2337 session: MessageContext,
2338) -> Result<()>
2339where
2340 T: EntityMessage + RequestMessage,
2341{
2342 let project_id = ProjectId::from_proto(request.remote_entity_id());
2343
2344 let host_connection_id = session
2345 .db()
2346 .await
2347 .host_for_mutating_project_request(project_id, session.connection_id)
2348 .await?;
2349 let payload = session.forward_request(host_connection_id, request).await?;
2350 response.send(payload)?;
2351 Ok(())
2352}
2353
2354async fn lsp_query(
2355 request: proto::LspQuery,
2356 response: Response<proto::LspQuery>,
2357 session: MessageContext,
2358) -> Result<()> {
2359 let (name, should_write) = request.query_name_and_write_permissions();
2360 tracing::Span::current().record("lsp_query_request", name);
2361 tracing::info!("lsp_query message received");
2362 if should_write {
2363 forward_mutating_project_request(request, response, session).await
2364 } else {
2365 forward_read_only_project_request(request, response, session).await
2366 }
2367}
2368
2369/// Notify other participants that a new buffer has been created
2370async fn create_buffer_for_peer(
2371 request: proto::CreateBufferForPeer,
2372 session: MessageContext,
2373) -> Result<()> {
2374 session
2375 .db()
2376 .await
2377 .check_user_is_project_host(
2378 ProjectId::from_proto(request.project_id),
2379 session.connection_id,
2380 )
2381 .await?;
2382 let peer_id = request.peer_id.context("invalid peer id")?;
2383 session
2384 .peer
2385 .forward_send(session.connection_id, peer_id.into(), request)?;
2386 Ok(())
2387}
2388
2389/// Notify other participants that a buffer has been updated. This is
2390/// allowed for guests as long as the update is limited to selections.
2391async fn update_buffer(
2392 request: proto::UpdateBuffer,
2393 response: Response<proto::UpdateBuffer>,
2394 session: MessageContext,
2395) -> Result<()> {
2396 let project_id = ProjectId::from_proto(request.project_id);
2397 let mut capability = Capability::ReadOnly;
2398
2399 for op in request.operations.iter() {
2400 match op.variant {
2401 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2402 Some(_) => capability = Capability::ReadWrite,
2403 }
2404 }
2405
2406 let host = {
2407 let guard = session
2408 .db()
2409 .await
2410 .connections_for_buffer_update(project_id, session.connection_id, capability)
2411 .await?;
2412
2413 let (host, guests) = &*guard;
2414
2415 broadcast(
2416 Some(session.connection_id),
2417 guests.clone(),
2418 |connection_id| {
2419 session
2420 .peer
2421 .forward_send(session.connection_id, connection_id, request.clone())
2422 },
2423 );
2424
2425 *host
2426 };
2427
2428 if host != session.connection_id {
2429 session.forward_request(host, request.clone()).await?;
2430 }
2431
2432 response.send(proto::Ack {})?;
2433 Ok(())
2434}
2435
/// Apply a context (assistant) operation, enforcing the sender's capability,
/// and relay it to the host and all other guests on the project.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Selection-only buffer operations are permitted for read-only
    // participants; any other operation — including a buffer operation with a
    // missing inner payload — is treated as requiring write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is just another recipient here and no
    // acknowledgement is awaited from anyone.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2477
2478/// Notify other participants that a project has been updated.
2479async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2480 request: T,
2481 session: MessageContext,
2482) -> Result<()> {
2483 let project_id = ProjectId::from_proto(request.remote_entity_id());
2484 let project_connection_ids = session
2485 .db()
2486 .await
2487 .project_connection_ids(project_id, session.connection_id, false)
2488 .await?;
2489
2490 broadcast(
2491 Some(session.connection_id),
2492 project_connection_ids.iter().copied(),
2493 |connection_id| {
2494 session
2495 .peer
2496 .forward_send(session.connection_id, connection_id, request.clone())
2497 },
2498 );
2499 Ok(())
2500}
2501
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Refuse to forward anything unless both connections are participants of
    // this room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its reply and relay it back to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // The follow is only persisted when it targets a project; the updated
    // room state is then broadcast to all participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2533
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both ends must be participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Fire-and-forget: unlike `follow`, no reply is awaited from the leader.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Project follows were persisted, so remove the record and broadcast the
    // updated room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2562
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks the database handle directly rather than going
    // through `session.db()` like the surrounding handlers — confirm whether
    // skipping that wrapper was intentional.
    let database = session.db.lock().await;

    // Scope the fan-out: project connections when a project is given,
    // otherwise everyone in the room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Also never echo the update back to its sender.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2594
2595/// Get public data about users.
2596async fn get_users(
2597 request: proto::GetUsers,
2598 response: Response<proto::GetUsers>,
2599 session: MessageContext,
2600) -> Result<()> {
2601 let user_ids = request
2602 .user_ids
2603 .into_iter()
2604 .map(UserId::from_proto)
2605 .collect();
2606 let users = session
2607 .db()
2608 .await
2609 .get_users_by_ids(user_ids)
2610 .await?
2611 .into_iter()
2612 .map(|user| proto::User {
2613 id: user.id.to_proto(),
2614 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2615 github_login: user.github_login,
2616 name: user.name,
2617 })
2618 .collect();
2619 response.send(proto::UsersResponse { users })?;
2620 Ok(())
2621}
2622
2623/// Search for users (to invite) buy Github login
2624async fn fuzzy_search_users(
2625 request: proto::FuzzySearchUsers,
2626 response: Response<proto::FuzzySearchUsers>,
2627 session: MessageContext,
2628) -> Result<()> {
2629 let query = request.query;
2630 let users = match query.len() {
2631 0 => vec![],
2632 1 | 2 => session
2633 .db()
2634 .await
2635 .get_user_by_github_login(&query)
2636 .await?
2637 .into_iter()
2638 .collect(),
2639 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2640 };
2641 let users = users
2642 .into_iter()
2643 .filter(|user| user.id != session.user_id())
2644 .map(|user| proto::User {
2645 id: user.id.to_proto(),
2646 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2647 github_login: user.github_login,
2648 name: user.name,
2649 })
2650 .collect();
2651 response.send(proto::UsersResponse { users })?;
2652 Ok(())
2653}
2654
2655/// Send a contact request to another user.
2656async fn request_contact(
2657 request: proto::RequestContact,
2658 response: Response<proto::RequestContact>,
2659 session: MessageContext,
2660) -> Result<()> {
2661 let requester_id = session.user_id();
2662 let responder_id = UserId::from_proto(request.responder_id);
2663 if requester_id == responder_id {
2664 return Err(anyhow!("cannot add yourself as a contact"))?;
2665 }
2666
2667 let notifications = session
2668 .db()
2669 .await
2670 .send_contact_request(requester_id, responder_id)
2671 .await?;
2672
2673 // Update outgoing contact requests of requester
2674 let mut update = proto::UpdateContacts::default();
2675 update.outgoing_requests.push(responder_id.to_proto());
2676 for connection_id in session
2677 .connection_pool()
2678 .await
2679 .user_connection_ids(requester_id)
2680 {
2681 session.peer.send(connection_id, update.clone())?;
2682 }
2683
2684 // Update incoming contact requests of responder
2685 let mut update = proto::UpdateContacts::default();
2686 update
2687 .incoming_requests
2688 .push(proto::IncomingContactRequest {
2689 requester_id: requester_id.to_proto(),
2690 });
2691 let connection_pool = session.connection_pool().await;
2692 for connection_id in connection_pool.user_connection_ids(responder_id) {
2693 session.peer.send(connection_id, update.clone())?;
2694 }
2695
2696 send_notifications(&connection_pool, &session.peer, notifications);
2697
2698 response.send(proto::Ack {})?;
2699 Ok(())
2700}
2701
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; the request itself remains
        // pending.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries pushed below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request disappears.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Mirror image: the requester's outgoing request is resolved.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2759
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes an established contact from a
    // still-pending request; the update lists below differ accordingly.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If the pending request had produced a notification, retract it too.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2810
/// Whether the server should proactively subscribe this client to its
/// channels, rather than waiting for a `SubscribeToChannels` request.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    // NOTE(review): presumably clients at 0.139 and later request channel
    // subscriptions themselves — confirm against client release history.
    version.0.minor() < 139
}
2814
2815async fn subscribe_to_channels(
2816 _: proto::SubscribeToChannels,
2817 session: MessageContext,
2818) -> Result<()> {
2819 subscribe_user_to_channels(session.user_id(), &session).await?;
2820 Ok(())
2821}
2822
2823async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2824 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2825 let mut pool = session.connection_pool().await;
2826 for membership in &channels_for_user.channel_memberships {
2827 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2828 }
2829 session.peer.send(
2830 session.connection_id,
2831 build_update_user_channels(&channels_for_user),
2832 )?;
2833 session.peer.send(
2834 session.connection_id,
2835 build_channels_update(channels_for_user),
2836 )?;
2837 Ok(())
2838}
2839
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Acknowledge the creator before broadcasting to other members.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // When the db returns a membership row, subscribe that user to the new
    // channel and tell all of their connections about the membership.
    // NOTE(review): presumably `membership` is Some only when the creator
    // gains a new membership (e.g. a new root channel) — confirm against
    // `Database::create_channel`.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone subscribed to its root whose role
    // allows them to see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2894
2895/// Delete a channel
2896async fn delete_channel(
2897 request: proto::DeleteChannel,
2898 response: Response<proto::DeleteChannel>,
2899 session: MessageContext,
2900) -> Result<()> {
2901 let db = session.db().await;
2902
2903 let channel_id = request.channel_id;
2904 let (root_channel, removed_channels) = db
2905 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2906 .await?;
2907 response.send(proto::Ack {})?;
2908
2909 // Notify members of removed channels
2910 let mut update = proto::UpdateChannels::default();
2911 update
2912 .delete_channels
2913 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2914
2915 let connection_pool = session.connection_pool().await;
2916 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2917 session.peer.send(connection_id, update.clone())?;
2918 }
2919
2920 Ok(())
2921}
2922
2923/// Invite someone to join a channel.
2924async fn invite_channel_member(
2925 request: proto::InviteChannelMember,
2926 response: Response<proto::InviteChannelMember>,
2927 session: MessageContext,
2928) -> Result<()> {
2929 let db = session.db().await;
2930 let channel_id = ChannelId::from_proto(request.channel_id);
2931 let invitee_id = UserId::from_proto(request.user_id);
2932 let InviteMemberResult {
2933 channel,
2934 notifications,
2935 } = db
2936 .invite_channel_member(
2937 channel_id,
2938 invitee_id,
2939 session.user_id(),
2940 request.role().into(),
2941 )
2942 .await?;
2943
2944 let update = proto::UpdateChannels {
2945 channel_invitations: vec![channel.to_proto()],
2946 ..Default::default()
2947 };
2948
2949 let connection_pool = session.connection_pool().await;
2950 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2951 session.peer.send(connection_id, update.clone())?;
2952 }
2953
2954 send_notifications(&connection_pool, &session.peer, notifications);
2955
2956 response.send(proto::Ack {})?;
2957 Ok(())
2958}
2959
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    let mut connection_pool = session.connection_pool().await;
    // Push the membership change to the removed member's connections.
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Also retract the original invite notification, if one existed. Send
    // failures here are logged (`trace_err`) rather than failing the request.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3001
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user list first: the loop body mutates the pool's channel
    // subscriptions, so we can't iterate while borrowing it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get (re)subscribed and sent the
        // updated entry; everyone else is unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3048
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The target is an existing member: push the full membership change.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // The target only had a pending invite: resend the invitation,
        // carrying the updated role.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3096
3097/// Change the name of a channel
3098async fn rename_channel(
3099 request: proto::RenameChannel,
3100 response: Response<proto::RenameChannel>,
3101 session: MessageContext,
3102) -> Result<()> {
3103 let db = session.db().await;
3104 let channel_id = ChannelId::from_proto(request.channel_id);
3105 let channel_model = db
3106 .rename_channel(channel_id, session.user_id(), &request.name)
3107 .await?;
3108 let root_id = channel_model.root_id();
3109 let channel = Channel::from_model(channel_model);
3110
3111 response.send(proto::RenameChannelResponse {
3112 channel: Some(channel.to_proto()),
3113 })?;
3114
3115 let connection_pool = session.connection_pool().await;
3116 let update = proto::UpdateChannels {
3117 channels: vec![channel.to_proto()],
3118 ..Default::default()
3119 };
3120 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3121 if role.can_see_channel(channel.visibility) {
3122 session.peer.send(connection_id, update.clone())?;
3123 }
3124 }
3125
3126 Ok(())
3127}
3128
3129/// Move a channel to a new parent.
3130async fn move_channel(
3131 request: proto::MoveChannel,
3132 response: Response<proto::MoveChannel>,
3133 session: MessageContext,
3134) -> Result<()> {
3135 let channel_id = ChannelId::from_proto(request.channel_id);
3136 let to = ChannelId::from_proto(request.to);
3137
3138 let (root_id, channels) = session
3139 .db()
3140 .await
3141 .move_channel(channel_id, to, session.user_id())
3142 .await?;
3143
3144 let connection_pool = session.connection_pool().await;
3145 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3146 let channels = channels
3147 .iter()
3148 .filter_map(|channel| {
3149 if role.can_see_channel(channel.visibility) {
3150 Some(channel.to_proto())
3151 } else {
3152 None
3153 }
3154 })
3155 .collect::<Vec<_>>();
3156 if channels.is_empty() {
3157 continue;
3158 }
3159
3160 let update = proto::UpdateChannels {
3161 channels,
3162 ..Default::default()
3163 };
3164
3165 session.peer.send(connection_id, update.clone())?;
3166 }
3167
3168 response.send(Ack {})?;
3169 Ok(())
3170}
3171
3172async fn reorder_channel(
3173 request: proto::ReorderChannel,
3174 response: Response<proto::ReorderChannel>,
3175 session: MessageContext,
3176) -> Result<()> {
3177 let channel_id = ChannelId::from_proto(request.channel_id);
3178 let direction = request.direction();
3179
3180 let updated_channels = session
3181 .db()
3182 .await
3183 .reorder_channel(channel_id, direction, session.user_id())
3184 .await?;
3185
3186 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3187 let connection_pool = session.connection_pool().await;
3188 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3189 let channels = updated_channels
3190 .iter()
3191 .filter_map(|channel| {
3192 if role.can_see_channel(channel.visibility) {
3193 Some(channel.to_proto())
3194 } else {
3195 None
3196 }
3197 })
3198 .collect::<Vec<_>>();
3199
3200 if channels.is_empty() {
3201 continue;
3202 }
3203
3204 let update = proto::UpdateChannels {
3205 channels,
3206 ..Default::default()
3207 };
3208
3209 session.peer.send(connection_id, update.clone())?;
3210 }
3211 }
3212
3213 response.send(Ack {})?;
3214 Ok(())
3215}
3216
3217/// Get the list of channel members
3218async fn get_channel_members(
3219 request: proto::GetChannelMembers,
3220 response: Response<proto::GetChannelMembers>,
3221 session: MessageContext,
3222) -> Result<()> {
3223 let db = session.db().await;
3224 let channel_id = ChannelId::from_proto(request.channel_id);
3225 let limit = if request.limit == 0 {
3226 u16::MAX as u64
3227 } else {
3228 request.limit
3229 };
3230 let (members, users) = db
3231 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3232 .await?;
3233 response.send(proto::GetChannelMembersResponse { members, users })?;
3234 Ok(())
3235}
3236
3237/// Accept or decline a channel invitation.
3238async fn respond_to_channel_invite(
3239 request: proto::RespondToChannelInvite,
3240 response: Response<proto::RespondToChannelInvite>,
3241 session: MessageContext,
3242) -> Result<()> {
3243 let db = session.db().await;
3244 let channel_id = ChannelId::from_proto(request.channel_id);
3245 let RespondToChannelInvite {
3246 membership_update,
3247 notifications,
3248 } = db
3249 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3250 .await?;
3251
3252 let mut connection_pool = session.connection_pool().await;
3253 if let Some(membership_update) = membership_update {
3254 notify_membership_updated(
3255 &mut connection_pool,
3256 membership_update,
3257 session.user_id(),
3258 &session.peer,
3259 );
3260 } else {
3261 let update = proto::UpdateChannels {
3262 remove_channel_invitations: vec![channel_id.to_proto()],
3263 ..Default::default()
3264 };
3265
3266 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3267 session.peer.send(connection_id, update.clone())?;
3268 }
3269 };
3270
3271 send_notifications(&connection_pool, &session.peer, notifications);
3272
3273 response.send(proto::Ack {})?;
3274
3275 Ok(())
3276}
3277
3278/// Join the channels' room
3279async fn join_channel(
3280 request: proto::JoinChannel,
3281 response: Response<proto::JoinChannel>,
3282 session: MessageContext,
3283) -> Result<()> {
3284 let channel_id = ChannelId::from_proto(request.channel_id);
3285 join_channel_internal(channel_id, Box::new(response), session).await
3286}
3287
/// Abstraction over the two response types that can carry a
/// `JoinRoomResponse` (`JoinChannel` and `JoinRoom`), letting
/// `join_channel_internal` serve both RPC entry points.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to the inherent `send` on `Response`, so this
        // does not recurse into the trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call to the inherent `send` on `Response`, so this
        // does not recurse into the trait method of the same name.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3301
/// Shared implementation for `JoinChannel` and `JoinRoom`: joins the
/// channel's room, replies with the room state (plus LiveKit credentials
/// when a LiveKit client is configured), and fans out membership, room, and
/// channel updates to interested connections.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database guard itself,
            // so the one held here must be dropped first, then re-acquired.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, everyone else may publish. A token-creation
        // failure is logged (`trace_err`) and yields no connection info
        // rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining can implicitly change channel membership; if it did, push
        // the updated channel list to all of this user's connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so the database and pool guards it held are
    // released before these broader notifications run.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3395
3396/// Start editing the channel notes
3397async fn join_channel_buffer(
3398 request: proto::JoinChannelBuffer,
3399 response: Response<proto::JoinChannelBuffer>,
3400 session: MessageContext,
3401) -> Result<()> {
3402 let db = session.db().await;
3403 let channel_id = ChannelId::from_proto(request.channel_id);
3404
3405 let open_response = db
3406 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3407 .await?;
3408
3409 let collaborators = open_response.collaborators.clone();
3410 response.send(open_response)?;
3411
3412 let update = UpdateChannelBufferCollaborators {
3413 channel_id: channel_id.to_proto(),
3414 collaborators: collaborators.clone(),
3415 };
3416 channel_buffer_updated(
3417 session.connection_id,
3418 collaborators
3419 .iter()
3420 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3421 &update,
3422 &session.peer,
3423 );
3424
3425 Ok(())
3426}
3427
3428/// Edit the channel notes
3429async fn update_channel_buffer(
3430 request: proto::UpdateChannelBuffer,
3431 session: MessageContext,
3432) -> Result<()> {
3433 let db = session.db().await;
3434 let channel_id = ChannelId::from_proto(request.channel_id);
3435
3436 let (collaborators, epoch, version) = db
3437 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3438 .await?;
3439
3440 channel_buffer_updated(
3441 session.connection_id,
3442 collaborators.clone(),
3443 &proto::UpdateChannelBuffer {
3444 channel_id: channel_id.to_proto(),
3445 operations: request.operations,
3446 },
3447 &session.peer,
3448 );
3449
3450 let pool = &*session.connection_pool().await;
3451
3452 let non_collaborators =
3453 pool.channel_connection_ids(channel_id)
3454 .filter_map(|(connection_id, _)| {
3455 if collaborators.contains(&connection_id) {
3456 None
3457 } else {
3458 Some(connection_id)
3459 }
3460 });
3461
3462 broadcast(None, non_collaborators, |peer_id| {
3463 session.peer.send(
3464 peer_id,
3465 proto::UpdateChannels {
3466 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3467 channel_id: channel_id.to_proto(),
3468 epoch: epoch as u64,
3469 version: version.clone(),
3470 }],
3471 ..Default::default()
3472 },
3473 )
3474 });
3475
3476 Ok(())
3477}
3478
3479/// Rejoin the channel notes after a connection blip
3480async fn rejoin_channel_buffers(
3481 request: proto::RejoinChannelBuffers,
3482 response: Response<proto::RejoinChannelBuffers>,
3483 session: MessageContext,
3484) -> Result<()> {
3485 let db = session.db().await;
3486 let buffers = db
3487 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3488 .await?;
3489
3490 for rejoined_buffer in &buffers {
3491 let collaborators_to_notify = rejoined_buffer
3492 .buffer
3493 .collaborators
3494 .iter()
3495 .filter_map(|c| Some(c.peer_id?.into()));
3496 channel_buffer_updated(
3497 session.connection_id,
3498 collaborators_to_notify,
3499 &proto::UpdateChannelBufferCollaborators {
3500 channel_id: rejoined_buffer.buffer.channel_id,
3501 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3502 },
3503 &session.peer,
3504 );
3505 }
3506
3507 response.send(proto::RejoinChannelBuffersResponse {
3508 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3509 })?;
3510
3511 Ok(())
3512}
3513
3514/// Stop editing the channel notes
3515async fn leave_channel_buffer(
3516 request: proto::LeaveChannelBuffer,
3517 response: Response<proto::LeaveChannelBuffer>,
3518 session: MessageContext,
3519) -> Result<()> {
3520 let db = session.db().await;
3521 let channel_id = ChannelId::from_proto(request.channel_id);
3522
3523 let left_buffer = db
3524 .leave_channel_buffer(channel_id, session.connection_id)
3525 .await?;
3526
3527 response.send(Ack {})?;
3528
3529 channel_buffer_updated(
3530 session.connection_id,
3531 left_buffer.connections,
3532 &proto::UpdateChannelBufferCollaborators {
3533 channel_id: channel_id.to_proto(),
3534 collaborators: left_buffer.collaborators,
3535 },
3536 &session.peer,
3537 );
3538
3539 Ok(())
3540}
3541
3542fn channel_buffer_updated<T: EnvelopedMessage>(
3543 sender_id: ConnectionId,
3544 collaborators: impl IntoIterator<Item = ConnectionId>,
3545 message: &T,
3546 peer: &Peer,
3547) {
3548 broadcast(Some(sender_id), collaborators, |peer_id| {
3549 peer.send(peer_id, message.clone())
3550 });
3551}
3552
3553fn send_notifications(
3554 connection_pool: &ConnectionPool,
3555 peer: &Peer,
3556 notifications: db::NotificationBatch,
3557) {
3558 for (user_id, notification) in notifications {
3559 for connection_id in connection_pool.user_connection_ids(user_id) {
3560 if let Err(error) = peer.send(
3561 connection_id,
3562 proto::AddNotification {
3563 notification: Some(notification.clone()),
3564 },
3565 ) {
3566 tracing::error!(
3567 "failed to send notification to {:?} {}",
3568 connection_id,
3569 error
3570 );
3571 }
3572 }
3573 }
3574}
3575
3576/// Send a message to the channel
3577async fn send_channel_message(
3578 _request: proto::SendChannelMessage,
3579 _response: Response<proto::SendChannelMessage>,
3580 _session: MessageContext,
3581) -> Result<()> {
3582 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3583}
3584
3585/// Delete a channel message
3586async fn remove_channel_message(
3587 _request: proto::RemoveChannelMessage,
3588 _response: Response<proto::RemoveChannelMessage>,
3589 _session: MessageContext,
3590) -> Result<()> {
3591 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3592}
3593
3594async fn update_channel_message(
3595 _request: proto::UpdateChannelMessage,
3596 _response: Response<proto::UpdateChannelMessage>,
3597 _session: MessageContext,
3598) -> Result<()> {
3599 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3600}
3601
3602/// Mark a channel message as read
3603async fn acknowledge_channel_message(
3604 _request: proto::AckChannelMessage,
3605 _session: MessageContext,
3606) -> Result<()> {
3607 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3608}
3609
3610/// Mark a buffer version as synced
3611async fn acknowledge_buffer_version(
3612 request: proto::AckBufferOperation,
3613 session: MessageContext,
3614) -> Result<()> {
3615 let buffer_id = BufferId::from_proto(request.buffer_id);
3616 session
3617 .db()
3618 .await
3619 .observe_buffer_version(
3620 buffer_id,
3621 session.user_id(),
3622 request.epoch as i32,
3623 &request.version,
3624 )
3625 .await?;
3626 Ok(())
3627}
3628
3629/// Get a Supermaven API key for the user
3630async fn get_supermaven_api_key(
3631 _request: proto::GetSupermavenApiKey,
3632 response: Response<proto::GetSupermavenApiKey>,
3633 session: MessageContext,
3634) -> Result<()> {
3635 let user_id: String = session.user_id().to_string();
3636 if !session.is_staff() {
3637 return Err(anyhow!("supermaven not enabled for this account"))?;
3638 }
3639
3640 let email = session.email().context("user must have an email")?;
3641
3642 let supermaven_admin_api = session
3643 .supermaven_client
3644 .as_ref()
3645 .context("supermaven not configured")?;
3646
3647 let result = supermaven_admin_api
3648 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3649 .await?;
3650
3651 response.send(proto::GetSupermavenApiKeyResponse {
3652 api_key: result.api_key,
3653 })?;
3654
3655 Ok(())
3656}
3657
3658/// Start receiving chat updates for a channel
3659async fn join_channel_chat(
3660 _request: proto::JoinChannelChat,
3661 _response: Response<proto::JoinChannelChat>,
3662 _session: MessageContext,
3663) -> Result<()> {
3664 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3665}
3666
3667/// Stop receiving chat updates for a channel
3668async fn leave_channel_chat(
3669 _request: proto::LeaveChannelChat,
3670 _session: MessageContext,
3671) -> Result<()> {
3672 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3673}
3674
3675/// Retrieve the chat history for a channel
3676async fn get_channel_messages(
3677 _request: proto::GetChannelMessages,
3678 _response: Response<proto::GetChannelMessages>,
3679 _session: MessageContext,
3680) -> Result<()> {
3681 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3682}
3683
3684/// Retrieve specific chat messages
3685async fn get_channel_messages_by_id(
3686 _request: proto::GetChannelMessagesById,
3687 _response: Response<proto::GetChannelMessagesById>,
3688 _session: MessageContext,
3689) -> Result<()> {
3690 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Retrieve the current users notifications
3694async fn get_notifications(
3695 request: proto::GetNotifications,
3696 response: Response<proto::GetNotifications>,
3697 session: MessageContext,
3698) -> Result<()> {
3699 let notifications = session
3700 .db()
3701 .await
3702 .get_notifications(
3703 session.user_id(),
3704 NOTIFICATION_COUNT_PER_PAGE,
3705 request.before_id.map(db::NotificationId::from_proto),
3706 )
3707 .await?;
3708 response.send(proto::GetNotificationsResponse {
3709 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3710 notifications,
3711 })?;
3712 Ok(())
3713}
3714
3715/// Mark notifications as read
3716async fn mark_notification_as_read(
3717 request: proto::MarkNotificationRead,
3718 response: Response<proto::MarkNotificationRead>,
3719 session: MessageContext,
3720) -> Result<()> {
3721 let database = &session.db().await;
3722 let notifications = database
3723 .mark_notification_as_read_by_id(
3724 session.user_id(),
3725 NotificationId::from_proto(request.notification_id),
3726 )
3727 .await?;
3728 send_notifications(
3729 &*session.connection_pool().await,
3730 &session.peer,
3731 notifications,
3732 );
3733 response.send(proto::Ack {})?;
3734 Ok(())
3735}
3736
3737fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3738 let message = match message {
3739 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3740 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3741 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3742 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3743 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3744 code: frame.code.into(),
3745 reason: frame.reason.as_str().to_owned().into(),
3746 })),
3747 // We should never receive a frame while reading the message, according
3748 // to the `tungstenite` maintainers:
3749 //
3750 // > It cannot occur when you read messages from the WebSocket, but it
3751 // > can be used when you want to send the raw frames (e.g. you want to
3752 // > send the frames to the WebSocket without composing the full message first).
3753 // >
3754 // > — https://github.com/snapview/tungstenite-rs/issues/268
3755 TungsteniteMessage::Frame(_) => {
3756 bail!("received an unexpected frame while reading the message")
3757 }
3758 };
3759
3760 Ok(message)
3761}
3762
3763fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3764 match message {
3765 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3766 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3767 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3768 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3769 AxumMessage::Close(frame) => {
3770 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3771 code: frame.code.into(),
3772 reason: frame.reason.as_ref().into(),
3773 }))
3774 }
3775 }
3776}
3777
3778fn notify_membership_updated(
3779 connection_pool: &mut ConnectionPool,
3780 result: MembershipUpdated,
3781 user_id: UserId,
3782 peer: &Peer,
3783) {
3784 for membership in &result.new_channels.channel_memberships {
3785 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3786 }
3787 for channel_id in &result.removed_channels {
3788 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3789 }
3790
3791 let user_channels_update = proto::UpdateUserChannels {
3792 channel_memberships: result
3793 .new_channels
3794 .channel_memberships
3795 .iter()
3796 .map(|cm| proto::ChannelMembership {
3797 channel_id: cm.channel_id.to_proto(),
3798 role: cm.role.into(),
3799 })
3800 .collect(),
3801 ..Default::default()
3802 };
3803
3804 let mut update = build_channels_update(result.new_channels);
3805 update.delete_channels = result
3806 .removed_channels
3807 .into_iter()
3808 .map(|id| id.to_proto())
3809 .collect();
3810 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3811
3812 for connection_id in connection_pool.user_connection_ids(user_id) {
3813 peer.send(connection_id, user_channels_update.clone())
3814 .trace_err();
3815 peer.send(connection_id, update.clone()).trace_err();
3816 }
3817}
3818
3819fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3820 proto::UpdateUserChannels {
3821 channel_memberships: channels
3822 .channel_memberships
3823 .iter()
3824 .map(|m| proto::ChannelMembership {
3825 channel_id: m.channel_id.to_proto(),
3826 role: m.role.into(),
3827 })
3828 .collect(),
3829 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3830 }
3831}
3832
3833fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3834 let mut update = proto::UpdateChannels::default();
3835
3836 for channel in channels.channels {
3837 update.channels.push(channel.to_proto());
3838 }
3839
3840 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3841
3842 for (channel_id, participants) in channels.channel_participants {
3843 update
3844 .channel_participants
3845 .push(proto::ChannelParticipants {
3846 channel_id: channel_id.to_proto(),
3847 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3848 });
3849 }
3850
3851 for channel in channels.invited_channels {
3852 update.channel_invitations.push(channel.to_proto());
3853 }
3854
3855 update
3856}
3857
3858fn build_initial_contacts_update(
3859 contacts: Vec<db::Contact>,
3860 pool: &ConnectionPool,
3861) -> proto::UpdateContacts {
3862 let mut update = proto::UpdateContacts::default();
3863
3864 for contact in contacts {
3865 match contact {
3866 db::Contact::Accepted { user_id, busy } => {
3867 update.contacts.push(contact_for_user(user_id, busy, pool));
3868 }
3869 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3870 db::Contact::Incoming { user_id } => {
3871 update
3872 .incoming_requests
3873 .push(proto::IncomingContactRequest {
3874 requester_id: user_id.to_proto(),
3875 })
3876 }
3877 }
3878 }
3879
3880 update
3881}
3882
3883fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3884 proto::Contact {
3885 user_id: user_id.to_proto(),
3886 online: pool.is_user_online(user_id),
3887 busy,
3888 }
3889}
3890
3891fn room_updated(room: &proto::Room, peer: &Peer) {
3892 broadcast(
3893 None,
3894 room.participants
3895 .iter()
3896 .filter_map(|participant| Some(participant.peer_id?.into())),
3897 |peer_id| {
3898 peer.send(
3899 peer_id,
3900 proto::RoomUpdated {
3901 room: Some(room.clone()),
3902 },
3903 )
3904 },
3905 );
3906}
3907
3908fn channel_updated(
3909 channel: &db::channel::Model,
3910 room: &proto::Room,
3911 peer: &Peer,
3912 pool: &ConnectionPool,
3913) {
3914 let participants = room
3915 .participants
3916 .iter()
3917 .map(|p| p.user_id)
3918 .collect::<Vec<_>>();
3919
3920 broadcast(
3921 None,
3922 pool.channel_connection_ids(channel.root_id())
3923 .filter_map(|(channel_id, role)| {
3924 role.can_see_channel(channel.visibility)
3925 .then_some(channel_id)
3926 }),
3927 |peer_id| {
3928 peer.send(
3929 peer_id,
3930 proto::UpdateChannels {
3931 channel_participants: vec![proto::ChannelParticipants {
3932 channel_id: channel.id.to_proto(),
3933 participant_user_ids: participants.clone(),
3934 }],
3935 ..Default::default()
3936 },
3937 )
3938 },
3939 );
3940}
3941
3942async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3943 let db = session.db().await;
3944
3945 let contacts = db.get_contacts(user_id).await?;
3946 let busy = db.is_user_busy(user_id).await?;
3947
3948 let pool = session.connection_pool().await;
3949 let updated_contact = contact_for_user(user_id, busy, &pool);
3950 for contact in contacts {
3951 if let db::Contact::Accepted {
3952 user_id: contact_user_id,
3953 ..
3954 } = contact
3955 {
3956 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3957 session
3958 .peer
3959 .send(
3960 contact_conn_id,
3961 proto::UpdateContacts {
3962 contacts: vec![updated_contact.clone()],
3963 remove_contacts: Default::default(),
3964 incoming_requests: Default::default(),
3965 remove_incoming_requests: Default::default(),
3966 outgoing_requests: Default::default(),
3967 remove_outgoing_requests: Default::default(),
3968 },
3969 )
3970 .trace_err();
3971 }
3972 }
3973 }
3974 Ok(())
3975}
3976
/// Remove `connection_id` from the room it currently occupies, notifying the
/// room's projects, canceled callees, affected contacts, and LiveKit. A
/// no-op when the connection is not in any room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized: all are assigned in the `if let` branch below,
    // and the `else` branch returns early before they are read.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell the remaining collaborators about every project this
        // connection was participating in.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Notify users whose incoming calls were canceled by this departure,
        // and mark them for a contact refresh as well.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: remove this participant, and delete the
    // LiveKit room entirely when the database reported the room as deleted.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4050
4051async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4052 let left_channel_buffers = session
4053 .db()
4054 .await
4055 .leave_channel_buffers(session.connection_id)
4056 .await?;
4057
4058 for left_buffer in left_channel_buffers {
4059 channel_buffer_updated(
4060 session.connection_id,
4061 left_buffer.connections,
4062 &proto::UpdateChannelBufferCollaborators {
4063 channel_id: left_buffer.channel_id.to_proto(),
4064 collaborators: left_buffer.collaborators,
4065 },
4066 &session.peer,
4067 );
4068 }
4069
4070 Ok(())
4071}
4072
4073fn project_left(project: &db::LeftProject, session: &Session) {
4074 for connection_id in &project.connection_ids {
4075 if project.should_unshare {
4076 session
4077 .peer
4078 .send(
4079 *connection_id,
4080 proto::UnshareProject {
4081 project_id: project.id.to_proto(),
4082 },
4083 )
4084 .trace_err();
4085 } else {
4086 session
4087 .peer
4088 .send(
4089 *connection_id,
4090 proto::RemoveProjectCollaborator {
4091 project_id: project.id.to_proto(),
4092 peer_id: Some(session.connection_id.into()),
4093 },
4094 )
4095 .trace_err();
4096 }
4097 }
4098}
4099
/// Extension trait converting a `Result` into an `Option`, logging the error
/// (instead of propagating it) when the result is `Err`.
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}
4105
4106impl<T, E> ResultExt for Result<T, E>
4107where
4108 E: std::fmt::Debug,
4109{
4110 type Ok = T;
4111
4112 #[track_caller]
4113 fn trace_err(self) -> Option<T> {
4114 match self {
4115 Ok(value) => Some(value),
4116 Err(error) => {
4117 tracing::error!("{:?}", error);
4118 None
4119 }
4120 }
4121 }
4122}