1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use rpc::proto::split_repository_update;
36use tracing::Span;
37use util::paths::PathStyle;
38
39use futures::{
40 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
41 stream::FuturesUnordered,
42};
43use prometheus::{IntGauge, register_int_gauge};
44use rpc::{
45 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
46 proto::{
47 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
48 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
49 },
50};
51use semver::Version;
52use serde::{Serialize, Serializer};
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// How long a disconnected client may reconnect before its server-side
/// state (rooms, shared projects) is considered abandoned.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for paginated notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently active connections (see `ConnectionGuard`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names used to report message-handling latency, recorded by
// `Server::add_handler` and `MessageContext::forward_request`.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased message handler stored in `Server::handlers`, keyed by the
// message payload's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133 Impersonated { user: User, admin: User },
134}
135
136impl Principal {
137 fn update_span(&self, span: &tracing::Span) {
138 match &self {
139 Principal::User(user) => {
140 span.record("user_id", user.id.0);
141 span.record("login", &user.github_login);
142 }
143 Principal::Impersonated { user, admin } => {
144 span.record("user_id", user.id.0);
145 span.record("login", &user.github_login);
146 span.record("impersonator", &admin.github_login);
147 }
148 }
149 }
150}
151
/// A [`Session`] paired with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` wherever a `Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
165
impl MessageContext {
    /// Forwards `request` to another connection (typically the project host)
    /// and relays its response, recording the time spent waiting on the
    /// receiver in the `host_waiting_ms` span field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` fires on success and failure alike: record elapsed
            // wall-clock time in fractional milliseconds either way.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
187
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex, so a session's handlers
    /// serialize their database access.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the whole server; access via
    /// `Session::connection_pool`, which returns a `!Send` guard.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Client-reported system identifier, if one was provided.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
203
204impl Session {
205 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
206 #[cfg(test)]
207 tokio::task::yield_now().await;
208 let guard = self.db.lock().await;
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 guard
212 }
213
214 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 let guard = self.connection_pool.lock();
218 ConnectionPoolGuard {
219 guard,
220 _not_send: PhantomData,
221 }
222 }
223
224 #[expect(dead_code)]
225 fn is_staff(&self) -> bool {
226 match &self.principal {
227 Principal::User(user) => user.admin,
228 Principal::Impersonated { .. } => true,
229 }
230 }
231
232 fn user_id(&self) -> UserId {
233 match &self.principal {
234 Principal::User(user) => user.id,
235 Principal::Impersonated { user, .. } => user.id,
236 }
237 }
238}
239
240impl Debug for Session {
241 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
242 let mut result = f.debug_struct("Session");
243 match &self.principal {
244 Principal::User(user) => {
245 result.field("user", &user.github_login);
246 }
247 Principal::Impersonated { user, admin } => {
248 result.field("user", &user.github_login);
249 result.field("impersonator", &admin.github_login);
250 }
251 }
252 result.field("connection_id", &self.connection_id).finish()
253 }
254}
255
256struct DbHandle(Arc<Database>);
257
258impl Deref for DbHandle {
259 type Target = Database;
260
261 fn deref(&self) -> &Self::Target {
262 self.0.as_ref()
263 }
264}
265
/// The collab RPC server: accepts client connections and routes each typed
/// proto message to the handler registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message `TypeId`; populated once in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` to every in-flight connection task on shutdown.
    teardown: watch::Sender<bool>,
}

/// Locked view of the connection pool.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` makes the guard `!Send`, preventing it from being held across an
    // `.await` point.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer and connection-pool state.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
286
287pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
288where
289 S: Serializer,
290 T: Deref<Target = U>,
291 U: Serialize,
292{
293 Serialize::serialize(value.deref(), serializer)
294}
295
296impl Server {
    /// Builds a server for `app_state` and registers a handler for every
    /// client-originated message type.
    ///
    /// Registration panics (via `add_handler`) if two handlers are added for
    /// the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host; "read only" ones are
            // allowed for guests without write access.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state broadcast from the host to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and presence.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
472
    /// Spawns the background cleanup task. After `CLEANUP_TIMEOUT` it
    /// releases resources still attributed to dead server instances: stale
    /// chat participants, channel-buffer collaborators, room participants,
    /// old worktree entries, and finally the stale server rows themselves.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop channel-buffer collaborators that belonged to dead
                    // servers and notify the remaining ones of the new roster.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants whose server is gone and tell
                        // everyone left about the updated room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose caller disappeared that the call
                        // was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-broadcast busy status for every affected user to
                        // all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
643
    /// Shuts the server down: drops peer connections, clears the connection
    /// pool, and signals every in-flight `handle_connection` task to exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Receivers may already be gone; a send error is fine here.
        let _ = self.teardown.send(true);
    }
649
    /// Test-only: tears the server down and restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Un-set the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
662
    /// Registers the type-erased `handler` for message type `M`.
    ///
    /// The stored closure downcasts the envelope back to `TypedEnvelope<M>`,
    /// invokes the handler, and records queue/processing/total latency on the
    /// message span once the handler completes.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The downcast cannot fail: this closure is only reachable
                // through the map entry keyed by `TypeId::of::<M>()`.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Durations in fractional milliseconds;
                    // total = time in queue + time spent processing.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
706
707 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
708 where
709 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
710 Fut: 'static + Send + Future<Output = Result<()>>,
711 M: EnvelopedMessage,
712 {
713 self.add_handler(move |envelope, session| handler(envelope.payload, session));
714 self
715 }
716
    /// Registers `handler` for a request/response message type `M`.
    ///
    /// Wraps the handler so that returning `Ok` without having called
    /// `Response::send` is surfaced as an error (every request must be
    /// answered), and so that handler errors are relayed back to the
    /// requesting client before being propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set to `true` by `Response::send`.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Relay the failure to the requester, then propagate
                        // it so the wrapper logs it.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
755
    /// Runs the message loop for an authenticated `connection` until the
    /// client disconnects or the server tears down, then cleans up the
    /// session via `connection_lost`.
    ///
    /// `connection_guard` (the connection-limit slot) is held only until the
    /// initial client update has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake complete: free the connection-limit slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Back-pressure: wait for a free handler slot before pulling
                // the next message off the wire.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // has run to completion.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
928
    /// Sends the initial state a freshly-accepted client needs: a `Hello`
    /// carrying its peer id, the contact list, channel subscriptions, and any
    /// call that was ringing while the user was offline.
    ///
    /// `send_connection_id` lets an in-process caller learn the assigned
    /// connection id as soon as the handshake is underway.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller; a dropped receiver is
        // not an error.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: surface the contacts UI
                // and persist the connected-once flag so this only happens once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and build the contacts update under
                    // the same pool lock so the reported presence is consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a ring that arrived while the user was disconnected.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
