1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::split_repository_update;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39use util::paths::PathStyle;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semver::Version;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use tokio::sync::{Semaphore, watch};
70use tower::ServiceBuilder;
71use tracing::{
72 Instrument,
73 field::{self},
74 info_span, instrument,
75};
76
/// How long a client has to reconnect before its server-side resources are released.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when returning notifications to clients.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Hard cap on simultaneously open client connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently open connections (see `ConnectionGuard`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names used for per-message timing metrics (all in milliseconds).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
94
95pub struct ConnectionGuard;
96
97impl ConnectionGuard {
98 pub fn try_acquire() -> Result<Self, ()> {
99 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
100 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
101 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
102 tracing::error!(
103 "too many concurrent connections: {}",
104 current_connections + 1
105 );
106 return Err(());
107 }
108 Ok(ConnectionGuard)
109 }
110}
111
112impl Drop for ConnectionGuard {
113 fn drop(&mut self) {
114 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
115 }
116}
117
118struct Response<R> {
119 peer: Arc<Peer>,
120 receipt: Receipt<R>,
121 responded: Arc<AtomicBool>,
122}
123
124impl<R: RequestMessage> Response<R> {
125 fn send(self, payload: R::Response) -> Result<()> {
126 self.responded.store(true, SeqCst);
127 self.peer.respond(self.receipt, payload)?;
128 Ok(())
129 }
130}
131
132#[derive(Clone, Debug)]
133pub enum Principal {
134 User(User),
135 Impersonated { user: User, admin: User },
136}
137
138impl Principal {
139 fn update_span(&self, span: &tracing::Span) {
140 match &self {
141 Principal::User(user) => {
142 span.record("user_id", user.id.0);
143 span.record("login", &user.github_login);
144 }
145 Principal::Impersonated { user, admin } => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 span.record("impersonator", &admin.github_login);
149 }
150 }
151 }
152}
153
/// A [`Session`] paired with the tracing span of the message currently being
/// handled; handed to every message handler.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Deref to the session so handlers can use a `MessageContext` anywhere a
// `Session` is expected.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
167
impl MessageContext {
    /// Forwards `request` to the client identified by `receiver_id` and
    /// returns a future resolving to that client's response.
    ///
    /// When the forwarded request completes (success or failure), the elapsed
    /// time is recorded on the message span's `host_waiting_ms` field, and
    /// start/finish/error events are logged.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs on completion regardless of outcome, so the
            // waiting time is recorded for both success and failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Milliseconds with sub-millisecond precision.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
189
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    /// The authenticated identity behind this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; guarded by an async mutex created per connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide pool of live connections, shared with `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
206
207impl Session {
208 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
209 #[cfg(test)]
210 tokio::task::yield_now().await;
211 let guard = self.db.lock().await;
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 guard
215 }
216
217 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
218 #[cfg(test)]
219 tokio::task::yield_now().await;
220 let guard = self.connection_pool.lock();
221 ConnectionPoolGuard {
222 guard,
223 _not_send: PhantomData,
224 }
225 }
226
227 fn is_staff(&self) -> bool {
228 match &self.principal {
229 Principal::User(user) => user.admin,
230 Principal::Impersonated { .. } => true,
231 }
232 }
233
234 fn user_id(&self) -> UserId {
235 match &self.principal {
236 Principal::User(user) => user.id,
237 Principal::Impersonated { user, .. } => user.id,
238 }
239 }
240
241 pub fn email(&self) -> Option<String> {
242 match &self.principal {
243 Principal::User(user) => user.email_address.clone(),
244 Principal::Impersonated { user, .. } => user.email_address.clone(),
245 }
246 }
247}
248
249impl Debug for Session {
250 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
251 let mut result = f.debug_struct("Session");
252 match &self.principal {
253 Principal::User(user) => {
254 result.field("user", &user.github_login);
255 }
256 Principal::Impersonated { user, admin } => {
257 result.field("user", &user.github_login);
258 result.field("impersonator", &admin.github_login);
259 }
260 }
261 result.field("connection_id", &self.connection_id).finish()
262 }
263}
264
265struct DbHandle(Arc<Database>);
266
267impl Deref for DbHandle {
268 type Target = Database;
269
270 fn deref(&self) -> &Self::Target {
271 self.0.as_ref()
272 }
273}
274
/// The collab RPC server: owns the peer, the pool of live connections, and the
/// table of registered message handlers.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message type, keyed by `TypeId`; populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` to every connection task when the server shuts down.
    teardown: watch::Sender<bool>,
}
283
/// Guard over the locked connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so the guard
/// cannot be moved to another thread while the lock is held.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
288
/// Serializable snapshot of the server's peer and connection-pool state.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is serialized through its deref target via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
295
296pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
297where
298 S: Serializer,
299 T: Deref<Target = U>,
300 U: Serialize,
301{
302 Serialize::serialize(value.deref(), serializer)
303}
304
305impl Server {
    /// Creates the server and registers every RPC message handler.
    ///
    /// Request handlers must reply via `Response::send`; message handlers are
    /// fire-and-forget. `add_handler` panics on duplicate registration, so a
    /// bad merge fails fast at startup.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Shared projects and worktrees.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the host, split by whether they
            // can mutate the host's state.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer sync between host and guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, and channels.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and presence broadcast.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
480
    /// Spawns the background cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT` has elapsed (giving terminated pods time to go
    /// away), the task clears resources left behind by stale servers in the
    /// same environment: chat participants, channel-buffer collaborators, room
    /// participants, old worktree entries, and stale server rows. Affected
    /// clients and contacts are notified of the resulting state changes.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell remaining collaborators who is still in the buffer.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose incoming calls were canceled by the cleanup.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to every accepted contact of
                        // each affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
651
652 pub fn teardown(&self) {
653 self.peer.teardown();
654 self.connection_pool.lock().reset();
655 let _ = self.teardown.send(true);
656 }
657
658 #[cfg(test)]
659 pub fn reset(&self, id: ServerId) {
660 self.teardown();
661 *self.id.lock() = id;
662 self.peer.reset(id.0 as u32);
663 let _ = self.teardown.send(false);
664 }
665
    /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
670
    /// Registers `handler` as the callback for message type `M`.
    ///
    /// The handler is wrapped so that, once it completes, the message span is
    /// annotated with total/processing/queue durations (milliseconds) and the
    /// outcome is logged.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Handlers are looked up by `TypeId`, so this downcast cannot
                // fail for a correctly-registered handler.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Millisecond durations with sub-millisecond precision;
                    // queue time = time between receipt and handler start.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
714
715 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
716 where
717 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
718 Fut: 'static + Send + Future<Output = Result<()>>,
719 M: EnvelopedMessage,
720 {
721 self.add_handler(move |envelope, session| handler(envelope.payload, session));
722 self
723 }
724
    /// Registers a handler for a request/response message type `M`.
    ///
    /// Wraps the handler with response bookkeeping: a handler that returns
    /// `Ok` without having called `Response::send` is turned into an error,
    /// and a handler that fails has its error forwarded to the requesting
    /// client before being propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // `Response::send` flips this flag, letting us detect
                // handlers that "succeed" without replying.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Report the failure to the requesting client, then
                        // propagate so `add_handler` logs it.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
763
    /// Returns a future that runs the full lifecycle of one client
    /// connection: registers it with the peer, sends the initial client
    /// update, then dispatches incoming messages to their handlers until the
    /// connection closes or the server tears down, finally running
    /// `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once shutdown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client used by auxiliary services (e.g. the Supermaven admin API).
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake complete — release the admission-limit slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Back-pressure: wait for a free handler slot before pulling
                // the next message off the wire.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
955
    /// Sends the initial messages to a freshly-established connection: the
    /// `Hello` handshake, first-connection prompt, contact list, channel
    /// subscriptions, and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Notify the caller (if one is waiting) that the connection is live.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's very first connection, ask the client to show the
                // contacts UI, then persist that this prompt has happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contacts update
                // under one pool lock, so the snapshot is internally consistent.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Forward a call that started ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1011
1012 pub async fn invite_code_redeemed(
1013 self: &Arc<Self>,
1014 inviter_id: UserId,
1015 invitee_id: UserId,
1016 ) -> Result<()> {
1017 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1018 && let Some(code) = &user.invite_code
1019 {
1020 let pool = self.connection_pool.lock();
1021 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1022 for connection_id in pool.user_connection_ids(inviter_id) {
1023 self.peer.send(
1024 connection_id,
1025 proto::UpdateContacts {
1026 contacts: vec![invitee_contact.clone()],
1027 ..Default::default()
1028 },
1029 )?;
1030 self.peer.send(
1031 connection_id,
1032 proto::UpdateInviteInfo {
1033 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1034 count: user.invite_count as u32,
1035 },
1036 )?;
1037 }
1038 }
1039 Ok(())
1040 }
1041
1042 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1043 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1044 && let Some(invite_code) = &user.invite_code
1045 {
1046 let pool = self.connection_pool.lock();
1047 for connection_id in pool.user_connection_ids(user_id) {
1048 self.peer.send(
1049 connection_id,
1050 proto::UpdateInviteInfo {
1051 url: format!(
1052 "{}{}",
1053 self.app_state.config.invite_link_prefix, invite_code
1054 ),
1055 count: user.invite_count as u32,
1056 },
1057 )?;
1058 }
1059 }
1060 Ok(())
1061 }
1062
    /// Captures a view of the server for inspection: a locked connection pool
    /// guard plus a reference to the peer.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker field; presumably keeps the guard from being sent
                // across threads — confirm at ConnectionPoolGuard's definition.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1072}
