1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use rpc::proto::split_repository_update;
36use tracing::Span;
37use util::paths::PathStyle;
38
39use futures::{
40 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
41 stream::FuturesUnordered,
42};
43use prometheus::{IntGauge, register_int_gauge};
44use rpc::{
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
48 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50};
51use semver::Version;
52use serde::{Serialize, Serializer};
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
75pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
76
77// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
78pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
79
80const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
81const MAX_CONCURRENT_CONNECTIONS: usize = 512;
82
83static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
84
85const TOTAL_DURATION_MS: &str = "total_duration_ms";
86const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
87const QUEUE_DURATION_MS: &str = "queue_duration_ms";
88const HOST_WAITING_MS: &str = "host_waiting_ms";
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133 Impersonated { user: User, admin: User },
134}
135
136impl Principal {
137 fn update_span(&self, span: &tracing::Span) {
138 match &self {
139 Principal::User(user) => {
140 span.record("user_id", user.id.0);
141 span.record("login", &user.github_login);
142 }
143 Principal::Impersonated { user, admin } => {
144 span.record("user_id", user.id.0);
145 span.record("login", &user.github_login);
146 span.record("impersonator", &admin.github_login);
147 }
148 }
149 }
150}
151
152#[derive(Clone)]
153struct MessageContext {
154 session: Session,
155 span: tracing::Span,
156}
157
158impl Deref for MessageContext {
159 type Target = Session;
160
161 fn deref(&self) -> &Self::Target {
162 &self.session
163 }
164}
165
166impl MessageContext {
167 pub fn forward_request<T: RequestMessage>(
168 &self,
169 receiver_id: ConnectionId,
170 request: T,
171 ) -> impl Future<Output = anyhow::Result<T::Response>> {
172 let request_start_time = Instant::now();
173 let span = self.span.clone();
174 tracing::info!("start forwarding request");
175 self.peer
176 .forward_request(self.connection_id, receiver_id, request)
177 .inspect(move |_| {
178 span.record(
179 HOST_WAITING_MS,
180 request_start_time.elapsed().as_micros() as f64 / 1000.0,
181 );
182 })
183 .inspect_err(|_| tracing::error!("error forwarding request"))
184 .inspect_ok(|_| tracing::info!("finished forwarding request"))
185 }
186}
187
188#[derive(Clone)]
189struct Session {
190 principal: Principal,
191 connection_id: ConnectionId,
192 db: Arc<tokio::sync::Mutex<DbHandle>>,
193 peer: Arc<Peer>,
194 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
195 app_state: Arc<AppState>,
196 /// The GeoIP country code for the user.
197 #[allow(unused)]
198 geoip_country_code: Option<String>,
199 #[allow(unused)]
200 system_id: Option<String>,
201 _executor: Executor,
202}
203
204impl Session {
205 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
206 #[cfg(test)]
207 tokio::task::yield_now().await;
208 let guard = self.db.lock().await;
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 guard
212 }
213
214 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 let guard = self.connection_pool.lock();
218 ConnectionPoolGuard {
219 guard,
220 _not_send: PhantomData,
221 }
222 }
223
224 #[expect(dead_code)]
225 fn is_staff(&self) -> bool {
226 match &self.principal {
227 Principal::User(user) => user.admin,
228 Principal::Impersonated { .. } => true,
229 }
230 }
231
232 fn user_id(&self) -> UserId {
233 match &self.principal {
234 Principal::User(user) => user.id,
235 Principal::Impersonated { user, .. } => user.id,
236 }
237 }
238}
239
240impl Debug for Session {
241 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
242 let mut result = f.debug_struct("Session");
243 match &self.principal {
244 Principal::User(user) => {
245 result.field("user", &user.github_login);
246 }
247 Principal::Impersonated { user, admin } => {
248 result.field("user", &user.github_login);
249 result.field("impersonator", &admin.github_login);
250 }
251 }
252 result.field("connection_id", &self.connection_id).finish()
253 }
254}
255
256struct DbHandle(Arc<Database>);
257
258impl Deref for DbHandle {
259 type Target = Database;
260
261 fn deref(&self) -> &Self::Target {
262 self.0.as_ref()
263 }
264}
265
266pub struct Server {
267 id: parking_lot::Mutex<ServerId>,
268 peer: Arc<Peer>,
269 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
270 app_state: Arc<AppState>,
271 handlers: HashMap<TypeId, MessageHandler>,
272 teardown: watch::Sender<bool>,
273}
274
275pub(crate) struct ConnectionPoolGuard<'a> {
276 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
277 _not_send: PhantomData<Rc<()>>,
278}
279
280#[derive(Serialize)]
281pub struct ServerSnapshot<'a> {
282 peer: &'a Peer,
283 #[serde(serialize_with = "serialize_deref")]
284 connection_pool: ConnectionPoolGuard<'a>,
285}
286
287pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
288where
289 S: Serializer,
290 T: Deref<Target = U>,
291 U: Serialize,
292{
293 Serialize::serialize(value.deref(), serializer)
294}
295
296impl Server {
297 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
298 let mut server = Self {
299 id: parking_lot::Mutex::new(id),
300 peer: Peer::new(id.0 as u32),
301 app_state,
302 connection_pool: Default::default(),
303 handlers: Default::default(),
304 teardown: watch::channel(false).0,
305 };
306
307 server
308 .add_request_handler(ping)
309 .add_request_handler(create_room)
310 .add_request_handler(join_room)
311 .add_request_handler(rejoin_room)
312 .add_request_handler(leave_room)
313 .add_request_handler(set_room_participant_role)
314 .add_request_handler(call)
315 .add_request_handler(cancel_call)
316 .add_message_handler(decline_call)
317 .add_request_handler(update_participant_location)
318 .add_request_handler(share_project)
319 .add_message_handler(unshare_project)
320 .add_request_handler(join_project)
321 .add_message_handler(leave_project)
322 .add_request_handler(update_project)
323 .add_request_handler(update_worktree)
324 .add_request_handler(update_repository)
325 .add_request_handler(remove_repository)
326 .add_message_handler(start_language_server)
327 .add_message_handler(update_language_server)
328 .add_message_handler(update_diagnostic_summary)
329 .add_message_handler(update_worktree_settings)
330 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
331 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
332 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
333 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
334 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
335 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
336 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
337 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
338 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
339 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
340 .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
341 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
342 .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
343 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
344 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
345 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
346 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
347 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
348 .add_request_handler(
349 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
350 )
351 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
352 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
353 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
354 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
355 .add_request_handler(
356 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
357 )
358 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
359 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
360 .add_request_handler(
361 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
362 )
363 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
364 .add_request_handler(
365 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
366 )
367 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
368 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
369 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
370 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
371 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
372 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
373 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
374 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
375 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
376 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
377 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
378 .add_request_handler(
379 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
380 )
381 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
382 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
383 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
384 .add_request_handler(lsp_query)
385 .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
386 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
387 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
388 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
389 .add_message_handler(create_buffer_for_peer)
390 .add_message_handler(create_image_for_peer)
391 .add_request_handler(update_buffer)
392 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
393 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
394 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
395 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
396 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
397 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
398 .add_message_handler(
399 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
400 )
401 .add_request_handler(get_users)
402 .add_request_handler(fuzzy_search_users)
403 .add_request_handler(request_contact)
404 .add_request_handler(remove_contact)
405 .add_request_handler(respond_to_contact_request)
406 .add_message_handler(subscribe_to_channels)
407 .add_request_handler(create_channel)
408 .add_request_handler(delete_channel)
409 .add_request_handler(invite_channel_member)
410 .add_request_handler(remove_channel_member)
411 .add_request_handler(set_channel_member_role)
412 .add_request_handler(set_channel_visibility)
413 .add_request_handler(rename_channel)
414 .add_request_handler(join_channel_buffer)
415 .add_request_handler(leave_channel_buffer)
416 .add_message_handler(update_channel_buffer)
417 .add_request_handler(rejoin_channel_buffers)
418 .add_request_handler(get_channel_members)
419 .add_request_handler(respond_to_channel_invite)
420 .add_request_handler(join_channel)
421 .add_request_handler(join_channel_chat)
422 .add_message_handler(leave_channel_chat)
423 .add_request_handler(send_channel_message)
424 .add_request_handler(remove_channel_message)
425 .add_request_handler(update_channel_message)
426 .add_request_handler(get_channel_messages)
427 .add_request_handler(get_channel_messages_by_id)
428 .add_request_handler(get_notifications)
429 .add_request_handler(mark_notification_as_read)
430 .add_request_handler(move_channel)
431 .add_request_handler(reorder_channel)
432 .add_request_handler(follow)
433 .add_message_handler(unfollow)
434 .add_message_handler(update_followers)
435 .add_message_handler(acknowledge_channel_message)
436 .add_message_handler(acknowledge_buffer_version)
437 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
438 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
439 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
440 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
441 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
442 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
443 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
444 .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
445 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
446 .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
447 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
448 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
449 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
450 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
451 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
452 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
453 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
454 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
455 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
456 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
457 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
458 .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
459 .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
460 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
461 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
462 .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
463 .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
464 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
465 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
466 .add_message_handler(update_context)
467 .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
468 .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);
469
470 Arc::new(server)
471 }
472
473 pub async fn start(&self) -> Result<()> {
474 let server_id = *self.id.lock();
475 let app_state = self.app_state.clone();
476 let peer = self.peer.clone();
477 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
478 let pool = self.connection_pool.clone();
479 let livekit_client = self.app_state.livekit_client.clone();
480
481 let span = info_span!("start server");
482 self.app_state.executor.spawn_detached(
483 async move {
484 tracing::info!("waiting for cleanup timeout");
485 timeout.await;
486 tracing::info!("cleanup timeout expired, retrieving stale rooms");
487
488 app_state
489 .db
490 .delete_stale_channel_chat_participants(
491 &app_state.config.zed_environment,
492 server_id,
493 )
494 .await
495 .trace_err();
496
497 if let Some((room_ids, channel_ids)) = app_state
498 .db
499 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
500 .await
501 .trace_err()
502 {
503 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
504 tracing::info!(
505 stale_channel_buffer_count = channel_ids.len(),
506 "retrieved stale channel buffers"
507 );
508
509 for channel_id in channel_ids {
510 if let Some(refreshed_channel_buffer) = app_state
511 .db
512 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
513 .await
514 .trace_err()
515 {
516 for connection_id in refreshed_channel_buffer.connection_ids {
517 peer.send(
518 connection_id,
519 proto::UpdateChannelBufferCollaborators {
520 channel_id: channel_id.to_proto(),
521 collaborators: refreshed_channel_buffer
522 .collaborators
523 .clone(),
524 },
525 )
526 .trace_err();
527 }
528 }
529 }
530
531 for room_id in room_ids {
532 let mut contacts_to_update = HashSet::default();
533 let mut canceled_calls_to_user_ids = Vec::new();
534 let mut livekit_room = String::new();
535 let mut delete_livekit_room = false;
536
537 if let Some(mut refreshed_room) = app_state
538 .db
539 .clear_stale_room_participants(room_id, server_id)
540 .await
541 .trace_err()
542 {
543 tracing::info!(
544 room_id = room_id.0,
545 new_participant_count = refreshed_room.room.participants.len(),
546 "refreshed room"
547 );
548 room_updated(&refreshed_room.room, &peer);
549 if let Some(channel) = refreshed_room.channel.as_ref() {
550 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
551 }
552 contacts_to_update
553 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
554 contacts_to_update
555 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
556 canceled_calls_to_user_ids =
557 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
558 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
559 delete_livekit_room = refreshed_room.room.participants.is_empty();
560 }
561
562 {
563 let pool = pool.lock();
564 for canceled_user_id in canceled_calls_to_user_ids {
565 for connection_id in pool.user_connection_ids(canceled_user_id) {
566 peer.send(
567 connection_id,
568 proto::CallCanceled {
569 room_id: room_id.to_proto(),
570 },
571 )
572 .trace_err();
573 }
574 }
575 }
576
577 for user_id in contacts_to_update {
578 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
579 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
580 if let Some((busy, contacts)) = busy.zip(contacts) {
581 let pool = pool.lock();
582 let updated_contact = contact_for_user(user_id, busy, &pool);
583 for contact in contacts {
584 if let db::Contact::Accepted {
585 user_id: contact_user_id,
586 ..
587 } = contact
588 {
589 for contact_conn_id in
590 pool.user_connection_ids(contact_user_id)
591 {
592 peer.send(
593 contact_conn_id,
594 proto::UpdateContacts {
595 contacts: vec![updated_contact.clone()],
596 remove_contacts: Default::default(),
597 incoming_requests: Default::default(),
598 remove_incoming_requests: Default::default(),
599 outgoing_requests: Default::default(),
600 remove_outgoing_requests: Default::default(),
601 },
602 )
603 .trace_err();
604 }
605 }
606 }
607 }
608 }
609
610 if let Some(live_kit) = livekit_client.as_ref()
611 && delete_livekit_room
612 {
613 live_kit.delete_room(livekit_room).await.trace_err();
614 }
615 }
616 }
617
618 app_state
619 .db
620 .delete_stale_channel_chat_participants(
621 &app_state.config.zed_environment,
622 server_id,
623 )
624 .await
625 .trace_err();
626
627 app_state
628 .db
629 .clear_old_worktree_entries(server_id)
630 .await
631 .trace_err();
632
633 app_state
634 .db
635 .delete_stale_servers(&app_state.config.zed_environment, server_id)
636 .await
637 .trace_err();
638 }
639 .instrument(span),
640 );
641 Ok(())
642 }
643
644 pub fn teardown(&self) {
645 self.peer.teardown();
646 self.connection_pool.lock().reset();
647 let _ = self.teardown.send(true);
648 }
649
650 #[cfg(test)]
651 pub fn reset(&self, id: ServerId) {
652 self.teardown();
653 *self.id.lock() = id;
654 self.peer.reset(id.0 as u32);
655 let _ = self.teardown.send(false);
656 }
657
658 #[cfg(test)]
659 pub fn id(&self) -> ServerId {
660 *self.id.lock()
661 }
662
663 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
664 where
665 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
666 Fut: 'static + Send + Future<Output = Result<()>>,
667 M: EnvelopedMessage,
668 {
669 let prev_handler = self.handlers.insert(
670 TypeId::of::<M>(),
671 Box::new(move |envelope, session, span| {
672 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
673 let received_at = envelope.received_at;
674 tracing::info!("message received");
675 let start_time = Instant::now();
676 let future = (handler)(
677 *envelope,
678 MessageContext {
679 session,
680 span: span.clone(),
681 },
682 );
683 async move {
684 let result = future.await;
685 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
686 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
687 let queue_duration_ms = total_duration_ms - processing_duration_ms;
688 span.record(TOTAL_DURATION_MS, total_duration_ms);
689 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
690 span.record(QUEUE_DURATION_MS, queue_duration_ms);
691 match result {
692 Err(error) => {
693 tracing::error!(?error, "error handling message")
694 }
695 Ok(()) => tracing::info!("finished handling message"),
696 }
697 }
698 .boxed()
699 }),
700 );
701 if prev_handler.is_some() {
702 panic!("registered a handler for the same message twice");
703 }
704 self
705 }
706
707 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
708 where
709 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
710 Fut: 'static + Send + Future<Output = Result<()>>,
711 M: EnvelopedMessage,
712 {
713 self.add_handler(move |envelope, session| handler(envelope.payload, session));
714 self
715 }
716
717 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
718 where
719 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
720 Fut: Send + Future<Output = Result<()>>,
721 M: RequestMessage,
722 {
723 let handler = Arc::new(handler);
724 self.add_handler(move |envelope, session| {
725 let receipt = envelope.receipt();
726 let handler = handler.clone();
727 async move {
728 let peer = session.peer.clone();
729 let responded = Arc::new(AtomicBool::default());
730 let response = Response {
731 peer: peer.clone(),
732 responded: responded.clone(),
733 receipt,
734 };
735 match (handler)(envelope.payload, response, session).await {
736 Ok(()) => {
737 if responded.load(std::sync::atomic::Ordering::SeqCst) {
738 Ok(())
739 } else {
740 Err(anyhow!("handler did not send a response"))?
741 }
742 }
743 Err(error) => {
744 let proto_err = match &error {
745 Error::Internal(err) => err.to_proto(),
746 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
747 };
748 peer.respond_with_error(receipt, proto_err)?;
749 Err(error)
750 }
751 }
752 }
753 })
754 }
755
756 pub fn handle_connection(
757 self: &Arc<Self>,
758 connection: Connection,
759 address: String,
760 principal: Principal,
761 zed_version: ZedVersion,
762 release_channel: Option<String>,
763 user_agent: Option<String>,
764 geoip_country_code: Option<String>,
765 system_id: Option<String>,
766 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
767 executor: Executor,
768 connection_guard: Option<ConnectionGuard>,
769 ) -> impl Future<Output = ()> + use<> {
770 let this = self.clone();
771 let span = info_span!("handle connection", %address,
772 connection_id=field::Empty,
773 user_id=field::Empty,
774 login=field::Empty,
775 impersonator=field::Empty,
776 user_agent=field::Empty,
777 geoip_country_code=field::Empty,
778 release_channel=field::Empty,
779 );
780 principal.update_span(&span);
781 if let Some(user_agent) = user_agent {
782 span.record("user_agent", user_agent);
783 }
784 if let Some(release_channel) = release_channel {
785 span.record("release_channel", release_channel);
786 }
787
788 if let Some(country_code) = geoip_country_code.as_ref() {
789 span.record("geoip_country_code", country_code);
790 }
791
792 let mut teardown = self.teardown.subscribe();
793 async move {
794 if *teardown.borrow() {
795 tracing::error!("server is tearing down");
796 return;
797 }
798
799 let (connection_id, handle_io, mut incoming_rx) =
800 this.peer.add_connection(connection, {
801 let executor = executor.clone();
802 move |duration| executor.sleep(duration)
803 });
804 tracing::Span::current().record("connection_id", format!("{}", connection_id));
805
806 tracing::info!("connection opened");
807
808 let session = Session {
809 principal: principal.clone(),
810 connection_id,
811 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
812 peer: this.peer.clone(),
813 connection_pool: this.connection_pool.clone(),
814 app_state: this.app_state.clone(),
815 geoip_country_code,
816 system_id,
817 _executor: executor.clone(),
818 };
819
820 if let Err(error) = this
821 .send_initial_client_update(
822 connection_id,
823 zed_version,
824 send_connection_id,
825 &session,
826 )
827 .await
828 {
829 tracing::error!(?error, "failed to send initial client update");
830 return;
831 }
832 drop(connection_guard);
833
834 let handle_io = handle_io.fuse();
835 futures::pin_mut!(handle_io);
836
837 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
838 // This prevents deadlocks when e.g., client A performs a request to client B and
839 // client B performs a request to client A. If both clients stop processing further
840 // messages until their respective request completes, they won't have a chance to
841 // respond to the other client's request and cause a deadlock.
842 //
843 // This arrangement ensures we will attempt to process earlier messages first, but fall
844 // back to processing messages arrived later in the spirit of making progress.
845 const MAX_CONCURRENT_HANDLERS: usize = 256;
846 let mut foreground_message_handlers = FuturesUnordered::new();
847 let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
848 let get_concurrent_handlers = {
849 let concurrent_handlers = concurrent_handlers.clone();
850 move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
851 };
852 loop {
853 let next_message = async {
854 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
855 let message = incoming_rx.next().await;
856 // Cache the concurrent_handlers here, so that we know what the
857 // queue looks like as each handler starts
858 (permit, message, get_concurrent_handlers())
859 }
860 .fuse();
861 futures::pin_mut!(next_message);
862 futures::select_biased! {
863 _ = teardown.changed().fuse() => return,
864 result = handle_io => {
865 if let Err(error) = result {
866 tracing::error!(?error, "error handling I/O");
867 }
868 break;
869 }
870 _ = foreground_message_handlers.next() => {}
871 next_message = next_message => {
872 let (permit, message, concurrent_handlers) = next_message;
873 if let Some(message) = message {
874 let type_name = message.payload_type_name();
875 // note: we copy all the fields from the parent span so we can query them in the logs.
876 // (https://github.com/tokio-rs/tracing/issues/2670).
877 let span = tracing::info_span!("receive message",
878 %connection_id,
879 %address,
880 type_name,
881 concurrent_handlers,
882 user_id=field::Empty,
883 login=field::Empty,
884 impersonator=field::Empty,
885 lsp_query_request=field::Empty,
886 release_channel=field::Empty,
887 { TOTAL_DURATION_MS }=field::Empty,
888 { PROCESSING_DURATION_MS }=field::Empty,
889 { QUEUE_DURATION_MS }=field::Empty,
890 { HOST_WAITING_MS }=field::Empty
891 );
892 principal.update_span(&span);
893 let span_enter = span.enter();
894 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
895 let is_background = message.is_background();
896 let handle_message = (handler)(message, session.clone(), span.clone());
897 drop(span_enter);
898
899 let handle_message = async move {
900 handle_message.await;
901 drop(permit);
902 }.instrument(span);
903 if is_background {
904 executor.spawn_detached(handle_message);
905 } else {
906 foreground_message_handlers.push(handle_message);
907 }
908 } else {
909 tracing::error!("no message handler");
910 }
911 } else {
912 tracing::info!("connection closed");
913 break;
914 }
915 }
916 }
917 }
918
919 drop(foreground_message_handlers);
920 let concurrent_handlers = get_concurrent_handlers();
921 tracing::info!(concurrent_handlers, "signing out");
922 if let Err(error) = connection_lost(session, teardown, executor).await {
923 tracing::error!(?error, "error signing out");
924 }
925 }
926 .instrument(span)
927 }
928
929 async fn send_initial_client_update(
930 &self,
931 connection_id: ConnectionId,
932 zed_version: ZedVersion,
933 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
934 session: &Session,
935 ) -> Result<()> {
936 self.peer.send(
937 connection_id,
938 proto::Hello {
939 peer_id: Some(connection_id.into()),
940 },
941 )?;
942 tracing::info!("sent hello message");
943 if let Some(send_connection_id) = send_connection_id.take() {
944 let _ = send_connection_id.send(connection_id);
945 }
946
947 match &session.principal {
948 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
949 if !user.connected_once {
950 self.peer.send(connection_id, proto::ShowContacts {})?;
951 self.app_state
952 .db
953 .set_user_connected_once(user.id, true)
954 .await?;
955 }
956
957 let contacts = self.app_state.db.get_contacts(user.id).await?;
958
959 {
960 let mut pool = self.connection_pool.lock();
961 pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
962 self.peer.send(
963 connection_id,
964 build_initial_contacts_update(contacts, &pool),
965 )?;
966 }
967
968 if should_auto_subscribe_to_channels(&zed_version) {
969 subscribe_user_to_channels(user.id, session).await?;
970 }
971
972 if let Some(incoming_call) =
973 self.app_state.db.incoming_call_for_user(user.id).await?
974 {
975 self.peer.send(connection_id, incoming_call)?;
976 }
977
978 update_user_contacts(user.id, session).await?;
979 }
980 }
981
982 Ok(())
983 }
984
985 pub async fn invite_code_redeemed(
986 self: &Arc<Self>,
987 inviter_id: UserId,
988 invitee_id: UserId,
989 ) -> Result<()> {
990 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
991 && let Some(code) = &user.invite_code
992 {
993 let pool = self.connection_pool.lock();
994 let invitee_contact = contact_for_user(invitee_id, false, &pool);
995 for connection_id in pool.user_connection_ids(inviter_id) {
996 self.peer.send(
997 connection_id,
998 proto::UpdateContacts {
999 contacts: vec![invitee_contact.clone()],
1000 ..Default::default()
1001 },
1002 )?;
1003 self.peer.send(
1004 connection_id,
1005 proto::UpdateInviteInfo {
1006 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1007 count: user.invite_count as u32,
1008 },
1009 )?;
1010 }
1011 }
1012 Ok(())
1013 }
1014
1015 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1016 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1017 && let Some(invite_code) = &user.invite_code
1018 {
1019 let pool = self.connection_pool.lock();
1020 for connection_id in pool.user_connection_ids(user_id) {
1021 self.peer.send(
1022 connection_id,
1023 proto::UpdateInviteInfo {
1024 url: format!(
1025 "{}{}",
1026 self.app_state.config.invite_link_prefix, invite_code
1027 ),
1028 count: user.invite_count as u32,
1029 },
1030 )?;
1031 }
1032 }
1033 Ok(())
1034 }
1035
1036 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1037 ServerSnapshot {
1038 connection_pool: ConnectionPoolGuard {
1039 guard: self.connection_pool.lock(),
1040 _not_send: PhantomData,
1041 },
1042 peer: &self.peer,
1043 }
1044 }
1045}
1046
1047impl Deref for ConnectionPoolGuard<'_> {
1048 type Target = ConnectionPool;
1049
1050 fn deref(&self) -> &Self::Target {
1051 &self.guard
1052 }
1053}
1054
1055impl DerefMut for ConnectionPoolGuard<'_> {
1056 fn deref_mut(&mut self) -> &mut Self::Target {
1057 &mut self.guard
1058 }
1059}
1060
1061impl Drop for ConnectionPoolGuard<'_> {
1062 fn drop(&mut self) {
1063 #[cfg(test)]
1064 self.check_invariants();
1065 }
1066}
1067
1068fn broadcast<F>(
1069 sender_id: Option<ConnectionId>,
1070 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1071 mut f: F,
1072) where
1073 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1074{
1075 for receiver_id in receiver_ids {
1076 if Some(receiver_id) != sender_id
1077 && let Err(error) = f(receiver_id)
1078 {
1079 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1080 }
1081 }
1082}
1083
1084pub struct ProtocolVersion(u32);
1085
1086impl Header for ProtocolVersion {
1087 fn name() -> &'static HeaderName {
1088 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1089 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1090 }
1091
1092 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1093 where
1094 Self: Sized,
1095 I: Iterator<Item = &'i axum::http::HeaderValue>,
1096 {
1097 let version = values
1098 .next()
1099 .ok_or_else(axum::headers::Error::invalid)?
1100 .to_str()
1101 .map_err(|_| axum::headers::Error::invalid())?
1102 .parse()
1103 .map_err(|_| axum::headers::Error::invalid())?;
1104 Ok(Self(version))
1105 }
1106
1107 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1108 values.extend([self.0.to_string().parse().unwrap()]);
1109 }
1110}
1111
1112pub struct AppVersionHeader(Version);
1113impl Header for AppVersionHeader {
1114 fn name() -> &'static HeaderName {
1115 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1116 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1117 }
1118
1119 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1120 where
1121 Self: Sized,
1122 I: Iterator<Item = &'i axum::http::HeaderValue>,
1123 {
1124 let version = values
1125 .next()
1126 .ok_or_else(axum::headers::Error::invalid)?
1127 .to_str()
1128 .map_err(|_| axum::headers::Error::invalid())?
1129 .parse()
1130 .map_err(|_| axum::headers::Error::invalid())?;
1131 Ok(Self(version))
1132 }
1133
1134 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1135 values.extend([self.0.to_string().parse().unwrap()]);
1136 }
1137}
1138
1139#[derive(Debug)]
1140pub struct ReleaseChannelHeader(String);
1141
1142impl Header for ReleaseChannelHeader {
1143 fn name() -> &'static HeaderName {
1144 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1145 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1146 }
1147
1148 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1149 where
1150 Self: Sized,
1151 I: Iterator<Item = &'i axum::http::HeaderValue>,
1152 {
1153 Ok(Self(
1154 values
1155 .next()
1156 .ok_or_else(axum::headers::Error::invalid)?
1157 .to_str()
1158 .map_err(|_| axum::headers::Error::invalid())?
1159 .to_owned(),
1160 ))
1161 }
1162
1163 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1164 values.extend([self.0.parse().unwrap()]);
1165 }
1166}
1167
1168pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1169 Router::new()
1170 .route("/rpc", get(handle_websocket_request))
1171 .layer(
1172 ServiceBuilder::new()
1173 .layer(Extension(server.app_state.clone()))
1174 .layer(middleware::from_fn(auth::validate_header)),
1175 )
1176 .route("/metrics", get(handle_metrics))
1177 .layer(Extension(server))
1178}
1179
1180pub async fn handle_websocket_request(
1181 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1182 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1183 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1184 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1185 Extension(server): Extension<Arc<Server>>,
1186 Extension(principal): Extension<Principal>,
1187 user_agent: Option<TypedHeader<UserAgent>>,
1188 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1189 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1190 ws: WebSocketUpgrade,
1191) -> axum::response::Response {
1192 if protocol_version != rpc::PROTOCOL_VERSION {
1193 return (
1194 StatusCode::UPGRADE_REQUIRED,
1195 "client must be upgraded".to_string(),
1196 )
1197 .into_response();
1198 }
1199
1200 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1201 return (
1202 StatusCode::UPGRADE_REQUIRED,
1203 "no version header found".to_string(),
1204 )
1205 .into_response();
1206 };
1207
1208 let release_channel = release_channel_header.map(|header| header.0.0);
1209
1210 if !version.can_collaborate() {
1211 return (
1212 StatusCode::UPGRADE_REQUIRED,
1213 "client must be upgraded".to_string(),
1214 )
1215 .into_response();
1216 }
1217
1218 let socket_address = socket_address.to_string();
1219
1220 // Acquire connection guard before WebSocket upgrade
1221 let connection_guard = match ConnectionGuard::try_acquire() {
1222 Ok(guard) => guard,
1223 Err(()) => {
1224 return (
1225 StatusCode::SERVICE_UNAVAILABLE,
1226 "Too many concurrent connections",
1227 )
1228 .into_response();
1229 }
1230 };
1231
1232 ws.on_upgrade(move |socket| {
1233 let socket = socket
1234 .map_ok(to_tungstenite_message)
1235 .err_into()
1236 .with(|message| async move { to_axum_message(message) });
1237 let connection = Connection::new(Box::pin(socket));
1238 async move {
1239 server
1240 .handle_connection(
1241 connection,
1242 socket_address,
1243 principal,
1244 version,
1245 release_channel,
1246 user_agent.map(|header| header.to_string()),
1247 country_code_header.map(|header| header.to_string()),
1248 system_id_header.map(|header| header.to_string()),
1249 None,
1250 Executor::Production,
1251 Some(connection_guard),
1252 )
1253 .await;
1254 }
1255 })
1256}
1257
1258pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1259 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1260 let connections_metric = CONNECTIONS_METRIC
1261 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1262
1263 let connections = server
1264 .connection_pool
1265 .lock()
1266 .connections()
1267 .filter(|connection| !connection.admin)
1268 .count();
1269 connections_metric.set(connections as _);
1270
1271 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1272 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1273 register_int_gauge!(
1274 "shared_projects",
1275 "number of open projects with one or more guests"
1276 )
1277 .unwrap()
1278 });
1279
1280 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1281 shared_projects_metric.set(shared_projects as _);
1282
1283 let encoder = prometheus::TextEncoder::new();
1284 let metric_families = prometheus::gather();
1285 let encoded_metrics = encoder
1286 .encode_to_string(&metric_families)
1287 .map_err(|err| anyhow!("{err}"))?;
1288 Ok(encoded_metrics)
1289}
1290
1291#[instrument(err, skip(executor))]
1292async fn connection_lost(
1293 session: Session,
1294 mut teardown: watch::Receiver<bool>,
1295 executor: Executor,
1296) -> Result<()> {
1297 session.peer.disconnect(session.connection_id);
1298 session
1299 .connection_pool()
1300 .await
1301 .remove_connection(session.connection_id)?;
1302
1303 session
1304 .db()
1305 .await
1306 .connection_lost(session.connection_id)
1307 .await
1308 .trace_err();
1309
1310 futures::select_biased! {
1311 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1312
1313 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1314 leave_room_for_session(&session, session.connection_id).await.trace_err();
1315 leave_channel_buffers_for_session(&session)
1316 .await
1317 .trace_err();
1318
1319 if !session
1320 .connection_pool()
1321 .await
1322 .is_user_online(session.user_id())
1323 {
1324 let db = session.db().await;
1325 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1326 room_updated(&room, &session.peer);
1327 }
1328 }
1329
1330 update_user_contacts(session.user_id(), &session).await?;
1331 },
1332 _ = teardown.changed().fuse() => {}
1333 }
1334
1335 Ok(())
1336}
1337
1338/// Acknowledges a ping from a client, used to keep the connection alive.
1339async fn ping(
1340 _: proto::Ping,
1341 response: Response<proto::Ping>,
1342 _session: MessageContext,
1343) -> Result<()> {
1344 response.send(proto::Ack {})?;
1345 Ok(())
1346}
1347
1348/// Creates a new room for calling (outside of channels)
1349async fn create_room(
1350 _request: proto::CreateRoom,
1351 response: Response<proto::CreateRoom>,
1352 session: MessageContext,
1353) -> Result<()> {
1354 let livekit_room = nanoid::nanoid!(30);
1355
1356 let live_kit_connection_info = util::maybe!(async {
1357 let live_kit = session.app_state.livekit_client.as_ref();
1358 let live_kit = live_kit?;
1359 let user_id = session.user_id().to_string();
1360
1361 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1362
1363 Some(proto::LiveKitConnectionInfo {
1364 server_url: live_kit.url().into(),
1365 token,
1366 can_publish: true,
1367 })
1368 })
1369 .await;
1370
1371 let room = session
1372 .db()
1373 .await
1374 .create_room(session.user_id(), session.connection_id, &livekit_room)
1375 .await?;
1376
1377 response.send(proto::CreateRoomResponse {
1378 room: Some(room.clone()),
1379 live_kit_connection_info,
1380 })?;
1381
1382 update_user_contacts(session.user_id(), &session).await?;
1383 Ok(())
1384}
1385
1386/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1387async fn join_room(
1388 request: proto::JoinRoom,
1389 response: Response<proto::JoinRoom>,
1390 session: MessageContext,
1391) -> Result<()> {
1392 let room_id = RoomId::from_proto(request.id);
1393
1394 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1395
1396 if let Some(channel_id) = channel_id {
1397 return join_channel_internal(channel_id, Box::new(response), session).await;
1398 }
1399
1400 let joined_room = {
1401 let room = session
1402 .db()
1403 .await
1404 .join_room(room_id, session.user_id(), session.connection_id)
1405 .await?;
1406 room_updated(&room.room, &session.peer);
1407 room.into_inner()
1408 };
1409
1410 for connection_id in session
1411 .connection_pool()
1412 .await
1413 .user_connection_ids(session.user_id())
1414 {
1415 session
1416 .peer
1417 .send(
1418 connection_id,
1419 proto::CallCanceled {
1420 room_id: room_id.to_proto(),
1421 },
1422 )
1423 .trace_err();
1424 }
1425
1426 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1427 {
1428 live_kit
1429 .room_token(
1430 &joined_room.room.livekit_room,
1431 &session.user_id().to_string(),
1432 )
1433 .trace_err()
1434 .map(|token| proto::LiveKitConnectionInfo {
1435 server_url: live_kit.url().into(),
1436 token,
1437 can_publish: true,
1438 })
1439 } else {
1440 None
1441 };
1442
1443 response.send(proto::JoinRoomResponse {
1444 room: Some(joined_room.room),
1445 channel_id: None,
1446 live_kit_connection_info,
1447 })?;
1448
1449 update_user_contacts(session.user_id(), &session).await?;
1450 Ok(())
1451}
1452
1453/// Rejoin room is used to reconnect to a room after connection errors.
1454async fn rejoin_room(
1455 request: proto::RejoinRoom,
1456 response: Response<proto::RejoinRoom>,
1457 session: MessageContext,
1458) -> Result<()> {
1459 let room;
1460 let channel;
1461 {
1462 let mut rejoined_room = session
1463 .db()
1464 .await
1465 .rejoin_room(request, session.user_id(), session.connection_id)
1466 .await?;
1467
1468 response.send(proto::RejoinRoomResponse {
1469 room: Some(rejoined_room.room.clone()),
1470 reshared_projects: rejoined_room
1471 .reshared_projects
1472 .iter()
1473 .map(|project| proto::ResharedProject {
1474 id: project.id.to_proto(),
1475 collaborators: project
1476 .collaborators
1477 .iter()
1478 .map(|collaborator| collaborator.to_proto())
1479 .collect(),
1480 })
1481 .collect(),
1482 rejoined_projects: rejoined_room
1483 .rejoined_projects
1484 .iter()
1485 .map(|rejoined_project| rejoined_project.to_proto())
1486 .collect(),
1487 })?;
1488 room_updated(&rejoined_room.room, &session.peer);
1489
1490 for project in &rejoined_room.reshared_projects {
1491 for collaborator in &project.collaborators {
1492 session
1493 .peer
1494 .send(
1495 collaborator.connection_id,
1496 proto::UpdateProjectCollaborator {
1497 project_id: project.id.to_proto(),
1498 old_peer_id: Some(project.old_connection_id.into()),
1499 new_peer_id: Some(session.connection_id.into()),
1500 },
1501 )
1502 .trace_err();
1503 }
1504
1505 broadcast(
1506 Some(session.connection_id),
1507 project
1508 .collaborators
1509 .iter()
1510 .map(|collaborator| collaborator.connection_id),
1511 |connection_id| {
1512 session.peer.forward_send(
1513 session.connection_id,
1514 connection_id,
1515 proto::UpdateProject {
1516 project_id: project.id.to_proto(),
1517 worktrees: project.worktrees.clone(),
1518 },
1519 )
1520 },
1521 );
1522 }
1523
1524 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1525
1526 let rejoined_room = rejoined_room.into_inner();
1527
1528 room = rejoined_room.room;
1529 channel = rejoined_room.channel;
1530 }
1531
1532 if let Some(channel) = channel {
1533 channel_updated(
1534 &channel,
1535 &room,
1536 &session.peer,
1537 &*session.connection_pool().await,
1538 );
1539 }
1540
1541 update_user_contacts(session.user_id(), &session).await?;
1542 Ok(())
1543}
1544
1545fn notify_rejoined_projects(
1546 rejoined_projects: &mut Vec<RejoinedProject>,
1547 session: &Session,
1548) -> Result<()> {
1549 for project in rejoined_projects.iter() {
1550 for collaborator in &project.collaborators {
1551 session
1552 .peer
1553 .send(
1554 collaborator.connection_id,
1555 proto::UpdateProjectCollaborator {
1556 project_id: project.id.to_proto(),
1557 old_peer_id: Some(project.old_connection_id.into()),
1558 new_peer_id: Some(session.connection_id.into()),
1559 },
1560 )
1561 .trace_err();
1562 }
1563 }
1564
1565 for project in rejoined_projects {
1566 for worktree in mem::take(&mut project.worktrees) {
1567 // Stream this worktree's entries.
1568 let message = proto::UpdateWorktree {
1569 project_id: project.id.to_proto(),
1570 worktree_id: worktree.id,
1571 abs_path: worktree.abs_path.clone(),
1572 root_name: worktree.root_name,
1573 updated_entries: worktree.updated_entries,
1574 removed_entries: worktree.removed_entries,
1575 scan_id: worktree.scan_id,
1576 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1577 updated_repositories: worktree.updated_repositories,
1578 removed_repositories: worktree.removed_repositories,
1579 };
1580 for update in proto::split_worktree_update(message) {
1581 session.peer.send(session.connection_id, update)?;
1582 }
1583
1584 // Stream this worktree's diagnostics.
1585 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1586 if let Some(summary) = worktree_diagnostics.next() {
1587 let message = proto::UpdateDiagnosticSummary {
1588 project_id: project.id.to_proto(),
1589 worktree_id: worktree.id,
1590 summary: Some(summary),
1591 more_summaries: worktree_diagnostics.collect(),
1592 };
1593 session.peer.send(session.connection_id, message)?;
1594 }
1595
1596 for settings_file in worktree.settings_files {
1597 session.peer.send(
1598 session.connection_id,
1599 proto::UpdateWorktreeSettings {
1600 project_id: project.id.to_proto(),
1601 worktree_id: worktree.id,
1602 path: settings_file.path,
1603 content: Some(settings_file.content),
1604 kind: Some(settings_file.kind.to_proto().into()),
1605 },
1606 )?;
1607 }
1608 }
1609
1610 for repository in mem::take(&mut project.updated_repositories) {
1611 for update in split_repository_update(repository) {
1612 session.peer.send(session.connection_id, update)?;
1613 }
1614 }
1615
1616 for id in mem::take(&mut project.removed_repositories) {
1617 session.peer.send(
1618 session.connection_id,
1619 proto::RemoveRepository {
1620 project_id: project.id.to_proto(),
1621 id,
1622 },
1623 )?;
1624 }
1625 }
1626
1627 Ok(())
1628}
1629
1630/// leave room disconnects from the room.
1631async fn leave_room(
1632 _: proto::LeaveRoom,
1633 response: Response<proto::LeaveRoom>,
1634 session: MessageContext,
1635) -> Result<()> {
1636 leave_room_for_session(&session, session.connection_id).await?;
1637 response.send(proto::Ack {})?;
1638 Ok(())
1639}
1640
1641/// Updates the permissions of someone else in the room.
1642async fn set_room_participant_role(
1643 request: proto::SetRoomParticipantRole,
1644 response: Response<proto::SetRoomParticipantRole>,
1645 session: MessageContext,
1646) -> Result<()> {
1647 let user_id = UserId::from_proto(request.user_id);
1648 let role = ChannelRole::from(request.role());
1649
1650 let (livekit_room, can_publish) = {
1651 let room = session
1652 .db()
1653 .await
1654 .set_room_participant_role(
1655 session.user_id(),
1656 RoomId::from_proto(request.room_id),
1657 user_id,
1658 role,
1659 )
1660 .await?;
1661
1662 let livekit_room = room.livekit_room.clone();
1663 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1664 room_updated(&room, &session.peer);
1665 (livekit_room, can_publish)
1666 };
1667
1668 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1669 live_kit
1670 .update_participant(
1671 livekit_room.clone(),
1672 request.user_id.to_string(),
1673 livekit_api::proto::ParticipantPermission {
1674 can_subscribe: true,
1675 can_publish,
1676 can_publish_data: can_publish,
1677 hidden: false,
1678 recorder: false,
1679 },
1680 )
1681 .await
1682 .trace_err();
1683 }
1684
1685 response.send(proto::Ack {})?;
1686 Ok(())
1687}
1688
1689/// Call someone else into the current room
1690async fn call(
1691 request: proto::Call,
1692 response: Response<proto::Call>,
1693 session: MessageContext,
1694) -> Result<()> {
1695 let room_id = RoomId::from_proto(request.room_id);
1696 let calling_user_id = session.user_id();
1697 let calling_connection_id = session.connection_id;
1698 let called_user_id = UserId::from_proto(request.called_user_id);
1699 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1700 if !session
1701 .db()
1702 .await
1703 .has_contact(calling_user_id, called_user_id)
1704 .await?
1705 {
1706 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1707 }
1708
1709 let incoming_call = {
1710 let (room, incoming_call) = &mut *session
1711 .db()
1712 .await
1713 .call(
1714 room_id,
1715 calling_user_id,
1716 calling_connection_id,
1717 called_user_id,
1718 initial_project_id,
1719 )
1720 .await?;
1721 room_updated(room, &session.peer);
1722 mem::take(incoming_call)
1723 };
1724 update_user_contacts(called_user_id, &session).await?;
1725
1726 let mut calls = session
1727 .connection_pool()
1728 .await
1729 .user_connection_ids(called_user_id)
1730 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1731 .collect::<FuturesUnordered<_>>();
1732
1733 while let Some(call_response) = calls.next().await {
1734 match call_response.as_ref() {
1735 Ok(_) => {
1736 response.send(proto::Ack {})?;
1737 return Ok(());
1738 }
1739 Err(_) => {
1740 call_response.trace_err();
1741 }
1742 }
1743 }
1744
1745 {
1746 let room = session
1747 .db()
1748 .await
1749 .call_failed(room_id, called_user_id)
1750 .await?;
1751 room_updated(&room, &session.peer);
1752 }
1753 update_user_contacts(called_user_id, &session).await?;
1754
1755 Err(anyhow!("failed to ring user"))?
1756}
1757
1758/// Cancel an outgoing call.
1759async fn cancel_call(
1760 request: proto::CancelCall,
1761 response: Response<proto::CancelCall>,
1762 session: MessageContext,
1763) -> Result<()> {
1764 let called_user_id = UserId::from_proto(request.called_user_id);
1765 let room_id = RoomId::from_proto(request.room_id);
1766 {
1767 let room = session
1768 .db()
1769 .await
1770 .cancel_call(room_id, session.connection_id, called_user_id)
1771 .await?;
1772 room_updated(&room, &session.peer);
1773 }
1774
1775 for connection_id in session
1776 .connection_pool()
1777 .await
1778 .user_connection_ids(called_user_id)
1779 {
1780 session
1781 .peer
1782 .send(
1783 connection_id,
1784 proto::CallCanceled {
1785 room_id: room_id.to_proto(),
1786 },
1787 )
1788 .trace_err();
1789 }
1790 response.send(proto::Ack {})?;
1791
1792 update_user_contacts(called_user_id, &session).await?;
1793 Ok(())
1794}
1795
1796/// Decline an incoming call.
1797async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1798 let room_id = RoomId::from_proto(message.room_id);
1799 {
1800 let room = session
1801 .db()
1802 .await
1803 .decline_call(Some(room_id), session.user_id())
1804 .await?
1805 .context("declining call")?;
1806 room_updated(&room, &session.peer);
1807 }
1808
1809 for connection_id in session
1810 .connection_pool()
1811 .await
1812 .user_connection_ids(session.user_id())
1813 {
1814 session
1815 .peer
1816 .send(
1817 connection_id,
1818 proto::CallCanceled {
1819 room_id: room_id.to_proto(),
1820 },
1821 )
1822 .trace_err();
1823 }
1824 update_user_contacts(session.user_id(), &session).await?;
1825 Ok(())
1826}
1827
1828/// Updates other participants in the room with your current location.
1829async fn update_participant_location(
1830 request: proto::UpdateParticipantLocation,
1831 response: Response<proto::UpdateParticipantLocation>,
1832 session: MessageContext,
1833) -> Result<()> {
1834 let room_id = RoomId::from_proto(request.room_id);
1835 let location = request.location.context("invalid location")?;
1836
1837 let db = session.db().await;
1838 let room = db
1839 .update_room_participant_location(room_id, session.connection_id, location)
1840 .await?;
1841
1842 room_updated(&room, &session.peer);
1843 response.send(proto::Ack {})?;
1844 Ok(())
1845}
1846
1847/// Share a project into the room.
1848async fn share_project(
1849 request: proto::ShareProject,
1850 response: Response<proto::ShareProject>,
1851 session: MessageContext,
1852) -> Result<()> {
1853 let (project_id, room) = &*session
1854 .db()
1855 .await
1856 .share_project(
1857 RoomId::from_proto(request.room_id),
1858 session.connection_id,
1859 &request.worktrees,
1860 request.is_ssh_project,
1861 request.windows_paths.unwrap_or(false),
1862 )
1863 .await?;
1864 response.send(proto::ShareProjectResponse {
1865 project_id: project_id.to_proto(),
1866 })?;
1867 room_updated(room, &session.peer);
1868
1869 Ok(())
1870}
1871
1872/// Unshare a project from the room.
1873async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1874 let project_id = ProjectId::from_proto(message.project_id);
1875 unshare_project_internal(project_id, session.connection_id, &session).await
1876}
1877
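/// Shared unshare logic: notifies the project's guests, updates the room if there is
/// one, and deletes the project when the database marks it for deletion.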
1878async fn unshare_project_internal(
1879 project_id: ProjectId,
1880 connection_id: ConnectionId,
1881 session: &Session,
1882) -> Result<()> {
1883 let delete = {
1884 let room_guard = session
1885 .db()
1886 .await
1887 .unshare_project(project_id, connection_id)
1888 .await?;
1889
1890 let (delete, room, guest_connection_ids) = &*room_guard;
1891
1892 let message = proto::UnshareProject {
1893 project_id: project_id.to_proto(),
1894 };
1895
1896 broadcast(
1897 Some(connection_id),
1898 guest_connection_ids.iter().copied(),
1899 |conn_id| session.peer.send(conn_id, message.clone()),
1900 );
1901 if let Some(room) = room {
1902 room_updated(room, &session.peer);
1903 }
1904
1905 *delete
1906 };
1907
1908 if delete {
1909 let db = session.db().await;
1910 db.delete_project(project_id).await?;
1911 }
1912
1913 Ok(())
1914}
1915
/// Join someone else's shared project.
1917async fn join_project(
1918 request: proto::JoinProject,
1919 response: Response<proto::JoinProject>,
1920 session: MessageContext,
1921) -> Result<()> {
1922 let project_id = ProjectId::from_proto(request.project_id);
1923
1924 tracing::info!(%project_id, "join project");
1925
1926 let db = session.db().await;
1927 let (project, replica_id) = &mut *db
1928 .join_project(
1929 project_id,
1930 session.connection_id,
1931 session.user_id(),
1932 request.committer_name.clone(),
1933 request.committer_email.clone(),
1934 )
1935 .await?;
1936 drop(db);
1937 tracing::info!(%project_id, "join remote project");
1938 let collaborators = project
1939 .collaborators
1940 .iter()
1941 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1942 .map(|collaborator| collaborator.to_proto())
1943 .collect::<Vec<_>>();
1944 let project_id = project.id;
1945 let guest_user_id = session.user_id();
1946
1947 let worktrees = project
1948 .worktrees
1949 .iter()
1950 .map(|(id, worktree)| proto::WorktreeMetadata {
1951 id: *id,
1952 root_name: worktree.root_name.clone(),
1953 visible: worktree.visible,
1954 abs_path: worktree.abs_path.clone(),
1955 })
1956 .collect::<Vec<_>>();
1957
1958 let add_project_collaborator = proto::AddProjectCollaborator {
1959 project_id: project_id.to_proto(),
1960 collaborator: Some(proto::Collaborator {
1961 peer_id: Some(session.connection_id.into()),
1962 replica_id: replica_id.0 as u32,
1963 user_id: guest_user_id.to_proto(),
1964 is_host: false,
1965 committer_name: request.committer_name.clone(),
1966 committer_email: request.committer_email.clone(),
1967 }),
1968 };
1969
1970 for collaborator in &collaborators {
1971 session
1972 .peer
1973 .send(
1974 collaborator.peer_id.unwrap().into(),
1975 add_project_collaborator.clone(),
1976 )
1977 .trace_err();
1978 }
1979
1980 // First, we send the metadata associated with each worktree.
1981 let (language_servers, language_server_capabilities) = project
1982 .language_servers
1983 .clone()
1984 .into_iter()
1985 .map(|server| (server.server, server.capabilities))
1986 .unzip();
1987 response.send(proto::JoinProjectResponse {
1988 project_id: project.id.0 as u64,
1989 worktrees,
1990 replica_id: replica_id.0 as u32,
1991 collaborators,
1992 language_servers,
1993 language_server_capabilities,
1994 role: project.role.into(),
1995 windows_paths: project.path_style == PathStyle::Windows,
1996 })?;
1997
1998 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1999 // Stream this worktree's entries.
2000 let message = proto::UpdateWorktree {
2001 project_id: project_id.to_proto(),
2002 worktree_id,
2003 abs_path: worktree.abs_path.clone(),
2004 root_name: worktree.root_name,
2005 updated_entries: worktree.entries,
2006 removed_entries: Default::default(),
2007 scan_id: worktree.scan_id,
2008 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2009 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2010 removed_repositories: Default::default(),
2011 };
2012 for update in proto::split_worktree_update(message) {
2013 session.peer.send(session.connection_id, update.clone())?;
2014 }
2015
2016 // Stream this worktree's diagnostics.
2017 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2018 if let Some(summary) = worktree_diagnostics.next() {
2019 let message = proto::UpdateDiagnosticSummary {
2020 project_id: project.id.to_proto(),
2021 worktree_id: worktree.id,
2022 summary: Some(summary),
2023 more_summaries: worktree_diagnostics.collect(),
2024 };
2025 session.peer.send(session.connection_id, message)?;
2026 }
2027
2028 for settings_file in worktree.settings_files {
2029 session.peer.send(
2030 session.connection_id,
2031 proto::UpdateWorktreeSettings {
2032 project_id: project_id.to_proto(),
2033 worktree_id: worktree.id,
2034 path: settings_file.path,
2035 content: Some(settings_file.content),
2036 kind: Some(settings_file.kind.to_proto() as i32),
2037 },
2038 )?;
2039 }
2040 }
2041
2042 for repository in mem::take(&mut project.repositories) {
2043 for update in split_repository_update(repository) {
2044 session.peer.send(session.connection_id, update)?;
2045 }
2046 }
2047
2048 for language_server in &project.language_servers {
2049 session.peer.send(
2050 session.connection_id,
2051 proto::UpdateLanguageServer {
2052 project_id: project_id.to_proto(),
2053 server_name: Some(language_server.server.name.clone()),
2054 language_server_id: language_server.server.id,
2055 variant: Some(
2056 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2057 proto::LspDiskBasedDiagnosticsUpdated {},
2058 ),
2059 ),
2060 },
2061 )?;
2062 }
2063
2064 Ok(())
2065}
2066
/// Leave someone else's shared project.
2068async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2069 let sender_id = session.connection_id;
2070 let project_id = ProjectId::from_proto(request.project_id);
2071 let db = session.db().await;
2072
2073 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2074 tracing::info!(
2075 %project_id,
2076 "leave project"
2077 );
2078
2079 project_left(project, &session);
2080 if let Some(room) = room {
2081 room_updated(room, &session.peer);
2082 }
2083
2084 Ok(())
2085}
2086
2087/// Updates other participants with changes to the project
2088async fn update_project(
2089 request: proto::UpdateProject,
2090 response: Response<proto::UpdateProject>,
2091 session: MessageContext,
2092) -> Result<()> {
2093 let project_id = ProjectId::from_proto(request.project_id);
2094 let (room, guest_connection_ids) = &*session
2095 .db()
2096 .await
2097 .update_project(project_id, session.connection_id, &request.worktrees)
2098 .await?;
2099 broadcast(
2100 Some(session.connection_id),
2101 guest_connection_ids.iter().copied(),
2102 |connection_id| {
2103 session
2104 .peer
2105 .forward_send(session.connection_id, connection_id, request.clone())
2106 },
2107 );
2108 if let Some(room) = room {
2109 room_updated(room, &session.peer);
2110 }
2111 response.send(proto::Ack {})?;
2112
2113 Ok(())
2114}
2115
2116/// Updates other participants with changes to the worktree
2117async fn update_worktree(
2118 request: proto::UpdateWorktree,
2119 response: Response<proto::UpdateWorktree>,
2120 session: MessageContext,
2121) -> Result<()> {
2122 let guest_connection_ids = session
2123 .db()
2124 .await
2125 .update_worktree(&request, session.connection_id)
2126 .await?;
2127
2128 broadcast(
2129 Some(session.connection_id),
2130 guest_connection_ids.iter().copied(),
2131 |connection_id| {
2132 session
2133 .peer
2134 .forward_send(session.connection_id, connection_id, request.clone())
2135 },
2136 );
2137 response.send(proto::Ack {})?;
2138 Ok(())
2139}
2140
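/// Updates other participants with changes to a repository.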
2141async fn update_repository(
2142 request: proto::UpdateRepository,
2143 response: Response<proto::UpdateRepository>,
2144 session: MessageContext,
2145) -> Result<()> {
2146 let guest_connection_ids = session
2147 .db()
2148 .await
2149 .update_repository(&request, session.connection_id)
2150 .await?;
2151
2152 broadcast(
2153 Some(session.connection_id),
2154 guest_connection_ids.iter().copied(),
2155 |connection_id| {
2156 session
2157 .peer
2158 .forward_send(session.connection_id, connection_id, request.clone())
2159 },
2160 );
2161 response.send(proto::Ack {})?;
2162 Ok(())
2163}
2164
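/// Notifies other participants that a repository has been removed.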
2165async fn remove_repository(
2166 request: proto::RemoveRepository,
2167 response: Response<proto::RemoveRepository>,
2168 session: MessageContext,
2169) -> Result<()> {
2170 let guest_connection_ids = session
2171 .db()
2172 .await
2173 .remove_repository(&request, session.connection_id)
2174 .await?;
2175
2176 broadcast(
2177 Some(session.connection_id),
2178 guest_connection_ids.iter().copied(),
2179 |connection_id| {
2180 session
2181 .peer
2182 .forward_send(session.connection_id, connection_id, request.clone())
2183 },
2184 );
2185 response.send(proto::Ack {})?;
2186 Ok(())
2187}
2188
2189/// Updates other participants with changes to the diagnostics
2190async fn update_diagnostic_summary(
2191 message: proto::UpdateDiagnosticSummary,
2192 session: MessageContext,
2193) -> Result<()> {
2194 let guest_connection_ids = session
2195 .db()
2196 .await
2197 .update_diagnostic_summary(&message, session.connection_id)
2198 .await?;
2199
2200 broadcast(
2201 Some(session.connection_id),
2202 guest_connection_ids.iter().copied(),
2203 |connection_id| {
2204 session
2205 .peer
2206 .forward_send(session.connection_id, connection_id, message.clone())
2207 },
2208 );
2209
2210 Ok(())
2211}
2212
2213/// Updates other participants with changes to the worktree settings
2214async fn update_worktree_settings(
2215 message: proto::UpdateWorktreeSettings,
2216 session: MessageContext,
2217) -> Result<()> {
2218 let guest_connection_ids = session
2219 .db()
2220 .await
2221 .update_worktree_settings(&message, session.connection_id)
2222 .await?;
2223
2224 broadcast(
2225 Some(session.connection_id),
2226 guest_connection_ids.iter().copied(),
2227 |connection_id| {
2228 session
2229 .peer
2230 .forward_send(session.connection_id, connection_id, message.clone())
2231 },
2232 );
2233
2234 Ok(())
2235}
2236
2237/// Notify other participants that a language server has started.
2238async fn start_language_server(
2239 request: proto::StartLanguageServer,
2240 session: MessageContext,
2241) -> Result<()> {
2242 let guest_connection_ids = session
2243 .db()
2244 .await
2245 .start_language_server(&request, session.connection_id)
2246 .await?;
2247
2248 broadcast(
2249 Some(session.connection_id),
2250 guest_connection_ids.iter().copied(),
2251 |connection_id| {
2252 session
2253 .peer
2254 .forward_send(session.connection_id, connection_id, request.clone())
2255 },
2256 );
2257 Ok(())
2258}
2259
2260/// Notify other participants that a language server has changed.
2261async fn update_language_server(
2262 request: proto::UpdateLanguageServer,
2263 session: MessageContext,
2264) -> Result<()> {
2265 let project_id = ProjectId::from_proto(request.project_id);
2266 let db = session.db().await;
2267
2268 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2269 && let Some(capabilities) = update.capabilities.clone()
2270 {
2271 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2272 .await?;
2273 }
2274
2275 let project_connection_ids = db
2276 .project_connection_ids(project_id, session.connection_id, true)
2277 .await?;
2278 broadcast(
2279 Some(session.connection_id),
2280 project_connection_ids.iter().copied(),
2281 |connection_id| {
2282 session
2283 .peer
2284 .forward_send(session.connection_id, connection_id, request.clone())
2285 },
2286 );
2287 Ok(())
2288}
2289
/// Forward a project request to the host. These requests should be read-only,
/// as guests are allowed to send them.
2292async fn forward_read_only_project_request<T>(
2293 request: T,
2294 response: Response<T>,
2295 session: MessageContext,
2296) -> Result<()>
2297where
2298 T: EntityMessage + RequestMessage,
2299{
2300 let project_id = ProjectId::from_proto(request.remote_entity_id());
2301 let host_connection_id = session
2302 .db()
2303 .await
2304 .host_for_read_only_project_request(project_id, session.connection_id)
2305 .await?;
2306 let payload = session.forward_request(host_connection_id, request).await?;
2307 response.send(payload)?;
2308 Ok(())
2309}
2310
/// Forward a project request to the host. These requests are disallowed
/// for guests.
2313async fn forward_mutating_project_request<T>(
2314 request: T,
2315 response: Response<T>,
2316 session: MessageContext,
2317) -> Result<()>
2318where
2319 T: EntityMessage + RequestMessage,
2320{
2321 let project_id = ProjectId::from_proto(request.remote_entity_id());
2322
2323 let host_connection_id = session
2324 .db()
2325 .await
2326 .host_for_mutating_project_request(project_id, session.connection_id)
2327 .await?;
2328 let payload = session.forward_request(host_connection_id, request).await?;
2329 response.send(payload)?;
2330 Ok(())
2331}
2332
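/// Forward an LSP query to the project host, routing it as a mutating or read-only
/// project request depending on the query's write permissions.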
2333async fn lsp_query(
2334 request: proto::LspQuery,
2335 response: Response<proto::LspQuery>,
2336 session: MessageContext,
2337) -> Result<()> {
2338 let (name, should_write) = request.query_name_and_write_permissions();
2339 tracing::Span::current().record("lsp_query_request", name);
2340 tracing::info!("lsp_query message received");
2341 if should_write {
2342 forward_mutating_project_request(request, response, session).await
2343 } else {
2344 forward_read_only_project_request(request, response, session).await
2345 }
2346}
2347
2348/// Notify other participants that a new buffer has been created
2349async fn create_buffer_for_peer(
2350 request: proto::CreateBufferForPeer,
2351 session: MessageContext,
2352) -> Result<()> {
2353 session
2354 .db()
2355 .await
2356 .check_user_is_project_host(
2357 ProjectId::from_proto(request.project_id),
2358 session.connection_id,
2359 )
2360 .await?;
2361 let peer_id = request.peer_id.context("invalid peer id")?;
2362 session
2363 .peer
2364 .forward_send(session.connection_id, peer_id.into(), request)?;
2365 Ok(())
2366}
2367
2368/// Notify other participants that a new image has been created
2369async fn create_image_for_peer(
2370 request: proto::CreateImageForPeer,
2371 session: MessageContext,
2372) -> Result<()> {
2373 session
2374 .db()
2375 .await
2376 .check_user_is_project_host(
2377 ProjectId::from_proto(request.project_id),
2378 session.connection_id,
2379 )
2380 .await?;
2381 let peer_id = request.peer_id.context("invalid peer id")?;
2382 session
2383 .peer
2384 .forward_send(session.connection_id, peer_id.into(), request)?;
2385 Ok(())
2386}
2387
2388/// Notify other participants that a buffer has been updated. This is
2389/// allowed for guests as long as the update is limited to selections.
2390async fn update_buffer(
2391 request: proto::UpdateBuffer,
2392 response: Response<proto::UpdateBuffer>,
2393 session: MessageContext,
2394) -> Result<()> {
2395 let project_id = ProjectId::from_proto(request.project_id);
2396 let mut capability = Capability::ReadOnly;
2397
2398 for op in request.operations.iter() {
2399 match op.variant {
2400 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2401 Some(_) => capability = Capability::ReadWrite,
2402 }
2403 }
2404
2405 let host = {
2406 let guard = session
2407 .db()
2408 .await
2409 .connections_for_buffer_update(project_id, session.connection_id, capability)
2410 .await?;
2411
2412 let (host, guests) = &*guard;
2413
2414 broadcast(
2415 Some(session.connection_id),
2416 guests.clone(),
2417 |connection_id| {
2418 session
2419 .peer
2420 .forward_send(session.connection_id, connection_id, request.clone())
2421 },
2422 );
2423
2424 *host
2425 };
2426
2427 if host != session.connection_id {
2428 session.forward_request(host, request.clone()).await?;
2429 }
2430
2431 response.send(proto::Ack {})?;
2432 Ok(())
2433}
2434
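/// Updates other participants with changes to a shared context. Write access is
/// required unless the operation only updates selections.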
2435async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2436 let project_id = ProjectId::from_proto(message.project_id);
2437
2438 let operation = message.operation.as_ref().context("invalid operation")?;
2439 let capability = match operation.variant.as_ref() {
2440 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2441 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2442 match buffer_op.variant {
2443 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2444 Capability::ReadOnly
2445 }
2446 _ => Capability::ReadWrite,
2447 }
2448 } else {
2449 Capability::ReadWrite
2450 }
2451 }
2452 Some(_) => Capability::ReadWrite,
2453 None => Capability::ReadOnly,
2454 };
2455
2456 let guard = session
2457 .db()
2458 .await
2459 .connections_for_buffer_update(project_id, session.connection_id, capability)
2460 .await?;
2461
2462 let (host, guests) = &*guard;
2463
2464 broadcast(
2465 Some(session.connection_id),
2466 guests.iter().chain([host]).copied(),
2467 |connection_id| {
2468 session
2469 .peer
2470 .forward_send(session.connection_id, connection_id, message.clone())
2471 },
2472 );
2473
2474 Ok(())
2475}
2476
/// Broadcast a message from a project's host to all other participants in the project.
2478async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2479 request: T,
2480 session: MessageContext,
2481) -> Result<()> {
2482 let project_id = ProjectId::from_proto(request.remote_entity_id());
2483 let project_connection_ids = session
2484 .db()
2485 .await
2486 .project_connection_ids(project_id, session.connection_id, false)
2487 .await?;
2488
2489 broadcast(
2490 Some(session.connection_id),
2491 project_connection_ids.iter().copied(),
2492 |connection_id| {
2493 session
2494 .peer
2495 .forward_send(session.connection_id, connection_id, request.clone())
2496 },
2497 );
2498 Ok(())
2499}
2500
2501/// Start following another user in a call.
2502async fn follow(
2503 request: proto::Follow,
2504 response: Response<proto::Follow>,
2505 session: MessageContext,
2506) -> Result<()> {
2507 let room_id = RoomId::from_proto(request.room_id);
2508 let project_id = request.project_id.map(ProjectId::from_proto);
2509 let leader_id = request.leader_id.context("invalid leader id")?.into();
2510 let follower_id = session.connection_id;
2511
2512 session
2513 .db()
2514 .await
2515 .check_room_participants(room_id, leader_id, session.connection_id)
2516 .await?;
2517
2518 let response_payload = session.forward_request(leader_id, request).await?;
2519 response.send(response_payload)?;
2520
2521 if let Some(project_id) = project_id {
2522 let room = session
2523 .db()
2524 .await
2525 .follow(room_id, project_id, leader_id, follower_id)
2526 .await?;
2527 room_updated(&room, &session.peer);
2528 }
2529
2530 Ok(())
2531}
2532
2533/// Stop following another user in a call.
2534async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2535 let room_id = RoomId::from_proto(request.room_id);
2536 let project_id = request.project_id.map(ProjectId::from_proto);
2537 let leader_id = request.leader_id.context("invalid leader id")?.into();
2538 let follower_id = session.connection_id;
2539
2540 session
2541 .db()
2542 .await
2543 .check_room_participants(room_id, leader_id, session.connection_id)
2544 .await?;
2545
2546 session
2547 .peer
2548 .forward_send(session.connection_id, leader_id, request)?;
2549
2550 if let Some(project_id) = project_id {
2551 let room = session
2552 .db()
2553 .await
2554 .unfollow(room_id, project_id, leader_id, follower_id)
2555 .await?;
2556 room_updated(&room, &session.peer);
2557 }
2558
2559 Ok(())
2560}
2561
2562/// Notify everyone following you of your current location.
2563async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2564 let room_id = RoomId::from_proto(request.room_id);
2565 let database = session.db.lock().await;
2566
2567 let connection_ids = if let Some(project_id) = request.project_id {
2568 let project_id = ProjectId::from_proto(project_id);
2569 database
2570 .project_connection_ids(project_id, session.connection_id, true)
2571 .await?
2572 } else {
2573 database
2574 .room_connection_ids(room_id, session.connection_id)
2575 .await?
2576 };
2577
2578 // For now, don't send view update messages back to that view's current leader.
2579 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2580 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2581 _ => None,
2582 });
2583
2584 for connection_id in connection_ids.iter().cloned() {
2585 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2586 session
2587 .peer
2588 .forward_send(session.connection_id, connection_id, request.clone())?;
2589 }
2590 }
2591 Ok(())
2592}
2593
2594/// Get public data about users.
2595async fn get_users(
2596 request: proto::GetUsers,
2597 response: Response<proto::GetUsers>,
2598 session: MessageContext,
2599) -> Result<()> {
2600 let user_ids = request
2601 .user_ids
2602 .into_iter()
2603 .map(UserId::from_proto)
2604 .collect();
2605 let users = session
2606 .db()
2607 .await
2608 .get_users_by_ids(user_ids)
2609 .await?
2610 .into_iter()
2611 .map(|user| proto::User {
2612 id: user.id.to_proto(),
2613 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2614 github_login: user.github_login,
2615 name: user.name,
2616 })
2617 .collect();
2618 response.send(proto::UsersResponse { users })?;
2619 Ok(())
2620}
2621
/// Search for users (to invite) by GitHub login.
2623async fn fuzzy_search_users(
2624 request: proto::FuzzySearchUsers,
2625 response: Response<proto::FuzzySearchUsers>,
2626 session: MessageContext,
2627) -> Result<()> {
2628 let query = request.query;
2629 let users = match query.len() {
2630 0 => vec![],
2631 1 | 2 => session
2632 .db()
2633 .await
2634 .get_user_by_github_login(&query)
2635 .await?
2636 .into_iter()
2637 .collect(),
2638 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2639 };
2640 let users = users
2641 .into_iter()
2642 .filter(|user| user.id != session.user_id())
2643 .map(|user| proto::User {
2644 id: user.id.to_proto(),
2645 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2646 github_login: user.github_login,
2647 name: user.name,
2648 })
2649 .collect();
2650 response.send(proto::UsersResponse { users })?;
2651 Ok(())
2652}
2653
2654/// Send a contact request to another user.
2655async fn request_contact(
2656 request: proto::RequestContact,
2657 response: Response<proto::RequestContact>,
2658 session: MessageContext,
2659) -> Result<()> {
2660 let requester_id = session.user_id();
2661 let responder_id = UserId::from_proto(request.responder_id);
2662 if requester_id == responder_id {
2663 return Err(anyhow!("cannot add yourself as a contact"))?;
2664 }
2665
2666 let notifications = session
2667 .db()
2668 .await
2669 .send_contact_request(requester_id, responder_id)
2670 .await?;
2671
2672 // Update outgoing contact requests of requester
2673 let mut update = proto::UpdateContacts::default();
2674 update.outgoing_requests.push(responder_id.to_proto());
2675 for connection_id in session
2676 .connection_pool()
2677 .await
2678 .user_connection_ids(requester_id)
2679 {
2680 session.peer.send(connection_id, update.clone())?;
2681 }
2682
2683 // Update incoming contact requests of responder
2684 let mut update = proto::UpdateContacts::default();
2685 update
2686 .incoming_requests
2687 .push(proto::IncomingContactRequest {
2688 requester_id: requester_id.to_proto(),
2689 });
2690 let connection_pool = session.connection_pool().await;
2691 for connection_id in connection_pool.user_connection_ids(responder_id) {
2692 session.peer.send(connection_id, update.clone())?;
2693 }
2694
2695 send_notifications(&connection_pool, &session.peer, notifications);
2696
2697 response.send(proto::Ack {})?;
2698 Ok(())
2699}
2700
2701/// Accept or decline a contact request
2702async fn respond_to_contact_request(
2703 request: proto::RespondToContactRequest,
2704 response: Response<proto::RespondToContactRequest>,
2705 session: MessageContext,
2706) -> Result<()> {
2707 let responder_id = session.user_id();
2708 let requester_id = UserId::from_proto(request.requester_id);
2709 let db = session.db().await;
2710 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2711 db.dismiss_contact_notification(responder_id, requester_id)
2712 .await?;
2713 } else {
2714 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2715
2716 let notifications = db
2717 .respond_to_contact_request(responder_id, requester_id, accept)
2718 .await?;
2719 let requester_busy = db.is_user_busy(requester_id).await?;
2720 let responder_busy = db.is_user_busy(responder_id).await?;
2721
2722 let pool = session.connection_pool().await;
2723 // Update responder with new contact
2724 let mut update = proto::UpdateContacts::default();
2725 if accept {
2726 update
2727 .contacts
2728 .push(contact_for_user(requester_id, requester_busy, &pool));
2729 }
2730 update
2731 .remove_incoming_requests
2732 .push(requester_id.to_proto());
2733 for connection_id in pool.user_connection_ids(responder_id) {
2734 session.peer.send(connection_id, update.clone())?;
2735 }
2736
2737 // Update requester with new contact
2738 let mut update = proto::UpdateContacts::default();
2739 if accept {
2740 update
2741 .contacts
2742 .push(contact_for_user(responder_id, responder_busy, &pool));
2743 }
2744 update
2745 .remove_outgoing_requests
2746 .push(responder_id.to_proto());
2747
2748 for connection_id in pool.user_connection_ids(requester_id) {
2749 session.peer.send(connection_id, update.clone())?;
2750 }
2751
2752 send_notifications(&pool, &session.peer, notifications);
2753 }
2754
2755 response.send(proto::Ack {})?;
2756 Ok(())
2757}
2758
2759/// Remove a contact.
2760async fn remove_contact(
2761 request: proto::RemoveContact,
2762 response: Response<proto::RemoveContact>,
2763 session: MessageContext,
2764) -> Result<()> {
2765 let requester_id = session.user_id();
2766 let responder_id = UserId::from_proto(request.user_id);
2767 let db = session.db().await;
2768 let (contact_accepted, deleted_notification_id) =
2769 db.remove_contact(requester_id, responder_id).await?;
2770
2771 let pool = session.connection_pool().await;
2772 // Update outgoing contact requests of requester
2773 let mut update = proto::UpdateContacts::default();
2774 if contact_accepted {
2775 update.remove_contacts.push(responder_id.to_proto());
2776 } else {
2777 update
2778 .remove_outgoing_requests
2779 .push(responder_id.to_proto());
2780 }
2781 for connection_id in pool.user_connection_ids(requester_id) {
2782 session.peer.send(connection_id, update.clone())?;
2783 }
2784
2785 // Update incoming contact requests of responder
2786 let mut update = proto::UpdateContacts::default();
2787 if contact_accepted {
2788 update.remove_contacts.push(requester_id.to_proto());
2789 } else {
2790 update
2791 .remove_incoming_requests
2792 .push(requester_id.to_proto());
2793 }
2794 for connection_id in pool.user_connection_ids(responder_id) {
2795 session.peer.send(connection_id, update.clone())?;
2796 if let Some(notification_id) = deleted_notification_id {
2797 session.peer.send(
2798 connection_id,
2799 proto::DeleteNotification {
2800 notification_id: notification_id.to_proto(),
2801 },
2802 )?;
2803 }
2804 }
2805
2806 response.send(proto::Ack {})?;
2807 Ok(())
2808}
2809
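/// Whether to subscribe this client to its channels automatically on connect.
/// Newer clients request this explicitly via `SubscribeToChannels`.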
2810fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2811 version.0.minor < 139
2812}
2813
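/// Subscribe the requesting user to updates for the channels they belong to.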
2814async fn subscribe_to_channels(
2815 _: proto::SubscribeToChannels,
2816 session: MessageContext,
2817) -> Result<()> {
2818 subscribe_user_to_channels(session.user_id(), &session).await?;
2819 Ok(())
2820}
2821
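/// Register the user's channel memberships in the connection pool and send them
/// their current channel state.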
2822async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2823 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2824 let mut pool = session.connection_pool().await;
2825 for membership in &channels_for_user.channel_memberships {
2826 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2827 }
2828 session.peer.send(
2829 session.connection_id,
2830 build_update_user_channels(&channels_for_user),
2831 )?;
2832 session.peer.send(
2833 session.connection_id,
2834 build_channels_update(channels_for_user),
2835 )?;
2836 Ok(())
2837}
2838
2839/// Creates a new channel.
2840async fn create_channel(
2841 request: proto::CreateChannel,
2842 response: Response<proto::CreateChannel>,
2843 session: MessageContext,
2844) -> Result<()> {
2845 let db = session.db().await;
2846
2847 let parent_id = request.parent_id.map(ChannelId::from_proto);
2848 let (channel, membership) = db
2849 .create_channel(&request.name, parent_id, session.user_id())
2850 .await?;
2851
2852 let root_id = channel.root_id();
2853 let channel = Channel::from_model(channel);
2854
2855 response.send(proto::CreateChannelResponse {
2856 channel: Some(channel.to_proto()),
2857 parent_id: request.parent_id,
2858 })?;
2859
2860 let mut connection_pool = session.connection_pool().await;
2861 if let Some(membership) = membership {
2862 connection_pool.subscribe_to_channel(
2863 membership.user_id,
2864 membership.channel_id,
2865 membership.role,
2866 );
2867 let update = proto::UpdateUserChannels {
2868 channel_memberships: vec![proto::ChannelMembership {
2869 channel_id: membership.channel_id.to_proto(),
2870 role: membership.role.into(),
2871 }],
2872 ..Default::default()
2873 };
2874 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2875 session.peer.send(connection_id, update.clone())?;
2876 }
2877 }
2878
2879 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2880 if !role.can_see_channel(channel.visibility) {
2881 continue;
2882 }
2883
2884 let update = proto::UpdateChannels {
2885 channels: vec![channel.to_proto()],
2886 ..Default::default()
2887 };
2888 session.peer.send(connection_id, update.clone())?;
2889 }
2890
2891 Ok(())
2892}
2893
2894/// Delete a channel
2895async fn delete_channel(
2896 request: proto::DeleteChannel,
2897 response: Response<proto::DeleteChannel>,
2898 session: MessageContext,
2899) -> Result<()> {
2900 let db = session.db().await;
2901
2902 let channel_id = request.channel_id;
2903 let (root_channel, removed_channels) = db
2904 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2905 .await?;
2906 response.send(proto::Ack {})?;
2907
2908 // Notify members of removed channels
2909 let mut update = proto::UpdateChannels::default();
2910 update
2911 .delete_channels
2912 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2913
2914 let connection_pool = session.connection_pool().await;
2915 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2916 session.peer.send(connection_id, update.clone())?;
2917 }
2918
2919 Ok(())
2920}
2921
2922/// Invite someone to join a channel.
2923async fn invite_channel_member(
2924 request: proto::InviteChannelMember,
2925 response: Response<proto::InviteChannelMember>,
2926 session: MessageContext,
2927) -> Result<()> {
2928 let db = session.db().await;
2929 let channel_id = ChannelId::from_proto(request.channel_id);
2930 let invitee_id = UserId::from_proto(request.user_id);
2931 let InviteMemberResult {
2932 channel,
2933 notifications,
2934 } = db
2935 .invite_channel_member(
2936 channel_id,
2937 invitee_id,
2938 session.user_id(),
2939 request.role().into(),
2940 )
2941 .await?;
2942
2943 let update = proto::UpdateChannels {
2944 channel_invitations: vec![channel.to_proto()],
2945 ..Default::default()
2946 };
2947
2948 let connection_pool = session.connection_pool().await;
2949 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2950 session.peer.send(connection_id, update.clone())?;
2951 }
2952
2953 send_notifications(&connection_pool, &session.peer, notifications);
2954
2955 response.send(proto::Ack {})?;
2956 Ok(())
2957}
2958
/// Remove someone from a channel.
2960async fn remove_channel_member(
2961 request: proto::RemoveChannelMember,
2962 response: Response<proto::RemoveChannelMember>,
2963 session: MessageContext,
2964) -> Result<()> {
2965 let db = session.db().await;
2966 let channel_id = ChannelId::from_proto(request.channel_id);
2967 let member_id = UserId::from_proto(request.user_id);
2968
2969 let RemoveChannelMemberResult {
2970 membership_update,
2971 notification_id,
2972 } = db
2973 .remove_channel_member(channel_id, member_id, session.user_id())
2974 .await?;
2975
2976 let mut connection_pool = session.connection_pool().await;
2977 notify_membership_updated(
2978 &mut connection_pool,
2979 membership_update,
2980 member_id,
2981 &session.peer,
2982 );
2983 for connection_id in connection_pool.user_connection_ids(member_id) {
2984 if let Some(notification_id) = notification_id {
2985 session
2986 .peer
2987 .send(
2988 connection_id,
2989 proto::DeleteNotification {
2990 notification_id: notification_id.to_proto(),
2991 },
2992 )
2993 .trace_err();
2994 }
2995 }
2996
2997 response.send(proto::Ack {})?;
2998 Ok(())
2999}
3000
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
3004async fn set_channel_visibility(
3005 request: proto::SetChannelVisibility,
3006 response: Response<proto::SetChannelVisibility>,
3007 session: MessageContext,
3008) -> Result<()> {
3009 let db = session.db().await;
3010 let channel_id = ChannelId::from_proto(request.channel_id);
3011 let visibility = request.visibility().into();
3012
3013 let channel_model = db
3014 .set_channel_visibility(channel_id, visibility, session.user_id())
3015 .await?;
3016 let root_id = channel_model.root_id();
3017 let channel = Channel::from_model(channel_model);
3018
3019 let mut connection_pool = session.connection_pool().await;
3020 for (user_id, role) in connection_pool
3021 .channel_user_ids(root_id)
3022 .collect::<Vec<_>>()
3023 .into_iter()
3024 {
3025 let update = if role.can_see_channel(channel.visibility) {
3026 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3027 proto::UpdateChannels {
3028 channels: vec![channel.to_proto()],
3029 ..Default::default()
3030 }
3031 } else {
3032 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3033 proto::UpdateChannels {
3034 delete_channels: vec![channel.id.to_proto()],
3035 ..Default::default()
3036 }
3037 };
3038
3039 for connection_id in connection_pool.user_connection_ids(user_id) {
3040 session.peer.send(connection_id, update.clone())?;
3041 }
3042 }
3043
3044 response.send(proto::Ack {})?;
3045 Ok(())
3046}
3047
3048/// Alter the role for a user in the channel.
3049async fn set_channel_member_role(
3050 request: proto::SetChannelMemberRole,
3051 response: Response<proto::SetChannelMemberRole>,
3052 session: MessageContext,
3053) -> Result<()> {
3054 let db = session.db().await;
3055 let channel_id = ChannelId::from_proto(request.channel_id);
3056 let member_id = UserId::from_proto(request.user_id);
3057 let result = db
3058 .set_channel_member_role(
3059 channel_id,
3060 session.user_id(),
3061 member_id,
3062 request.role().into(),
3063 )
3064 .await?;
3065
3066 match result {
3067 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3068 let mut connection_pool = session.connection_pool().await;
3069 notify_membership_updated(
3070 &mut connection_pool,
3071 membership_update,
3072 member_id,
3073 &session.peer,
3074 )
3075 }
3076 db::SetMemberRoleResult::InviteUpdated(channel) => {
3077 let update = proto::UpdateChannels {
3078 channel_invitations: vec![channel.to_proto()],
3079 ..Default::default()
3080 };
3081
3082 for connection_id in session
3083 .connection_pool()
3084 .await
3085 .user_connection_ids(member_id)
3086 {
3087 session.peer.send(connection_id, update.clone())?;
3088 }
3089 }
3090 }
3091
3092 response.send(proto::Ack {})?;
3093 Ok(())
3094}
3095
3096/// Change the name of a channel
3097async fn rename_channel(
3098 request: proto::RenameChannel,
3099 response: Response<proto::RenameChannel>,
3100 session: MessageContext,
3101) -> Result<()> {
3102 let db = session.db().await;
3103 let channel_id = ChannelId::from_proto(request.channel_id);
3104 let channel_model = db
3105 .rename_channel(channel_id, session.user_id(), &request.name)
3106 .await?;
3107 let root_id = channel_model.root_id();
3108 let channel = Channel::from_model(channel_model);
3109
3110 response.send(proto::RenameChannelResponse {
3111 channel: Some(channel.to_proto()),
3112 })?;
3113
3114 let connection_pool = session.connection_pool().await;
3115 let update = proto::UpdateChannels {
3116 channels: vec![channel.to_proto()],
3117 ..Default::default()
3118 };
3119 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3120 if role.can_see_channel(channel.visibility) {
3121 session.peer.send(connection_id, update.clone())?;
3122 }
3123 }
3124
3125 Ok(())
3126}
3127
3128/// Move a channel to a new parent.
3129async fn move_channel(
3130 request: proto::MoveChannel,
3131 response: Response<proto::MoveChannel>,
3132 session: MessageContext,
3133) -> Result<()> {
3134 let channel_id = ChannelId::from_proto(request.channel_id);
3135 let to = ChannelId::from_proto(request.to);
3136
3137 let (root_id, channels) = session
3138 .db()
3139 .await
3140 .move_channel(channel_id, to, session.user_id())
3141 .await?;
3142
3143 let connection_pool = session.connection_pool().await;
3144 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3145 let channels = channels
3146 .iter()
3147 .filter_map(|channel| {
3148 if role.can_see_channel(channel.visibility) {
3149 Some(channel.to_proto())
3150 } else {
3151 None
3152 }
3153 })
3154 .collect::<Vec<_>>();
3155 if channels.is_empty() {
3156 continue;
3157 }
3158
3159 let update = proto::UpdateChannels {
3160 channels,
3161 ..Default::default()
3162 };
3163
3164 session.peer.send(connection_id, update.clone())?;
3165 }
3166
3167 response.send(Ack {})?;
3168 Ok(())
3169}
3170
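/// Move a channel up or down relative to its siblings and notify channel members
/// of the new ordering.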
3171async fn reorder_channel(
3172 request: proto::ReorderChannel,
3173 response: Response<proto::ReorderChannel>,
3174 session: MessageContext,
3175) -> Result<()> {
3176 let channel_id = ChannelId::from_proto(request.channel_id);
3177 let direction = request.direction();
3178
3179 let updated_channels = session
3180 .db()
3181 .await
3182 .reorder_channel(channel_id, direction, session.user_id())
3183 .await?;
3184
3185 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3186 let connection_pool = session.connection_pool().await;
3187 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3188 let channels = updated_channels
3189 .iter()
3190 .filter_map(|channel| {
3191 if role.can_see_channel(channel.visibility) {
3192 Some(channel.to_proto())
3193 } else {
3194 None
3195 }
3196 })
3197 .collect::<Vec<_>>();
3198
3199 if channels.is_empty() {
3200 continue;
3201 }
3202
3203 let update = proto::UpdateChannels {
3204 channels,
3205 ..Default::default()
3206 };
3207
3208 session.peer.send(connection_id, update.clone())?;
3209 }
3210 }
3211
3212 response.send(Ack {})?;
3213 Ok(())
3214}
3215
3216/// Get the list of channel members
3217async fn get_channel_members(
3218 request: proto::GetChannelMembers,
3219 response: Response<proto::GetChannelMembers>,
3220 session: MessageContext,
3221) -> Result<()> {
3222 let db = session.db().await;
3223 let channel_id = ChannelId::from_proto(request.channel_id);
3224 let limit = if request.limit == 0 {
3225 u16::MAX as u64
3226 } else {
3227 request.limit
3228 };
3229 let (members, users) = db
3230 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3231 .await?;
3232 response.send(proto::GetChannelMembersResponse { members, users })?;
3233 Ok(())
3234}
3235
3236/// Accept or decline a channel invitation.
3237async fn respond_to_channel_invite(
3238 request: proto::RespondToChannelInvite,
3239 response: Response<proto::RespondToChannelInvite>,
3240 session: MessageContext,
3241) -> Result<()> {
3242 let db = session.db().await;
3243 let channel_id = ChannelId::from_proto(request.channel_id);
3244 let RespondToChannelInvite {
3245 membership_update,
3246 notifications,
3247 } = db
3248 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3249 .await?;
3250
3251 let mut connection_pool = session.connection_pool().await;
3252 if let Some(membership_update) = membership_update {
3253 notify_membership_updated(
3254 &mut connection_pool,
3255 membership_update,
3256 session.user_id(),
3257 &session.peer,
3258 );
3259 } else {
3260 let update = proto::UpdateChannels {
3261 remove_channel_invitations: vec![channel_id.to_proto()],
3262 ..Default::default()
3263 };
3264
3265 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3266 session.peer.send(connection_id, update.clone())?;
3267 }
3268 };
3269
3270 send_notifications(&connection_pool, &session.peer, notifications);
3271
3272 response.send(proto::Ack {})?;
3273
3274 Ok(())
3275}
3276
/// Join the channel's room
3278async fn join_channel(
3279 request: proto::JoinChannel,
3280 response: Response<proto::JoinChannel>,
3281 session: MessageContext,
3282) -> Result<()> {
3283 let channel_id = ChannelId::from_proto(request.channel_id);
3284 join_channel_internal(channel_id, Box::new(response), session).await
3285}
3286
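/// Lets `join_channel_internal` reply with a `JoinRoomResponse` to either a
/// `JoinChannel` or a `JoinRoom` request.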
3287trait JoinChannelInternalResponse {
3288 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3289}
3290impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3291 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3292 Response::<proto::JoinChannel>::send(self, result)
3293 }
3294}
3295impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3296 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3297 Response::<proto::JoinRoom>::send(self, result)
3298 }
3299}
3300
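/// Shared implementation of joining a channel's room, cleaning up any stale room
/// connection left behind by a previous session before joining.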
3301async fn join_channel_internal(
3302 channel_id: ChannelId,
3303 response: Box<impl JoinChannelInternalResponse>,
3304 session: MessageContext,
3305) -> Result<()> {
3306 let joined_room = {
3307 let mut db = session.db().await;
        // If Zed quits without leaving the room and the user re-opens it before the
        // RECONNECT_TIMEOUT, we need to kick the user out of the room they were
        // previously in.
3311 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3312 tracing::info!(
3313 stale_connection_id = %connection,
3314 "cleaning up stale connection",
3315 );
3316 drop(db);
3317 leave_room_for_session(&session, connection).await?;
3318 db = session.db().await;
3319 }
3320
3321 let (joined_room, membership_updated, role) = db
3322 .join_channel(channel_id, session.user_id(), session.connection_id)
3323 .await?;
3324
3325 let live_kit_connection_info =
3326 session
3327 .app_state
3328 .livekit_client
3329 .as_ref()
3330 .and_then(|live_kit| {
3331 let (can_publish, token) = if role == ChannelRole::Guest {
3332 (
3333 false,
3334 live_kit
3335 .guest_token(
3336 &joined_room.room.livekit_room,
3337 &session.user_id().to_string(),
3338 )
3339 .trace_err()?,
3340 )
3341 } else {
3342 (
3343 true,
3344 live_kit
3345 .room_token(
3346 &joined_room.room.livekit_room,
3347 &session.user_id().to_string(),
3348 )
3349 .trace_err()?,
3350 )
3351 };
3352
3353 Some(LiveKitConnectionInfo {
3354 server_url: live_kit.url().into(),
3355 token,
3356 can_publish,
3357 })
3358 });
3359
3360 response.send(proto::JoinRoomResponse {
3361 room: Some(joined_room.room.clone()),
3362 channel_id: joined_room
3363 .channel
3364 .as_ref()
3365 .map(|channel| channel.id.to_proto()),
3366 live_kit_connection_info,
3367 })?;
3368
3369 let mut connection_pool = session.connection_pool().await;
3370 if let Some(membership_updated) = membership_updated {
3371 notify_membership_updated(
3372 &mut connection_pool,
3373 membership_updated,
3374 session.user_id(),
3375 &session.peer,
3376 );
3377 }
3378
3379 room_updated(&joined_room.room, &session.peer);
3380
3381 joined_room
3382 };
3383
3384 channel_updated(
3385 &joined_room.channel.context("channel not returned")?,
3386 &joined_room.room,
3387 &session.peer,
3388 &*session.connection_pool().await,
3389 );
3390
3391 update_user_contacts(session.user_id(), &session).await?;
3392 Ok(())
3393}
3394
3395/// Start editing the channel notes
3396async fn join_channel_buffer(
3397 request: proto::JoinChannelBuffer,
3398 response: Response<proto::JoinChannelBuffer>,
3399 session: MessageContext,
3400) -> Result<()> {
3401 let db = session.db().await;
3402 let channel_id = ChannelId::from_proto(request.channel_id);
3403
3404 let open_response = db
3405 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3406 .await?;
3407
3408 let collaborators = open_response.collaborators.clone();
3409 response.send(open_response)?;
3410
3411 let update = UpdateChannelBufferCollaborators {
3412 channel_id: channel_id.to_proto(),
3413 collaborators: collaborators.clone(),
3414 };
3415 channel_buffer_updated(
3416 session.connection_id,
3417 collaborators
3418 .iter()
3419 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3420 &update,
3421 &session.peer,
3422 );
3423
3424 Ok(())
3425}
3426
3427/// Edit the channel notes
3428async fn update_channel_buffer(
3429 request: proto::UpdateChannelBuffer,
3430 session: MessageContext,
3431) -> Result<()> {
3432 let db = session.db().await;
3433 let channel_id = ChannelId::from_proto(request.channel_id);
3434
3435 let (collaborators, epoch, version) = db
3436 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3437 .await?;
3438
3439 channel_buffer_updated(
3440 session.connection_id,
3441 collaborators.clone(),
3442 &proto::UpdateChannelBuffer {
3443 channel_id: channel_id.to_proto(),
3444 operations: request.operations,
3445 },
3446 &session.peer,
3447 );
3448
3449 let pool = &*session.connection_pool().await;
3450
3451 let non_collaborators =
3452 pool.channel_connection_ids(channel_id)
3453 .filter_map(|(connection_id, _)| {
3454 if collaborators.contains(&connection_id) {
3455 None
3456 } else {
3457 Some(connection_id)
3458 }
3459 });
3460
3461 broadcast(None, non_collaborators, |peer_id| {
3462 session.peer.send(
3463 peer_id,
3464 proto::UpdateChannels {
3465 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3466 channel_id: channel_id.to_proto(),
3467 epoch: epoch as u64,
3468 version: version.clone(),
3469 }],
3470 ..Default::default()
3471 },
3472 )
3473 });
3474
3475 Ok(())
3476}
3477
3478/// Rejoin the channel notes after a connection blip
3479async fn rejoin_channel_buffers(
3480 request: proto::RejoinChannelBuffers,
3481 response: Response<proto::RejoinChannelBuffers>,
3482 session: MessageContext,
3483) -> Result<()> {
3484 let db = session.db().await;
3485 let buffers = db
3486 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3487 .await?;
3488
3489 for rejoined_buffer in &buffers {
3490 let collaborators_to_notify = rejoined_buffer
3491 .buffer
3492 .collaborators
3493 .iter()
3494 .filter_map(|c| Some(c.peer_id?.into()));
3495 channel_buffer_updated(
3496 session.connection_id,
3497 collaborators_to_notify,
3498 &proto::UpdateChannelBufferCollaborators {
3499 channel_id: rejoined_buffer.buffer.channel_id,
3500 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3501 },
3502 &session.peer,
3503 );
3504 }
3505
3506 response.send(proto::RejoinChannelBuffersResponse {
3507 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3508 })?;
3509
3510 Ok(())
3511}
3512
3513/// Stop editing the channel notes
3514async fn leave_channel_buffer(
3515 request: proto::LeaveChannelBuffer,
3516 response: Response<proto::LeaveChannelBuffer>,
3517 session: MessageContext,
3518) -> Result<()> {
3519 let db = session.db().await;
3520 let channel_id = ChannelId::from_proto(request.channel_id);
3521
3522 let left_buffer = db
3523 .leave_channel_buffer(channel_id, session.connection_id)
3524 .await?;
3525
3526 response.send(Ack {})?;
3527
3528 channel_buffer_updated(
3529 session.connection_id,
3530 left_buffer.connections,
3531 &proto::UpdateChannelBufferCollaborators {
3532 channel_id: channel_id.to_proto(),
3533 collaborators: left_buffer.collaborators,
3534 },
3535 &session.peer,
3536 );
3537
3538 Ok(())
3539}
3540
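/// Broadcast a message to the other collaborators in a channel buffer.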
3541fn channel_buffer_updated<T: EnvelopedMessage>(
3542 sender_id: ConnectionId,
3543 collaborators: impl IntoIterator<Item = ConnectionId>,
3544 message: &T,
3545 peer: &Peer,
3546) {
3547 broadcast(Some(sender_id), collaborators, |peer_id| {
3548 peer.send(peer_id, message.clone())
3549 });
3550}
3551
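/// Deliver a batch of notifications to every connection of each recipient.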
3552fn send_notifications(
3553 connection_pool: &ConnectionPool,
3554 peer: &Peer,
3555 notifications: db::NotificationBatch,
3556) {
3557 for (user_id, notification) in notifications {
3558 for connection_id in connection_pool.user_connection_ids(user_id) {
3559 if let Err(error) = peer.send(
3560 connection_id,
3561 proto::AddNotification {
3562 notification: Some(notification.clone()),
3563 },
3564 ) {
3565 tracing::error!(
3566 "failed to send notification to {:?} {}",
3567 connection_id,
3568 error
3569 );
3570 }
3571 }
3572 }
3573}
3574
3575/// Send a message to the channel
3576async fn send_channel_message(
3577 _request: proto::SendChannelMessage,
3578 _response: Response<proto::SendChannelMessage>,
3579 _session: MessageContext,
3580) -> Result<()> {
3581 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3582}
3583
3584/// Delete a channel message
3585async fn remove_channel_message(
3586 _request: proto::RemoveChannelMessage,
3587 _response: Response<proto::RemoveChannelMessage>,
3588 _session: MessageContext,
3589) -> Result<()> {
3590 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3591}
3592
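/// Update a channel message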
3593async fn update_channel_message(
3594 _request: proto::UpdateChannelMessage,
3595 _response: Response<proto::UpdateChannelMessage>,
3596 _session: MessageContext,
3597) -> Result<()> {
3598 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3599}
3600
3601/// Mark a channel message as read
3602async fn acknowledge_channel_message(
3603 _request: proto::AckChannelMessage,
3604 _session: MessageContext,
3605) -> Result<()> {
3606 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3607}
3608
3609/// Mark a buffer version as synced
3610async fn acknowledge_buffer_version(
3611 request: proto::AckBufferOperation,
3612 session: MessageContext,
3613) -> Result<()> {
3614 let buffer_id = BufferId::from_proto(request.buffer_id);
3615 session
3616 .db()
3617 .await
3618 .observe_buffer_version(
3619 buffer_id,
3620 session.user_id(),
3621 request.epoch as i32,
3622 &request.version,
3623 )
3624 .await?;
3625 Ok(())
3626}
3627
3628/// Start receiving chat updates for a channel
3629async fn join_channel_chat(
3630 _request: proto::JoinChannelChat,
3631 _response: Response<proto::JoinChannelChat>,
3632 _session: MessageContext,
3633) -> Result<()> {
3634 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3635}
3636
3637/// Stop receiving chat updates for a channel
3638async fn leave_channel_chat(
3639 _request: proto::LeaveChannelChat,
3640 _session: MessageContext,
3641) -> Result<()> {
3642 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3643}
3644
3645/// Retrieve the chat history for a channel
3646async fn get_channel_messages(
3647 _request: proto::GetChannelMessages,
3648 _response: Response<proto::GetChannelMessages>,
3649 _session: MessageContext,
3650) -> Result<()> {
3651 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3652}
3653
3654/// Retrieve specific chat messages
3655async fn get_channel_messages_by_id(
3656 _request: proto::GetChannelMessagesById,
3657 _response: Response<proto::GetChannelMessagesById>,
3658 _session: MessageContext,
3659) -> Result<()> {
3660 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3661}
3662
/// Retrieve the current user's notifications
3664async fn get_notifications(
3665 request: proto::GetNotifications,
3666 response: Response<proto::GetNotifications>,
3667 session: MessageContext,
3668) -> Result<()> {
3669 let notifications = session
3670 .db()
3671 .await
3672 .get_notifications(
3673 session.user_id(),
3674 NOTIFICATION_COUNT_PER_PAGE,
3675 request.before_id.map(db::NotificationId::from_proto),
3676 )
3677 .await?;
3678 response.send(proto::GetNotificationsResponse {
3679 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3680 notifications,
3681 })?;
3682 Ok(())
3683}
3684
/// Mark a notification as read
3686async fn mark_notification_as_read(
3687 request: proto::MarkNotificationRead,
3688 response: Response<proto::MarkNotificationRead>,
3689 session: MessageContext,
3690) -> Result<()> {
3691 let database = &session.db().await;
3692 let notifications = database
3693 .mark_notification_as_read_by_id(
3694 session.user_id(),
3695 NotificationId::from_proto(request.notification_id),
3696 )
3697 .await?;
3698 send_notifications(
3699 &*session.connection_pool().await,
3700 &session.peer,
3701 notifications,
3702 );
3703 response.send(proto::Ack {})?;
3704 Ok(())
3705}
3706
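/// Convert a tungstenite WebSocket message into its axum equivalent.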
3707fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3708 let message = match message {
3709 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3710 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3711 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3712 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3713 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3714 code: frame.code.into(),
3715 reason: frame.reason.as_str().to_owned().into(),
3716 })),
3717 // We should never receive a frame while reading the message, according
3718 // to the `tungstenite` maintainers:
3719 //
3720 // > It cannot occur when you read messages from the WebSocket, but it
3721 // > can be used when you want to send the raw frames (e.g. you want to
3722 // > send the frames to the WebSocket without composing the full message first).
3723 // >
3724 // > — https://github.com/snapview/tungstenite-rs/issues/268
3725 TungsteniteMessage::Frame(_) => {
3726 bail!("received an unexpected frame while reading the message")
3727 }
3728 };
3729
3730 Ok(message)
3731}
3732
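/// Convert an axum WebSocket message into its tungstenite equivalent.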
3733fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3734 match message {
3735 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3736 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3737 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3738 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3739 AxumMessage::Close(frame) => {
3740 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3741 code: frame.code.into(),
3742 reason: frame.reason.as_ref().into(),
3743 }))
3744 }
3745 }
3746}
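
// These two converters bridge the WebSocket message types of `axum` (used by the HTTP
// handler) and `tungstenite` (presumably consumed by the rpc `Connection`); the mapping
// covers every variant that can appear on the wire: text, binary, ping, pong, and close.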

/// Update a user's channel subscriptions and notify all of their connections after
/// their membership in a channel has changed.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
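
// Two messages are fanned out per connection above: `UpdateUserChannels` carries the
// user's own membership roles, while `UpdateChannels` carries the structural changes
// (new channels, deleted channels, and the cleared invitation for this channel).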

/// Build the `UpdateUserChannels` message describing a user's memberships and
/// observed channel-buffer versions.
fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
    proto::UpdateUserChannels {
        channel_memberships: channels
            .channel_memberships
            .iter()
            .map(|m| proto::ChannelMembership {
                channel_id: m.channel_id.to_proto(),
                role: m.role.into(),
            })
            .collect(),
        observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
    }
}

/// Build the `UpdateChannels` message containing the channels a user belongs to,
/// their latest buffer versions, current participants, and pending invitations.
fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
    let mut update = proto::UpdateChannels::default();

    for channel in channels.channels {
        update.channels.push(channel.to_proto());
    }

    update.latest_channel_buffer_versions = channels.latest_buffer_versions;

    for (channel_id, participants) in channels.channel_participants {
        update
            .channel_participants
            .push(proto::ChannelParticipants {
                channel_id: channel_id.to_proto(),
                participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
            });
    }

    for channel in channels.invited_channels {
        update.channel_invitations.push(channel.to_proto());
    }

    update
}

fn build_initial_contacts_update(
    contacts: Vec<db::Contact>,
    pool: &ConnectionPool,
) -> proto::UpdateContacts {
    let mut update = proto::UpdateContacts::default();

    for contact in contacts {
        match contact {
            db::Contact::Accepted { user_id, busy } => {
                update.contacts.push(contact_for_user(user_id, busy, pool));
            }
            db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
            db::Contact::Incoming { user_id } => {
                update
                    .incoming_requests
                    .push(proto::IncomingContactRequest {
                        requester_id: user_id.to_proto(),
                    })
            }
        }
    }

    update
}

fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}

fn room_updated(room: &proto::Room, peer: &Peer) {
    broadcast(
        None,
        room.participants
            .iter()
            .filter_map(|participant| Some(participant.peer_id?.into())),
        |peer_id| {
            peer.send(
                peer_id,
                proto::RoomUpdated {
                    room: Some(room.clone()),
                },
            )
        },
    );
}

/// Broadcast a channel's current call participants to every connection subscribed to the
/// channel's root that is allowed to see the channel.
fn channel_updated(
    channel: &db::channel::Model,
    room: &proto::Room,
    peer: &Peer,
    pool: &ConnectionPool,
) {
    let participants = room
        .participants
        .iter()
        .map(|p| p.user_id)
        .collect::<Vec<_>>();

    broadcast(
        None,
        pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
            }),
        |peer_id| {
            peer.send(
                peer_id,
                proto::UpdateChannels {
                    channel_participants: vec![proto::ChannelParticipants {
                        channel_id: channel.id.to_proto(),
                        participant_user_ids: participants.clone(),
                    }],
                    ..Default::default()
                },
            )
        },
    );
}

/// Push a user's latest contact state (online and busy status) to every connection of
/// each of their accepted contacts.
async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let contacts = db.get_contacts(user_id).await?;
    let busy = db.is_user_busy(user_id).await?;

    let pool = session.connection_pool().await;
    let updated_contact = contact_for_user(user_id, busy, &pool);
    for contact in contacts {
        if let db::Contact::Accepted {
            user_id: contact_user_id,
            ..
        } = contact
        {
            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                session
                    .peer
                    .send(
                        contact_conn_id,
                        proto::UpdateContacts {
                            contacts: vec![updated_contact.clone()],
                            remove_contacts: Default::default(),
                            incoming_requests: Default::default(),
                            remove_incoming_requests: Default::default(),
                            outgoing_requests: Default::default(),
                            remove_outgoing_requests: Default::default(),
                        },
                    )
                    .trace_err();
            }
        }
    }
    Ok(())
}

async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
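
// Teardown order above: leave the room in the database, notify each left project's
// remaining collaborators, broadcast the updated room (and channel participants, if the
// room was attached to a channel), cancel any calls this user had outstanding, refresh
// the affected contacts, and finally remove the participant from, or delete, the LiveKit
// room when a LiveKit client is configured.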

async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
    let left_channel_buffers = session
        .db()
        .await
        .leave_channel_buffers(session.connection_id)
        .await?;

    for left_buffer in left_channel_buffers {
        channel_buffer_updated(
            session.connection_id,
            left_buffer.connections,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: left_buffer.channel_id.to_proto(),
                collaborators: left_buffer.collaborators,
            },
            &session.peer,
        );
    }

    Ok(())
}

fn project_left(project: &db::LeftProject, session: &Session) {
    for connection_id in &project.connection_ids {
        if project.should_unshare {
            session
                .peer
                .send(
                    *connection_id,
                    proto::UnshareProject {
                        project_id: project.id.to_proto(),
                    },
                )
                .trace_err();
        } else {
            session
                .peer
                .send(
                    *connection_id,
                    proto::RemoveProjectCollaborator {
                        project_id: project.id.to_proto(),
                        peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }
}
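
// Note on the two branches above: `UnshareProject` tells the remaining connections that
// the project is no longer shared at all, whereas `RemoveProjectCollaborator` only drops
// the departing peer from an otherwise live share; `should_unshare` appears to mark the
// case where the leaving connection owned the share.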

/// Extension trait for logging errors instead of propagating them.
pub trait ResultExt {
    type Ok;

    /// Log the error, if any, and convert the result into an `Option`, yielding `None`
    /// on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}
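
// Usage sketch (illustrative): `trace_err` is useful when a failure for one connection
// should not abort work for the rest, e.g.
// `peer.send(connection_id, update.clone()).trace_err();` logs the failed send and
// yields `None` instead of returning early with `?`.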