984
985 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
986 ServerSnapshot {
987 connection_pool: ConnectionPoolGuard {
988 guard: self.connection_pool.lock(),
989 _not_send: PhantomData,
990 },
991 peer: &self.peer,
992 }
993 }
994}
995
996impl Deref for ConnectionPoolGuard<'_> {
997 type Target = ConnectionPool;
998
999 fn deref(&self) -> &Self::Target {
1000 &self.guard
1001 }
1002}
1003
1004impl DerefMut for ConnectionPoolGuard<'_> {
1005 fn deref_mut(&mut self) -> &mut Self::Target {
1006 &mut self.guard
1007 }
1008}
1009
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, validate the pool's internal invariants every time a
        // guard is released, so corruption is caught close to where it happened.
        #[cfg(test)]
        self.check_invariants();
    }
}
1016
1017fn broadcast<F>(
1018 sender_id: Option<ConnectionId>,
1019 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1020 mut f: F,
1021) where
1022 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1023{
1024 for receiver_id in receiver_ids {
1025 if Some(receiver_id) != sender_id
1026 && let Err(error) = f(receiver_id)
1027 {
1028 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1029 }
1030 }
1031}
1032
1033pub struct ProtocolVersion(u32);
1034
1035impl Header for ProtocolVersion {
1036 fn name() -> &'static HeaderName {
1037 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1038 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1039 }
1040
1041 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1042 where
1043 Self: Sized,
1044 I: Iterator<Item = &'i axum::http::HeaderValue>,
1045 {
1046 let version = values
1047 .next()
1048 .ok_or_else(axum::headers::Error::invalid)?
1049 .to_str()
1050 .map_err(|_| axum::headers::Error::invalid())?
1051 .parse()
1052 .map_err(|_| axum::headers::Error::invalid())?;
1053 Ok(Self(version))
1054 }
1055
1056 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1057 values.extend([self.0.to_string().parse().unwrap()]);
1058 }
1059}
1060
1061pub struct AppVersionHeader(Version);
1062impl Header for AppVersionHeader {
1063 fn name() -> &'static HeaderName {
1064 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1065 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1066 }
1067
1068 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069 where
1070 Self: Sized,
1071 I: Iterator<Item = &'i axum::http::HeaderValue>,
1072 {
1073 let version = values
1074 .next()
1075 .ok_or_else(axum::headers::Error::invalid)?
1076 .to_str()
1077 .map_err(|_| axum::headers::Error::invalid())?
1078 .parse()
1079 .map_err(|_| axum::headers::Error::invalid())?;
1080 Ok(Self(version))
1081 }
1082
1083 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084 values.extend([self.0.to_string().parse().unwrap()]);
1085 }
1086}
1087
1088#[derive(Debug)]
1089pub struct ReleaseChannelHeader(String);
1090
1091impl Header for ReleaseChannelHeader {
1092 fn name() -> &'static HeaderName {
1093 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1094 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1095 }
1096
1097 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1098 where
1099 Self: Sized,
1100 I: Iterator<Item = &'i axum::http::HeaderValue>,
1101 {
1102 Ok(Self(
1103 values
1104 .next()
1105 .ok_or_else(axum::headers::Error::invalid)?
1106 .to_str()
1107 .map_err(|_| axum::headers::Error::invalid())?
1108 .to_owned(),
1109 ))
1110 }
1111
1112 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1113 values.extend([self.0.parse().unwrap()]);
1114 }
1115}
1116
/// Builds the collab HTTP router: `/rpc` (the websocket endpoint, gated by
/// header-based auth) and `/metrics` (the Prometheus scrape endpoint).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // This layer is added before `/metrics` is routed, so only `/rpc`
            // requires an authenticated principal; metrics stay open.
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1128
/// Axum handler for `GET /rpc`: validates the protocol/app-version headers,
/// enforces the concurrent-connection limit, then upgrades to a websocket and
/// hands the connection to the server.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // An RPC protocol mismatch cannot be negotiated; the client must update.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // server can refuse with a plain HTTP 503 instead of accepting the
    // upgrade and then dropping the socket.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's websocket message type to the tungstenite message
        // type the RPC layer speaks, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1206
1207pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1208 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1209 let connections_metric = CONNECTIONS_METRIC
1210 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1211
1212 let connections = server
1213 .connection_pool
1214 .lock()
1215 .connections()
1216 .filter(|connection| !connection.admin)
1217 .count();
1218 connections_metric.set(connections as _);
1219
1220 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1221 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1222 register_int_gauge!(
1223 "shared_projects",
1224 "number of open projects with one or more guests"
1225 )
1226 .unwrap()
1227 });
1228
1229 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1230 shared_projects_metric.set(shared_projects as _);
1231
1232 let encoder = prometheus::TextEncoder::new();
1233 let metric_families = prometheus::gather();
1234 let encoded_metrics = encoder
1235 .encode_to_string(&metric_families)
1236 .map_err(|err| anyhow!("{err}"))?;
1237 Ok(encoded_metrics)
1238}
1239
/// Runs after a client's connection drops: tears the connection down
/// immediately, then waits up to `RECONNECT_TIMEOUT` before releasing the
/// user's rooms and channel buffers, giving the client a chance to rejoin.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediate cleanup: drop the peer connection and remove it from the pool.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort DB bookkeeping; a failure here is logged, not fatal.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period elapsed without a rejoin: release everything.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last live
            // connection; another window may still be ringing.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the per-connection cleanup above.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1286
1287/// Acknowledges a ping from a client, used to keep the connection alive.
1288async fn ping(
1289 _: proto::Ping,
1290 response: Response<proto::Ping>,
1291 _session: MessageContext,
1292) -> Result<()> {
1293 response.send(proto::Ack {})?;
1294 Ok(())
1295}
1296
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random LiveKit room name; the DB row records the association.
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: if no LiveKit client is configured or token minting
    // fails, the room is still created — just without media connection info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // The caller's contacts see their busy/free status change.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1334
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms backed by a channel are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the ringing call on all of this user's connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token when a media client is configured; a failure is
    // logged and the join proceeds without audio/video.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1401
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope for the DB guard returned by `rejoin_room`: all notifications
    // that depend on its contents happen before it is released.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this session re-shared as host: point existing
        // collaborators at the session's new connection id, then refresh the
        // project's worktree metadata for everyone else.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Keep only the pieces needed after the guard is dropped.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1493
