1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
/// Window given to a disconnected client before its state is reclaimed
/// (used by code outside this chunk).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification listings.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names used to report per-message latency metrics.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler for one RPC message type; stored in `Server::handlers`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
111impl Drop for ConnectionGuard {
112 fn drop(&mut self) {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 }
115}
116
/// Handle a request handler uses to reply to an RPC request.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Flipped to true once a response is sent; `add_request_handler` checks it
    // to verify the handler actually responded.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request as responded.
    /// Consumes `self`, so a handler can reply at most once through this handle.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
130
131#[derive(Clone, Debug)]
132pub enum Principal {
133 User(User),
134 Impersonated { user: User, admin: User },
135}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// Everything a handler needs for one message: the connection's `Session`
/// plus the tracing span created for that message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}

impl MessageContext {
    /// Forwards `request` to the peer at `receiver_id` and records how long we
    /// waited for the reply in the message's span (as fractional milliseconds).
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs on completion regardless of success or failure,
            // so the waiting time is always recorded.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Microsecond precision reported as fractional ms.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; a tokio mutex so it can be held across await points.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the whole server; see `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
205
206impl Session {
207 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
208 #[cfg(test)]
209 tokio::task::yield_now().await;
210 let guard = self.db.lock().await;
211 #[cfg(test)]
212 tokio::task::yield_now().await;
213 guard
214 }
215
216 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
217 #[cfg(test)]
218 tokio::task::yield_now().await;
219 let guard = self.connection_pool.lock();
220 ConnectionPoolGuard {
221 guard,
222 _not_send: PhantomData,
223 }
224 }
225
226 fn is_staff(&self) -> bool {
227 match &self.principal {
228 Principal::User(user) => user.admin,
229 Principal::Impersonated { .. } => true,
230 }
231 }
232
233 fn user_id(&self) -> UserId {
234 match &self.principal {
235 Principal::User(user) => user.id,
236 Principal::Impersonated { user, .. } => user.id,
237 }
238 }
239
240 pub fn email(&self) -> Option<String> {
241 match &self.principal {
242 Principal::User(user) => user.email_address.clone(),
243 Principal::Impersonated { user, .. } => user.email_address.clone(),
244 }
245 }
246}
247
248impl Debug for Session {
249 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
250 let mut result = f.debug_struct("Session");
251 match &self.principal {
252 Principal::User(user) => {
253 result.field("user", &user.github_login);
254 }
255 Principal::Impersonated { user, admin } => {
256 result.field("user", &user.github_login);
257 result.field("impersonator", &admin.github_login);
258 }
259 }
260 result.field("connection_id", &self.connection_id).finish()
261 }
262}
263
264struct DbHandle(Arc<Database>);
265
266impl Deref for DbHandle {
267 type Target = Database;
268
269 fn deref(&self) -> &Self::Target {
270 self.0.as_ref()
271 }
272}
273
/// The collab RPC server: owns the peer, the connection pool, and the table
/// of message handlers populated in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One type-erased handler per message `TypeId`; see `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel flipped to `true` by `teardown` so connection tasks exit.
    teardown: watch::Sender<bool>,
}
282
/// Guard over the locked `ConnectionPool`. The `PhantomData<Rc<()>>` makes
/// the guard `!Send`, preventing it from being held across an await point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
287
/// Serializable view of the server's live state (peer + connection pool).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize what it derefs to instead.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
294
295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
296where
297 S: Serializer,
298 T: Deref<Target = U>,
299 U: Serialize,
300{
301 Serialize::serialize(value.deref(), serializer)
302}
303
304impl Server {
    /// Builds the server and registers a handler for every RPC message type.
    ///
    /// Request handlers must reply exactly once; message handlers are
    /// fire-and-forget. Registering the same message type twice panics
    /// (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // Room / call lifecycle.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and state sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Forwarded project requests (read-only vs. mutating variants).
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
473
    /// Spawns the one-shot background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` it reclaims state left behind by previous
    /// server instances — stale chat participants, channel-buffer
    /// collaborators, room participants, old worktree entries, and stale
    /// server rows — notifying affected clients along the way. All database
    /// failures are logged via `trace_err` and do not abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and tell the
                    // remaining collaborators about the refreshed set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the dead server,
                        // then notify the room (and its channel, if any).
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // the call is no longer ringing.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy/contact state to the contacts of
                        // every user affected by the room cleanup.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once it has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
644
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every connection task to exit via the watch channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
650
    /// Test-only: tears the server down and restores it to a usable state
    /// under a fresh server id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Un-set the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }
658
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
663
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and latency-metric recording.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Dispatch is keyed by `TypeId::of::<M>()`, so this downcast
                // cannot fail for envelopes routed to this handler.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Durations are fractional milliseconds; queue time is the
                    // gap between receipt and the handler starting.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
707
708 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
709 where
710 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
711 Fut: 'static + Send + Future<Output = Result<()>>,
712 M: EnvelopedMessage,
713 {
714 self.add_handler(move |envelope, session| handler(envelope.payload, session));
715 self
716 }
717
718 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
719 where
720 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
721 Fut: Send + Future<Output = Result<()>>,
722 M: RequestMessage,
723 {
724 let handler = Arc::new(handler);
725 self.add_handler(move |envelope, session| {
726 let receipt = envelope.receipt();
727 let handler = handler.clone();
728 async move {
729 let peer = session.peer.clone();
730 let responded = Arc::new(AtomicBool::default());
731 let response = Response {
732 peer: peer.clone(),
733 responded: responded.clone(),
734 receipt,
735 };
736 match (handler)(envelope.payload, response, session).await {
737 Ok(()) => {
738 if responded.load(std::sync::atomic::Ordering::SeqCst) {
739 Ok(())
740 } else {
741 Err(anyhow!("handler did not send a response"))?
742 }
743 }
744 Err(error) => {
745 let proto_err = match &error {
746 Error::Internal(err) => err.to_proto(),
747 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
748 };
749 peer.respond_with_error(receipt, proto_err)?;
750 Err(error)
751 }
752 }
753 }
754 })
755 }
756
    /// Drives a single client connection: registers it with the peer, builds
    /// its `Session`, sends the initial client update, then dispatches
    /// incoming messages until the connection closes or the server tears down.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Outgoing HTTP client, used below for the Supermaven admin API.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Release the connection-count guard now that setup succeeded.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    // Backpressure: don't pull the next message until a
                    // handler slot is available.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // finishes so the concurrency cap stays accurate.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
948
    /// Sends the initial messages a freshly-accepted client needs: the `Hello`
    /// handshake, contact state, channel subscriptions, and any call that is
    /// currently ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client which peer id was assigned to
        // this connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Notify the caller (used by tests) of the assigned connection id.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's very first connection, prompt them to add
                // contacts, and persist that the prompt has been shown.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the contact list while
                // holding the pool lock, so the update reflects a consistent
                // snapshot of who is online.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Older clients rely on the server subscribing them to their
                // channels; newer clients request subscriptions themselves.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver any call that was ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Let this user's contacts know they are now online.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1004
1005 pub async fn invite_code_redeemed(
1006 self: &Arc<Self>,
1007 inviter_id: UserId,
1008 invitee_id: UserId,
1009 ) -> Result<()> {
1010 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1011 && let Some(code) = &user.invite_code
1012 {
1013 let pool = self.connection_pool.lock();
1014 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1015 for connection_id in pool.user_connection_ids(inviter_id) {
1016 self.peer.send(
1017 connection_id,
1018 proto::UpdateContacts {
1019 contacts: vec![invitee_contact.clone()],
1020 ..Default::default()
1021 },
1022 )?;
1023 self.peer.send(
1024 connection_id,
1025 proto::UpdateInviteInfo {
1026 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1027 count: user.invite_count as u32,
1028 },
1029 )?;
1030 }
1031 }
1032 Ok(())
1033 }
1034
1035 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1036 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1037 && let Some(invite_code) = &user.invite_code
1038 {
1039 let pool = self.connection_pool.lock();
1040 for connection_id in pool.user_connection_ids(user_id) {
1041 self.peer.send(
1042 connection_id,
1043 proto::UpdateInviteInfo {
1044 url: format!(
1045 "{}{}",
1046 self.app_state.config.invite_link_prefix, invite_code
1047 ),
1048 count: user.invite_count as u32,
1049 },
1050 )?;
1051 }
1052 }
1053 Ok(())
1054 }
1055
1056 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1057 ServerSnapshot {
1058 connection_pool: ConnectionPoolGuard {
1059 guard: self.connection_pool.lock(),
1060 _not_send: PhantomData,
1061 },
1062 peer: &self.peer,
1063 }
1064 }
1065}
1066
1067impl Deref for ConnectionPoolGuard<'_> {
1068 type Target = ConnectionPool;
1069
1070 fn deref(&self) -> &Self::Target {
1071 &self.guard
1072 }
1073}
1074
1075impl DerefMut for ConnectionPoolGuard<'_> {
1076 fn deref_mut(&mut self) -> &mut Self::Target {
1077 &mut self.guard
1078 }
1079}
1080
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In tests, verify the pool's internal invariants each time a guard is
        // released, so corruption is caught close to where it was introduced.
        #[cfg(test)]
        self.check_invariants();
    }
}
1087
1088fn broadcast<F>(
1089 sender_id: Option<ConnectionId>,
1090 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1091 mut f: F,
1092) where
1093 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1094{
1095 for receiver_id in receiver_ids {
1096 if Some(receiver_id) != sender_id
1097 && let Err(error) = f(receiver_id)
1098 {
1099 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1100 }
1101 }
1102}
1103
1104pub struct ProtocolVersion(u32);
1105
1106impl Header for ProtocolVersion {
1107 fn name() -> &'static HeaderName {
1108 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1109 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1110 }
1111
1112 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1113 where
1114 Self: Sized,
1115 I: Iterator<Item = &'i axum::http::HeaderValue>,
1116 {
1117 let version = values
1118 .next()
1119 .ok_or_else(axum::headers::Error::invalid)?
1120 .to_str()
1121 .map_err(|_| axum::headers::Error::invalid())?
1122 .parse()
1123 .map_err(|_| axum::headers::Error::invalid())?;
1124 Ok(Self(version))
1125 }
1126
1127 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1128 values.extend([self.0.to_string().parse().unwrap()]);
1129 }
1130}
1131
1132pub struct AppVersionHeader(SemanticVersion);
1133impl Header for AppVersionHeader {
1134 fn name() -> &'static HeaderName {
1135 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1136 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1137 }
1138
1139 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1140 where
1141 Self: Sized,
1142 I: Iterator<Item = &'i axum::http::HeaderValue>,
1143 {
1144 let version = values
1145 .next()
1146 .ok_or_else(axum::headers::Error::invalid)?
1147 .to_str()
1148 .map_err(|_| axum::headers::Error::invalid())?
1149 .parse()
1150 .map_err(|_| axum::headers::Error::invalid())?;
1151 Ok(Self(version))
1152 }
1153
1154 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1155 values.extend([self.0.to_string().parse().unwrap()]);
1156 }
1157}
1158
1159#[derive(Debug)]
1160pub struct ReleaseChannelHeader(String);
1161
1162impl Header for ReleaseChannelHeader {
1163 fn name() -> &'static HeaderName {
1164 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1165 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1166 }
1167
1168 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1169 where
1170 Self: Sized,
1171 I: Iterator<Item = &'i axum::http::HeaderValue>,
1172 {
1173 Ok(Self(
1174 values
1175 .next()
1176 .ok_or_else(axum::headers::Error::invalid)?
1177 .to_str()
1178 .map_err(|_| axum::headers::Error::invalid())?
1179 .to_owned(),
1180 ))
1181 }
1182
1183 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1184 values.extend([self.0.parse().unwrap()]);
1185 }
1186}
1187
1188pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1189 Router::new()
1190 .route("/rpc", get(handle_websocket_request))
1191 .layer(
1192 ServiceBuilder::new()
1193 .layer(Extension(server.app_state.clone()))
1194 .layer(middleware::from_fn(auth::validate_header)),
1195 )
1196 .route("/metrics", get(handle_metrics))
1197 .layer(Extension(server))
1198}
1199
/// Axum handler that upgrades an authenticated HTTP request into the collab
/// websocket connection.
///
/// Clients are rejected with `426 Upgrade Required` when their protocol
/// version mismatches, when no app-version header is present, or when their
/// version can no longer collaborate; with `503 Service Unavailable` when the
/// server is at its concurrent-connection limit.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's websocket message type to the tungstenite message type
        // used by `Connection`, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1277