1073
/// Read access to the pooled connections through the guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1081
/// Mutable access to the pooled connections through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1087
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, validate pool invariants every time a guard is
        // released, catching corruption close to where it happened.
        #[cfg(test)]
        self.check_invariants();
    }
}
1094
1095fn broadcast<F>(
1096 sender_id: Option<ConnectionId>,
1097 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1098 mut f: F,
1099) where
1100 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1101{
1102 for receiver_id in receiver_ids {
1103 if Some(receiver_id) != sender_id
1104 && let Err(error) = f(receiver_id)
1105 {
1106 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1107 }
1108 }
1109}
1110
/// Typed value of the `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
1112
1113impl Header for ProtocolVersion {
1114 fn name() -> &'static HeaderName {
1115 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1116 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1117 }
1118
1119 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1120 where
1121 Self: Sized,
1122 I: Iterator<Item = &'i axum::http::HeaderValue>,
1123 {
1124 let version = values
1125 .next()
1126 .ok_or_else(axum::headers::Error::invalid)?
1127 .to_str()
1128 .map_err(|_| axum::headers::Error::invalid())?
1129 .parse()
1130 .map_err(|_| axum::headers::Error::invalid())?;
1131 Ok(Self(version))
1132 }
1133
1134 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1135 values.extend([self.0.to_string().parse().unwrap()]);
1136 }
1137}
1138
1139pub struct AppVersionHeader(Version);
1140impl Header for AppVersionHeader {
1141 fn name() -> &'static HeaderName {
1142 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1143 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1144 }
1145
1146 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1147 where
1148 Self: Sized,
1149 I: Iterator<Item = &'i axum::http::HeaderValue>,
1150 {
1151 let version = values
1152 .next()
1153 .ok_or_else(axum::headers::Error::invalid)?
1154 .to_str()
1155 .map_err(|_| axum::headers::Error::invalid())?
1156 .parse()
1157 .map_err(|_| axum::headers::Error::invalid())?;
1158 Ok(Self(version))
1159 }
1160
1161 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1162 values.extend([self.0.to_string().parse().unwrap()]);
1163 }
1164}
1165
1166#[derive(Debug)]
1167pub struct ReleaseChannelHeader(String);
1168
1169impl Header for ReleaseChannelHeader {
1170 fn name() -> &'static HeaderName {
1171 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1172 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1173 }
1174
1175 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1176 where
1177 Self: Sized,
1178 I: Iterator<Item = &'i axum::http::HeaderValue>,
1179 {
1180 Ok(Self(
1181 values
1182 .next()
1183 .ok_or_else(axum::headers::Error::invalid)?
1184 .to_str()
1185 .map_err(|_| axum::headers::Error::invalid())?
1186 .to_owned(),
1187 ))
1188 }
1189
1190 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1191 values.extend([self.0.parse().unwrap()]);
1192 }
1193}
1194
/// Builds the RPC router: `/rpc` (websocket upgrade, behind the auth
/// middleware) and `/metrics` (no auth middleware; added after the auth layer).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Layers here apply only to the routes registered above them, so
        // `/metrics` below is not auth-gated.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Both routes need access to the server itself.
        .layer(Extension(server))
}
1206
/// Axum handler for `/rpc`: validates the protocol/app-version headers and the
/// connection limit, then upgrades to a websocket served by `Server`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The client's RPC protocol version must match the server's exactly.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // An app version header is mandatory.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // Reject app versions too old to collaborate.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket message types to the tungstenite-based
        // `Connection` the peer expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1284
/// Prometheus scrape endpoint: refreshes the `connections` and
/// `shared_projects` gauges, then returns all gathered metrics as text.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    // Gauges are registered once per process and reused across scrapes.
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded from the reported count.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Serialize every registered metric (not just the two above) in the
    // Prometheus text exposition format.
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1317
/// Cleans up after a dropped connection: deregisters it immediately, then
/// waits out a reconnect grace period before tearing down room/channel state.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Disconnect the peer and remove it from the pool right away.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // If the client has not reconnected within RECONNECT_TIMEOUT, release
        // all of its room and channel-buffer state.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if the user has no other live
            // connections.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown short-circuits the grace period.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1364