/// For projects this session rejoined as a guest: tells collaborators about
/// the session's new connection id, then replays each project's current
/// state (worktree entries, diagnostics, settings, repositories) back to the
/// rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Announce the new connection id to the other collaborators first.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their buffers aren't cloned.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // A large update is split into several messages before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1578
1579/// leave room disconnects from the room.
1580async fn leave_room(
1581 _: proto::LeaveRoom,
1582 response: Response<proto::LeaveRoom>,
1583 session: MessageContext,
1584) -> Result<()> {
1585 leave_room_for_session(&session, session.connection_id).await?;
1586 response.send(proto::Ack {})?;
1587 Ok(())
1588}
1589
1590/// Updates the permissions of someone else in the room.
1591async fn set_room_participant_role(
1592 request: proto::SetRoomParticipantRole,
1593 response: Response<proto::SetRoomParticipantRole>,
1594 session: MessageContext,
1595) -> Result<()> {
1596 let user_id = UserId::from_proto(request.user_id);
1597 let role = ChannelRole::from(request.role());
1598
1599 let (livekit_room, can_publish) = {
1600 let room = session
1601 .db()
1602 .await
1603 .set_room_participant_role(
1604 session.user_id(),
1605 RoomId::from_proto(request.room_id),
1606 user_id,
1607 role,
1608 )
1609 .await?;
1610
1611 let livekit_room = room.livekit_room.clone();
1612 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1613 room_updated(&room, &session.peer);
1614 (livekit_room, can_publish)
1615 };
1616
1617 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1618 live_kit
1619 .update_participant(
1620 livekit_room.clone(),
1621 request.user_id.to_string(),
1622 livekit_api::proto::ParticipantPermission {
1623 can_subscribe: true,
1624 can_publish,
1625 can_publish_data: can_publish,
1626 hidden: false,
1627 recorder: false,
1628 },
1629 )
1630 .await
1631 .trace_err();
1632 }
1633
1634 response.send(proto::Ack {})?;
1635 Ok(())
1636}
1637
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `Err(...)?` (rather than a plain return) converts anyhow's error
        // into this crate's `Error` type.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the call message out of the DB guard before it is released.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the called user has; the first one to answer
    // completes this request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failed ring and keep waiting on the other
                // connections.
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the ring: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1706
1707/// Cancel an outgoing call.
1708async fn cancel_call(
1709 request: proto::CancelCall,
1710 response: Response<proto::CancelCall>,
1711 session: MessageContext,
1712) -> Result<()> {
1713 let called_user_id = UserId::from_proto(request.called_user_id);
1714 let room_id = RoomId::from_proto(request.room_id);
1715 {
1716 let room = session
1717 .db()
1718 .await
1719 .cancel_call(room_id, session.connection_id, called_user_id)
1720 .await?;
1721 room_updated(&room, &session.peer);
1722 }
1723
1724 for connection_id in session
1725 .connection_pool()
1726 .await
1727 .user_connection_ids(called_user_id)
1728 {
1729 session
1730 .peer
1731 .send(
1732 connection_id,
1733 proto::CallCanceled {
1734 room_id: room_id.to_proto(),
1735 },
1736 )
1737 .trace_err();
1738 }
1739 response.send(proto::Ack {})?;
1740
1741 update_user_contacts(called_user_id, &session).await?;
1742 Ok(())
1743}
1744
1745/// Decline an incoming call.
1746async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1747 let room_id = RoomId::from_proto(message.room_id);
1748 {
1749 let room = session
1750 .db()
1751 .await
1752 .decline_call(Some(room_id), session.user_id())
1753 .await?
1754 .context("declining call")?;
1755 room_updated(&room, &session.peer);
1756 }
1757
1758 for connection_id in session
1759 .connection_pool()
1760 .await
1761 .user_connection_ids(session.user_id())
1762 {
1763 session
1764 .peer
1765 .send(
1766 connection_id,
1767 proto::CallCanceled {
1768 room_id: room_id.to_proto(),
1769 },
1770 )
1771 .trace_err();
1772 }
1773 update_user_contacts(session.user_id(), &session).await?;
1774 Ok(())
1775}
1776
1777/// Updates other participants in the room with your current location.
1778async fn update_participant_location(
1779 request: proto::UpdateParticipantLocation,
1780 response: Response<proto::UpdateParticipantLocation>,
1781 session: MessageContext,
1782) -> Result<()> {
1783 let room_id = RoomId::from_proto(request.room_id);
1784 let location = request.location.context("invalid location")?;
1785
1786 let db = session.db().await;
1787 let room = db
1788 .update_room_participant_location(room_id, session.connection_id, location)
1789 .await?;
1790
1791 room_updated(&room, &session.peer);
1792 response.send(proto::Ack {})?;
1793 Ok(())
1794}
1795
1796/// Share a project into the room.
1797async fn share_project(
1798 request: proto::ShareProject,
1799 response: Response<proto::ShareProject>,
1800 session: MessageContext,
1801) -> Result<()> {
1802 let (project_id, room) = &*session
1803 .db()
1804 .await
1805 .share_project(
1806 RoomId::from_proto(request.room_id),
1807 session.connection_id,
1808 &request.worktrees,
1809 request.is_ssh_project,
1810 request.windows_paths.unwrap_or(false),
1811 )
1812 .await?;
1813 response.send(proto::ShareProjectResponse {
1814 project_id: project_id.to_proto(),
1815 })?;
1816 room_updated(room, &session.peer);
1817
1818 Ok(())
1819}
1820
1821/// Unshare a project from the room.
1822async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1823 let project_id = ProjectId::from_proto(message.project_id);
1824 unshare_project_internal(project_id, session.connection_id, &session).await
1825}
1826
/// Shared implementation of unsharing a project: notifies guests, refreshes
/// the room, and deletes the project row when the DB flags it for deletion.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection) the project is
        // gone; failures are logged per-connection inside `broadcast`.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The deletion happens after the guard above has been released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1864