/// Serves Prometheus metrics: refreshes the connection and shared-project
/// gauges, then renders every registered metric in the text exposition format.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded so the gauge reflects real users.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Render all registered metrics (including ones updated elsewhere).
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1310
/// Tears down a dropped connection: removes it from the pool, marks it lost in
/// the database, then — after a grace period allowing the client to reconnect —
/// releases the user's rooms, channel buffers, and pending calls.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: mark the connection lost even if the database write fails.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The client did not reconnect within the grace period: release all
        // resources that were held on behalf of this connection.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls once the user has no other live
            // connection — another device may still want to answer.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server is shutting this session down; skip the grace period.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1357
1358/// Acknowledges a ping from a client, used to keep the connection alive.
1359async fn ping(
1360 _: proto::Ping,
1361 response: Response<proto::Ping>,
1362 _session: MessageContext,
1363) -> Result<()> {
1364 response.send(proto::Ack {})?;
1365 Ok(())
1366}
1367
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random id for the backing LiveKit room.
    let livekit_room = nanoid::nanoid!(30);

    // LiveKit is optional; if the client is unconfigured or token creation
    // fails, the room is still created without audio/video connection info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // The caller is now busy; let their contacts see the status change.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1405
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms are joined via the channel path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The user answered on this connection; stop the call ringing on their
    // other devices (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token if audio/video is configured; failures degrade to
    // a join without media.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1472
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // Re-establish room membership and project shares in the database.
        // The guard returned here must stay alive while we notify peers.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first so the client can proceed while peers are notified.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this user re-shared as host, tell guests about the
        // host's new peer id and current worktrees.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // For projects this user rejoined as guest, stream fresh state.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed room: refresh the channel state for participants.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1564