1365/// Acknowledges a ping from a client, used to keep the connection alive.
1366async fn ping(
1367 _: proto::Ping,
1368 response: Response<proto::Ping>,
1369 _session: MessageContext,
1370) -> Result<()> {
1371 response.send(proto::Ack {})?;
1372 Ok(())
1373}
1374
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random 30-character id naming the room in LiveKit.
    let livekit_room = nanoid::nanoid!(30);

    // Mint a LiveKit token when a client is configured; any failure in this
    // chain degrades to `None` (a room without audio/video info).
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1412
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms backed by a channel are joined through the channel path instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        // Tell existing participants about the new member before releasing
        // the room guard.
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Send CallCanceled to all of this user's connections so no client keeps
    // showing the incoming-call UI for this room.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token when a client is configured; failures degrade to
    // joining without audio/video info.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1479
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database guard: everything inside runs while the rejoin state
    // is held; `room`/`channel` carry what the post-guard steps need.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Respond to the rejoining client first, with the room plus the
        // projects it reshared (as host) and rejoined (as guest).
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this peer's connection id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktrees to the other collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, refresh the channel state too.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1571
/// After a rejoin, informs each project's collaborators of this peer's new
/// connection id, then re-streams project state (worktrees, diagnostics,
/// settings, repositories) to the rejoining connection itself.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        // Collaborator notifications are best-effort (trace_err).
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // mem::take moves the data out so it can be consumed without cloning.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into multiple messages before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Repository updates are split the same way as worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1656
1657/// leave room disconnects from the room.
1658async fn leave_room(
1659 _: proto::LeaveRoom,
1660 response: Response<proto::LeaveRoom>,
1661 session: MessageContext,
1662) -> Result<()> {
1663 leave_room_for_session(&session, session.connection_id).await?;
1664 response.send(proto::Ack {})?;
1665 Ok(())
1666}
1667
1668/// Updates the permissions of someone else in the room.
1669async fn set_room_participant_role(
1670 request: proto::SetRoomParticipantRole,
1671 response: Response<proto::SetRoomParticipantRole>,
1672 session: MessageContext,
1673) -> Result<()> {
1674 let user_id = UserId::from_proto(request.user_id);
1675 let role = ChannelRole::from(request.role());
1676
1677 let (livekit_room, can_publish) = {
1678 let room = session
1679 .db()
1680 .await
1681 .set_room_participant_role(
1682 session.user_id(),
1683 RoomId::from_proto(request.room_id),
1684 user_id,
1685 role,
1686 )
1687 .await?;
1688
1689 let livekit_room = room.livekit_room.clone();
1690 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1691 room_updated(&room, &session.peer);
1692 (livekit_room, can_publish)
1693 };
1694
1695 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1696 live_kit
1697 .update_participant(
1698 livekit_room.clone(),
1699 request.user_id.to_string(),
1700 livekit_api::proto::ParticipantPermission {
1701 can_subscribe: true,
1702 can_publish,
1703 can_publish_data: can_publish,
1704 hidden: false,
1705 recorder: false,
1706 },
1707 )
1708 .await
1709 .trace_err();
1710 }
1711
1712 response.send(proto::Ack {})?;
1713 Ok(())
1714}
1715
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Calls are only allowed between contacts.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the room, broadcast the updated room, and take the
    // IncomingCall message used to ring the callee.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to answer completes the request; failures are
    // logged and the remaining connections are awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted the ring: roll back the call state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1784
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Remove the pending call and broadcast the change; scope the guard so it
    // is released before notifying the callee's connections.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Stop the ring on every connection of the callee (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1822
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Record the decline; a missing declinable call is reported as an error.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Stop the ring on all of this user's own connections (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1854
/// Updates other participants in the room with your current location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // A request without a location is a client error.
    let location = request.location.context("invalid location")?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    // Broadcast the refreshed room state, then acknowledge.
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1873
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Register the shared project in the database, which assigns its id and
    // returns the updated room.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    // Let the other room participants see the newly-shared project.
    room_updated(room, &session.peer);

    Ok(())
}
1898
1899/// Unshare a project from the room.
1900async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1901 let project_id = ProjectId::from_proto(message.project_id);
1902 unshare_project_internal(project_id, session.connection_id, &session).await
1903}
1904
/// Removes a shared project: notifies guests, refreshes the room, and deletes
/// the project from the database when the unshare operation requests it.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scope the room guard so it is released before the deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection) the project is
        // gone; then broadcast the room change if the project was in a room.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1942
/// Join someone else's shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    // Register the guest in the database; drop the guard before the long
    // streaming phase below.
    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joining connection.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to the existing participants
    // (best-effort per recipient).
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    // mem::take moves the worktrees out so their data can be consumed.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // The update is split into multiple messages before sending.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream per-worktree settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state, split the same way as worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Send a DiskBasedDiagnosticsUpdated notification for each language
    // server in the project.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2093
2094/// Leave someone elses shared project.
2095async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2096 let sender_id = session.connection_id;
2097 let project_id = ProjectId::from_proto(request.project_id);
2098 let db = session.db().await;
2099
2100 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2101 tracing::info!(
2102 %project_id,
2103 "leave project"
2104 );
2105
2106 project_left(project, &session);
2107 if let Some(room) = room {
2108 room_updated(room, &session.peer);
2109 }
2110
2111 Ok(())
2112}
2113
2114/// Updates other participants with changes to the project
2115async fn update_project(
2116 request: proto::UpdateProject,
2117 response: Response<proto::UpdateProject>,
2118 session: MessageContext,
2119) -> Result<()> {
2120 let project_id = ProjectId::from_proto(request.project_id);
2121 let (room, guest_connection_ids) = &*session
2122 .db()
2123 .await
2124 .update_project(project_id, session.connection_id, &request.worktrees)
2125 .await?;
2126 broadcast(
2127 Some(session.connection_id),
2128 guest_connection_ids.iter().copied(),
2129 |connection_id| {
2130 session
2131 .peer
2132 .forward_send(session.connection_id, connection_id, request.clone())
2133 },
2134 );
2135 if let Some(room) = room {
2136 room_updated(room, &session.peer);
2137 }
2138 response.send(proto::Ack {})?;
2139
2140 Ok(())
2141}
2142
2143/// Updates other participants with changes to the worktree
2144async fn update_worktree(
2145 request: proto::UpdateWorktree,
2146 response: Response<proto::UpdateWorktree>,
2147 session: MessageContext,
2148) -> Result<()> {
2149 let guest_connection_ids = session
2150 .db()
2151 .await
2152 .update_worktree(&request, session.connection_id)
2153 .await?;
2154
2155 broadcast(
2156 Some(session.connection_id),
2157 guest_connection_ids.iter().copied(),
2158 |connection_id| {
2159 session
2160 .peer
2161 .forward_send(session.connection_id, connection_id, request.clone())
2162 },
2163 );
2164 response.send(proto::Ack {})?;
2165 Ok(())
2166}
2167
2168async fn update_repository(
2169 request: proto::UpdateRepository,
2170 response: Response<proto::UpdateRepository>,
2171 session: MessageContext,
2172) -> Result<()> {
2173 let guest_connection_ids = session
2174 .db()
2175 .await
2176 .update_repository(&request, session.connection_id)
2177 .await?;
2178
2179 broadcast(
2180 Some(session.connection_id),
2181 guest_connection_ids.iter().copied(),
2182 |connection_id| {
2183 session
2184 .peer
2185 .forward_send(session.connection_id, connection_id, request.clone())
2186 },
2187 );
2188 response.send(proto::Ack {})?;
2189 Ok(())
2190}
2191
2192async fn remove_repository(
2193 request: proto::RemoveRepository,
2194 response: Response<proto::RemoveRepository>,
2195 session: MessageContext,
2196) -> Result<()> {
2197 let guest_connection_ids = session
2198 .db()
2199 .await
2200 .remove_repository(&request, session.connection_id)
2201 .await?;
2202
2203 broadcast(
2204 Some(session.connection_id),
2205 guest_connection_ids.iter().copied(),
2206 |connection_id| {
2207 session
2208 .peer
2209 .forward_send(session.connection_id, connection_id, request.clone())
2210 },
2211 );
2212 response.send(proto::Ack {})?;
2213 Ok(())
2214}
2215
2216/// Updates other participants with changes to the diagnostics
2217async fn update_diagnostic_summary(
2218 message: proto::UpdateDiagnosticSummary,
2219 session: MessageContext,
2220) -> Result<()> {
2221 let guest_connection_ids = session
2222 .db()
2223 .await
2224 .update_diagnostic_summary(&message, session.connection_id)
2225 .await?;
2226
2227 broadcast(
2228 Some(session.connection_id),
2229 guest_connection_ids.iter().copied(),
2230 |connection_id| {
2231 session
2232 .peer
2233 .forward_send(session.connection_id, connection_id, message.clone())
2234 },
2235 );
2236
2237 Ok(())
2238}
2239
2240/// Updates other participants with changes to the worktree settings
2241async fn update_worktree_settings(
2242 message: proto::UpdateWorktreeSettings,
2243 session: MessageContext,
2244) -> Result<()> {
2245 let guest_connection_ids = session
2246 .db()
2247 .await
2248 .update_worktree_settings(&message, session.connection_id)
2249 .await?;
2250
2251 broadcast(
2252 Some(session.connection_id),
2253 guest_connection_ids.iter().copied(),
2254 |connection_id| {
2255 session
2256 .peer
2257 .forward_send(session.connection_id, connection_id, message.clone())
2258 },
2259 );
2260
2261 Ok(())
2262}
2263
2264/// Notify other participants that a language server has started.
2265async fn start_language_server(
2266 request: proto::StartLanguageServer,
2267 session: MessageContext,
2268) -> Result<()> {
2269 let guest_connection_ids = session
2270 .db()
2271 .await
2272 .start_language_server(&request, session.connection_id)
2273 .await?;
2274
2275 broadcast(
2276 Some(session.connection_id),
2277 guest_connection_ids.iter().copied(),
2278 |connection_id| {
2279 session
2280 .peer
2281 .forward_send(session.connection_id, connection_id, request.clone())
2282 },
2283 );
2284 Ok(())
2285}
2286
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capability metadata is persisted so join_project can include it in
    // its `language_server_capabilities` response to late joiners.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the update out to the project's other connections.
    // NOTE(review): the `true` flag's meaning (presumably "include host")
    // is not visible here — confirm against `project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2316