1865/// Join someone elses shared project.
1866async fn join_project(
1867 request: proto::JoinProject,
1868 response: Response<proto::JoinProject>,
1869 session: MessageContext,
1870) -> Result<()> {
1871 let project_id = ProjectId::from_proto(request.project_id);
1872
1873 tracing::info!(%project_id, "join project");
1874
1875 let db = session.db().await;
1876 let (project, replica_id) = &mut *db
1877 .join_project(
1878 project_id,
1879 session.connection_id,
1880 session.user_id(),
1881 request.committer_name.clone(),
1882 request.committer_email.clone(),
1883 )
1884 .await?;
1885 drop(db);
1886 tracing::info!(%project_id, "join remote project");
1887 let collaborators = project
1888 .collaborators
1889 .iter()
1890 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1891 .map(|collaborator| collaborator.to_proto())
1892 .collect::<Vec<_>>();
1893 let project_id = project.id;
1894 let guest_user_id = session.user_id();
1895
1896 let worktrees = project
1897 .worktrees
1898 .iter()
1899 .map(|(id, worktree)| proto::WorktreeMetadata {
1900 id: *id,
1901 root_name: worktree.root_name.clone(),
1902 visible: worktree.visible,
1903 abs_path: worktree.abs_path.clone(),
1904 })
1905 .collect::<Vec<_>>();
1906
1907 let add_project_collaborator = proto::AddProjectCollaborator {
1908 project_id: project_id.to_proto(),
1909 collaborator: Some(proto::Collaborator {
1910 peer_id: Some(session.connection_id.into()),
1911 replica_id: replica_id.0 as u32,
1912 user_id: guest_user_id.to_proto(),
1913 is_host: false,
1914 committer_name: request.committer_name.clone(),
1915 committer_email: request.committer_email.clone(),
1916 }),
1917 };
1918
1919 for collaborator in &collaborators {
1920 session
1921 .peer
1922 .send(
1923 collaborator.peer_id.unwrap().into(),
1924 add_project_collaborator.clone(),
1925 )
1926 .trace_err();
1927 }
1928
1929 // First, we send the metadata associated with each worktree.
1930 let (language_servers, language_server_capabilities) = project
1931 .language_servers
1932 .clone()
1933 .into_iter()
1934 .map(|server| (server.server, server.capabilities))
1935 .unzip();
1936 response.send(proto::JoinProjectResponse {
1937 project_id: project.id.0 as u64,
1938 worktrees,
1939 replica_id: replica_id.0 as u32,
1940 collaborators,
1941 language_servers,
1942 language_server_capabilities,
1943 role: project.role.into(),
1944 windows_paths: project.path_style == PathStyle::Windows,
1945 })?;
1946
1947 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1948 // Stream this worktree's entries.
1949 let message = proto::UpdateWorktree {
1950 project_id: project_id.to_proto(),
1951 worktree_id,
1952 abs_path: worktree.abs_path.clone(),
1953 root_name: worktree.root_name,
1954 updated_entries: worktree.entries,
1955 removed_entries: Default::default(),
1956 scan_id: worktree.scan_id,
1957 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1958 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1959 removed_repositories: Default::default(),
1960 };
1961 for update in proto::split_worktree_update(message) {
1962 session.peer.send(session.connection_id, update.clone())?;
1963 }
1964
1965 // Stream this worktree's diagnostics.
1966 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1967 if let Some(summary) = worktree_diagnostics.next() {
1968 let message = proto::UpdateDiagnosticSummary {
1969 project_id: project.id.to_proto(),
1970 worktree_id: worktree.id,
1971 summary: Some(summary),
1972 more_summaries: worktree_diagnostics.collect(),
1973 };
1974 session.peer.send(session.connection_id, message)?;
1975 }
1976
1977 for settings_file in worktree.settings_files {
1978 session.peer.send(
1979 session.connection_id,
1980 proto::UpdateWorktreeSettings {
1981 project_id: project_id.to_proto(),
1982 worktree_id: worktree.id,
1983 path: settings_file.path,
1984 content: Some(settings_file.content),
1985 kind: Some(settings_file.kind.to_proto() as i32),
1986 },
1987 )?;
1988 }
1989 }
1990
1991 for repository in mem::take(&mut project.repositories) {
1992 for update in split_repository_update(repository) {
1993 session.peer.send(session.connection_id, update)?;
1994 }
1995 }
1996
1997 for language_server in &project.language_servers {
1998 session.peer.send(
1999 session.connection_id,
2000 proto::UpdateLanguageServer {
2001 project_id: project_id.to_proto(),
2002 server_name: Some(language_server.server.name.clone()),
2003 language_server_id: language_server.server.id,
2004 variant: Some(
2005 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2006 proto::LspDiskBasedDiagnosticsUpdated {},
2007 ),
2008 ),
2009 },
2010 )?;
2011 }
2012
2013 Ok(())
2014}
2015
2016/// Leave someone elses shared project.
2017async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2018 let sender_id = session.connection_id;
2019 let project_id = ProjectId::from_proto(request.project_id);
2020 let db = session.db().await;
2021
2022 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2023 tracing::info!(
2024 %project_id,
2025 "leave project"
2026 );
2027
2028 project_left(project, &session);
2029 if let Some(room) = room {
2030 room_updated(room, &session.peer);
2031 }
2032
2033 Ok(())
2034}
2035
2036/// Updates other participants with changes to the project
2037async fn update_project(
2038 request: proto::UpdateProject,
2039 response: Response<proto::UpdateProject>,
2040 session: MessageContext,
2041) -> Result<()> {
2042 let project_id = ProjectId::from_proto(request.project_id);
2043 let (room, guest_connection_ids) = &*session
2044 .db()
2045 .await
2046 .update_project(project_id, session.connection_id, &request.worktrees)
2047 .await?;
2048 broadcast(
2049 Some(session.connection_id),
2050 guest_connection_ids.iter().copied(),
2051 |connection_id| {
2052 session
2053 .peer
2054 .forward_send(session.connection_id, connection_id, request.clone())
2055 },
2056 );
2057 if let Some(room) = room {
2058 room_updated(room, &session.peer);
2059 }
2060 response.send(proto::Ack {})?;
2061
2062 Ok(())
2063}
2064
2065/// Updates other participants with changes to the worktree
2066async fn update_worktree(
2067 request: proto::UpdateWorktree,
2068 response: Response<proto::UpdateWorktree>,
2069 session: MessageContext,
2070) -> Result<()> {
2071 let guest_connection_ids = session
2072 .db()
2073 .await
2074 .update_worktree(&request, session.connection_id)
2075 .await?;
2076
2077 broadcast(
2078 Some(session.connection_id),
2079 guest_connection_ids.iter().copied(),
2080 |connection_id| {
2081 session
2082 .peer
2083 .forward_send(session.connection_id, connection_id, request.clone())
2084 },
2085 );
2086 response.send(proto::Ack {})?;
2087 Ok(())
2088}
2089
2090async fn update_repository(
2091 request: proto::UpdateRepository,
2092 response: Response<proto::UpdateRepository>,
2093 session: MessageContext,
2094) -> Result<()> {
2095 let guest_connection_ids = session
2096 .db()
2097 .await
2098 .update_repository(&request, session.connection_id)
2099 .await?;
2100
2101 broadcast(
2102 Some(session.connection_id),
2103 guest_connection_ids.iter().copied(),
2104 |connection_id| {
2105 session
2106 .peer
2107 .forward_send(session.connection_id, connection_id, request.clone())
2108 },
2109 );
2110 response.send(proto::Ack {})?;
2111 Ok(())
2112}
2113
2114async fn remove_repository(
2115 request: proto::RemoveRepository,
2116 response: Response<proto::RemoveRepository>,
2117 session: MessageContext,
2118) -> Result<()> {
2119 let guest_connection_ids = session
2120 .db()
2121 .await
2122 .remove_repository(&request, session.connection_id)
2123 .await?;
2124
2125 broadcast(
2126 Some(session.connection_id),
2127 guest_connection_ids.iter().copied(),
2128 |connection_id| {
2129 session
2130 .peer
2131 .forward_send(session.connection_id, connection_id, request.clone())
2132 },
2133 );
2134 response.send(proto::Ack {})?;
2135 Ok(())
2136}
2137
2138/// Updates other participants with changes to the diagnostics
2139async fn update_diagnostic_summary(
2140 message: proto::UpdateDiagnosticSummary,
2141 session: MessageContext,
2142) -> Result<()> {
2143 let guest_connection_ids = session
2144 .db()
2145 .await
2146 .update_diagnostic_summary(&message, session.connection_id)
2147 .await?;
2148
2149 broadcast(
2150 Some(session.connection_id),
2151 guest_connection_ids.iter().copied(),
2152 |connection_id| {
2153 session
2154 .peer
2155 .forward_send(session.connection_id, connection_id, message.clone())
2156 },
2157 );
2158
2159 Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree settings
2163async fn update_worktree_settings(
2164 message: proto::UpdateWorktreeSettings,
2165 session: MessageContext,
2166) -> Result<()> {
2167 let guest_connection_ids = session
2168 .db()
2169 .await
2170 .update_worktree_settings(&message, session.connection_id)
2171 .await?;
2172
2173 broadcast(
2174 Some(session.connection_id),
2175 guest_connection_ids.iter().copied(),
2176 |connection_id| {
2177 session
2178 .peer
2179 .forward_send(session.connection_id, connection_id, message.clone())
2180 },
2181 );
2182
2183 Ok(())
2184}
2185
2186/// Notify other participants that a language server has started.
2187async fn start_language_server(
2188 request: proto::StartLanguageServer,
2189 session: MessageContext,
2190) -> Result<()> {
2191 let guest_connection_ids = session
2192 .db()
2193 .await
2194 .start_language_server(&request, session.connection_id)
2195 .await?;
2196
2197 broadcast(
2198 Some(session.connection_id),
2199 guest_connection_ids.iter().copied(),
2200 |connection_id| {
2201 session
2202 .peer
2203 .forward_send(session.connection_id, connection_id, request.clone())
2204 },
2205 );
2206 Ok(())
2207}
2208
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Metadata updates can carry refreshed server capabilities; persist them
    // before broadcasting so the stored state matches what guests receive.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    // Relay the original message to every other connection in the project.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2238
2239/// forward a project request to the host. These requests should be read only
2240/// as guests are allowed to send them.
2241async fn forward_read_only_project_request<T>(
2242 request: T,
2243 response: Response<T>,
2244 session: MessageContext,
2245) -> Result<()>
2246where
2247 T: EntityMessage + RequestMessage,
2248{
2249 let project_id = ProjectId::from_proto(request.remote_entity_id());
2250 let host_connection_id = session
2251 .db()
2252 .await
2253 .host_for_read_only_project_request(project_id, session.connection_id)
2254 .await?;
2255 let payload = session.forward_request(host_connection_id, request).await?;
2256 response.send(payload)?;
2257 Ok(())
2258}
2259
2260/// forward a project request to the host. These requests are disallowed
2261/// for guests.
2262async fn forward_mutating_project_request<T>(
2263 request: T,
2264 response: Response<T>,
2265 session: MessageContext,
2266) -> Result<()>
2267where
2268 T: EntityMessage + RequestMessage,
2269{
2270 let project_id = ProjectId::from_proto(request.remote_entity_id());
2271
2272 let host_connection_id = session
2273 .db()
2274 .await
2275 .host_for_mutating_project_request(project_id, session.connection_id)
2276 .await?;
2277 let payload = session.forward_request(host_connection_id, request).await?;
2278 response.send(payload)?;
2279 Ok(())
2280}
2281
2282async fn lsp_query(
2283 request: proto::LspQuery,
2284 response: Response<proto::LspQuery>,
2285 session: MessageContext,
2286) -> Result<()> {
2287 let (name, should_write) = request.query_name_and_write_permissions();
2288 tracing::Span::current().record("lsp_query_request", name);
2289 tracing::info!("lsp_query message received");
2290 if should_write {
2291 forward_mutating_project_request(request, response, session).await
2292 } else {
2293 forward_read_only_project_request(request, response, session).await
2294 }
2295}
2296
2297/// Notify other participants that a new buffer has been created
2298async fn create_buffer_for_peer(
2299 request: proto::CreateBufferForPeer,
2300 session: MessageContext,
2301) -> Result<()> {
2302 session
2303 .db()
2304 .await
2305 .check_user_is_project_host(
2306 ProjectId::from_proto(request.project_id),
2307 session.connection_id,
2308 )
2309 .await?;
2310 let peer_id = request.peer_id.context("invalid peer id")?;
2311 session
2312 .peer
2313 .forward_send(session.connection_id, peer_id.into(), request)?;
2314 Ok(())
2315}
2316
2317/// Notify other participants that a new image has been created
2318async fn create_image_for_peer(
2319 request: proto::CreateImageForPeer,
2320 session: MessageContext,
2321) -> Result<()> {
2322 session
2323 .db()
2324 .await
2325 .check_user_is_project_host(
2326 ProjectId::from_proto(request.project_id),
2327 session.connection_id,
2328 )
2329 .await?;
2330 let peer_id = request.peer_id.context("invalid peer id")?;
2331 session
2332 .peer
2333 .forward_send(session.connection_id, peer_id.into(), request)?;
2334 Ok(())
2335}
2336
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations can come from read-only guests; any other
    // operation escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The block scope ensures the database guard is dropped before we await
    // the host's response below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Relay the operations to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Forward to the host as a request (awaiting its response), unless the
    // sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2383
/// Relays a context update to all other participants in the project.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Same rule as `update_buffer`: selection-only buffer operations are
    // treated as read-only; every other operation requires write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is just another recipient here — no
    // response from it is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2425
2426/// Notify other participants that a project has been updated.
2427async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2428 request: T,
2429 session: MessageContext,
2430) -> Result<()> {
2431 let project_id = ProjectId::from_proto(request.remote_entity_id());
2432 let project_connection_ids = session
2433 .db()
2434 .await
2435 .project_connection_ids(project_id, session.connection_id, false)
2436 .await?;
2437
2438 broadcast(
2439 Some(session.connection_id),
2440 project_connection_ids.iter().copied(),
2441 |connection_id| {
2442 session
2443 .peer
2444 .forward_send(session.connection_id, connection_id, request.clone())
2445 },
2446 );
2447 Ok(())
2448}
2449
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Forward the request to the leader and relay its response back to the
    // follower before touching persistent state.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Persist the follow and refresh room state only when a project is involved.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2481
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Fire-and-forget: the leader is informed but no reply is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Mirror `follow`: persisted state is only touched when a project is involved.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2510
2511/// Notify everyone following you of your current location.
2512async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2513 let room_id = RoomId::from_proto(request.room_id);
2514 let database = session.db.lock().await;
2515
2516 let connection_ids = if let Some(project_id) = request.project_id {
2517 let project_id = ProjectId::from_proto(project_id);
2518 database
2519 .project_connection_ids(project_id, session.connection_id, true)
2520 .await?
2521 } else {
2522 database
2523 .room_connection_ids(room_id, session.connection_id)
2524 .await?
2525 };
2526
2527 // For now, don't send view update messages back to that view's current leader.
2528 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2529 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2530 _ => None,
2531 });
2532
2533 for connection_id in connection_ids.iter().cloned() {
2534 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2535 session
2536 .peer
2537 .forward_send(session.connection_id, connection_id, request.clone())?;
2538 }
2539 }
2540 Ok(())
2541}
2542
2543/// Get public data about users.
2544async fn get_users(
2545 request: proto::GetUsers,
2546 response: Response<proto::GetUsers>,
2547 session: MessageContext,
2548) -> Result<()> {
2549 let user_ids = request
2550 .user_ids
2551 .into_iter()
2552 .map(UserId::from_proto)
2553 .collect();
2554 let users = session
2555 .db()
2556 .await
2557 .get_users_by_ids(user_ids)
2558 .await?
2559 .into_iter()
2560 .map(|user| proto::User {
2561 id: user.id.to_proto(),
2562 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2563 github_login: user.github_login,
2564 name: user.name,
2565 })
2566 .collect();
2567 response.send(proto::UsersResponse { users })?;
2568 Ok(())
2569}
2570
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries use an exact login lookup; longer ones use fuzzy
    // matching capped at 10 results. NOTE(review): `len()` is a byte count,
    // so a short non-ASCII query can take the fuzzy path — confirm intended.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the caller and project only the public fields of each user.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2602