/// Streams the current state of each rejoined project back to the reconnecting
/// guest, and tells the other collaborators about the guest's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell every collaborator this guest's connection id changed (best-effort).
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are chunked so no single message exceeds limits.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Re-send per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository state that changed while the guest was disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1649
/// leave room disconnects from the room.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: MessageContext,
) -> Result<()> {
    // All the work (room bookkeeping, peer notifications) lives in the shared
    // helper, which is also used during connection teardown.
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1660
1661/// Updates the permissions of someone else in the room.
1662async fn set_room_participant_role(
1663 request: proto::SetRoomParticipantRole,
1664 response: Response<proto::SetRoomParticipantRole>,
1665 session: MessageContext,
1666) -> Result<()> {
1667 let user_id = UserId::from_proto(request.user_id);
1668 let role = ChannelRole::from(request.role());
1669
1670 let (livekit_room, can_publish) = {
1671 let room = session
1672 .db()
1673 .await
1674 .set_room_participant_role(
1675 session.user_id(),
1676 RoomId::from_proto(request.room_id),
1677 user_id,
1678 role,
1679 )
1680 .await?;
1681
1682 let livekit_room = room.livekit_room.clone();
1683 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1684 room_updated(&room, &session.peer);
1685 (livekit_room, can_publish)
1686 };
1687
1688 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1689 live_kit
1690 .update_participant(
1691 livekit_room.clone(),
1692 request.user_id.to_string(),
1693 livekit_api::proto::ParticipantPermission {
1694 can_subscribe: true,
1695 can_publish,
1696 can_publish_data: can_publish,
1697 hidden: false,
1698 recorder: false,
1699 },
1700 )
1701 .await
1702 .trace_err();
1703 }
1704
1705 response.send(proto::Ack {})?;
1706 Ok(())
1707}
1708
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may call each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call and broadcast the new room state.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee; the first one to respond
    // successfully completes the call.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Keep ringing other connections; just log this failure.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: undo the pending call in the room state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1777
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Remove the pending call from the room and broadcast the change.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Stop the call ringing on every device of the callee (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    // The callee is no longer being called; refresh their contact status.
    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1815
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Remove the pending call from the room and broadcast the change.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // The user declined on one device; stop the call ringing on all of their
    // other devices (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1847
1848/// Updates other participants in the room with your current location.
1849async fn update_participant_location(
1850 request: proto::UpdateParticipantLocation,
1851 response: Response<proto::UpdateParticipantLocation>,
1852 session: MessageContext,
1853) -> Result<()> {
1854 let room_id = RoomId::from_proto(request.room_id);
1855 let location = request.location.context("invalid location")?;
1856
1857 let db = session.db().await;
1858 let room = db
1859 .update_room_participant_location(room_id, session.connection_id, location)
1860 .await?;
1861
1862 room_updated(&room, &session.peer);
1863 response.send(proto::Ack {})?;
1864 Ok(())
1865}
1866
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Record the share in the database; this assigns the project id and
    // returns the updated room state to broadcast.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    // Tell the host which id their project was assigned.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1890
1891/// Unshare a project from the room.
1892async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1893 let project_id = ProjectId::from_proto(message.project_id);
1894 unshare_project_internal(project_id, session.connection_id, &session).await
1895}
1896
/// Shared implementation of unsharing: removes the share from the database,
/// notifies guests and the room, and deletes the project entirely when the
/// database indicates it should no longer exist.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharer) the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        // `room` is None when the project was not associated with a live room.
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Deletion happens after the guard above is released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1934
1935/// Join someone elses shared project.
1936async fn join_project(
1937 request: proto::JoinProject,
1938 response: Response<proto::JoinProject>,
1939 session: MessageContext,
1940) -> Result<()> {
1941 let project_id = ProjectId::from_proto(request.project_id);
1942
1943 tracing::info!(%project_id, "join project");
1944
1945 let db = session.db().await;
1946 let (project, replica_id) = &mut *db
1947 .join_project(
1948 project_id,
1949 session.connection_id,
1950 session.user_id(),
1951 request.committer_name.clone(),
1952 request.committer_email.clone(),
1953 )
1954 .await?;
1955 drop(db);
1956 tracing::info!(%project_id, "join remote project");
1957 let collaborators = project
1958 .collaborators
1959 .iter()
1960 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1961 .map(|collaborator| collaborator.to_proto())
1962 .collect::<Vec<_>>();
1963 let project_id = project.id;
1964 let guest_user_id = session.user_id();
1965
1966 let worktrees = project
1967 .worktrees
1968 .iter()
1969 .map(|(id, worktree)| proto::WorktreeMetadata {
1970 id: *id,
1971 root_name: worktree.root_name.clone(),
1972 visible: worktree.visible,
1973 abs_path: worktree.abs_path.clone(),
1974 })
1975 .collect::<Vec<_>>();
1976
1977 let add_project_collaborator = proto::AddProjectCollaborator {
1978 project_id: project_id.to_proto(),
1979 collaborator: Some(proto::Collaborator {
1980 peer_id: Some(session.connection_id.into()),
1981 replica_id: replica_id.0 as u32,
1982 user_id: guest_user_id.to_proto(),
1983 is_host: false,
1984 committer_name: request.committer_name.clone(),
1985 committer_email: request.committer_email.clone(),
1986 }),
1987 };
1988
1989 for collaborator in &collaborators {
1990 session
1991 .peer
1992 .send(
1993 collaborator.peer_id.unwrap().into(),
1994 add_project_collaborator.clone(),
1995 )
1996 .trace_err();
1997 }
1998
1999 // First, we send the metadata associated with each worktree.
2000 let (language_servers, language_server_capabilities) = project
2001 .language_servers
2002 .clone()
2003 .into_iter()
2004 .map(|server| (server.server, server.capabilities))
2005 .unzip();
2006 response.send(proto::JoinProjectResponse {
2007 project_id: project.id.0 as u64,
2008 worktrees,
2009 replica_id: replica_id.0 as u32,
2010 collaborators,
2011 language_servers,
2012 language_server_capabilities,
2013 role: project.role.into(),
2014 })?;
2015
2016 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2017 // Stream this worktree's entries.
2018 let message = proto::UpdateWorktree {
2019 project_id: project_id.to_proto(),
2020 worktree_id,
2021 abs_path: worktree.abs_path.clone(),
2022 root_name: worktree.root_name,
2023 updated_entries: worktree.entries,
2024 removed_entries: Default::default(),
2025 scan_id: worktree.scan_id,
2026 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2027 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2028 removed_repositories: Default::default(),
2029 };
2030 for update in proto::split_worktree_update(message) {
2031 session.peer.send(session.connection_id, update.clone())?;
2032 }
2033
2034 // Stream this worktree's diagnostics.
2035 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2036 if let Some(summary) = worktree_diagnostics.next() {
2037 let message = proto::UpdateDiagnosticSummary {
2038 project_id: project.id.to_proto(),
2039 worktree_id: worktree.id,
2040 summary: Some(summary),
2041 more_summaries: worktree_diagnostics.collect(),
2042 };
2043 session.peer.send(session.connection_id, message)?;
2044 }
2045
2046 for settings_file in worktree.settings_files {
2047 session.peer.send(
2048 session.connection_id,
2049 proto::UpdateWorktreeSettings {
2050 project_id: project_id.to_proto(),
2051 worktree_id: worktree.id,
2052 path: settings_file.path,
2053 content: Some(settings_file.content),
2054 kind: Some(settings_file.kind.to_proto() as i32),
2055 },
2056 )?;
2057 }
2058 }
2059
2060 for repository in mem::take(&mut project.repositories) {
2061 for update in split_repository_update(repository) {
2062 session.peer.send(session.connection_id, update)?;
2063 }
2064 }
2065
2066 for language_server in &project.language_servers {
2067 session.peer.send(
2068 session.connection_id,
2069 proto::UpdateLanguageServer {
2070 project_id: project_id.to_proto(),
2071 server_name: Some(language_server.server.name.clone()),
2072 language_server_id: language_server.server.id,
2073 variant: Some(
2074 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2075 proto::LspDiskBasedDiagnosticsUpdated {},
2076 ),
2077 ),
2078 },
2079 )?;
2080 }
2081
2082 Ok(())
2083}
2084
2085/// Leave someone elses shared project.
2086async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2087 let sender_id = session.connection_id;
2088 let project_id = ProjectId::from_proto(request.project_id);
2089 let db = session.db().await;
2090
2091 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2092 tracing::info!(
2093 %project_id,
2094 "leave project"
2095 );
2096
2097 project_left(project, &session);
2098 if let Some(room) = room {
2099 room_updated(room, &session.peer);
2100 }
2101
2102 Ok(())
2103}
2104
2105/// Updates other participants with changes to the project
2106async fn update_project(
2107 request: proto::UpdateProject,
2108 response: Response<proto::UpdateProject>,
2109 session: MessageContext,
2110) -> Result<()> {
2111 let project_id = ProjectId::from_proto(request.project_id);
2112 let (room, guest_connection_ids) = &*session
2113 .db()
2114 .await
2115 .update_project(project_id, session.connection_id, &request.worktrees)
2116 .await?;
2117 broadcast(
2118 Some(session.connection_id),
2119 guest_connection_ids.iter().copied(),
2120 |connection_id| {
2121 session
2122 .peer
2123 .forward_send(session.connection_id, connection_id, request.clone())
2124 },
2125 );
2126 if let Some(room) = room {
2127 room_updated(room, &session.peer);
2128 }
2129 response.send(proto::Ack {})?;
2130
2131 Ok(())
2132}
2133
2134/// Updates other participants with changes to the worktree
2135async fn update_worktree(
2136 request: proto::UpdateWorktree,
2137 response: Response<proto::UpdateWorktree>,
2138 session: MessageContext,
2139) -> Result<()> {
2140 let guest_connection_ids = session
2141 .db()
2142 .await
2143 .update_worktree(&request, session.connection_id)
2144 .await?;
2145
2146 broadcast(
2147 Some(session.connection_id),
2148 guest_connection_ids.iter().copied(),
2149 |connection_id| {
2150 session
2151 .peer
2152 .forward_send(session.connection_id, connection_id, request.clone())
2153 },
2154 );
2155 response.send(proto::Ack {})?;
2156 Ok(())
2157}
2158
2159async fn update_repository(
2160 request: proto::UpdateRepository,
2161 response: Response<proto::UpdateRepository>,
2162 session: MessageContext,
2163) -> Result<()> {
2164 let guest_connection_ids = session
2165 .db()
2166 .await
2167 .update_repository(&request, session.connection_id)
2168 .await?;
2169
2170 broadcast(
2171 Some(session.connection_id),
2172 guest_connection_ids.iter().copied(),
2173 |connection_id| {
2174 session
2175 .peer
2176 .forward_send(session.connection_id, connection_id, request.clone())
2177 },
2178 );
2179 response.send(proto::Ack {})?;
2180 Ok(())
2181}
2182
2183async fn remove_repository(
2184 request: proto::RemoveRepository,
2185 response: Response<proto::RemoveRepository>,
2186 session: MessageContext,
2187) -> Result<()> {
2188 let guest_connection_ids = session
2189 .db()
2190 .await
2191 .remove_repository(&request, session.connection_id)
2192 .await?;
2193
2194 broadcast(
2195 Some(session.connection_id),
2196 guest_connection_ids.iter().copied(),
2197 |connection_id| {
2198 session
2199 .peer
2200 .forward_send(session.connection_id, connection_id, request.clone())
2201 },
2202 );
2203 response.send(proto::Ack {})?;
2204 Ok(())
2205}
2206
2207/// Updates other participants with changes to the diagnostics
2208async fn update_diagnostic_summary(
2209 message: proto::UpdateDiagnosticSummary,
2210 session: MessageContext,
2211) -> Result<()> {
2212 let guest_connection_ids = session
2213 .db()
2214 .await
2215 .update_diagnostic_summary(&message, session.connection_id)
2216 .await?;
2217
2218 broadcast(
2219 Some(session.connection_id),
2220 guest_connection_ids.iter().copied(),
2221 |connection_id| {
2222 session
2223 .peer
2224 .forward_send(session.connection_id, connection_id, message.clone())
2225 },
2226 );
2227
2228 Ok(())
2229}
2230
2231/// Updates other participants with changes to the worktree settings
2232async fn update_worktree_settings(
2233 message: proto::UpdateWorktreeSettings,
2234 session: MessageContext,
2235) -> Result<()> {
2236 let guest_connection_ids = session
2237 .db()
2238 .await
2239 .update_worktree_settings(&message, session.connection_id)
2240 .await?;
2241
2242 broadcast(
2243 Some(session.connection_id),
2244 guest_connection_ids.iter().copied(),
2245 |connection_id| {
2246 session
2247 .peer
2248 .forward_send(session.connection_id, connection_id, message.clone())
2249 },
2250 );
2251
2252 Ok(())
2253}
2254
2255/// Notify other participants that a language server has started.
2256async fn start_language_server(
2257 request: proto::StartLanguageServer,
2258 session: MessageContext,
2259) -> Result<()> {
2260 let guest_connection_ids = session
2261 .db()
2262 .await
2263 .start_language_server(&request, session.connection_id)
2264 .await?;
2265
2266 broadcast(
2267 Some(session.connection_id),
2268 guest_connection_ids.iter().copied(),
2269 |connection_id| {
2270 session
2271 .peer
2272 .forward_send(session.connection_id, connection_id, request.clone())
2273 },
2274 );
2275 Ok(())
2276}
2277
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // When the update carries fresh server capabilities, persist them first so
    // the database stays the source of truth for later joiners.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the raw message out to the project's other connections.
    // NOTE(review): the trailing boolean's meaning is defined by
    // `Database::project_connection_ids` (other call sites pass `false`) —
    // confirm against the db layer before relying on it.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2307