2317/// forward a project request to the host. These requests should be read only
2318/// as guests are allowed to send them.
2319async fn forward_read_only_project_request<T>(
2320 request: T,
2321 response: Response<T>,
2322 session: MessageContext,
2323) -> Result<()>
2324where
2325 T: EntityMessage + RequestMessage,
2326{
2327 let project_id = ProjectId::from_proto(request.remote_entity_id());
2328 let host_connection_id = session
2329 .db()
2330 .await
2331 .host_for_read_only_project_request(project_id, session.connection_id)
2332 .await?;
2333 let payload = session.forward_request(host_connection_id, request).await?;
2334 response.send(payload)?;
2335 Ok(())
2336}
2337
2338/// forward a project request to the host. These requests are disallowed
2339/// for guests.
2340async fn forward_mutating_project_request<T>(
2341 request: T,
2342 response: Response<T>,
2343 session: MessageContext,
2344) -> Result<()>
2345where
2346 T: EntityMessage + RequestMessage,
2347{
2348 let project_id = ProjectId::from_proto(request.remote_entity_id());
2349
2350 let host_connection_id = session
2351 .db()
2352 .await
2353 .host_for_mutating_project_request(project_id, session.connection_id)
2354 .await?;
2355 let payload = session.forward_request(host_connection_id, request).await?;
2356 response.send(payload)?;
2357 Ok(())
2358}
2359
2360async fn lsp_query(
2361 request: proto::LspQuery,
2362 response: Response<proto::LspQuery>,
2363 session: MessageContext,
2364) -> Result<()> {
2365 let (name, should_write) = request.query_name_and_write_permissions();
2366 tracing::Span::current().record("lsp_query_request", name);
2367 tracing::info!("lsp_query message received");
2368 if should_write {
2369 forward_mutating_project_request(request, response, session).await
2370 } else {
2371 forward_read_only_project_request(request, response, session).await
2372 }
2373}
2374
2375/// Notify other participants that a new buffer has been created
2376async fn create_buffer_for_peer(
2377 request: proto::CreateBufferForPeer,
2378 session: MessageContext,
2379) -> Result<()> {
2380 session
2381 .db()
2382 .await
2383 .check_user_is_project_host(
2384 ProjectId::from_proto(request.project_id),
2385 session.connection_id,
2386 )
2387 .await?;
2388 let peer_id = request.peer_id.context("invalid peer id")?;
2389 session
2390 .peer
2391 .forward_send(session.connection_id, peer_id.into(), request)?;
2392 Ok(())
2393}
2394
2395/// Notify other participants that a new image has been created
2396async fn create_image_for_peer(
2397 request: proto::CreateImageForPeer,
2398 session: MessageContext,
2399) -> Result<()> {
2400 session
2401 .db()
2402 .await
2403 .check_user_is_project_host(
2404 ProjectId::from_proto(request.project_id),
2405 session.connection_id,
2406 )
2407 .await?;
2408 let peer_id = request.peer_id.context("invalid peer id")?;
2409 session
2410 .peer
2411 .forward_send(session.connection_id, peer_id.into(), request)?;
2412 Ok(())
2413}
2414
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations may come from read-only guests; any other
    // operation upgrades the required capability to write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard so it is released before awaiting the
    // host's acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests are notified fire-and-forget; only the host's reply is
        // awaited (after the guard is dropped).
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Forward to the host (unless the sender *is* the host) and wait for
    // its ack before acknowledging the original sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2461
/// Applies a shared context operation and relays it to the other
/// collaborators in the project.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Selection-only buffer operations are permitted from read-only guests;
    // everything else (including a missing inner operation) requires write
    // access. A missing outer variant is treated as read-only.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike update_buffer, the host is included in the broadcast here and
    // no acknowledgement is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2503
2504/// Notify other participants that a project has been updated.
2505async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2506 request: T,
2507 session: MessageContext,
2508) -> Result<()> {
2509 let project_id = ProjectId::from_proto(request.remote_entity_id());
2510 let project_connection_ids = session
2511 .db()
2512 .await
2513 .project_connection_ids(project_id, session.connection_id, false)
2514 .await?;
2515
2516 broadcast(
2517 Some(session.connection_id),
2518 project_connection_ids.iter().copied(),
2519 |connection_id| {
2520 session
2521 .peer
2522 .forward_send(session.connection_id, connection_id, request.clone())
2523 },
2524 );
2525 Ok(())
2526}
2527
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay it to the follower
    // before recording the follow.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are persisted in room state; the room is
    // then re-broadcast so everyone sees the updated follower list.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2559
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader this peer stopped following (no reply awaited).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Mirror `follow`: only project-scoped follows live in room state, so
    // only those need removing, after which the room is re-broadcast.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2588