2603/// Send a contact request to another user.
2604async fn request_contact(
2605 request: proto::RequestContact,
2606 response: Response<proto::RequestContact>,
2607 session: MessageContext,
2608) -> Result<()> {
2609 let requester_id = session.user_id();
2610 let responder_id = UserId::from_proto(request.responder_id);
2611 if requester_id == responder_id {
2612 return Err(anyhow!("cannot add yourself as a contact"))?;
2613 }
2614
2615 let notifications = session
2616 .db()
2617 .await
2618 .send_contact_request(requester_id, responder_id)
2619 .await?;
2620
2621 // Update outgoing contact requests of requester
2622 let mut update = proto::UpdateContacts::default();
2623 update.outgoing_requests.push(responder_id.to_proto());
2624 for connection_id in session
2625 .connection_pool()
2626 .await
2627 .user_connection_ids(requester_id)
2628 {
2629 session.peer.send(connection_id, update.clone())?;
2630 }
2631
2632 // Update incoming contact requests of responder
2633 let mut update = proto::UpdateContacts::default();
2634 update
2635 .incoming_requests
2636 .push(proto::IncomingContactRequest {
2637 requester_id: requester_id.to_proto(),
2638 });
2639 let connection_pool = session.connection_pool().await;
2640 for connection_id in connection_pool.user_connection_ids(responder_id) {
2641 session.peer.send(connection_id, update.clone())?;
2642 }
2643
2644 send_notifications(&connection_pool, &session.peer, notifications);
2645
2646 response.send(proto::Ack {})?;
2647 Ok(())
2648}
2649
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; no contact state changes here.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries pushed below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2707
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes an established contact from a pending
    // request; `deleted_notification_id` is a notification to retract, if any.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the associated notification on each of the
        // responder's connections.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2758
/// Whether the server should subscribe this client to its channels itself,
/// presumably because clients older than 0.139 don't send
/// `SubscribeToChannels` — confirm against client history.
// NOTE(review): only the minor version is compared; assumes major is 0.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2762
/// Handles an explicit channel-subscription request from the client.
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    // The message carries no payload; subscribe the requesting user directly.
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
2770
/// Subscribes `user_id` to every channel they are a member of and sends them
/// their current channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register interest in each channel so later channel events reach this user.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send the membership summary first, then the full channels update
    // (which consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2787
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // Creating a channel may also create a membership row for the creator,
    // returned as `membership` when present.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out to other subscribers.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Track the new membership in the pool and push it to every open
        // connection belonging to that member.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to subscribers of the same channel tree, skipping
    // connections whose role cannot see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2842