2308/// forward a project request to the host. These requests should be read only
2309/// as guests are allowed to send them.
2310async fn forward_read_only_project_request<T>(
2311 request: T,
2312 response: Response<T>,
2313 session: MessageContext,
2314) -> Result<()>
2315where
2316 T: EntityMessage + RequestMessage,
2317{
2318 let project_id = ProjectId::from_proto(request.remote_entity_id());
2319 let host_connection_id = session
2320 .db()
2321 .await
2322 .host_for_read_only_project_request(project_id, session.connection_id)
2323 .await?;
2324 let payload = session.forward_request(host_connection_id, request).await?;
2325 response.send(payload)?;
2326 Ok(())
2327}
2328
2329/// forward a project request to the host. These requests are disallowed
2330/// for guests.
2331async fn forward_mutating_project_request<T>(
2332 request: T,
2333 response: Response<T>,
2334 session: MessageContext,
2335) -> Result<()>
2336where
2337 T: EntityMessage + RequestMessage,
2338{
2339 let project_id = ProjectId::from_proto(request.remote_entity_id());
2340
2341 let host_connection_id = session
2342 .db()
2343 .await
2344 .host_for_mutating_project_request(project_id, session.connection_id)
2345 .await?;
2346 let payload = session.forward_request(host_connection_id, request).await?;
2347 response.send(payload)?;
2348 Ok(())
2349}
2350
2351async fn lsp_query(
2352 request: proto::LspQuery,
2353 response: Response<proto::LspQuery>,
2354 session: MessageContext,
2355) -> Result<()> {
2356 let (name, should_write) = request.query_name_and_write_permissions();
2357 tracing::Span::current().record("lsp_query_request", name);
2358 tracing::info!("lsp_query message received");
2359 if should_write {
2360 forward_mutating_project_request(request, response, session).await
2361 } else {
2362 forward_read_only_project_request(request, response, session).await
2363 }
2364}
2365
2366/// Notify other participants that a new buffer has been created
2367async fn create_buffer_for_peer(
2368 request: proto::CreateBufferForPeer,
2369 session: MessageContext,
2370) -> Result<()> {
2371 session
2372 .db()
2373 .await
2374 .check_user_is_project_host(
2375 ProjectId::from_proto(request.project_id),
2376 session.connection_id,
2377 )
2378 .await?;
2379 let peer_id = request.peer_id.context("invalid peer id")?;
2380 session
2381 .peer
2382 .forward_send(session.connection_id, peer_id.into(), request)?;
2383 Ok(())
2384}
2385
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations may come from read-only participants; any
    // other operation escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard to this block so it is dropped before we await
    // the host's acknowledgment below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests receive the update fire-and-forget.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host is sent the update as a request so the sender is only acked
    // after the host has processed it — unless the sender *is* the host.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2432
/// Relays a context update to the project's other participants,
/// host included.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirrors `update_buffer`: selection-only buffer operations are allowed
    // from read-only participants; everything else requires write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the fire-and-forget
    // broadcast instead of being sent a separate awaited request.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2474