2589/// Notify everyone following you of your current location.
2590async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2591 let room_id = RoomId::from_proto(request.room_id);
2592 let database = session.db.lock().await;
2593
2594 let connection_ids = if let Some(project_id) = request.project_id {
2595 let project_id = ProjectId::from_proto(project_id);
2596 database
2597 .project_connection_ids(project_id, session.connection_id, true)
2598 .await?
2599 } else {
2600 database
2601 .room_connection_ids(room_id, session.connection_id)
2602 .await?
2603 };
2604
2605 // For now, don't send view update messages back to that view's current leader.
2606 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2607 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2608 _ => None,
2609 });
2610
2611 for connection_id in connection_ids.iter().cloned() {
2612 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2613 session
2614 .peer
2615 .forward_send(session.connection_id, connection_id, request.clone())?;
2616 }
2617 }
2618 Ok(())
2619}
2620
2621/// Get public data about users.
2622async fn get_users(
2623 request: proto::GetUsers,
2624 response: Response<proto::GetUsers>,
2625 session: MessageContext,
2626) -> Result<()> {
2627 let user_ids = request
2628 .user_ids
2629 .into_iter()
2630 .map(UserId::from_proto)
2631 .collect();
2632 let users = session
2633 .db()
2634 .await
2635 .get_users_by_ids(user_ids)
2636 .await?
2637 .into_iter()
2638 .map(|user| proto::User {
2639 id: user.id.to_proto(),
2640 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2641 github_login: user.github_login,
2642 name: user.name,
2643 })
2644 .collect();
2645 response.send(proto::UsersResponse { users })?;
2646 Ok(())
2647}
2648
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries (1-2 bytes) are treated as exact-login lookups;
    // longer ones go through the fuzzy matcher, capped at 10 results.
    // NOTE(review): `query.len()` is a byte count, not a character count —
    // a single multi-byte character takes the fuzzy path; confirm intended.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requesting user to themselves.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2680
/// Send a contact request to another user.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    // Record the pending request; the database also produces the
    // notification records to deliver below.
    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Push the notification records produced by the request.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2727
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismissal only clears the notification; the request itself stays
    // pending. Accept/decline mutates the contact relationship below.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact entries pushed below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Push the notification records produced by the response.
        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2785
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // cancelling a still-pending request; the update fields differ below.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the original contact-request notification, if the
        // database deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2836
2837fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2838 version.0.minor < 139
2839}
2840
2841async fn subscribe_to_channels(
2842 _: proto::SubscribeToChannels,
2843 session: MessageContext,
2844) -> Result<()> {
2845 subscribe_user_to_channels(session.user_id(), &session).await?;
2846 Ok(())
2847}
2848
/// Subscribes the user's connection to every channel they are a member of,
/// then sends the initial membership and channel-tree snapshots.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register interest so subsequent channel updates reach this connection.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Membership roles are sent first, then the channel data itself
    // (the second call consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2865
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is only produced when the creator gains a new membership
    // row (see the `if let` below).
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out updates to other members.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Subscribe the new member's connections and tell them their role.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to every connection under the root that is
    // allowed to see it (respecting channel visibility).
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2920
2921/// Delete a channel
2922async fn delete_channel(
2923 request: proto::DeleteChannel,
2924 response: Response<proto::DeleteChannel>,
2925 session: MessageContext,
2926) -> Result<()> {
2927 let db = session.db().await;
2928
2929 let channel_id = request.channel_id;
2930 let (root_channel, removed_channels) = db
2931 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2932 .await?;
2933 response.send(proto::Ack {})?;
2934
2935 // Notify members of removed channels
2936 let mut update = proto::UpdateChannels::default();
2937 update
2938 .delete_channels
2939 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2940
2941 let connection_pool = session.connection_pool().await;
2942 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2943 session.peer.send(connection_id, update.clone())?;
2944 }
2945
2946 Ok(())
2947}
2948
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // Record the invitation; the database also produces the notification
    // records delivered below.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Surface the invitation on every connection the invitee has open.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2985
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    // Remove the membership; `notification_id` identifies the original
    // invite notification to retract, when one exists.
    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Unsubscribe the member's connections and push the membership change.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the invite notification best-effort: a send failure to one
    // connection is logged (trace_err) rather than failing the request.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3027
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user list first (collect) because the loop body mutates
    // the pool's subscriptions while iterating.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel get subscribed and receive the
        // channel; users who no longer can are unsubscribed and told to
        // delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3074
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The affected user may hold either a full membership or just a pending
    // invite; the database result distinguishes the two.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            // Existing member: push the full membership change (and adjust
            // pool subscriptions) via the shared helper.
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Pending invite: just re-send the invitation with the new role.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3122
/// Change the name of a channel
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    // Reply to the requester first, then fan out to other members.
    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    // Broadcast the renamed channel to every connection under the root
    // that is allowed to see it (respecting channel visibility).
    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3154