2843/// Delete a channel
2844async fn delete_channel(
2845 request: proto::DeleteChannel,
2846 response: Response<proto::DeleteChannel>,
2847 session: MessageContext,
2848) -> Result<()> {
2849 let db = session.db().await;
2850
2851 let channel_id = request.channel_id;
2852 let (root_channel, removed_channels) = db
2853 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2854 .await?;
2855 response.send(proto::Ack {})?;
2856
2857 // Notify members of removed channels
2858 let mut update = proto::UpdateChannels::default();
2859 update
2860 .delete_channels
2861 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2862
2863 let connection_pool = session.connection_pool().await;
2864 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2865 session.peer.send(connection_id, update.clone())?;
2866 }
2867
2868 Ok(())
2869}
2870
2871/// Invite someone to join a channel.
2872async fn invite_channel_member(
2873 request: proto::InviteChannelMember,
2874 response: Response<proto::InviteChannelMember>,
2875 session: MessageContext,
2876) -> Result<()> {
2877 let db = session.db().await;
2878 let channel_id = ChannelId::from_proto(request.channel_id);
2879 let invitee_id = UserId::from_proto(request.user_id);
2880 let InviteMemberResult {
2881 channel,
2882 notifications,
2883 } = db
2884 .invite_channel_member(
2885 channel_id,
2886 invitee_id,
2887 session.user_id(),
2888 request.role().into(),
2889 )
2890 .await?;
2891
2892 let update = proto::UpdateChannels {
2893 channel_invitations: vec![channel.to_proto()],
2894 ..Default::default()
2895 };
2896
2897 let connection_pool = session.connection_pool().await;
2898 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2899 session.peer.send(connection_id, update.clone())?;
2900 }
2901
2902 send_notifications(&connection_pool, &session.peer, notifications);
2903
2904 response.send(proto::Ack {})?;
2905 Ok(())
2906}
2907
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Push the membership change to the removed member's connections.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the associated notification, if one exists. Send failures here
    // are logged (`trace_err`) rather than propagated.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2949
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list first: the loop body mutates the pool's channel
    // subscriptions, so the iterator must not keep borrowing the pool.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can see the channel under its new visibility are
        // (re)subscribed and sent the channel; the rest are unsubscribed and
        // told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2996
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an actual member (membership changed) or only hold a
    // pending invitation (the invitation's role changed).
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Refresh the invitation shown on the member's connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3044
3045/// Change the name of a channel
3046async fn rename_channel(
3047 request: proto::RenameChannel,
3048 response: Response<proto::RenameChannel>,
3049 session: MessageContext,
3050) -> Result<()> {
3051 let db = session.db().await;
3052 let channel_id = ChannelId::from_proto(request.channel_id);
3053 let channel_model = db
3054 .rename_channel(channel_id, session.user_id(), &request.name)
3055 .await?;
3056 let root_id = channel_model.root_id();
3057 let channel = Channel::from_model(channel_model);
3058
3059 response.send(proto::RenameChannelResponse {
3060 channel: Some(channel.to_proto()),
3061 })?;
3062
3063 let connection_pool = session.connection_pool().await;
3064 let update = proto::UpdateChannels {
3065 channels: vec![channel.to_proto()],
3066 ..Default::default()
3067 };
3068 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3069 if role.can_see_channel(channel.visibility) {
3070 session.peer.send(connection_id, update.clone())?;
3071 }
3072 }
3073
3074 Ok(())
3075}
3076
3077/// Move a channel to a new parent.
3078async fn move_channel(
3079 request: proto::MoveChannel,
3080 response: Response<proto::MoveChannel>,
3081 session: MessageContext,
3082) -> Result<()> {
3083 let channel_id = ChannelId::from_proto(request.channel_id);
3084 let to = ChannelId::from_proto(request.to);
3085
3086 let (root_id, channels) = session
3087 .db()
3088 .await
3089 .move_channel(channel_id, to, session.user_id())
3090 .await?;
3091
3092 let connection_pool = session.connection_pool().await;
3093 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3094 let channels = channels
3095 .iter()
3096 .filter_map(|channel| {
3097 if role.can_see_channel(channel.visibility) {
3098 Some(channel.to_proto())
3099 } else {
3100 None
3101 }
3102 })
3103 .collect::<Vec<_>>();
3104 if channels.is_empty() {
3105 continue;
3106 }
3107
3108 let update = proto::UpdateChannels {
3109 channels,
3110 ..Default::default()
3111 };
3112
3113 session.peer.send(connection_id, update.clone())?;
3114 }
3115
3116 response.send(Ack {})?;
3117 Ok(())
3118}
3119
/// Moves a channel up or down among its siblings and broadcasts the new order.
async fn reorder_channel(
    request: proto::ReorderChannel,
    response: Response<proto::ReorderChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let direction = request.direction();

    let updated_channels = session
        .db()
        .await
        .reorder_channel(channel_id, direction, session.user_id())
        .await?;

    // Derive the tree root from the first updated channel (assumes all
    // updated channels share one root — the broadcast below relies on it).
    // An empty list means there is nothing to announce.
    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
        let connection_pool = session.connection_pool().await;
        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
            // Only include channels this connection's role is allowed to see.
            let channels = updated_channels
                .iter()
                .filter_map(|channel| {
                    if role.can_see_channel(channel.visibility) {
                        Some(channel.to_proto())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if channels.is_empty() {
                continue;
            }

            let update = proto::UpdateChannels {
                channels,
                ..Default::default()
            };

            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;
    Ok(())
}
3164
3165/// Get the list of channel members
3166async fn get_channel_members(
3167 request: proto::GetChannelMembers,
3168 response: Response<proto::GetChannelMembers>,
3169 session: MessageContext,
3170) -> Result<()> {
3171 let db = session.db().await;
3172 let channel_id = ChannelId::from_proto(request.channel_id);
3173 let limit = if request.limit == 0 {
3174 u16::MAX as u64
3175 } else {
3176 request.limit
3177 };
3178 let (members, users) = db
3179 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3180 .await?;
3181 response.send(proto::GetChannelMembersResponse { members, users })?;
3182 Ok(())
3183}
3184
3185/// Accept or decline a channel invitation.
3186async fn respond_to_channel_invite(
3187 request: proto::RespondToChannelInvite,
3188 response: Response<proto::RespondToChannelInvite>,
3189 session: MessageContext,
3190) -> Result<()> {
3191 let db = session.db().await;
3192 let channel_id = ChannelId::from_proto(request.channel_id);
3193 let RespondToChannelInvite {
3194 membership_update,
3195 notifications,
3196 } = db
3197 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3198 .await?;
3199
3200 let mut connection_pool = session.connection_pool().await;
3201 if let Some(membership_update) = membership_update {
3202 notify_membership_updated(
3203 &mut connection_pool,
3204 membership_update,
3205 session.user_id(),
3206 &session.peer,
3207 );
3208 } else {
3209 let update = proto::UpdateChannels {
3210 remove_channel_invitations: vec![channel_id.to_proto()],
3211 ..Default::default()
3212 };
3213
3214 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3215 session.peer.send(connection_id, update.clone())?;
3216 }
3217 };
3218
3219 send_notifications(&connection_pool, &session.peer, notifications);
3220
3221 response.send(proto::Ack {})?;
3222
3223 Ok(())
3224}
3225
3226/// Join the channels' room
3227async fn join_channel(
3228 request: proto::JoinChannel,
3229 response: Response<proto::JoinChannel>,
3230 session: MessageContext,
3231) -> Result<()> {
3232 let channel_id = ChannelId::from_proto(request.channel_id);
3233 join_channel_internal(channel_id, Box::new(response), session).await
3234}
3235
/// Abstracts over the two request types (`JoinChannel` and `JoinRoom`) that
/// both reply with a `JoinRoomResponse`, so `join_channel_internal` can serve
/// either one.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3249
/// Shared implementation behind `JoinChannel` and `JoinRoom`: joins the
/// caller to the channel's room and notifies other participants and contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before cleaning up, since
            // `leave_room_for_session` acquires it itself, then reacquire.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a non-publishing LiveKit token; every other role gets a
        // publishing token. Token failures are logged and yield no LiveKit
        // info rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting to other participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3343
3344/// Start editing the channel notes
3345async fn join_channel_buffer(
3346 request: proto::JoinChannelBuffer,
3347 response: Response<proto::JoinChannelBuffer>,
3348 session: MessageContext,
3349) -> Result<()> {
3350 let db = session.db().await;
3351 let channel_id = ChannelId::from_proto(request.channel_id);
3352
3353 let open_response = db
3354 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3355 .await?;
3356
3357 let collaborators = open_response.collaborators.clone();
3358 response.send(open_response)?;
3359
3360 let update = UpdateChannelBufferCollaborators {
3361 channel_id: channel_id.to_proto(),
3362 collaborators: collaborators.clone(),
3363 };
3364 channel_buffer_updated(
3365 session.connection_id,
3366 collaborators
3367 .iter()
3368 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3369 &update,
3370 &session.peer,
3371 );
3372
3373 Ok(())
3374}
3375
3376/// Edit the channel notes
3377async fn update_channel_buffer(
3378 request: proto::UpdateChannelBuffer,
3379 session: MessageContext,
3380) -> Result<()> {
3381 let db = session.db().await;
3382 let channel_id = ChannelId::from_proto(request.channel_id);
3383
3384 let (collaborators, epoch, version) = db
3385 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3386 .await?;
3387
3388 channel_buffer_updated(
3389 session.connection_id,
3390 collaborators.clone(),
3391 &proto::UpdateChannelBuffer {
3392 channel_id: channel_id.to_proto(),
3393 operations: request.operations,
3394 },
3395 &session.peer,
3396 );
3397
3398 let pool = &*session.connection_pool().await;
3399
3400 let non_collaborators =
3401 pool.channel_connection_ids(channel_id)
3402 .filter_map(|(connection_id, _)| {
3403 if collaborators.contains(&connection_id) {
3404 None
3405 } else {
3406 Some(connection_id)
3407 }
3408 });
3409
3410 broadcast(None, non_collaborators, |peer_id| {
3411 session.peer.send(
3412 peer_id,
3413 proto::UpdateChannels {
3414 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3415 channel_id: channel_id.to_proto(),
3416 epoch: epoch as u64,
3417 version: version.clone(),
3418 }],
3419 ..Default::default()
3420 },
3421 )
3422 });
3423
3424 Ok(())
3425}
3426
3427/// Rejoin the channel notes after a connection blip
3428async fn rejoin_channel_buffers(
3429 request: proto::RejoinChannelBuffers,
3430 response: Response<proto::RejoinChannelBuffers>,
3431 session: MessageContext,
3432) -> Result<()> {
3433 let db = session.db().await;
3434 let buffers = db
3435 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3436 .await?;
3437
3438 for rejoined_buffer in &buffers {
3439 let collaborators_to_notify = rejoined_buffer
3440 .buffer
3441 .collaborators
3442 .iter()
3443 .filter_map(|c| Some(c.peer_id?.into()));
3444 channel_buffer_updated(
3445 session.connection_id,
3446 collaborators_to_notify,
3447 &proto::UpdateChannelBufferCollaborators {
3448 channel_id: rejoined_buffer.buffer.channel_id,
3449 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3450 },
3451 &session.peer,
3452 );
3453 }
3454
3455 response.send(proto::RejoinChannelBuffersResponse {
3456 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3457 })?;
3458
3459 Ok(())
3460}
3461
3462/// Stop editing the channel notes
3463async fn leave_channel_buffer(
3464 request: proto::LeaveChannelBuffer,
3465 response: Response<proto::LeaveChannelBuffer>,
3466 session: MessageContext,
3467) -> Result<()> {
3468 let db = session.db().await;
3469 let channel_id = ChannelId::from_proto(request.channel_id);
3470
3471 let left_buffer = db
3472 .leave_channel_buffer(channel_id, session.connection_id)
3473 .await?;
3474
3475 response.send(Ack {})?;
3476
3477 channel_buffer_updated(
3478 session.connection_id,
3479 left_buffer.connections,
3480 &proto::UpdateChannelBufferCollaborators {
3481 channel_id: channel_id.to_proto(),
3482 collaborators: left_buffer.collaborators,
3483 },
3484 &session.peer,
3485 );
3486
3487 Ok(())
3488}
3489
3490fn channel_buffer_updated<T: EnvelopedMessage>(
3491 sender_id: ConnectionId,
3492 collaborators: impl IntoIterator<Item = ConnectionId>,
3493 message: &T,
3494 peer: &Peer,
3495) {
3496 broadcast(Some(sender_id), collaborators, |peer_id| {
3497 peer.send(peer_id, message.clone())
3498 });
3499}
3500
3501fn send_notifications(
3502 connection_pool: &ConnectionPool,
3503 peer: &Peer,
3504 notifications: db::NotificationBatch,
3505) {
3506 for (user_id, notification) in notifications {
3507 for connection_id in connection_pool.user_connection_ids(user_id) {
3508 if let Err(error) = peer.send(
3509 connection_id,
3510 proto::AddNotification {
3511 notification: Some(notification.clone()),
3512 },
3513 ) {
3514 tracing::error!(
3515 "failed to send notification to {:?} {}",
3516 connection_id,
3517 error
3518 );
3519 }
3520 }
3521 }
3522}
3523
/// Send a message to the channel.
///
/// Chat has been removed from Zed; the handler is kept so clients that still
/// send this request receive a descriptive error.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Delete a channel message.
///
/// Chat has been removed from Zed; always returns an error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Edit a channel message.
///
/// Chat has been removed from Zed; always returns an error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Mark a channel message as read.
///
/// Chat has been removed from Zed; always returns an error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3557
3558/// Mark a buffer version as synced
3559async fn acknowledge_buffer_version(
3560 request: proto::AckBufferOperation,
3561 session: MessageContext,
3562) -> Result<()> {
3563 let buffer_id = BufferId::from_proto(request.buffer_id);
3564 session
3565 .db()
3566 .await
3567 .observe_buffer_version(
3568 buffer_id,
3569 session.user_id(),
3570 request.epoch as i32,
3571 &request.version,
3572 )
3573 .await?;
3574 Ok(())
3575}
3576
/// Start receiving chat updates for a channel.
///
/// Chat has been removed from Zed; the handler is kept so clients that still
/// send this request receive a descriptive error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Stop receiving chat updates for a channel.
///
/// Chat has been removed from Zed; always returns an error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Retrieve the chat history for a channel.
///
/// Chat has been removed from Zed; always returns an error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Retrieve specific chat messages.
///
/// Chat has been removed from Zed; always returns an error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3611
3612/// Retrieve the current users notifications
3613async fn get_notifications(
3614 request: proto::GetNotifications,
3615 response: Response<proto::GetNotifications>,
3616 session: MessageContext,
3617) -> Result<()> {
3618 let notifications = session
3619 .db()
3620 .await
3621 .get_notifications(
3622 session.user_id(),
3623 NOTIFICATION_COUNT_PER_PAGE,
3624 request.before_id.map(db::NotificationId::from_proto),
3625 )
3626 .await?;
3627 response.send(proto::GetNotificationsResponse {
3628 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3629 notifications,
3630 })?;
3631 Ok(())
3632}
3633
3634/// Mark notifications as read
3635async fn mark_notification_as_read(
3636 request: proto::MarkNotificationRead,
3637 response: Response<proto::MarkNotificationRead>,
3638 session: MessageContext,
3639) -> Result<()> {
3640 let database = &session.db().await;
3641 let notifications = database
3642 .mark_notification_as_read_by_id(
3643 session.user_id(),
3644 NotificationId::from_proto(request.notification_id),
3645 )
3646 .await?;
3647 send_notifications(
3648 &*session.connection_pool().await,
3649 &session.peer,
3650 notifications,
3651 );
3652 response.send(proto::Ack {})?;
3653 Ok(())
3654}
3655
3656fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3657 let message = match message {
3658 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3659 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3660 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3661 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3662 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3663 code: frame.code.into(),
3664 reason: frame.reason.as_str().to_owned().into(),
3665 })),
3666 // We should never receive a frame while reading the message, according
3667 // to the `tungstenite` maintainers:
3668 //
3669 // > It cannot occur when you read messages from the WebSocket, but it
3670 // > can be used when you want to send the raw frames (e.g. you want to
3671 // > send the frames to the WebSocket without composing the full message first).
3672 // >
3673 // > — https://github.com/snapview/tungstenite-rs/issues/268
3674 TungsteniteMessage::Frame(_) => {
3675 bail!("received an unexpected frame while reading the message")
3676 }
3677 };
3678
3679 Ok(message)
3680}
3681
3682fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3683 match message {
3684 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3685 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3686 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3687 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3688 AxumMessage::Close(frame) => {
3689 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3690 code: frame.code.into(),
3691 reason: frame.reason.as_ref().into(),
3692 }))
3693 }
3694 }
3695}
3696
/// Push the consequences of a membership change (e.g. an accepted invite or a
/// removal) to every connection belonging to `user_id`, keeping the in-memory
/// connection pool's channel subscriptions in sync with the database.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Mirror the membership change into the connection pool so channel
    // broadcasts start (or stop) reaching this user.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): built by hand rather than via `build_update_user_channels`,
    // so `observed_channel_buffer_version` is left at its default here —
    // confirm that omission is intentional.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged rather than propagated.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3737