2475/// Notify other participants that a project has been updated.
2476async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2477 request: T,
2478 session: MessageContext,
2479) -> Result<()> {
2480 let project_id = ProjectId::from_proto(request.remote_entity_id());
2481 let project_connection_ids = session
2482 .db()
2483 .await
2484 .project_connection_ids(project_id, session.connection_id, false)
2485 .await?;
2486
2487 broadcast(
2488 Some(session.connection_id),
2489 project_connection_ids.iter().copied(),
2490 |connection_id| {
2491 session
2492 .peer
2493 .forward_send(session.connection_id, connection_id, request.clone())
2494 },
2495 );
2496 Ok(())
2497}
2498
/// Start following another user in a call.
///
/// Verifies both peers share the room, forwards the `Follow` request to the
/// leader (whose reply becomes this request's reply), and — for
/// project-scoped follows — records the relationship and broadcasts the
/// updated room state.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Refuse to follow someone who isn't in the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // The leader's response is relayed verbatim back to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are persisted in the room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2530
/// Stop following another user in a call.
///
/// Validates room membership, tells the leader (fire-and-forget, unlike
/// `follow`) that this client stopped following, and — for project-scoped
/// follows — removes the persisted relationship and broadcasts the new room
/// state.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both peers must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2559
2560/// Notify everyone following you of your current location.
2561async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2562 let room_id = RoomId::from_proto(request.room_id);
2563 let database = session.db.lock().await;
2564
2565 let connection_ids = if let Some(project_id) = request.project_id {
2566 let project_id = ProjectId::from_proto(project_id);
2567 database
2568 .project_connection_ids(project_id, session.connection_id, true)
2569 .await?
2570 } else {
2571 database
2572 .room_connection_ids(room_id, session.connection_id)
2573 .await?
2574 };
2575
2576 // For now, don't send view update messages back to that view's current leader.
2577 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2578 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2579 _ => None,
2580 });
2581
2582 for connection_id in connection_ids.iter().cloned() {
2583 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2584 session
2585 .peer
2586 .forward_send(session.connection_id, connection_id, request.clone())?;
2587 }
2588 }
2589 Ok(())
2590}
2591
2592/// Get public data about users.
2593async fn get_users(
2594 request: proto::GetUsers,
2595 response: Response<proto::GetUsers>,
2596 session: MessageContext,
2597) -> Result<()> {
2598 let user_ids = request
2599 .user_ids
2600 .into_iter()
2601 .map(UserId::from_proto)
2602 .collect();
2603 let users = session
2604 .db()
2605 .await
2606 .get_users_by_ids(user_ids)
2607 .await?
2608 .into_iter()
2609 .map(|user| proto::User {
2610 id: user.id.to_proto(),
2611 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2612 github_login: user.github_login,
2613 name: user.name,
2614 })
2615 .collect();
2616 response.send(proto::UsersResponse { users })?;
2617 Ok(())
2618}
2619
2620/// Search for users (to invite) buy Github login
2621async fn fuzzy_search_users(
2622 request: proto::FuzzySearchUsers,
2623 response: Response<proto::FuzzySearchUsers>,
2624 session: MessageContext,
2625) -> Result<()> {
2626 let query = request.query;
2627 let users = match query.len() {
2628 0 => vec![],
2629 1 | 2 => session
2630 .db()
2631 .await
2632 .get_user_by_github_login(&query)
2633 .await?
2634 .into_iter()
2635 .collect(),
2636 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2637 };
2638 let users = users
2639 .into_iter()
2640 .filter(|user| user.id != session.user_id())
2641 .map(|user| proto::User {
2642 id: user.id.to_proto(),
2643 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2644 github_login: user.github_login,
2645 name: user.name,
2646 })
2647 .collect();
2648 response.send(proto::UsersResponse { users })?;
2649 Ok(())
2650}
2651
2652/// Send a contact request to another user.
2653async fn request_contact(
2654 request: proto::RequestContact,
2655 response: Response<proto::RequestContact>,
2656 session: MessageContext,
2657) -> Result<()> {
2658 let requester_id = session.user_id();
2659 let responder_id = UserId::from_proto(request.responder_id);
2660 if requester_id == responder_id {
2661 return Err(anyhow!("cannot add yourself as a contact"))?;
2662 }
2663
2664 let notifications = session
2665 .db()
2666 .await
2667 .send_contact_request(requester_id, responder_id)
2668 .await?;
2669
2670 // Update outgoing contact requests of requester
2671 let mut update = proto::UpdateContacts::default();
2672 update.outgoing_requests.push(responder_id.to_proto());
2673 for connection_id in session
2674 .connection_pool()
2675 .await
2676 .user_connection_ids(requester_id)
2677 {
2678 session.peer.send(connection_id, update.clone())?;
2679 }
2680
2681 // Update incoming contact requests of responder
2682 let mut update = proto::UpdateContacts::default();
2683 update
2684 .incoming_requests
2685 .push(proto::IncomingContactRequest {
2686 requester_id: requester_id.to_proto(),
2687 });
2688 let connection_pool = session.connection_pool().await;
2689 for connection_id in connection_pool.user_connection_ids(responder_id) {
2690 session.peer.send(connection_id, update.clone())?;
2691 }
2692
2693 send_notifications(&connection_pool, &session.peer, notifications);
2694
2695 response.send(proto::Ack {})?;
2696 Ok(())
2697}
2698
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismissal only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state accompanies each new contact — presumably so clients can
        // render availability; confirm against `contact_for_user`.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2756
/// Remove a contact.
///
/// Handles both an established contact and a pending request: the database
/// reports which case applied, and the right `UpdateContacts` delta is sent
/// to each side accordingly.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the responder's contact-request notification, if the
        // database deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2807
/// Whether the server should automatically subscribe this client to its
/// channels on connect. Clients at minor version 0.139 or newer are expected
/// to request subscriptions themselves (presumably via `SubscribeToChannels`
/// — confirm against the client's connect flow), so only older clients are
/// auto-subscribed.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2811
2812async fn subscribe_to_channels(
2813 _: proto::SubscribeToChannels,
2814 session: MessageContext,
2815) -> Result<()> {
2816 subscribe_user_to_channels(session.user_id(), &session).await?;
2817 Ok(())
2818}
2819
/// Subscribes the user's connection to each channel they belong to, then
/// sends the initial channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register interest in every channel on the connection pool before
    // sending the snapshot below.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // First the per-user membership/role data, then the full channels update.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2836
2837/// Creates a new channel.
2838async fn create_channel(
2839 request: proto::CreateChannel,
2840 response: Response<proto::CreateChannel>,
2841 session: MessageContext,
2842) -> Result<()> {
2843 let db = session.db().await;
2844
2845 let parent_id = request.parent_id.map(ChannelId::from_proto);
2846 let (channel, membership) = db
2847 .create_channel(&request.name, parent_id, session.user_id())
2848 .await?;
2849
2850 let root_id = channel.root_id();
2851 let channel = Channel::from_model(channel);
2852
2853 response.send(proto::CreateChannelResponse {
2854 channel: Some(channel.to_proto()),
2855 parent_id: request.parent_id,
2856 })?;
2857
2858 let mut connection_pool = session.connection_pool().await;
2859 if let Some(membership) = membership {
2860 connection_pool.subscribe_to_channel(
2861 membership.user_id,
2862 membership.channel_id,
2863 membership.role,
2864 );
2865 let update = proto::UpdateUserChannels {
2866 channel_memberships: vec![proto::ChannelMembership {
2867 channel_id: membership.channel_id.to_proto(),
2868 role: membership.role.into(),
2869 }],
2870 ..Default::default()
2871 };
2872 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2873 session.peer.send(connection_id, update.clone())?;
2874 }
2875 }
2876
2877 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2878 if !role.can_see_channel(channel.visibility) {
2879 continue;
2880 }
2881
2882 let update = proto::UpdateChannels {
2883 channels: vec![channel.to_proto()],
2884 ..Default::default()
2885 };
2886 session.peer.send(connection_id, update.clone())?;
2887 }
2888
2889 Ok(())
2890}
2891
2892/// Delete a channel
2893async fn delete_channel(
2894 request: proto::DeleteChannel,
2895 response: Response<proto::DeleteChannel>,
2896 session: MessageContext,
2897) -> Result<()> {
2898 let db = session.db().await;
2899
2900 let channel_id = request.channel_id;
2901 let (root_channel, removed_channels) = db
2902 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2903 .await?;
2904 response.send(proto::Ack {})?;
2905
2906 // Notify members of removed channels
2907 let mut update = proto::UpdateChannels::default();
2908 update
2909 .delete_channels
2910 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2911
2912 let connection_pool = session.connection_pool().await;
2913 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2914 session.peer.send(connection_id, update.clone())?;
2915 }
2916
2917 Ok(())
2918}
2919
2920/// Invite someone to join a channel.
2921async fn invite_channel_member(
2922 request: proto::InviteChannelMember,
2923 response: Response<proto::InviteChannelMember>,
2924 session: MessageContext,
2925) -> Result<()> {
2926 let db = session.db().await;
2927 let channel_id = ChannelId::from_proto(request.channel_id);
2928 let invitee_id = UserId::from_proto(request.user_id);
2929 let InviteMemberResult {
2930 channel,
2931 notifications,
2932 } = db
2933 .invite_channel_member(
2934 channel_id,
2935 invitee_id,
2936 session.user_id(),
2937 request.role().into(),
2938 )
2939 .await?;
2940
2941 let update = proto::UpdateChannels {
2942 channel_invitations: vec![channel.to_proto()],
2943 ..Default::default()
2944 };
2945
2946 let connection_pool = session.connection_pool().await;
2947 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2948 session.peer.send(connection_id, update.clone())?;
2949 }
2950
2951 send_notifications(&connection_pool, &session.peer, notifications);
2952
2953 response.send(proto::Ack {})?;
2954 Ok(())
2955}
2956
2957/// remove someone from a channel
2958async fn remove_channel_member(
2959 request: proto::RemoveChannelMember,
2960 response: Response<proto::RemoveChannelMember>,
2961 session: MessageContext,
2962) -> Result<()> {
2963 let db = session.db().await;
2964 let channel_id = ChannelId::from_proto(request.channel_id);
2965 let member_id = UserId::from_proto(request.user_id);
2966
2967 let RemoveChannelMemberResult {
2968 membership_update,
2969 notification_id,
2970 } = db
2971 .remove_channel_member(channel_id, member_id, session.user_id())
2972 .await?;
2973
2974 let mut connection_pool = session.connection_pool().await;
2975 notify_membership_updated(
2976 &mut connection_pool,
2977 membership_update,
2978 member_id,
2979 &session.peer,
2980 );
2981 for connection_id in connection_pool.user_connection_ids(member_id) {
2982 if let Some(notification_id) = notification_id {
2983 session
2984 .peer
2985 .send(
2986 connection_id,
2987 proto::DeleteNotification {
2988 notification_id: notification_id.to_proto(),
2989 },
2990 )
2991 .trace_err();
2992 }
2993 }
2994
2995 response.send(proto::Ack {})?;
2996 Ok(())
2997}
2998
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs into a Vec first: the loop body mutates
    // the pool (subscribe/unsubscribe), which can't coexist with a live
    // iterator borrow.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel are subscribed and told about it;
        // users who no longer can are unsubscribed and told to drop it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3045
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an actual member (membership updated) or only have a
    // pending invite (invite updated); each case is announced differently.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Refresh the pending invitation on each of the member's
            // connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3093