3155/// Move a channel to a new parent.
3156async fn move_channel(
3157 request: proto::MoveChannel,
3158 response: Response<proto::MoveChannel>,
3159 session: MessageContext,
3160) -> Result<()> {
3161 let channel_id = ChannelId::from_proto(request.channel_id);
3162 let to = ChannelId::from_proto(request.to);
3163
3164 let (root_id, channels) = session
3165 .db()
3166 .await
3167 .move_channel(channel_id, to, session.user_id())
3168 .await?;
3169
3170 let connection_pool = session.connection_pool().await;
3171 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3172 let channels = channels
3173 .iter()
3174 .filter_map(|channel| {
3175 if role.can_see_channel(channel.visibility) {
3176 Some(channel.to_proto())
3177 } else {
3178 None
3179 }
3180 })
3181 .collect::<Vec<_>>();
3182 if channels.is_empty() {
3183 continue;
3184 }
3185
3186 let update = proto::UpdateChannels {
3187 channels,
3188 ..Default::default()
3189 };
3190
3191 session.peer.send(connection_id, update.clone())?;
3192 }
3193
3194 response.send(Ack {})?;
3195 Ok(())
3196}
3197
3198async fn reorder_channel(
3199 request: proto::ReorderChannel,
3200 response: Response<proto::ReorderChannel>,
3201 session: MessageContext,
3202) -> Result<()> {
3203 let channel_id = ChannelId::from_proto(request.channel_id);
3204 let direction = request.direction();
3205
3206 let updated_channels = session
3207 .db()
3208 .await
3209 .reorder_channel(channel_id, direction, session.user_id())
3210 .await?;
3211
3212 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3213 let connection_pool = session.connection_pool().await;
3214 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3215 let channels = updated_channels
3216 .iter()
3217 .filter_map(|channel| {
3218 if role.can_see_channel(channel.visibility) {
3219 Some(channel.to_proto())
3220 } else {
3221 None
3222 }
3223 })
3224 .collect::<Vec<_>>();
3225
3226 if channels.is_empty() {
3227 continue;
3228 }
3229
3230 let update = proto::UpdateChannels {
3231 channels,
3232 ..Default::default()
3233 };
3234
3235 session.peer.send(connection_id, update.clone())?;
3236 }
3237 }
3238
3239 response.send(Ack {})?;
3240 Ok(())
3241}
3242
3243/// Get the list of channel members
3244async fn get_channel_members(
3245 request: proto::GetChannelMembers,
3246 response: Response<proto::GetChannelMembers>,
3247 session: MessageContext,
3248) -> Result<()> {
3249 let db = session.db().await;
3250 let channel_id = ChannelId::from_proto(request.channel_id);
3251 let limit = if request.limit == 0 {
3252 u16::MAX as u64
3253 } else {
3254 request.limit
3255 };
3256 let (members, users) = db
3257 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3258 .await?;
3259 response.send(proto::GetChannelMembersResponse { members, users })?;
3260 Ok(())
3261}
3262
3263/// Accept or decline a channel invitation.
3264async fn respond_to_channel_invite(
3265 request: proto::RespondToChannelInvite,
3266 response: Response<proto::RespondToChannelInvite>,
3267 session: MessageContext,
3268) -> Result<()> {
3269 let db = session.db().await;
3270 let channel_id = ChannelId::from_proto(request.channel_id);
3271 let RespondToChannelInvite {
3272 membership_update,
3273 notifications,
3274 } = db
3275 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3276 .await?;
3277
3278 let mut connection_pool = session.connection_pool().await;
3279 if let Some(membership_update) = membership_update {
3280 notify_membership_updated(
3281 &mut connection_pool,
3282 membership_update,
3283 session.user_id(),
3284 &session.peer,
3285 );
3286 } else {
3287 let update = proto::UpdateChannels {
3288 remove_channel_invitations: vec![channel_id.to_proto()],
3289 ..Default::default()
3290 };
3291
3292 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3293 session.peer.send(connection_id, update.clone())?;
3294 }
3295 };
3296
3297 send_notifications(&connection_pool, &session.peer, notifications);
3298
3299 response.send(proto::Ack {})?;
3300
3301 Ok(())
3302}
3303
3304/// Join the channels' room
3305async fn join_channel(
3306 request: proto::JoinChannel,
3307 response: Response<proto::JoinChannel>,
3308 session: MessageContext,
3309) -> Result<()> {
3310 let channel_id = ChannelId::from_proto(request.channel_id);
3311 join_channel_internal(channel_id, Box::new(response), session).await
3312}
3313
/// Abstraction over the two response types that can carry a
/// `proto::JoinRoomResponse` (joining via a channel vs. via a room), so that
/// `join_channel_internal` can serve both RPCs with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send`; the fully-qualified call
        // avoids resolving back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send`; the fully-qualified call
        // avoids resolving back into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3327
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` RPCs (see `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection left over from a previous session,
/// joins the channel in the database, mints a LiveKit token appropriate to
/// the caller's role (guests get a non-publishing token), replies to the
/// caller, and then broadcasts the membership/room/channel updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The db guard must be released before `leave_room_for_session`
            // takes its own database access, then re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Token failures are logged and treated as "no LiveKit" rather than
        // failing the whole join (`trace_err()?` inside `and_then`).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting to everyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3421