3738fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3739 proto::UpdateUserChannels {
3740 channel_memberships: channels
3741 .channel_memberships
3742 .iter()
3743 .map(|m| proto::ChannelMembership {
3744 channel_id: m.channel_id.to_proto(),
3745 role: m.role.into(),
3746 })
3747 .collect(),
3748 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3749 }
3750}
3751
3752fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3753 let mut update = proto::UpdateChannels::default();
3754
3755 for channel in channels.channels {
3756 update.channels.push(channel.to_proto());
3757 }
3758
3759 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3760
3761 for (channel_id, participants) in channels.channel_participants {
3762 update
3763 .channel_participants
3764 .push(proto::ChannelParticipants {
3765 channel_id: channel_id.to_proto(),
3766 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3767 });
3768 }
3769
3770 for channel in channels.invited_channels {
3771 update.channel_invitations.push(channel.to_proto());
3772 }
3773
3774 update
3775}
3776
3777fn build_initial_contacts_update(
3778 contacts: Vec<db::Contact>,
3779 pool: &ConnectionPool,
3780) -> proto::UpdateContacts {
3781 let mut update = proto::UpdateContacts::default();
3782
3783 for contact in contacts {
3784 match contact {
3785 db::Contact::Accepted { user_id, busy } => {
3786 update.contacts.push(contact_for_user(user_id, busy, pool));
3787 }
3788 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3789 db::Contact::Incoming { user_id } => {
3790 update
3791 .incoming_requests
3792 .push(proto::IncomingContactRequest {
3793 requester_id: user_id.to_proto(),
3794 })
3795 }
3796 }
3797 }
3798
3799 update
3800}
3801
3802fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3803 proto::Contact {
3804 user_id: user_id.to_proto(),
3805 online: pool.is_user_online(user_id),
3806 busy,
3807 }
3808}
3809
3810fn room_updated(room: &proto::Room, peer: &Peer) {
3811 broadcast(
3812 None,
3813 room.participants
3814 .iter()
3815 .filter_map(|participant| Some(participant.peer_id?.into())),
3816 |peer_id| {
3817 peer.send(
3818 peer_id,
3819 proto::RoomUpdated {
3820 room: Some(room.clone()),
3821 },
3822 )
3823 },
3824 );
3825}
3826
3827fn channel_updated(
3828 channel: &db::channel::Model,
3829 room: &proto::Room,
3830 peer: &Peer,
3831 pool: &ConnectionPool,
3832) {
3833 let participants = room
3834 .participants
3835 .iter()
3836 .map(|p| p.user_id)
3837 .collect::<Vec<_>>();
3838
3839 broadcast(
3840 None,
3841 pool.channel_connection_ids(channel.root_id())
3842 .filter_map(|(channel_id, role)| {
3843 role.can_see_channel(channel.visibility)
3844 .then_some(channel_id)
3845 }),
3846 |peer_id| {
3847 peer.send(
3848 peer_id,
3849 proto::UpdateChannels {
3850 channel_participants: vec![proto::ChannelParticipants {
3851 channel_id: channel.id.to_proto(),
3852 participant_user_ids: participants.clone(),
3853 }],
3854 ..Default::default()
3855 },
3856 )
3857 },
3858 );
3859}
3860
3861async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3862 let db = session.db().await;
3863
3864 let contacts = db.get_contacts(user_id).await?;
3865 let busy = db.is_user_busy(user_id).await?;
3866
3867 let pool = session.connection_pool().await;
3868 let updated_contact = contact_for_user(user_id, busy, &pool);
3869 for contact in contacts {
3870 if let db::Contact::Accepted {
3871 user_id: contact_user_id,
3872 ..
3873 } = contact
3874 {
3875 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3876 session
3877 .peer
3878 .send(
3879 contact_conn_id,
3880 proto::UpdateContacts {
3881 contacts: vec![updated_contact.clone()],
3882 remove_contacts: Default::default(),
3883 incoming_requests: Default::default(),
3884 remove_incoming_requests: Default::default(),
3885 outgoing_requests: Default::default(),
3886 remove_outgoing_requests: Default::default(),
3887 },
3888 )
3889 .trace_err();
3890 }
3891 }
3892 }
3893 Ok(())
3894}
3895
/// Remove `connection_id` from whatever room it is in, notifying the room's
/// projects, callees, contacts, and LiveKit as needed. No-op if the
/// connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so the data extracted from `left_room` outlives the
    // database guard, which is released at the end of the `if let` below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell each project's remaining collaborators that this peer left.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in any room; nothing to clean up.
        return Ok(());
    }

    // For channel rooms, also refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Cancel any outgoing calls that were pending from this room, and
        // remember those users so their contact state is refreshed below.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup: remove this participant, and delete the LiveKit room
    // entirely when the database reports the room itself was deleted.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3969