3094/// Change the name of a channel
3095async fn rename_channel(
3096 request: proto::RenameChannel,
3097 response: Response<proto::RenameChannel>,
3098 session: MessageContext,
3099) -> Result<()> {
3100 let db = session.db().await;
3101 let channel_id = ChannelId::from_proto(request.channel_id);
3102 let channel_model = db
3103 .rename_channel(channel_id, session.user_id(), &request.name)
3104 .await?;
3105 let root_id = channel_model.root_id();
3106 let channel = Channel::from_model(channel_model);
3107
3108 response.send(proto::RenameChannelResponse {
3109 channel: Some(channel.to_proto()),
3110 })?;
3111
3112 let connection_pool = session.connection_pool().await;
3113 let update = proto::UpdateChannels {
3114 channels: vec![channel.to_proto()],
3115 ..Default::default()
3116 };
3117 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118 if role.can_see_channel(channel.visibility) {
3119 session.peer.send(connection_id, update.clone())?;
3120 }
3121 }
3122
3123 Ok(())
3124}
3125
3126/// Move a channel to a new parent.
3127async fn move_channel(
3128 request: proto::MoveChannel,
3129 response: Response<proto::MoveChannel>,
3130 session: MessageContext,
3131) -> Result<()> {
3132 let channel_id = ChannelId::from_proto(request.channel_id);
3133 let to = ChannelId::from_proto(request.to);
3134
3135 let (root_id, channels) = session
3136 .db()
3137 .await
3138 .move_channel(channel_id, to, session.user_id())
3139 .await?;
3140
3141 let connection_pool = session.connection_pool().await;
3142 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143 let channels = channels
3144 .iter()
3145 .filter_map(|channel| {
3146 if role.can_see_channel(channel.visibility) {
3147 Some(channel.to_proto())
3148 } else {
3149 None
3150 }
3151 })
3152 .collect::<Vec<_>>();
3153 if channels.is_empty() {
3154 continue;
3155 }
3156
3157 let update = proto::UpdateChannels {
3158 channels,
3159 ..Default::default()
3160 };
3161
3162 session.peer.send(connection_id, update.clone())?;
3163 }
3164
3165 response.send(Ack {})?;
3166 Ok(())
3167}
3168
3169async fn reorder_channel(
3170 request: proto::ReorderChannel,
3171 response: Response<proto::ReorderChannel>,
3172 session: MessageContext,
3173) -> Result<()> {
3174 let channel_id = ChannelId::from_proto(request.channel_id);
3175 let direction = request.direction();
3176
3177 let updated_channels = session
3178 .db()
3179 .await
3180 .reorder_channel(channel_id, direction, session.user_id())
3181 .await?;
3182
3183 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3184 let connection_pool = session.connection_pool().await;
3185 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3186 let channels = updated_channels
3187 .iter()
3188 .filter_map(|channel| {
3189 if role.can_see_channel(channel.visibility) {
3190 Some(channel.to_proto())
3191 } else {
3192 None
3193 }
3194 })
3195 .collect::<Vec<_>>();
3196
3197 if channels.is_empty() {
3198 continue;
3199 }
3200
3201 let update = proto::UpdateChannels {
3202 channels,
3203 ..Default::default()
3204 };
3205
3206 session.peer.send(connection_id, update.clone())?;
3207 }
3208 }
3209
3210 response.send(Ack {})?;
3211 Ok(())
3212}
3213
3214/// Get the list of channel members
3215async fn get_channel_members(
3216 request: proto::GetChannelMembers,
3217 response: Response<proto::GetChannelMembers>,
3218 session: MessageContext,
3219) -> Result<()> {
3220 let db = session.db().await;
3221 let channel_id = ChannelId::from_proto(request.channel_id);
3222 let limit = if request.limit == 0 {
3223 u16::MAX as u64
3224 } else {
3225 request.limit
3226 };
3227 let (members, users) = db
3228 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3229 .await?;
3230 response.send(proto::GetChannelMembersResponse { members, users })?;
3231 Ok(())
3232}
3233
/// Accept or decline a channel invitation.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    // When the database reports a membership change (acceptance), update the
    // user's subscriptions and channel list; otherwise (decline) just remove
    // the invitation from all of the user's connected clients.
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    // Deliver whatever notifications the database produced for this response
    // (recipients are decided by the db layer).
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3274
/// Join the channels' room
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Delegates to the shared join flow; the box erases the concrete response
    // type so the `JoinRoom` RPC can reuse the same implementation
    // (see `JoinChannelInternalResponse`).
    join_channel_internal(channel_id, Box::new(response), session).await
}
3284
/// Abstraction over the two response types (`JoinChannel` and `JoinRoom`)
/// that both carry a `JoinRoomResponse` payload, allowing
/// `join_channel_internal` to serve either RPC with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3298
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` RPCs.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires its own database guard, so
            // release ours first, then re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a media server is configured: guests get
        // a non-publishing token, everyone else may publish. A token-creation
        // failure is logged by `trace_err` and yields no connection info
        // rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // If joining granted new memberships, sync subscriptions and channel
        // lists on all of the user's connections before announcing the room.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3392
3393/// Start editing the channel notes
3394async fn join_channel_buffer(
3395 request: proto::JoinChannelBuffer,
3396 response: Response<proto::JoinChannelBuffer>,
3397 session: MessageContext,
3398) -> Result<()> {
3399 let db = session.db().await;
3400 let channel_id = ChannelId::from_proto(request.channel_id);
3401
3402 let open_response = db
3403 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3404 .await?;
3405
3406 let collaborators = open_response.collaborators.clone();
3407 response.send(open_response)?;
3408
3409 let update = UpdateChannelBufferCollaborators {
3410 channel_id: channel_id.to_proto(),
3411 collaborators: collaborators.clone(),
3412 };
3413 channel_buffer_updated(
3414 session.connection_id,
3415 collaborators
3416 .iter()
3417 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3418 &update,
3419 &session.peer,
3420 );
3421
3422 Ok(())
3423}
3424
3425/// Edit the channel notes
3426async fn update_channel_buffer(
3427 request: proto::UpdateChannelBuffer,
3428 session: MessageContext,
3429) -> Result<()> {
3430 let db = session.db().await;
3431 let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433 let (collaborators, epoch, version) = db
3434 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3435 .await?;
3436
3437 channel_buffer_updated(
3438 session.connection_id,
3439 collaborators.clone(),
3440 &proto::UpdateChannelBuffer {
3441 channel_id: channel_id.to_proto(),
3442 operations: request.operations,
3443 },
3444 &session.peer,
3445 );
3446
3447 let pool = &*session.connection_pool().await;
3448
3449 let non_collaborators =
3450 pool.channel_connection_ids(channel_id)
3451 .filter_map(|(connection_id, _)| {
3452 if collaborators.contains(&connection_id) {
3453 None
3454 } else {
3455 Some(connection_id)
3456 }
3457 });
3458
3459 broadcast(None, non_collaborators, |peer_id| {
3460 session.peer.send(
3461 peer_id,
3462 proto::UpdateChannels {
3463 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3464 channel_id: channel_id.to_proto(),
3465 epoch: epoch as u64,
3466 version: version.clone(),
3467 }],
3468 ..Default::default()
3469 },
3470 )
3471 });
3472
3473 Ok(())
3474}
3475
3476/// Rejoin the channel notes after a connection blip
3477async fn rejoin_channel_buffers(
3478 request: proto::RejoinChannelBuffers,
3479 response: Response<proto::RejoinChannelBuffers>,
3480 session: MessageContext,
3481) -> Result<()> {
3482 let db = session.db().await;
3483 let buffers = db
3484 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3485 .await?;
3486
3487 for rejoined_buffer in &buffers {
3488 let collaborators_to_notify = rejoined_buffer
3489 .buffer
3490 .collaborators
3491 .iter()
3492 .filter_map(|c| Some(c.peer_id?.into()));
3493 channel_buffer_updated(
3494 session.connection_id,
3495 collaborators_to_notify,
3496 &proto::UpdateChannelBufferCollaborators {
3497 channel_id: rejoined_buffer.buffer.channel_id,
3498 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3499 },
3500 &session.peer,
3501 );
3502 }
3503
3504 response.send(proto::RejoinChannelBuffersResponse {
3505 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3506 })?;
3507
3508 Ok(())
3509}
3510
3511/// Stop editing the channel notes
3512async fn leave_channel_buffer(
3513 request: proto::LeaveChannelBuffer,
3514 response: Response<proto::LeaveChannelBuffer>,
3515 session: MessageContext,
3516) -> Result<()> {
3517 let db = session.db().await;
3518 let channel_id = ChannelId::from_proto(request.channel_id);
3519
3520 let left_buffer = db
3521 .leave_channel_buffer(channel_id, session.connection_id)
3522 .await?;
3523
3524 response.send(Ack {})?;
3525
3526 channel_buffer_updated(
3527 session.connection_id,
3528 left_buffer.connections,
3529 &proto::UpdateChannelBufferCollaborators {
3530 channel_id: channel_id.to_proto(),
3531 collaborators: left_buffer.collaborators,
3532 },
3533 &session.peer,
3534 );
3535
3536 Ok(())
3537}
3538
/// Send `message` to each of the given collaborator connections via
/// `broadcast`; `sender_id` is passed as the sender, which `broadcast`
/// presumably excludes from delivery — confirm in `broadcast`'s definition.
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3549
3550fn send_notifications(
3551 connection_pool: &ConnectionPool,
3552 peer: &Peer,
3553 notifications: db::NotificationBatch,
3554) {
3555 for (user_id, notification) in notifications {
3556 for connection_id in connection_pool.user_connection_ids(user_id) {
3557 if let Err(error) = peer.send(
3558 connection_id,
3559 proto::AddNotification {
3560 notification: Some(notification.clone()),
3561 },
3562 ) {
3563 tracing::error!(
3564 "failed to send notification to {:?} {}",
3565 connection_id,
3566 error
3567 );
3568 }
3569 }
3570 }
3571}
3572
3573/// Send a message to the channel
3574async fn send_channel_message(
3575 _request: proto::SendChannelMessage,
3576 _response: Response<proto::SendChannelMessage>,
3577 _session: MessageContext,
3578) -> Result<()> {
3579 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3580}
3581
3582/// Delete a channel message
3583async fn remove_channel_message(
3584 _request: proto::RemoveChannelMessage,
3585 _response: Response<proto::RemoveChannelMessage>,
3586 _session: MessageContext,
3587) -> Result<()> {
3588 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3589}
3590
3591async fn update_channel_message(
3592 _request: proto::UpdateChannelMessage,
3593 _response: Response<proto::UpdateChannelMessage>,
3594 _session: MessageContext,
3595) -> Result<()> {
3596 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Mark a channel message as read
3600async fn acknowledge_channel_message(
3601 _request: proto::AckChannelMessage,
3602 _session: MessageContext,
3603) -> Result<()> {
3604 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3605}
3606
3607/// Mark a buffer version as synced
3608async fn acknowledge_buffer_version(
3609 request: proto::AckBufferOperation,
3610 session: MessageContext,
3611) -> Result<()> {
3612 let buffer_id = BufferId::from_proto(request.buffer_id);
3613 session
3614 .db()
3615 .await
3616 .observe_buffer_version(
3617 buffer_id,
3618 session.user_id(),
3619 request.epoch as i32,
3620 &request.version,
3621 )
3622 .await?;
3623 Ok(())
3624}
3625
3626/// Get a Supermaven API key for the user
3627async fn get_supermaven_api_key(
3628 _request: proto::GetSupermavenApiKey,
3629 response: Response<proto::GetSupermavenApiKey>,
3630 session: MessageContext,
3631) -> Result<()> {
3632 let user_id: String = session.user_id().to_string();
3633 if !session.is_staff() {
3634 return Err(anyhow!("supermaven not enabled for this account"))?;
3635 }
3636
3637 let email = session.email().context("user must have an email")?;
3638
3639 let supermaven_admin_api = session
3640 .supermaven_client
3641 .as_ref()
3642 .context("supermaven not configured")?;
3643
3644 let result = supermaven_admin_api
3645 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3646 .await?;
3647
3648 response.send(proto::GetSupermavenApiKeyResponse {
3649 api_key: result.api_key,
3650 })?;
3651
3652 Ok(())
3653}
3654
3655/// Start receiving chat updates for a channel
3656async fn join_channel_chat(
3657 _request: proto::JoinChannelChat,
3658 _response: Response<proto::JoinChannelChat>,
3659 _session: MessageContext,
3660) -> Result<()> {
3661 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3662}
3663
3664/// Stop receiving chat updates for a channel
3665async fn leave_channel_chat(
3666 _request: proto::LeaveChannelChat,
3667 _session: MessageContext,
3668) -> Result<()> {
3669 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3670}
3671
3672/// Retrieve the chat history for a channel
3673async fn get_channel_messages(
3674 _request: proto::GetChannelMessages,
3675 _response: Response<proto::GetChannelMessages>,
3676 _session: MessageContext,
3677) -> Result<()> {
3678 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3679}
3680
3681/// Retrieve specific chat messages
3682async fn get_channel_messages_by_id(
3683 _request: proto::GetChannelMessagesById,
3684 _response: Response<proto::GetChannelMessagesById>,
3685 _session: MessageContext,
3686) -> Result<()> {
3687 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3688}
3689
3690/// Retrieve the current users notifications
3691async fn get_notifications(
3692 request: proto::GetNotifications,
3693 response: Response<proto::GetNotifications>,
3694 session: MessageContext,
3695) -> Result<()> {
3696 let notifications = session
3697 .db()
3698 .await
3699 .get_notifications(
3700 session.user_id(),
3701 NOTIFICATION_COUNT_PER_PAGE,
3702 request.before_id.map(db::NotificationId::from_proto),
3703 )
3704 .await?;
3705 response.send(proto::GetNotificationsResponse {
3706 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3707 notifications,
3708 })?;
3709 Ok(())
3710}
3711
3712/// Mark notifications as read
3713async fn mark_notification_as_read(
3714 request: proto::MarkNotificationRead,
3715 response: Response<proto::MarkNotificationRead>,
3716 session: MessageContext,
3717) -> Result<()> {
3718 let database = &session.db().await;
3719 let notifications = database
3720 .mark_notification_as_read_by_id(
3721 session.user_id(),
3722 NotificationId::from_proto(request.notification_id),
3723 )
3724 .await?;
3725 send_notifications(
3726 &*session.connection_pool().await,
3727 &session.peer,
3728 notifications,
3729 );
3730 response.send(proto::Ack {})?;
3731 Ok(())
3732}
3733
3734fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3735 let message = match message {
3736 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3737 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3738 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3739 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3740 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3741 code: frame.code.into(),
3742 reason: frame.reason.as_str().to_owned().into(),
3743 })),
3744 // We should never receive a frame while reading the message, according
3745 // to the `tungstenite` maintainers:
3746 //
3747 // > It cannot occur when you read messages from the WebSocket, but it
3748 // > can be used when you want to send the raw frames (e.g. you want to
3749 // > send the frames to the WebSocket without composing the full message first).
3750 // >
3751 // > — https://github.com/snapview/tungstenite-rs/issues/268
3752 TungsteniteMessage::Frame(_) => {
3753 bail!("received an unexpected frame while reading the message")
3754 }
3755 };
3756
3757 Ok(message)
3758}
3759
3760fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3761 match message {
3762 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3763 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3764 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3765 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3766 AxumMessage::Close(frame) => {
3767 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3768 code: frame.code.into(),
3769 reason: frame.reason.as_ref().into(),
3770 }))
3771 }
3772 }
3773}
3774
/// Sync a user's channel subscriptions and client-side channel list after a
/// membership change (e.g. invite accepted or role updated).
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // First bring the server-side subscription state in line with the new
    // memberships.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this duplicates the membership mapping from
    // `build_update_user_channels` but omits
    // `observed_channel_buffer_version` — presumably intentional; confirm
    // before unifying the two.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Push both updates to every connection the user has open; failures are
    // logged so one dead connection does not block the rest.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3815
3816fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3817 proto::UpdateUserChannels {
3818 channel_memberships: channels
3819 .channel_memberships
3820 .iter()
3821 .map(|m| proto::ChannelMembership {
3822 channel_id: m.channel_id.to_proto(),
3823 role: m.role.into(),
3824 })
3825 .collect(),
3826 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3827 }
3828}
3829
3830fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3831 let mut update = proto::UpdateChannels::default();
3832
3833 for channel in channels.channels {
3834 update.channels.push(channel.to_proto());
3835 }
3836
3837 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3838
3839 for (channel_id, participants) in channels.channel_participants {
3840 update
3841 .channel_participants
3842 .push(proto::ChannelParticipants {
3843 channel_id: channel_id.to_proto(),
3844 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3845 });
3846 }
3847
3848 for channel in channels.invited_channels {
3849 update.channel_invitations.push(channel.to_proto());
3850 }
3851
3852 update
3853}
3854
3855fn build_initial_contacts_update(
3856 contacts: Vec<db::Contact>,
3857 pool: &ConnectionPool,
3858) -> proto::UpdateContacts {
3859 let mut update = proto::UpdateContacts::default();
3860
3861 for contact in contacts {
3862 match contact {
3863 db::Contact::Accepted { user_id, busy } => {
3864 update.contacts.push(contact_for_user(user_id, busy, pool));
3865 }
3866 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3867 db::Contact::Incoming { user_id } => {
3868 update
3869 .incoming_requests
3870 .push(proto::IncomingContactRequest {
3871 requester_id: user_id.to_proto(),
3872 })
3873 }
3874 }
3875 }
3876
3877 update
3878}
3879
/// Build the wire representation of a contact: online means the user has at
/// least one live connection in the pool; `busy` is supplied by the caller.
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}
3887
3888fn room_updated(room: &proto::Room, peer: &Peer) {
3889 broadcast(
3890 None,
3891 room.participants
3892 .iter()
3893 .filter_map(|participant| Some(participant.peer_id?.into())),
3894 |peer_id| {
3895 peer.send(
3896 peer_id,
3897 proto::RoomUpdated {
3898 room: Some(room.clone()),
3899 },
3900 )
3901 },
3902 );
3903}
3904
3905fn channel_updated(
3906 channel: &db::channel::Model,
3907 room: &proto::Room,
3908 peer: &Peer,
3909 pool: &ConnectionPool,
3910) {
3911 let participants = room
3912 .participants
3913 .iter()
3914 .map(|p| p.user_id)
3915 .collect::<Vec<_>>();
3916
3917 broadcast(
3918 None,
3919 pool.channel_connection_ids(channel.root_id())
3920 .filter_map(|(channel_id, role)| {
3921 role.can_see_channel(channel.visibility)
3922 .then_some(channel_id)
3923 }),
3924 |peer_id| {
3925 peer.send(
3926 peer_id,
3927 proto::UpdateChannels {
3928 channel_participants: vec![proto::ChannelParticipants {
3929 channel_id: channel.id.to_proto(),
3930 participant_user_ids: participants.clone(),
3931 }],
3932 ..Default::default()
3933 },
3934 )
3935 },
3936 );
3937}
3938
3939async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3940 let db = session.db().await;
3941
3942 let contacts = db.get_contacts(user_id).await?;
3943 let busy = db.is_user_busy(user_id).await?;
3944
3945 let pool = session.connection_pool().await;
3946 let updated_contact = contact_for_user(user_id, busy, &pool);
3947 for contact in contacts {
3948 if let db::Contact::Accepted {
3949 user_id: contact_user_id,
3950 ..
3951 } = contact
3952 {
3953 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3954 session
3955 .peer
3956 .send(
3957 contact_conn_id,
3958 proto::UpdateContacts {
3959 contacts: vec![updated_contact.clone()],
3960 remove_contacts: Default::default(),
3961 incoming_requests: Default::default(),
3962 remove_incoming_requests: Default::default(),
3963 outgoing_requests: Default::default(),
3964 remove_outgoing_requests: Default::default(),
3965 },
3966 )
3967 .trace_err();
3968 }
3969 }
3970 }
3971 Ok(())
3972}
3973
/// Remove `connection_id` from whatever room it is in, then notify the room,
/// its channel (if any), users whose calls were canceled, contacts, and the
/// LiveKit server.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move the data we still need out of `left_room` so the database
        // guard (presumably held by `left_room` — confirm) is released
        // before the broadcasts below.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell every user whose pending call got canceled by this departure,
        // on every one of their connections; failures are logged and ignored.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Refresh busy/online state for this user and anyone whose call was
    // canceled.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4047
4048async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4049 let left_channel_buffers = session
4050 .db()
4051 .await
4052 .leave_channel_buffers(session.connection_id)
4053 .await?;
4054
4055 for left_buffer in left_channel_buffers {
4056 channel_buffer_updated(
4057 session.connection_id,
4058 left_buffer.connections,
4059 &proto::UpdateChannelBufferCollaborators {
4060 channel_id: left_buffer.channel_id.to_proto(),
4061 collaborators: left_buffer.collaborators,
4062 },
4063 &session.peer,
4064 );
4065 }
4066
4067 Ok(())
4068}
4069
4070fn project_left(project: &db::LeftProject, session: &Session) {
4071 for connection_id in &project.connection_ids {
4072 if project.should_unshare {
4073 session
4074 .peer
4075 .send(
4076 *connection_id,
4077 proto::UnshareProject {
4078 project_id: project.id.to_proto(),
4079 },
4080 )
4081 .trace_err();
4082 } else {
4083 session
4084 .peer
4085 .send(
4086 *connection_id,
4087 proto::RemoveProjectCollaborator {
4088 project_id: project.id.to_proto(),
4089 peer_id: Some(session.connection_id.into()),
4090 },
4091 )
4092 .trace_err();
4093 }
4094 }
4095}
4096
/// Extension trait for converting a `Result` into an `Option`, logging the
/// error instead of propagating it.
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    // `#[track_caller]` attributes diagnostics to the call site rather than
    // this helper.
    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                // Log via Debug and swallow the error, yielding `None`.
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}