3422/// Start editing the channel notes
3423async fn join_channel_buffer(
3424 request: proto::JoinChannelBuffer,
3425 response: Response<proto::JoinChannelBuffer>,
3426 session: MessageContext,
3427) -> Result<()> {
3428 let db = session.db().await;
3429 let channel_id = ChannelId::from_proto(request.channel_id);
3430
3431 let open_response = db
3432 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3433 .await?;
3434
3435 let collaborators = open_response.collaborators.clone();
3436 response.send(open_response)?;
3437
3438 let update = UpdateChannelBufferCollaborators {
3439 channel_id: channel_id.to_proto(),
3440 collaborators: collaborators.clone(),
3441 };
3442 channel_buffer_updated(
3443 session.connection_id,
3444 collaborators
3445 .iter()
3446 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3447 &update,
3448 &session.peer,
3449 );
3450
3451 Ok(())
3452}
3453
3454/// Edit the channel notes
3455async fn update_channel_buffer(
3456 request: proto::UpdateChannelBuffer,
3457 session: MessageContext,
3458) -> Result<()> {
3459 let db = session.db().await;
3460 let channel_id = ChannelId::from_proto(request.channel_id);
3461
3462 let (collaborators, epoch, version) = db
3463 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3464 .await?;
3465
3466 channel_buffer_updated(
3467 session.connection_id,
3468 collaborators.clone(),
3469 &proto::UpdateChannelBuffer {
3470 channel_id: channel_id.to_proto(),
3471 operations: request.operations,
3472 },
3473 &session.peer,
3474 );
3475
3476 let pool = &*session.connection_pool().await;
3477
3478 let non_collaborators =
3479 pool.channel_connection_ids(channel_id)
3480 .filter_map(|(connection_id, _)| {
3481 if collaborators.contains(&connection_id) {
3482 None
3483 } else {
3484 Some(connection_id)
3485 }
3486 });
3487
3488 broadcast(None, non_collaborators, |peer_id| {
3489 session.peer.send(
3490 peer_id,
3491 proto::UpdateChannels {
3492 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3493 channel_id: channel_id.to_proto(),
3494 epoch: epoch as u64,
3495 version: version.clone(),
3496 }],
3497 ..Default::default()
3498 },
3499 )
3500 });
3501
3502 Ok(())
3503}
3504
3505/// Rejoin the channel notes after a connection blip
3506async fn rejoin_channel_buffers(
3507 request: proto::RejoinChannelBuffers,
3508 response: Response<proto::RejoinChannelBuffers>,
3509 session: MessageContext,
3510) -> Result<()> {
3511 let db = session.db().await;
3512 let buffers = db
3513 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3514 .await?;
3515
3516 for rejoined_buffer in &buffers {
3517 let collaborators_to_notify = rejoined_buffer
3518 .buffer
3519 .collaborators
3520 .iter()
3521 .filter_map(|c| Some(c.peer_id?.into()));
3522 channel_buffer_updated(
3523 session.connection_id,
3524 collaborators_to_notify,
3525 &proto::UpdateChannelBufferCollaborators {
3526 channel_id: rejoined_buffer.buffer.channel_id,
3527 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3528 },
3529 &session.peer,
3530 );
3531 }
3532
3533 response.send(proto::RejoinChannelBuffersResponse {
3534 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3535 })?;
3536
3537 Ok(())
3538}
3539
3540/// Stop editing the channel notes
3541async fn leave_channel_buffer(
3542 request: proto::LeaveChannelBuffer,
3543 response: Response<proto::LeaveChannelBuffer>,
3544 session: MessageContext,
3545) -> Result<()> {
3546 let db = session.db().await;
3547 let channel_id = ChannelId::from_proto(request.channel_id);
3548
3549 let left_buffer = db
3550 .leave_channel_buffer(channel_id, session.connection_id)
3551 .await?;
3552
3553 response.send(Ack {})?;
3554
3555 channel_buffer_updated(
3556 session.connection_id,
3557 left_buffer.connections,
3558 &proto::UpdateChannelBufferCollaborators {
3559 channel_id: channel_id.to_proto(),
3560 collaborators: left_buffer.collaborators,
3561 },
3562 &session.peer,
3563 );
3564
3565 Ok(())
3566}
3567
3568fn channel_buffer_updated<T: EnvelopedMessage>(
3569 sender_id: ConnectionId,
3570 collaborators: impl IntoIterator<Item = ConnectionId>,
3571 message: &T,
3572 peer: &Peer,
3573) {
3574 broadcast(Some(sender_id), collaborators, |peer_id| {
3575 peer.send(peer_id, message.clone())
3576 });
3577}
3578
3579fn send_notifications(
3580 connection_pool: &ConnectionPool,
3581 peer: &Peer,
3582 notifications: db::NotificationBatch,
3583) {
3584 for (user_id, notification) in notifications {
3585 for connection_id in connection_pool.user_connection_ids(user_id) {
3586 if let Err(error) = peer.send(
3587 connection_id,
3588 proto::AddNotification {
3589 notification: Some(notification.clone()),
3590 },
3591 ) {
3592 tracing::error!(
3593 "failed to send notification to {:?} {}",
3594 connection_id,
3595 error
3596 );
3597 }
3598 }
3599 }
3600}
3601
3602/// Send a message to the channel
3603async fn send_channel_message(
3604 _request: proto::SendChannelMessage,
3605 _response: Response<proto::SendChannelMessage>,
3606 _session: MessageContext,
3607) -> Result<()> {
3608 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3609}
3610
3611/// Delete a channel message
3612async fn remove_channel_message(
3613 _request: proto::RemoveChannelMessage,
3614 _response: Response<proto::RemoveChannelMessage>,
3615 _session: MessageContext,
3616) -> Result<()> {
3617 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3618}
3619
3620async fn update_channel_message(
3621 _request: proto::UpdateChannelMessage,
3622 _response: Response<proto::UpdateChannelMessage>,
3623 _session: MessageContext,
3624) -> Result<()> {
3625 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3626}
3627
3628/// Mark a channel message as read
3629async fn acknowledge_channel_message(
3630 _request: proto::AckChannelMessage,
3631 _session: MessageContext,
3632) -> Result<()> {
3633 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3634}
3635
3636/// Mark a buffer version as synced
3637async fn acknowledge_buffer_version(
3638 request: proto::AckBufferOperation,
3639 session: MessageContext,
3640) -> Result<()> {
3641 let buffer_id = BufferId::from_proto(request.buffer_id);
3642 session
3643 .db()
3644 .await
3645 .observe_buffer_version(
3646 buffer_id,
3647 session.user_id(),
3648 request.epoch as i32,
3649 &request.version,
3650 )
3651 .await?;
3652 Ok(())
3653}
3654
3655/// Get a Supermaven API key for the user
3656async fn get_supermaven_api_key(
3657 _request: proto::GetSupermavenApiKey,
3658 response: Response<proto::GetSupermavenApiKey>,
3659 session: MessageContext,
3660) -> Result<()> {
3661 let user_id: String = session.user_id().to_string();
3662 if !session.is_staff() {
3663 return Err(anyhow!("supermaven not enabled for this account"))?;
3664 }
3665
3666 let email = session.email().context("user must have an email")?;
3667
3668 let supermaven_admin_api = session
3669 .supermaven_client
3670 .as_ref()
3671 .context("supermaven not configured")?;
3672
3673 let result = supermaven_admin_api
3674 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3675 .await?;
3676
3677 response.send(proto::GetSupermavenApiKeyResponse {
3678 api_key: result.api_key,
3679 })?;
3680
3681 Ok(())
3682}
3683
3684/// Start receiving chat updates for a channel
3685async fn join_channel_chat(
3686 _request: proto::JoinChannelChat,
3687 _response: Response<proto::JoinChannelChat>,
3688 _session: MessageContext,
3689) -> Result<()> {
3690 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Stop receiving chat updates for a channel
3694async fn leave_channel_chat(
3695 _request: proto::LeaveChannelChat,
3696 _session: MessageContext,
3697) -> Result<()> {
3698 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3699}
3700
3701/// Retrieve the chat history for a channel
3702async fn get_channel_messages(
3703 _request: proto::GetChannelMessages,
3704 _response: Response<proto::GetChannelMessages>,
3705 _session: MessageContext,
3706) -> Result<()> {
3707 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3708}
3709
3710/// Retrieve specific chat messages
3711async fn get_channel_messages_by_id(
3712 _request: proto::GetChannelMessagesById,
3713 _response: Response<proto::GetChannelMessagesById>,
3714 _session: MessageContext,
3715) -> Result<()> {
3716 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3717}
3718
3719/// Retrieve the current users notifications
3720async fn get_notifications(
3721 request: proto::GetNotifications,
3722 response: Response<proto::GetNotifications>,
3723 session: MessageContext,
3724) -> Result<()> {
3725 let notifications = session
3726 .db()
3727 .await
3728 .get_notifications(
3729 session.user_id(),
3730 NOTIFICATION_COUNT_PER_PAGE,
3731 request.before_id.map(db::NotificationId::from_proto),
3732 )
3733 .await?;
3734 response.send(proto::GetNotificationsResponse {
3735 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3736 notifications,
3737 })?;
3738 Ok(())
3739}
3740
3741/// Mark notifications as read
3742async fn mark_notification_as_read(
3743 request: proto::MarkNotificationRead,
3744 response: Response<proto::MarkNotificationRead>,
3745 session: MessageContext,
3746) -> Result<()> {
3747 let database = &session.db().await;
3748 let notifications = database
3749 .mark_notification_as_read_by_id(
3750 session.user_id(),
3751 NotificationId::from_proto(request.notification_id),
3752 )
3753 .await?;
3754 send_notifications(
3755 &*session.connection_pool().await,
3756 &session.peer,
3757 notifications,
3758 );
3759 response.send(proto::Ack {})?;
3760 Ok(())
3761}
3762
3763fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3764 let message = match message {
3765 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3766 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3767 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3768 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3769 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3770 code: frame.code.into(),
3771 reason: frame.reason.as_str().to_owned().into(),
3772 })),
3773 // We should never receive a frame while reading the message, according
3774 // to the `tungstenite` maintainers:
3775 //
3776 // > It cannot occur when you read messages from the WebSocket, but it
3777 // > can be used when you want to send the raw frames (e.g. you want to
3778 // > send the frames to the WebSocket without composing the full message first).
3779 // >
3780 // > — https://github.com/snapview/tungstenite-rs/issues/268
3781 TungsteniteMessage::Frame(_) => {
3782 bail!("received an unexpected frame while reading the message")
3783 }
3784 };
3785
3786 Ok(message)
3787}
3788
3789fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3790 match message {
3791 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3792 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3793 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3794 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3795 AxumMessage::Close(frame) => {
3796 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3797 code: frame.code.into(),
3798 reason: frame.reason.as_ref().into(),
3799 }))
3800 }
3801 }
3802}
3803
3804fn notify_membership_updated(
3805 connection_pool: &mut ConnectionPool,
3806 result: MembershipUpdated,
3807 user_id: UserId,
3808 peer: &Peer,
3809) {
3810 for membership in &result.new_channels.channel_memberships {
3811 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3812 }
3813 for channel_id in &result.removed_channels {
3814 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3815 }
3816
3817 let user_channels_update = proto::UpdateUserChannels {
3818 channel_memberships: result
3819 .new_channels
3820 .channel_memberships
3821 .iter()
3822 .map(|cm| proto::ChannelMembership {
3823 channel_id: cm.channel_id.to_proto(),
3824 role: cm.role.into(),
3825 })
3826 .collect(),
3827 ..Default::default()
3828 };
3829
3830 let mut update = build_channels_update(result.new_channels);
3831 update.delete_channels = result
3832 .removed_channels
3833 .into_iter()
3834 .map(|id| id.to_proto())
3835 .collect();
3836 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3837
3838 for connection_id in connection_pool.user_connection_ids(user_id) {
3839 peer.send(connection_id, user_channels_update.clone())
3840 .trace_err();
3841 peer.send(connection_id, update.clone()).trace_err();
3842 }
3843}
3844
3845fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3846 proto::UpdateUserChannels {
3847 channel_memberships: channels
3848 .channel_memberships
3849 .iter()
3850 .map(|m| proto::ChannelMembership {
3851 channel_id: m.channel_id.to_proto(),
3852 role: m.role.into(),
3853 })
3854 .collect(),
3855 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3856 }
3857}
3858
3859fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3860 let mut update = proto::UpdateChannels::default();
3861
3862 for channel in channels.channels {
3863 update.channels.push(channel.to_proto());
3864 }
3865
3866 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3867
3868 for (channel_id, participants) in channels.channel_participants {
3869 update
3870 .channel_participants
3871 .push(proto::ChannelParticipants {
3872 channel_id: channel_id.to_proto(),
3873 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3874 });
3875 }
3876
3877 for channel in channels.invited_channels {
3878 update.channel_invitations.push(channel.to_proto());
3879 }
3880
3881 update
3882}
3883
3884fn build_initial_contacts_update(
3885 contacts: Vec<db::Contact>,
3886 pool: &ConnectionPool,
3887) -> proto::UpdateContacts {
3888 let mut update = proto::UpdateContacts::default();
3889
3890 for contact in contacts {
3891 match contact {
3892 db::Contact::Accepted { user_id, busy } => {
3893 update.contacts.push(contact_for_user(user_id, busy, pool));
3894 }
3895 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3896 db::Contact::Incoming { user_id } => {
3897 update
3898 .incoming_requests
3899 .push(proto::IncomingContactRequest {
3900 requester_id: user_id.to_proto(),
3901 })
3902 }
3903 }
3904 }
3905
3906 update
3907}
3908
3909fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3910 proto::Contact {
3911 user_id: user_id.to_proto(),
3912 online: pool.is_user_online(user_id),
3913 busy,
3914 }
3915}
3916
3917fn room_updated(room: &proto::Room, peer: &Peer) {
3918 broadcast(
3919 None,
3920 room.participants
3921 .iter()
3922 .filter_map(|participant| Some(participant.peer_id?.into())),
3923 |peer_id| {
3924 peer.send(
3925 peer_id,
3926 proto::RoomUpdated {
3927 room: Some(room.clone()),
3928 },
3929 )
3930 },
3931 );
3932}
3933
3934fn channel_updated(
3935 channel: &db::channel::Model,
3936 room: &proto::Room,
3937 peer: &Peer,
3938 pool: &ConnectionPool,
3939) {
3940 let participants = room
3941 .participants
3942 .iter()
3943 .map(|p| p.user_id)
3944 .collect::<Vec<_>>();
3945
3946 broadcast(
3947 None,
3948 pool.channel_connection_ids(channel.root_id())
3949 .filter_map(|(channel_id, role)| {
3950 role.can_see_channel(channel.visibility)
3951 .then_some(channel_id)
3952 }),
3953 |peer_id| {
3954 peer.send(
3955 peer_id,
3956 proto::UpdateChannels {
3957 channel_participants: vec![proto::ChannelParticipants {
3958 channel_id: channel.id.to_proto(),
3959 participant_user_ids: participants.clone(),
3960 }],
3961 ..Default::default()
3962 },
3963 )
3964 },
3965 );
3966}
3967
3968async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3969 let db = session.db().await;
3970
3971 let contacts = db.get_contacts(user_id).await?;
3972 let busy = db.is_user_busy(user_id).await?;
3973
3974 let pool = session.connection_pool().await;
3975 let updated_contact = contact_for_user(user_id, busy, &pool);
3976 for contact in contacts {
3977 if let db::Contact::Accepted {
3978 user_id: contact_user_id,
3979 ..
3980 } = contact
3981 {
3982 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3983 session
3984 .peer
3985 .send(
3986 contact_conn_id,
3987 proto::UpdateContacts {
3988 contacts: vec![updated_contact.clone()],
3989 remove_contacts: Default::default(),
3990 incoming_requests: Default::default(),
3991 remove_incoming_requests: Default::default(),
3992 outgoing_requests: Default::default(),
3993 remove_outgoing_requests: Default::default(),
3994 },
3995 )
3996 .trace_err();
3997 }
3998 }
3999 }
4000 Ok(())
4001}
4002
/// Remove the given connection from the room it is in, notifying everyone
/// affected: remaining room participants, the channel's subscribers,
/// callees whose pending calls are cancelled, and contacts watching the
/// user's busy state. Also cleans up the LiveKit side of the room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized in the `if let` below; the early `return` in the
    // `else` branch guarantees they are assigned before use.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare / remove this peer from every project it had open in the room.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to do.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Cancel any calls this user had pending to other users, and remember
    // those users so their contact state gets refreshed below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup; failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4076