3970async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3971 let left_channel_buffers = session
3972 .db()
3973 .await
3974 .leave_channel_buffers(session.connection_id)
3975 .await?;
3976
3977 for left_buffer in left_channel_buffers {
3978 channel_buffer_updated(
3979 session.connection_id,
3980 left_buffer.connections,
3981 &proto::UpdateChannelBufferCollaborators {
3982 channel_id: left_buffer.channel_id.to_proto(),
3983 collaborators: left_buffer.collaborators,
3984 },
3985 &session.peer,
3986 );
3987 }
3988
3989 Ok(())
3990}
3991
3992fn project_left(project: &db::LeftProject, session: &Session) {
3993 for connection_id in &project.connection_ids {
3994 if project.should_unshare {
3995 session
3996 .peer
3997 .send(
3998 *connection_id,
3999 proto::UnshareProject {
4000 project_id: project.id.to_proto(),
4001 },
4002 )
4003 .trace_err();
4004 } else {
4005 session
4006 .peer
4007 .send(
4008 *connection_id,
4009 proto::RemoveProjectCollaborator {
4010 project_id: project.id.to_proto(),
4011 peer_id: Some(session.connection_id.into()),
4012 },
4013 )
4014 .trace_err();
4015 }
4016 }
4017}
4018
4019pub trait ResultExt {
4020 type Ok;
4021
4022 fn trace_err(self) -> Option<Self::Ok>;
4023}
4024
4025impl<T, E> ResultExt for Result<T, E>
4026where
4027 E: std::fmt::Debug,
4028{
4029 type Ok = T;
4030
4031 #[track_caller]
4032 fn trace_err(self) -> Option<T> {
4033 match self {
4034 Ok(value) => Some(value),
4035 Err(error) => {
4036 tracing::error!("{:?}", error);
4037 None
4038 }
4039 }
4040 }
4041}