4077async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4078 let left_channel_buffers = session
4079 .db()
4080 .await
4081 .leave_channel_buffers(session.connection_id)
4082 .await?;
4083
4084 for left_buffer in left_channel_buffers {
4085 channel_buffer_updated(
4086 session.connection_id,
4087 left_buffer.connections,
4088 &proto::UpdateChannelBufferCollaborators {
4089 channel_id: left_buffer.channel_id.to_proto(),
4090 collaborators: left_buffer.collaborators,
4091 },
4092 &session.peer,
4093 );
4094 }
4095
4096 Ok(())
4097}
4098
4099fn project_left(project: &db::LeftProject, session: &Session) {
4100 for connection_id in &project.connection_ids {
4101 if project.should_unshare {
4102 session
4103 .peer
4104 .send(
4105 *connection_id,
4106 proto::UnshareProject {
4107 project_id: project.id.to_proto(),
4108 },
4109 )
4110 .trace_err();
4111 } else {
4112 session
4113 .peer
4114 .send(
4115 *connection_id,
4116 proto::RemoveProjectCollaborator {
4117 project_id: project.id.to_proto(),
4118 peer_id: Some(session.connection_id.into()),
4119 },
4120 )
4121 .trace_err();
4122 }
4123 }
4124}
4125
/// Extension trait for logging-and-discarding errors on `Result`s.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) at `error` level and convert the result into
    /// an `Option`, so best-effort operations don't abort their caller.
    fn trace_err(self) -> Option<Self::Ok>;
}
4131
4132impl<T, E> ResultExt for Result<T, E>
4133where
4134 E: std::fmt::Debug,
4135{
4136 type Ok = T;
4137
4138 #[track_caller]
4139 fn trace_err(self) -> Option<T> {
4140 match self {
4141 Ok(value) => Some(value),
4142 Err(error) => {
4143 tracing::error!("{:?}", error);
4144 None
4145 }
4146 }
4147 }
4148}