1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, User, UserId,
10 },
11 executor::Executor,
12};
13use anyhow::{Context as _, anyhow, bail};
14use async_tungstenite::tungstenite::{
15 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
16};
17use axum::headers::UserAgent;
18use axum::{
19 Extension, Router, TypedHeader,
20 body::Body,
21 extract::{
22 ConnectInfo, WebSocketUpgrade,
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30};
31use collections::{HashMap, HashSet};
32pub use connection_pool::{ConnectionPool, ZedVersion};
33use core::fmt::{self, Debug, Formatter};
34use futures::TryFutureExt as _;
35use reqwest_client::ReqwestClient;
36use rpc::proto::{MultiLspQuery, split_repository_update};
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38use tracing::Span;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semantic_version::SemanticVersion;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
/// How long a dropped client has to reconnect before its server-side state is
/// released for good.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when returning notifications to clients.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing span fields used to record message-handling timings.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one RPC message type: receives the envelope, the
/// session it arrived on, and the tracing span created for the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
111impl Drop for ConnectionGuard {
112 fn drop(&mut self) {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 }
115}
116
117struct Response<R> {
118 peer: Arc<Peer>,
119 receipt: Receipt<R>,
120 responded: Arc<AtomicBool>,
121}
122
123impl<R: RequestMessage> Response<R> {
124 fn send(self, payload: R::Response) -> Result<()> {
125 self.responded.store(true, SeqCst);
126 self.peer.respond(self.receipt, payload)?;
127 Ok(())
128 }
129}
130
131#[derive(Clone, Debug)]
132pub enum Principal {
133 User(User),
134 Impersonated { user: User, admin: User },
135}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// A `Session` paired with the tracing span of the message currently being
/// handled, so handlers can record timing fields on the correct span.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Handlers mostly need session data, so delegate to it transparently.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
166
impl MessageContext {
    /// Forwards `request` to the peer identified by `receiver_id` and returns
    /// a future resolving to that peer's response. Records how long the remote
    /// host took to answer in the `host_waiting_ms` span field, and logs the
    /// start, success, or failure of the forwarding.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                // Micros / 1000 keeps fractional-millisecond precision;
                // recorded for both success and failure outcomes.
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state shared with every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; see `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Client-supplied system identifier, if any was sent.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
205
206impl Session {
207 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
208 #[cfg(test)]
209 tokio::task::yield_now().await;
210 let guard = self.db.lock().await;
211 #[cfg(test)]
212 tokio::task::yield_now().await;
213 guard
214 }
215
216 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
217 #[cfg(test)]
218 tokio::task::yield_now().await;
219 let guard = self.connection_pool.lock();
220 ConnectionPoolGuard {
221 guard,
222 _not_send: PhantomData,
223 }
224 }
225
226 fn is_staff(&self) -> bool {
227 match &self.principal {
228 Principal::User(user) => user.admin,
229 Principal::Impersonated { .. } => true,
230 }
231 }
232
233 fn user_id(&self) -> UserId {
234 match &self.principal {
235 Principal::User(user) => user.id,
236 Principal::Impersonated { user, .. } => user.id,
237 }
238 }
239
240 pub fn email(&self) -> Option<String> {
241 match &self.principal {
242 Principal::User(user) => user.email_address.clone(),
243 Principal::Impersonated { user, .. } => user.email_address.clone(),
244 }
245 }
246}
247
248impl Debug for Session {
249 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
250 let mut result = f.debug_struct("Session");
251 match &self.principal {
252 Principal::User(user) => {
253 result.field("user", &user.github_login);
254 }
255 Principal::Impersonated { user, admin } => {
256 result.field("user", &user.github_login);
257 result.field("impersonator", &admin.github_login);
258 }
259 }
260 result.field("connection_id", &self.connection_id).finish()
261 }
262}
263
264struct DbHandle(Arc<Database>);
265
266impl Deref for DbHandle {
267 type Target = Database;
268
269 fn deref(&self) -> &Self::Target {
270 self.0.as_ref()
271 }
272}
273
/// The collab RPC server: owns the peer, the pool of live client connections,
/// and the table of message handlers registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Maps each message's `TypeId` to its handler; exactly one per type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcast channel flipped to `true` when the server is shutting down.
    teardown: watch::Sender<bool>,
}
282
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` makes the
/// guard `!Send`, which statically prevents holding the lock across `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
287
288#[derive(Serialize)]
289pub struct ServerSnapshot<'a> {
290 peer: &'a Peer,
291 #[serde(serialize_with = "serialize_deref")]
292 connection_pool: ConnectionPoolGuard<'a>,
293}
294
295pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
296where
297 S: Serializer,
298 T: Deref<Target = U>,
299 U: Serialize,
300{
301 Serialize::serialize(value.deref(), serializer)
302}
303
304impl Server {
    /// Creates the server for `id` and registers a handler for every RPC
    /// message type the collab server understands. Registering the same
    /// message type twice panics (see `Server::add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/repository state sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. `forward_read_only_…`
            // allows any participant; `forward_mutating_…` requires write access.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            // NOTE(review): GetCodeLens is forwarded as mutating while most
            // Get* requests are read-only — confirm this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state broadcast from the host to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and miscellany.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles mutate the working
            // tree but are forwarded as read-only — verify the intended
            // capability check for these.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
481
    /// Spawns the background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` (giving terminated pods time to shut down
    /// gracefully), it removes resources left behind by previous server runs:
    /// stale channel-chat participants, stale channel-buffer collaborators
    /// (notifying remaining collaborators), stale room participants (notifying
    /// affected contacts and deleting emptied LiveKit rooms), old worktree
    /// entries, and stale server rows. All failures are logged and skipped via
    /// `trace_err` so one bad row does not abort the whole sweep.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and broadcast the
                    // refreshed collaborator set to the remaining connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the stale server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated busy status to all
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
652
    /// Shuts the server down: tears down the peer, clears the connection pool,
    /// and broadcasts the teardown signal that makes every
    /// `handle_connection` task return.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Receivers may already be gone; the send error is not actionable.
        let _ = self.teardown.send(true);
    }
658
    /// Test-only: tears the server down and re-arms it under a new id so it
    /// can accept connections again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
671
    /// Registers `handler` for message type `M`, wrapping it to downcast the
    /// type-erased envelope, build the per-message `MessageContext`, and record
    /// timing fields (total / processing / queue durations) on the message span.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Dispatch is keyed on `TypeId::of::<M>()`, so the envelope is
                // guaranteed to be a `TypedEnvelope<M>` and the downcast holds.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Micros / 1000 keeps fractional-millisecond precision.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
715
716 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
717 where
718 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
719 Fut: 'static + Send + Future<Output = Result<()>>,
720 M: EnvelopedMessage,
721 {
722 self.add_handler(move |envelope, session| handler(envelope.payload, session));
723 self
724 }
725
    /// Registers a handler for a request message.
    ///
    /// The wrapper builds a `Response` the handler must use to reply. If the
    /// handler returns `Ok` without having sent a response, that is surfaced
    /// as an error; if the handler fails, the error is converted to a proto
    /// error and relayed back to the requester before being re-raised.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Flipped to `true` by `Response::send`; lets us detect
                // handlers that forget to respond.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry a proto form; anything
                        // else is wrapped under the generic Internal code.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
764
    /// Drives a single client connection for its entire lifetime.
    ///
    /// Registers the connection with the peer, sends the initial client
    /// update, then pumps incoming messages through the registered handlers —
    /// bounded by a semaphore of `MAX_CONCURRENT_HANDLERS` — until the client
    /// disconnects, I/O fails, or the server tears down. Finally runs
    /// `connection_lost` to clean up server-side state. The returned future
    /// must be polled to completion by the caller.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Span fields are pre-declared empty and filled in as values arrive.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once shutdown has begun.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client used by per-session integrations (e.g. Supermaven).
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release the admission slot
            // held since the websocket upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler slot is free; the
                // permit travels with the message and is dropped when its
                // handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Priority order: teardown, I/O completion/failure, finished
                // foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                // todo(lsp) remove after Zed Stable hits v0.204.x
                                multi_lsp_query_request=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // completes, bounding concurrent handlers.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages don't block the
                                    // foreground ordering guarantees.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Abort any in-flight foreground handlers before cleanup.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
958
    /// Sends the initial state a freshly-authenticated client needs:
    /// the `Hello` handshake, the contact list, channel subscriptions
    /// (for older clients), and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` carries the client's assigned peer id and is sent before
        // anything else on this connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id to the caller, if one was requested.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's first-ever connection, prompt the client to show
                // the contacts UI, then persist that this happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact list while
                // holding the pool lock, so online statuses in the update are
                // consistent with the pool's state at registration time.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a call that is currently ringing for this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1014
1015 pub async fn invite_code_redeemed(
1016 self: &Arc<Self>,
1017 inviter_id: UserId,
1018 invitee_id: UserId,
1019 ) -> Result<()> {
1020 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1021 && let Some(code) = &user.invite_code
1022 {
1023 let pool = self.connection_pool.lock();
1024 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1025 for connection_id in pool.user_connection_ids(inviter_id) {
1026 self.peer.send(
1027 connection_id,
1028 proto::UpdateContacts {
1029 contacts: vec![invitee_contact.clone()],
1030 ..Default::default()
1031 },
1032 )?;
1033 self.peer.send(
1034 connection_id,
1035 proto::UpdateInviteInfo {
1036 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1037 count: user.invite_count as u32,
1038 },
1039 )?;
1040 }
1041 }
1042 Ok(())
1043 }
1044
1045 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1046 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1047 && let Some(invite_code) = &user.invite_code
1048 {
1049 let pool = self.connection_pool.lock();
1050 for connection_id in pool.user_connection_ids(user_id) {
1051 self.peer.send(
1052 connection_id,
1053 proto::UpdateInviteInfo {
1054 url: format!(
1055 "{}{}",
1056 self.app_state.config.invite_link_prefix, invite_code
1057 ),
1058 count: user.invite_count as u32,
1059 },
1060 )?;
1061 }
1062 }
1063 Ok(())
1064 }
1065
1066 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1067 ServerSnapshot {
1068 connection_pool: ConnectionPoolGuard {
1069 guard: self.connection_pool.lock(),
1070 _not_send: PhantomData,
1071 },
1072 peer: &self.peer,
1073 }
1074 }
1075}
1076
1077impl Deref for ConnectionPoolGuard<'_> {
1078 type Target = ConnectionPool;
1079
1080 fn deref(&self) -> &Self::Target {
1081 &self.guard
1082 }
1083}
1084
1085impl DerefMut for ConnectionPoolGuard<'_> {
1086 fn deref_mut(&mut self) -> &mut Self::Target {
1087 &mut self.guard
1088 }
1089}
1090
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time the
        // guard is released, so corruption is caught close to where it happened.
        #[cfg(test)]
        self.check_invariants();
    }
}
1097
1098fn broadcast<F>(
1099 sender_id: Option<ConnectionId>,
1100 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1101 mut f: F,
1102) where
1103 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1104{
1105 for receiver_id in receiver_ids {
1106 if Some(receiver_id) != sender_id
1107 && let Err(error) = f(receiver_id)
1108 {
1109 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1110 }
1111 }
1112}
1113
1114pub struct ProtocolVersion(u32);
1115
1116impl Header for ProtocolVersion {
1117 fn name() -> &'static HeaderName {
1118 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1120 }
1121
1122 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123 where
1124 Self: Sized,
1125 I: Iterator<Item = &'i axum::http::HeaderValue>,
1126 {
1127 let version = values
1128 .next()
1129 .ok_or_else(axum::headers::Error::invalid)?
1130 .to_str()
1131 .map_err(|_| axum::headers::Error::invalid())?
1132 .parse()
1133 .map_err(|_| axum::headers::Error::invalid())?;
1134 Ok(Self(version))
1135 }
1136
1137 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138 values.extend([self.0.to_string().parse().unwrap()]);
1139 }
1140}
1141
1142pub struct AppVersionHeader(SemanticVersion);
1143impl Header for AppVersionHeader {
1144 fn name() -> &'static HeaderName {
1145 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1146 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1147 }
1148
1149 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1150 where
1151 Self: Sized,
1152 I: Iterator<Item = &'i axum::http::HeaderValue>,
1153 {
1154 let version = values
1155 .next()
1156 .ok_or_else(axum::headers::Error::invalid)?
1157 .to_str()
1158 .map_err(|_| axum::headers::Error::invalid())?
1159 .parse()
1160 .map_err(|_| axum::headers::Error::invalid())?;
1161 Ok(Self(version))
1162 }
1163
1164 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1165 values.extend([self.0.to_string().parse().unwrap()]);
1166 }
1167}
1168
1169#[derive(Debug)]
1170pub struct ReleaseChannelHeader(String);
1171
1172impl Header for ReleaseChannelHeader {
1173 fn name() -> &'static HeaderName {
1174 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1175 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1176 }
1177
1178 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1179 where
1180 Self: Sized,
1181 I: Iterator<Item = &'i axum::http::HeaderValue>,
1182 {
1183 Ok(Self(
1184 values
1185 .next()
1186 .ok_or_else(axum::headers::Error::invalid)?
1187 .to_str()
1188 .map_err(|_| axum::headers::Error::invalid())?
1189 .to_owned(),
1190 ))
1191 }
1192
1193 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1194 values.extend([self.0.parse().unwrap()]);
1195 }
1196}
1197
1198pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1199 Router::new()
1200 .route("/rpc", get(handle_websocket_request))
1201 .layer(
1202 ServiceBuilder::new()
1203 .layer(Extension(server.app_state.clone()))
1204 .layer(middleware::from_fn(auth::validate_header)),
1205 )
1206 .route("/metrics", get(handle_metrics))
1207 .layer(Extension(server))
1208}
1209
/// Axum handler for `/rpc`: validates protocol/app-version headers, enforces
/// the concurrent-connection limit, then upgrades to a websocket and hands the
/// connection to `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients speaking a different RPC protocol version outright.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app-version header is mandatory; without it we can't apply
    // version-based gating below.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // Reject app versions that are too old to collaborate.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // server refuses with 503 instead of accepting and then dropping the socket.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket into the tungstenite message types that
        // `rpc::Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1287
1288pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1289 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1290 let connections_metric = CONNECTIONS_METRIC
1291 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1292
1293 let connections = server
1294 .connection_pool
1295 .lock()
1296 .connections()
1297 .filter(|connection| !connection.admin)
1298 .count();
1299 connections_metric.set(connections as _);
1300
1301 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1302 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1303 register_int_gauge!(
1304 "shared_projects",
1305 "number of open projects with one or more guests"
1306 )
1307 .unwrap()
1308 });
1309
1310 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1311 shared_projects_metric.set(shared_projects as _);
1312
1313 let encoder = prometheus::TextEncoder::new();
1314 let metric_families = prometheus::gather();
1315 let encoded_metrics = encoder
1316 .encode_to_string(&metric_families)
1317 .map_err(|err| anyhow!("{err}"))?;
1318 Ok(encoded_metrics)
1319}
1320
/// Cleans up after a dropped client connection. Immediately removes the
/// connection from the peer and pool, then waits `RECONNECT_TIMEOUT` before
/// tearing down rooms/calls, giving the client a chance to reconnect first.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Drop the connection from the peer and the pool right away so no new
    // messages are routed to it.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database; failures are
    // logged rather than aborting cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Either the reconnect window elapses (full cleanup below) or the server
    // tears down first (skip cleanup entirely).
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if the user has no other live
            // connections that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1367
1368/// Acknowledges a ping from a client, used to keep the connection alive.
1369async fn ping(
1370 _: proto::Ping,
1371 response: Response<proto::Ping>,
1372 _session: MessageContext,
1373) -> Result<()> {
1374 response.send(proto::Ack {})?;
1375 Ok(())
1376}
1377
1378/// Creates a new room for calling (outside of channels)
1379async fn create_room(
1380 _request: proto::CreateRoom,
1381 response: Response<proto::CreateRoom>,
1382 session: MessageContext,
1383) -> Result<()> {
1384 let livekit_room = nanoid::nanoid!(30);
1385
1386 let live_kit_connection_info = util::maybe!(async {
1387 let live_kit = session.app_state.livekit_client.as_ref();
1388 let live_kit = live_kit?;
1389 let user_id = session.user_id().to_string();
1390
1391 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1392
1393 Some(proto::LiveKitConnectionInfo {
1394 server_url: live_kit.url().into(),
1395 token,
1396 can_publish: true,
1397 })
1398 })
1399 .await;
1400
1401 let room = session
1402 .db()
1403 .await
1404 .create_room(session.user_id(), session.connection_id, &livekit_room)
1405 .await?;
1406
1407 response.send(proto::CreateRoomResponse {
1408 room: Some(room.clone()),
1409 live_kit_connection_info,
1410 })?;
1411
1412 update_user_contacts(session.user_id(), &session).await?;
1413 Ok(())
1414}
1415
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel are joined through the channel path so that
    // channel membership rules apply.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        // Tell existing participants about the new member before replying.
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the ringing call on all of this user's other connections,
    // since the call has now been answered here.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token for audio/video; failures are logged and the join
    // still succeeds without A/V.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1482
1483/// Rejoin room is used to reconnect to a room after connection errors.
1484async fn rejoin_room(
1485 request: proto::RejoinRoom,
1486 response: Response<proto::RejoinRoom>,
1487 session: MessageContext,
1488) -> Result<()> {
1489 let room;
1490 let channel;
1491 {
1492 let mut rejoined_room = session
1493 .db()
1494 .await
1495 .rejoin_room(request, session.user_id(), session.connection_id)
1496 .await?;
1497
1498 response.send(proto::RejoinRoomResponse {
1499 room: Some(rejoined_room.room.clone()),
1500 reshared_projects: rejoined_room
1501 .reshared_projects
1502 .iter()
1503 .map(|project| proto::ResharedProject {
1504 id: project.id.to_proto(),
1505 collaborators: project
1506 .collaborators
1507 .iter()
1508 .map(|collaborator| collaborator.to_proto())
1509 .collect(),
1510 })
1511 .collect(),
1512 rejoined_projects: rejoined_room
1513 .rejoined_projects
1514 .iter()
1515 .map(|rejoined_project| rejoined_project.to_proto())
1516 .collect(),
1517 })?;
1518 room_updated(&rejoined_room.room, &session.peer);
1519
1520 for project in &rejoined_room.reshared_projects {
1521 for collaborator in &project.collaborators {
1522 session
1523 .peer
1524 .send(
1525 collaborator.connection_id,
1526 proto::UpdateProjectCollaborator {
1527 project_id: project.id.to_proto(),
1528 old_peer_id: Some(project.old_connection_id.into()),
1529 new_peer_id: Some(session.connection_id.into()),
1530 },
1531 )
1532 .trace_err();
1533 }
1534
1535 broadcast(
1536 Some(session.connection_id),
1537 project
1538 .collaborators
1539 .iter()
1540 .map(|collaborator| collaborator.connection_id),
1541 |connection_id| {
1542 session.peer.forward_send(
1543 session.connection_id,
1544 connection_id,
1545 proto::UpdateProject {
1546 project_id: project.id.to_proto(),
1547 worktrees: project.worktrees.clone(),
1548 },
1549 )
1550 },
1551 );
1552 }
1553
1554 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1555
1556 let rejoined_room = rejoined_room.into_inner();
1557
1558 room = rejoined_room.room;
1559 channel = rejoined_room.channel;
1560 }
1561
1562 if let Some(channel) = channel {
1563 channel_updated(
1564 &channel,
1565 &room,
1566 &session.peer,
1567 &*session.connection_pool().await,
1568 );
1569 }
1570
1571 update_user_contacts(session.user_id(), &session).await?;
1572 Ok(())
1573}
1574
/// After a rejoin, notifies collaborators of the rejoiner's new peer id and
/// streams each rejoined project's current state (worktree entries,
/// diagnostics, settings, repositories) back to the rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: tell every collaborator in every project that this user's
    // peer id changed from the old connection to the current one.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream project contents to the rejoining connection itself.
    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their owned data can be sent
        // without cloning.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into size-bounded chunks.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            // The first summary rides in `summary`; the rest in `more_summaries`.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Send each settings file in the worktree individually.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository state that changed while the client was away.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        // And removals of repositories that disappeared in the meantime.
        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1659
1660/// leave room disconnects from the room.
1661async fn leave_room(
1662 _: proto::LeaveRoom,
1663 response: Response<proto::LeaveRoom>,
1664 session: MessageContext,
1665) -> Result<()> {
1666 leave_room_for_session(&session, session.connection_id).await?;
1667 response.send(proto::Ack {})?;
1668 Ok(())
1669}
1670
1671/// Updates the permissions of someone else in the room.
1672async fn set_room_participant_role(
1673 request: proto::SetRoomParticipantRole,
1674 response: Response<proto::SetRoomParticipantRole>,
1675 session: MessageContext,
1676) -> Result<()> {
1677 let user_id = UserId::from_proto(request.user_id);
1678 let role = ChannelRole::from(request.role());
1679
1680 let (livekit_room, can_publish) = {
1681 let room = session
1682 .db()
1683 .await
1684 .set_room_participant_role(
1685 session.user_id(),
1686 RoomId::from_proto(request.room_id),
1687 user_id,
1688 role,
1689 )
1690 .await?;
1691
1692 let livekit_room = room.livekit_room.clone();
1693 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1694 room_updated(&room, &session.peer);
1695 (livekit_room, can_publish)
1696 };
1697
1698 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1699 live_kit
1700 .update_participant(
1701 livekit_room.clone(),
1702 request.user_id.to_string(),
1703 livekit_api::proto::ParticipantPermission {
1704 can_subscribe: true,
1705 can_publish,
1706 can_publish_data: can_publish,
1707 hidden: false,
1708 recorder: false,
1709 },
1710 )
1711 .await
1712 .trace_err();
1713 }
1714
1715 response.send(proto::Ack {})?;
1716 Ok(())
1717}
1718
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Calls are only allowed between contacts.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call in the room and broadcast the room update while
    // the database guard is held; only the owned IncomingCall escapes.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the callee has; the first one to acknowledge wins.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failed ring and keep waiting on the other connections.
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the ring: roll back the pending call and
    // broadcast the corrected room state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1787
1788/// Cancel an outgoing call.
1789async fn cancel_call(
1790 request: proto::CancelCall,
1791 response: Response<proto::CancelCall>,
1792 session: MessageContext,
1793) -> Result<()> {
1794 let called_user_id = UserId::from_proto(request.called_user_id);
1795 let room_id = RoomId::from_proto(request.room_id);
1796 {
1797 let room = session
1798 .db()
1799 .await
1800 .cancel_call(room_id, session.connection_id, called_user_id)
1801 .await?;
1802 room_updated(&room, &session.peer);
1803 }
1804
1805 for connection_id in session
1806 .connection_pool()
1807 .await
1808 .user_connection_ids(called_user_id)
1809 {
1810 session
1811 .peer
1812 .send(
1813 connection_id,
1814 proto::CallCanceled {
1815 room_id: room_id.to_proto(),
1816 },
1817 )
1818 .trace_err();
1819 }
1820 response.send(proto::Ack {})?;
1821
1822 update_user_contacts(called_user_id, &session).await?;
1823 Ok(())
1824}
1825
1826/// Decline an incoming call.
1827async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1828 let room_id = RoomId::from_proto(message.room_id);
1829 {
1830 let room = session
1831 .db()
1832 .await
1833 .decline_call(Some(room_id), session.user_id())
1834 .await?
1835 .context("declining call")?;
1836 room_updated(&room, &session.peer);
1837 }
1838
1839 for connection_id in session
1840 .connection_pool()
1841 .await
1842 .user_connection_ids(session.user_id())
1843 {
1844 session
1845 .peer
1846 .send(
1847 connection_id,
1848 proto::CallCanceled {
1849 room_id: room_id.to_proto(),
1850 },
1851 )
1852 .trace_err();
1853 }
1854 update_user_contacts(session.user_id(), &session).await?;
1855 Ok(())
1856}
1857
/// Updates other participants in the room with your current location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // The location field is optional in the protobuf; reject requests without it.
    let location = request.location.context("invalid location")?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    // Broadcast the updated room (which carries participant locations) to peers.
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1876
1877/// Share a project into the room.
1878async fn share_project(
1879 request: proto::ShareProject,
1880 response: Response<proto::ShareProject>,
1881 session: MessageContext,
1882) -> Result<()> {
1883 let (project_id, room) = &*session
1884 .db()
1885 .await
1886 .share_project(
1887 RoomId::from_proto(request.room_id),
1888 session.connection_id,
1889 &request.worktrees,
1890 request.is_ssh_project,
1891 )
1892 .await?;
1893 response.send(proto::ShareProjectResponse {
1894 project_id: project_id.to_proto(),
1895 })?;
1896 room_updated(room, &session.peer);
1897
1898 Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903 let project_id = ProjectId::from_proto(message.project_id);
1904 unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
1907async fn unshare_project_internal(
1908 project_id: ProjectId,
1909 connection_id: ConnectionId,
1910 session: &Session,
1911) -> Result<()> {
1912 let delete = {
1913 let room_guard = session
1914 .db()
1915 .await
1916 .unshare_project(project_id, connection_id)
1917 .await?;
1918
1919 let (delete, room, guest_connection_ids) = &*room_guard;
1920
1921 let message = proto::UnshareProject {
1922 project_id: project_id.to_proto(),
1923 };
1924
1925 broadcast(
1926 Some(connection_id),
1927 guest_connection_ids.iter().copied(),
1928 |conn_id| session.peer.send(conn_id, message.clone()),
1929 );
1930 if let Some(room) = room {
1931 room_updated(room, &session.peer);
1932 }
1933
1934 *delete
1935 };
1936
1937 if delete {
1938 let db = session.db().await;
1939 db.delete_project(project_id).await?;
1940 }
1941
1942 Ok(())
1943}
1944
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long streaming section below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator (with its assigned replica id) to the
    // existing participants.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    // Then stream each worktree's full contents to the joining connection.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are split into size-bounded chunks.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // The first summary rides in `summary`; the rest in `more_summaries`.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Send each settings file in the worktree individually.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state, chunked like worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Finally, nudge the client to refresh disk-based diagnostics for each
    // language server running in the project.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2094
2095/// Leave someone elses shared project.
2096async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2097 let sender_id = session.connection_id;
2098 let project_id = ProjectId::from_proto(request.project_id);
2099 let db = session.db().await;
2100
2101 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2102 tracing::info!(
2103 %project_id,
2104 "leave project"
2105 );
2106
2107 project_left(project, &session);
2108 if let Some(room) = room {
2109 room_updated(room, &session.peer);
2110 }
2111
2112 Ok(())
2113}
2114
2115/// Updates other participants with changes to the project
2116async fn update_project(
2117 request: proto::UpdateProject,
2118 response: Response<proto::UpdateProject>,
2119 session: MessageContext,
2120) -> Result<()> {
2121 let project_id = ProjectId::from_proto(request.project_id);
2122 let (room, guest_connection_ids) = &*session
2123 .db()
2124 .await
2125 .update_project(project_id, session.connection_id, &request.worktrees)
2126 .await?;
2127 broadcast(
2128 Some(session.connection_id),
2129 guest_connection_ids.iter().copied(),
2130 |connection_id| {
2131 session
2132 .peer
2133 .forward_send(session.connection_id, connection_id, request.clone())
2134 },
2135 );
2136 if let Some(room) = room {
2137 room_updated(room, &session.peer);
2138 }
2139 response.send(proto::Ack {})?;
2140
2141 Ok(())
2142}
2143
2144/// Updates other participants with changes to the worktree
2145async fn update_worktree(
2146 request: proto::UpdateWorktree,
2147 response: Response<proto::UpdateWorktree>,
2148 session: MessageContext,
2149) -> Result<()> {
2150 let guest_connection_ids = session
2151 .db()
2152 .await
2153 .update_worktree(&request, session.connection_id)
2154 .await?;
2155
2156 broadcast(
2157 Some(session.connection_id),
2158 guest_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, request.clone())
2163 },
2164 );
2165 response.send(proto::Ack {})?;
2166 Ok(())
2167}
2168
2169async fn update_repository(
2170 request: proto::UpdateRepository,
2171 response: Response<proto::UpdateRepository>,
2172 session: MessageContext,
2173) -> Result<()> {
2174 let guest_connection_ids = session
2175 .db()
2176 .await
2177 .update_repository(&request, session.connection_id)
2178 .await?;
2179
2180 broadcast(
2181 Some(session.connection_id),
2182 guest_connection_ids.iter().copied(),
2183 |connection_id| {
2184 session
2185 .peer
2186 .forward_send(session.connection_id, connection_id, request.clone())
2187 },
2188 );
2189 response.send(proto::Ack {})?;
2190 Ok(())
2191}
2192
2193async fn remove_repository(
2194 request: proto::RemoveRepository,
2195 response: Response<proto::RemoveRepository>,
2196 session: MessageContext,
2197) -> Result<()> {
2198 let guest_connection_ids = session
2199 .db()
2200 .await
2201 .remove_repository(&request, session.connection_id)
2202 .await?;
2203
2204 broadcast(
2205 Some(session.connection_id),
2206 guest_connection_ids.iter().copied(),
2207 |connection_id| {
2208 session
2209 .peer
2210 .forward_send(session.connection_id, connection_id, request.clone())
2211 },
2212 );
2213 response.send(proto::Ack {})?;
2214 Ok(())
2215}
2216
2217/// Updates other participants with changes to the diagnostics
2218async fn update_diagnostic_summary(
2219 message: proto::UpdateDiagnosticSummary,
2220 session: MessageContext,
2221) -> Result<()> {
2222 let guest_connection_ids = session
2223 .db()
2224 .await
2225 .update_diagnostic_summary(&message, session.connection_id)
2226 .await?;
2227
2228 broadcast(
2229 Some(session.connection_id),
2230 guest_connection_ids.iter().copied(),
2231 |connection_id| {
2232 session
2233 .peer
2234 .forward_send(session.connection_id, connection_id, message.clone())
2235 },
2236 );
2237
2238 Ok(())
2239}
2240
2241/// Updates other participants with changes to the worktree settings
2242async fn update_worktree_settings(
2243 message: proto::UpdateWorktreeSettings,
2244 session: MessageContext,
2245) -> Result<()> {
2246 let guest_connection_ids = session
2247 .db()
2248 .await
2249 .update_worktree_settings(&message, session.connection_id)
2250 .await?;
2251
2252 broadcast(
2253 Some(session.connection_id),
2254 guest_connection_ids.iter().copied(),
2255 |connection_id| {
2256 session
2257 .peer
2258 .forward_send(session.connection_id, connection_id, message.clone())
2259 },
2260 );
2261
2262 Ok(())
2263}
2264
2265/// Notify other participants that a language server has started.
2266async fn start_language_server(
2267 request: proto::StartLanguageServer,
2268 session: MessageContext,
2269) -> Result<()> {
2270 let guest_connection_ids = session
2271 .db()
2272 .await
2273 .start_language_server(&request, session.connection_id)
2274 .await?;
2275
2276 broadcast(
2277 Some(session.connection_id),
2278 guest_connection_ids.iter().copied(),
2279 |connection_id| {
2280 session
2281 .peer
2282 .forward_send(session.connection_id, connection_id, request.clone())
2283 },
2284 );
2285 Ok(())
2286}
2287
2288/// Notify other participants that a language server has changed.
2289async fn update_language_server(
2290 request: proto::UpdateLanguageServer,
2291 session: MessageContext,
2292) -> Result<()> {
2293 let project_id = ProjectId::from_proto(request.project_id);
2294 let db = session.db().await;
2295
2296 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2297 && let Some(capabilities) = update.capabilities.clone()
2298 {
2299 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2300 .await?;
2301 }
2302
2303 let project_connection_ids = db
2304 .project_connection_ids(project_id, session.connection_id, true)
2305 .await?;
2306 broadcast(
2307 Some(session.connection_id),
2308 project_connection_ids.iter().copied(),
2309 |connection_id| {
2310 session
2311 .peer
2312 .forward_send(session.connection_id, connection_id, request.clone())
2313 },
2314 );
2315 Ok(())
2316}
2317
2318/// forward a project request to the host. These requests should be read only
2319/// as guests are allowed to send them.
2320async fn forward_read_only_project_request<T>(
2321 request: T,
2322 response: Response<T>,
2323 session: MessageContext,
2324) -> Result<()>
2325where
2326 T: EntityMessage + RequestMessage,
2327{
2328 let project_id = ProjectId::from_proto(request.remote_entity_id());
2329 let host_connection_id = session
2330 .db()
2331 .await
2332 .host_for_read_only_project_request(project_id, session.connection_id)
2333 .await?;
2334 let payload = session.forward_request(host_connection_id, request).await?;
2335 response.send(payload)?;
2336 Ok(())
2337}
2338
2339/// forward a project request to the host. These requests are disallowed
2340/// for guests.
2341async fn forward_mutating_project_request<T>(
2342 request: T,
2343 response: Response<T>,
2344 session: MessageContext,
2345) -> Result<()>
2346where
2347 T: EntityMessage + RequestMessage,
2348{
2349 let project_id = ProjectId::from_proto(request.remote_entity_id());
2350
2351 let host_connection_id = session
2352 .db()
2353 .await
2354 .host_for_mutating_project_request(project_id, session.connection_id)
2355 .await?;
2356 let payload = session.forward_request(host_connection_id, request).await?;
2357 response.send(payload)?;
2358 Ok(())
2359}
2360
2361// todo(lsp) remove after Zed Stable hits v0.204.x
2362async fn multi_lsp_query(
2363 request: MultiLspQuery,
2364 response: Response<MultiLspQuery>,
2365 session: MessageContext,
2366) -> Result<()> {
2367 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2368 tracing::info!("multi_lsp_query message received");
2369 forward_mutating_project_request(request, response, session).await
2370}
2371
2372async fn lsp_query(
2373 request: proto::LspQuery,
2374 response: Response<proto::LspQuery>,
2375 session: MessageContext,
2376) -> Result<()> {
2377 let (name, should_write) = request.query_name_and_write_permissions();
2378 tracing::Span::current().record("lsp_query_request", name);
2379 tracing::info!("lsp_query message received");
2380 if should_write {
2381 forward_mutating_project_request(request, response, session).await
2382 } else {
2383 forward_read_only_project_request(request, response, session).await
2384 }
2385}
2386
2387/// Notify other participants that a new buffer has been created
2388async fn create_buffer_for_peer(
2389 request: proto::CreateBufferForPeer,
2390 session: MessageContext,
2391) -> Result<()> {
2392 session
2393 .db()
2394 .await
2395 .check_user_is_project_host(
2396 ProjectId::from_proto(request.project_id),
2397 session.connection_id,
2398 )
2399 .await?;
2400 let peer_id = request.peer_id.context("invalid peer id")?;
2401 session
2402 .peer
2403 .forward_send(session.connection_id, peer_id.into(), request)?;
2404 Ok(())
2405}
2406
2407/// Notify other participants that a buffer has been updated. This is
2408/// allowed for guests as long as the update is limited to selections.
2409async fn update_buffer(
2410 request: proto::UpdateBuffer,
2411 response: Response<proto::UpdateBuffer>,
2412 session: MessageContext,
2413) -> Result<()> {
2414 let project_id = ProjectId::from_proto(request.project_id);
2415 let mut capability = Capability::ReadOnly;
2416
2417 for op in request.operations.iter() {
2418 match op.variant {
2419 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2420 Some(_) => capability = Capability::ReadWrite,
2421 }
2422 }
2423
2424 let host = {
2425 let guard = session
2426 .db()
2427 .await
2428 .connections_for_buffer_update(project_id, session.connection_id, capability)
2429 .await?;
2430
2431 let (host, guests) = &*guard;
2432
2433 broadcast(
2434 Some(session.connection_id),
2435 guests.clone(),
2436 |connection_id| {
2437 session
2438 .peer
2439 .forward_send(session.connection_id, connection_id, request.clone())
2440 },
2441 );
2442
2443 *host
2444 };
2445
2446 if host != session.connection_id {
2447 session.forward_request(host, request.clone()).await?;
2448 }
2449
2450 response.send(proto::Ack {})?;
2451 Ok(())
2452}
2453
2454async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2455 let project_id = ProjectId::from_proto(message.project_id);
2456
2457 let operation = message.operation.as_ref().context("invalid operation")?;
2458 let capability = match operation.variant.as_ref() {
2459 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2460 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2461 match buffer_op.variant {
2462 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2463 Capability::ReadOnly
2464 }
2465 _ => Capability::ReadWrite,
2466 }
2467 } else {
2468 Capability::ReadWrite
2469 }
2470 }
2471 Some(_) => Capability::ReadWrite,
2472 None => Capability::ReadOnly,
2473 };
2474
2475 let guard = session
2476 .db()
2477 .await
2478 .connections_for_buffer_update(project_id, session.connection_id, capability)
2479 .await?;
2480
2481 let (host, guests) = &*guard;
2482
2483 broadcast(
2484 Some(session.connection_id),
2485 guests.iter().chain([host]).copied(),
2486 |connection_id| {
2487 session
2488 .peer
2489 .forward_send(session.connection_id, connection_id, message.clone())
2490 },
2491 );
2492
2493 Ok(())
2494}
2495
2496/// Notify other participants that a project has been updated.
2497async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2498 request: T,
2499 session: MessageContext,
2500) -> Result<()> {
2501 let project_id = ProjectId::from_proto(request.remote_entity_id());
2502 let project_connection_ids = session
2503 .db()
2504 .await
2505 .project_connection_ids(project_id, session.connection_id, false)
2506 .await?;
2507
2508 broadcast(
2509 Some(session.connection_id),
2510 project_connection_ids.iter().copied(),
2511 |connection_id| {
2512 session
2513 .peer
2514 .forward_send(session.connection_id, connection_id, request.clone())
2515 },
2516 );
2517 Ok(())
2518}
2519
2520/// Start following another user in a call.
2521async fn follow(
2522 request: proto::Follow,
2523 response: Response<proto::Follow>,
2524 session: MessageContext,
2525) -> Result<()> {
2526 let room_id = RoomId::from_proto(request.room_id);
2527 let project_id = request.project_id.map(ProjectId::from_proto);
2528 let leader_id = request.leader_id.context("invalid leader id")?.into();
2529 let follower_id = session.connection_id;
2530
2531 session
2532 .db()
2533 .await
2534 .check_room_participants(room_id, leader_id, session.connection_id)
2535 .await?;
2536
2537 let response_payload = session.forward_request(leader_id, request).await?;
2538 response.send(response_payload)?;
2539
2540 if let Some(project_id) = project_id {
2541 let room = session
2542 .db()
2543 .await
2544 .follow(room_id, project_id, leader_id, follower_id)
2545 .await?;
2546 room_updated(&room, &session.peer);
2547 }
2548
2549 Ok(())
2550}
2551
2552/// Stop following another user in a call.
2553async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2554 let room_id = RoomId::from_proto(request.room_id);
2555 let project_id = request.project_id.map(ProjectId::from_proto);
2556 let leader_id = request.leader_id.context("invalid leader id")?.into();
2557 let follower_id = session.connection_id;
2558
2559 session
2560 .db()
2561 .await
2562 .check_room_participants(room_id, leader_id, session.connection_id)
2563 .await?;
2564
2565 session
2566 .peer
2567 .forward_send(session.connection_id, leader_id, request)?;
2568
2569 if let Some(project_id) = project_id {
2570 let room = session
2571 .db()
2572 .await
2573 .unfollow(room_id, project_id, leader_id, follower_id)
2574 .await?;
2575 room_updated(&room, &session.peer);
2576 }
2577
2578 Ok(())
2579}
2580
2581/// Notify everyone following you of your current location.
2582async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2583 let room_id = RoomId::from_proto(request.room_id);
2584 let database = session.db.lock().await;
2585
2586 let connection_ids = if let Some(project_id) = request.project_id {
2587 let project_id = ProjectId::from_proto(project_id);
2588 database
2589 .project_connection_ids(project_id, session.connection_id, true)
2590 .await?
2591 } else {
2592 database
2593 .room_connection_ids(room_id, session.connection_id)
2594 .await?
2595 };
2596
2597 // For now, don't send view update messages back to that view's current leader.
2598 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2599 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2600 _ => None,
2601 });
2602
2603 for connection_id in connection_ids.iter().cloned() {
2604 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2605 session
2606 .peer
2607 .forward_send(session.connection_id, connection_id, request.clone())?;
2608 }
2609 }
2610 Ok(())
2611}
2612
2613/// Get public data about users.
2614async fn get_users(
2615 request: proto::GetUsers,
2616 response: Response<proto::GetUsers>,
2617 session: MessageContext,
2618) -> Result<()> {
2619 let user_ids = request
2620 .user_ids
2621 .into_iter()
2622 .map(UserId::from_proto)
2623 .collect();
2624 let users = session
2625 .db()
2626 .await
2627 .get_users_by_ids(user_ids)
2628 .await?
2629 .into_iter()
2630 .map(|user| proto::User {
2631 id: user.id.to_proto(),
2632 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2633 github_login: user.github_login,
2634 name: user.name,
2635 })
2636 .collect();
2637 response.send(proto::UsersResponse { users })?;
2638 Ok(())
2639}
2640
2641/// Search for users (to invite) buy Github login
2642async fn fuzzy_search_users(
2643 request: proto::FuzzySearchUsers,
2644 response: Response<proto::FuzzySearchUsers>,
2645 session: MessageContext,
2646) -> Result<()> {
2647 let query = request.query;
2648 let users = match query.len() {
2649 0 => vec![],
2650 1 | 2 => session
2651 .db()
2652 .await
2653 .get_user_by_github_login(&query)
2654 .await?
2655 .into_iter()
2656 .collect(),
2657 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2658 };
2659 let users = users
2660 .into_iter()
2661 .filter(|user| user.id != session.user_id())
2662 .map(|user| proto::User {
2663 id: user.id.to_proto(),
2664 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2665 github_login: user.github_login,
2666 name: user.name,
2667 })
2668 .collect();
2669 response.send(proto::UsersResponse { users })?;
2670 Ok(())
2671}
2672
2673/// Send a contact request to another user.
2674async fn request_contact(
2675 request: proto::RequestContact,
2676 response: Response<proto::RequestContact>,
2677 session: MessageContext,
2678) -> Result<()> {
2679 let requester_id = session.user_id();
2680 let responder_id = UserId::from_proto(request.responder_id);
2681 if requester_id == responder_id {
2682 return Err(anyhow!("cannot add yourself as a contact"))?;
2683 }
2684
2685 let notifications = session
2686 .db()
2687 .await
2688 .send_contact_request(requester_id, responder_id)
2689 .await?;
2690
2691 // Update outgoing contact requests of requester
2692 let mut update = proto::UpdateContacts::default();
2693 update.outgoing_requests.push(responder_id.to_proto());
2694 for connection_id in session
2695 .connection_pool()
2696 .await
2697 .user_connection_ids(requester_id)
2698 {
2699 session.peer.send(connection_id, update.clone())?;
2700 }
2701
2702 // Update incoming contact requests of responder
2703 let mut update = proto::UpdateContacts::default();
2704 update
2705 .incoming_requests
2706 .push(proto::IncomingContactRequest {
2707 requester_id: requester_id.to_proto(),
2708 });
2709 let connection_pool = session.connection_pool().await;
2710 for connection_id in connection_pool.user_connection_ids(responder_id) {
2711 session.peer.send(connection_id, update.clone())?;
2712 }
2713
2714 send_notifications(&connection_pool, &session.peer, notifications);
2715
2716 response.send(proto::Ack {})?;
2717 Ok(())
2718}
2719
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the responder's notification; it neither accepts
    // nor declines, so no contact state changes and nobody is notified.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Anything other than Accept (e.g. Decline) is treated as a refusal.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact payloads pushed below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The pending incoming request is removed on accept AND decline.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Mirror image: the requester's outgoing request is removed either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2777
2778/// Remove a contact.
2779async fn remove_contact(
2780 request: proto::RemoveContact,
2781 response: Response<proto::RemoveContact>,
2782 session: MessageContext,
2783) -> Result<()> {
2784 let requester_id = session.user_id();
2785 let responder_id = UserId::from_proto(request.user_id);
2786 let db = session.db().await;
2787 let (contact_accepted, deleted_notification_id) =
2788 db.remove_contact(requester_id, responder_id).await?;
2789
2790 let pool = session.connection_pool().await;
2791 // Update outgoing contact requests of requester
2792 let mut update = proto::UpdateContacts::default();
2793 if contact_accepted {
2794 update.remove_contacts.push(responder_id.to_proto());
2795 } else {
2796 update
2797 .remove_outgoing_requests
2798 .push(responder_id.to_proto());
2799 }
2800 for connection_id in pool.user_connection_ids(requester_id) {
2801 session.peer.send(connection_id, update.clone())?;
2802 }
2803
2804 // Update incoming contact requests of responder
2805 let mut update = proto::UpdateContacts::default();
2806 if contact_accepted {
2807 update.remove_contacts.push(requester_id.to_proto());
2808 } else {
2809 update
2810 .remove_incoming_requests
2811 .push(requester_id.to_proto());
2812 }
2813 for connection_id in pool.user_connection_ids(responder_id) {
2814 session.peer.send(connection_id, update.clone())?;
2815 if let Some(notification_id) = deleted_notification_id {
2816 session.peer.send(
2817 connection_id,
2818 proto::DeleteNotification {
2819 notification_id: notification_id.to_proto(),
2820 },
2821 )?;
2822 }
2823 }
2824
2825 response.send(proto::Ack {})?;
2826 Ok(())
2827}
2828
/// Whether this client version still relies on the server subscribing it to
/// its channels automatically at connect time. Clients from v0.139 onward
/// request this explicitly instead (see `subscribe_to_channels`).
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2832
2833async fn subscribe_to_channels(
2834 _: proto::SubscribeToChannels,
2835 session: MessageContext,
2836) -> Result<()> {
2837 subscribe_user_to_channels(session.user_id(), &session).await?;
2838 Ok(())
2839}
2840
2841async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2842 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2843 let mut pool = session.connection_pool().await;
2844 for membership in &channels_for_user.channel_memberships {
2845 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2846 }
2847 session.peer.send(
2848 session.connection_id,
2849 build_update_user_channels(&channels_for_user),
2850 )?;
2851 session.peer.send(
2852 session.connection_id,
2853 build_channels_update(channels_for_user),
2854 )?;
2855 Ok(())
2856}
2857
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is only Some when creating the channel also made the
    // creator a member (per the DB's rules for root vs. nested channels).
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Acknowledge the creator before fanning out to other connections.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Register the new membership so this user receives future updates
        // for the channel, and tell all their connections about their role.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone subscribed to its root whose role
    // permits seeing a channel of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2912
2913/// Delete a channel
2914async fn delete_channel(
2915 request: proto::DeleteChannel,
2916 response: Response<proto::DeleteChannel>,
2917 session: MessageContext,
2918) -> Result<()> {
2919 let db = session.db().await;
2920
2921 let channel_id = request.channel_id;
2922 let (root_channel, removed_channels) = db
2923 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2924 .await?;
2925 response.send(proto::Ack {})?;
2926
2927 // Notify members of removed channels
2928 let mut update = proto::UpdateChannels::default();
2929 update
2930 .delete_channels
2931 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2932
2933 let connection_pool = session.connection_pool().await;
2934 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2935 session.peer.send(connection_id, update.clone())?;
2936 }
2937
2938 Ok(())
2939}
2940
2941/// Invite someone to join a channel.
2942async fn invite_channel_member(
2943 request: proto::InviteChannelMember,
2944 response: Response<proto::InviteChannelMember>,
2945 session: MessageContext,
2946) -> Result<()> {
2947 let db = session.db().await;
2948 let channel_id = ChannelId::from_proto(request.channel_id);
2949 let invitee_id = UserId::from_proto(request.user_id);
2950 let InviteMemberResult {
2951 channel,
2952 notifications,
2953 } = db
2954 .invite_channel_member(
2955 channel_id,
2956 invitee_id,
2957 session.user_id(),
2958 request.role().into(),
2959 )
2960 .await?;
2961
2962 let update = proto::UpdateChannels {
2963 channel_invitations: vec![channel.to_proto()],
2964 ..Default::default()
2965 };
2966
2967 let connection_pool = session.connection_pool().await;
2968 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2969 session.peer.send(connection_id, update.clone())?;
2970 }
2971
2972 send_notifications(&connection_pool, &session.peer, notifications);
2973
2974 response.send(proto::Ack {})?;
2975 Ok(())
2976}
2977
2978/// remove someone from a channel
2979async fn remove_channel_member(
2980 request: proto::RemoveChannelMember,
2981 response: Response<proto::RemoveChannelMember>,
2982 session: MessageContext,
2983) -> Result<()> {
2984 let db = session.db().await;
2985 let channel_id = ChannelId::from_proto(request.channel_id);
2986 let member_id = UserId::from_proto(request.user_id);
2987
2988 let RemoveChannelMemberResult {
2989 membership_update,
2990 notification_id,
2991 } = db
2992 .remove_channel_member(channel_id, member_id, session.user_id())
2993 .await?;
2994
2995 let mut connection_pool = session.connection_pool().await;
2996 notify_membership_updated(
2997 &mut connection_pool,
2998 membership_update,
2999 member_id,
3000 &session.peer,
3001 );
3002 for connection_id in connection_pool.user_connection_ids(member_id) {
3003 if let Some(notification_id) = notification_id {
3004 session
3005 .peer
3006 .send(
3007 connection_id,
3008 proto::DeleteNotification {
3009 notification_id: notification_id.to_proto(),
3010 },
3011 )
3012 .trace_err();
3013 }
3014 }
3015
3016 response.send(proto::Ack {})?;
3017 Ok(())
3018}
3019
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the loop body mutates the pool's
    // subscriptions, so we can't keep borrowing its iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get the refreshed state and a
        // (re-)subscription; users who no longer can are unsubscribed and
        // told to drop the channel locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3066
3067/// Alter the role for a user in the channel.
3068async fn set_channel_member_role(
3069 request: proto::SetChannelMemberRole,
3070 response: Response<proto::SetChannelMemberRole>,
3071 session: MessageContext,
3072) -> Result<()> {
3073 let db = session.db().await;
3074 let channel_id = ChannelId::from_proto(request.channel_id);
3075 let member_id = UserId::from_proto(request.user_id);
3076 let result = db
3077 .set_channel_member_role(
3078 channel_id,
3079 session.user_id(),
3080 member_id,
3081 request.role().into(),
3082 )
3083 .await?;
3084
3085 match result {
3086 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3087 let mut connection_pool = session.connection_pool().await;
3088 notify_membership_updated(
3089 &mut connection_pool,
3090 membership_update,
3091 member_id,
3092 &session.peer,
3093 )
3094 }
3095 db::SetMemberRoleResult::InviteUpdated(channel) => {
3096 let update = proto::UpdateChannels {
3097 channel_invitations: vec![channel.to_proto()],
3098 ..Default::default()
3099 };
3100
3101 for connection_id in session
3102 .connection_pool()
3103 .await
3104 .user_connection_ids(member_id)
3105 {
3106 session.peer.send(connection_id, update.clone())?;
3107 }
3108 }
3109 }
3110
3111 response.send(proto::Ack {})?;
3112 Ok(())
3113}
3114
3115/// Change the name of a channel
3116async fn rename_channel(
3117 request: proto::RenameChannel,
3118 response: Response<proto::RenameChannel>,
3119 session: MessageContext,
3120) -> Result<()> {
3121 let db = session.db().await;
3122 let channel_id = ChannelId::from_proto(request.channel_id);
3123 let channel_model = db
3124 .rename_channel(channel_id, session.user_id(), &request.name)
3125 .await?;
3126 let root_id = channel_model.root_id();
3127 let channel = Channel::from_model(channel_model);
3128
3129 response.send(proto::RenameChannelResponse {
3130 channel: Some(channel.to_proto()),
3131 })?;
3132
3133 let connection_pool = session.connection_pool().await;
3134 let update = proto::UpdateChannels {
3135 channels: vec![channel.to_proto()],
3136 ..Default::default()
3137 };
3138 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3139 if role.can_see_channel(channel.visibility) {
3140 session.peer.send(connection_id, update.clone())?;
3141 }
3142 }
3143
3144 Ok(())
3145}
3146
3147/// Move a channel to a new parent.
3148async fn move_channel(
3149 request: proto::MoveChannel,
3150 response: Response<proto::MoveChannel>,
3151 session: MessageContext,
3152) -> Result<()> {
3153 let channel_id = ChannelId::from_proto(request.channel_id);
3154 let to = ChannelId::from_proto(request.to);
3155
3156 let (root_id, channels) = session
3157 .db()
3158 .await
3159 .move_channel(channel_id, to, session.user_id())
3160 .await?;
3161
3162 let connection_pool = session.connection_pool().await;
3163 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3164 let channels = channels
3165 .iter()
3166 .filter_map(|channel| {
3167 if role.can_see_channel(channel.visibility) {
3168 Some(channel.to_proto())
3169 } else {
3170 None
3171 }
3172 })
3173 .collect::<Vec<_>>();
3174 if channels.is_empty() {
3175 continue;
3176 }
3177
3178 let update = proto::UpdateChannels {
3179 channels,
3180 ..Default::default()
3181 };
3182
3183 session.peer.send(connection_id, update.clone())?;
3184 }
3185
3186 response.send(Ack {})?;
3187 Ok(())
3188}
3189
3190async fn reorder_channel(
3191 request: proto::ReorderChannel,
3192 response: Response<proto::ReorderChannel>,
3193 session: MessageContext,
3194) -> Result<()> {
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196 let direction = request.direction();
3197
3198 let updated_channels = session
3199 .db()
3200 .await
3201 .reorder_channel(channel_id, direction, session.user_id())
3202 .await?;
3203
3204 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3205 let connection_pool = session.connection_pool().await;
3206 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3207 let channels = updated_channels
3208 .iter()
3209 .filter_map(|channel| {
3210 if role.can_see_channel(channel.visibility) {
3211 Some(channel.to_proto())
3212 } else {
3213 None
3214 }
3215 })
3216 .collect::<Vec<_>>();
3217
3218 if channels.is_empty() {
3219 continue;
3220 }
3221
3222 let update = proto::UpdateChannels {
3223 channels,
3224 ..Default::default()
3225 };
3226
3227 session.peer.send(connection_id, update.clone())?;
3228 }
3229 }
3230
3231 response.send(Ack {})?;
3232 Ok(())
3233}
3234
3235/// Get the list of channel members
3236async fn get_channel_members(
3237 request: proto::GetChannelMembers,
3238 response: Response<proto::GetChannelMembers>,
3239 session: MessageContext,
3240) -> Result<()> {
3241 let db = session.db().await;
3242 let channel_id = ChannelId::from_proto(request.channel_id);
3243 let limit = if request.limit == 0 {
3244 u16::MAX as u64
3245 } else {
3246 request.limit
3247 };
3248 let (members, users) = db
3249 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3250 .await?;
3251 response.send(proto::GetChannelMembersResponse { members, users })?;
3252 Ok(())
3253}
3254
3255/// Accept or decline a channel invitation.
3256async fn respond_to_channel_invite(
3257 request: proto::RespondToChannelInvite,
3258 response: Response<proto::RespondToChannelInvite>,
3259 session: MessageContext,
3260) -> Result<()> {
3261 let db = session.db().await;
3262 let channel_id = ChannelId::from_proto(request.channel_id);
3263 let RespondToChannelInvite {
3264 membership_update,
3265 notifications,
3266 } = db
3267 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3268 .await?;
3269
3270 let mut connection_pool = session.connection_pool().await;
3271 if let Some(membership_update) = membership_update {
3272 notify_membership_updated(
3273 &mut connection_pool,
3274 membership_update,
3275 session.user_id(),
3276 &session.peer,
3277 );
3278 } else {
3279 let update = proto::UpdateChannels {
3280 remove_channel_invitations: vec![channel_id.to_proto()],
3281 ..Default::default()
3282 };
3283
3284 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3285 session.peer.send(connection_id, update.clone())?;
3286 }
3287 };
3288
3289 send_notifications(&connection_pool, &session.peer, notifications);
3290
3291 response.send(proto::Ack {})?;
3292
3293 Ok(())
3294}
3295
3296/// Join the channels' room
3297async fn join_channel(
3298 request: proto::JoinChannel,
3299 response: Response<proto::JoinChannel>,
3300 session: MessageContext,
3301) -> Result<()> {
3302 let channel_id = ChannelId::from_proto(request.channel_id);
3303 join_channel_internal(channel_id, Box::new(response), session).await
3304}
3305
/// Abstraction over the two RPCs that can complete a channel join —
/// `JoinChannel` and `JoinRoom` — both of which reply with a
/// `JoinRoomResponse`. This lets `join_channel_internal` be shared between
/// the two handlers.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3319
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers (see `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection left over from a crashed client,
/// joins the channel in the database, mints LiveKit credentials when a
/// LiveKit client is configured, replies to the caller, and then notifies
/// membership/room/channel/contact observers.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires its own database handle, so
            // release ours first and reacquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint LiveKit credentials when a client is configured. Guests get
        // a token with `can_publish: false`; all other roles may publish.
        // A token failure is logged (`trace_err`) and results in no
        // connection info rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting to other connections.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database guard is dropped at the end of the block above; only
    // then broadcast the channel's participant list and refresh contacts.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3413
3414/// Start editing the channel notes
3415async fn join_channel_buffer(
3416 request: proto::JoinChannelBuffer,
3417 response: Response<proto::JoinChannelBuffer>,
3418 session: MessageContext,
3419) -> Result<()> {
3420 let db = session.db().await;
3421 let channel_id = ChannelId::from_proto(request.channel_id);
3422
3423 let open_response = db
3424 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3425 .await?;
3426
3427 let collaborators = open_response.collaborators.clone();
3428 response.send(open_response)?;
3429
3430 let update = UpdateChannelBufferCollaborators {
3431 channel_id: channel_id.to_proto(),
3432 collaborators: collaborators.clone(),
3433 };
3434 channel_buffer_updated(
3435 session.connection_id,
3436 collaborators
3437 .iter()
3438 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3439 &update,
3440 &session.peer,
3441 );
3442
3443 Ok(())
3444}
3445
3446/// Edit the channel notes
3447async fn update_channel_buffer(
3448 request: proto::UpdateChannelBuffer,
3449 session: MessageContext,
3450) -> Result<()> {
3451 let db = session.db().await;
3452 let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454 let (collaborators, epoch, version) = db
3455 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3456 .await?;
3457
3458 channel_buffer_updated(
3459 session.connection_id,
3460 collaborators.clone(),
3461 &proto::UpdateChannelBuffer {
3462 channel_id: channel_id.to_proto(),
3463 operations: request.operations,
3464 },
3465 &session.peer,
3466 );
3467
3468 let pool = &*session.connection_pool().await;
3469
3470 let non_collaborators =
3471 pool.channel_connection_ids(channel_id)
3472 .filter_map(|(connection_id, _)| {
3473 if collaborators.contains(&connection_id) {
3474 None
3475 } else {
3476 Some(connection_id)
3477 }
3478 });
3479
3480 broadcast(None, non_collaborators, |peer_id| {
3481 session.peer.send(
3482 peer_id,
3483 proto::UpdateChannels {
3484 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3485 channel_id: channel_id.to_proto(),
3486 epoch: epoch as u64,
3487 version: version.clone(),
3488 }],
3489 ..Default::default()
3490 },
3491 )
3492 });
3493
3494 Ok(())
3495}
3496
3497/// Rejoin the channel notes after a connection blip
3498async fn rejoin_channel_buffers(
3499 request: proto::RejoinChannelBuffers,
3500 response: Response<proto::RejoinChannelBuffers>,
3501 session: MessageContext,
3502) -> Result<()> {
3503 let db = session.db().await;
3504 let buffers = db
3505 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3506 .await?;
3507
3508 for rejoined_buffer in &buffers {
3509 let collaborators_to_notify = rejoined_buffer
3510 .buffer
3511 .collaborators
3512 .iter()
3513 .filter_map(|c| Some(c.peer_id?.into()));
3514 channel_buffer_updated(
3515 session.connection_id,
3516 collaborators_to_notify,
3517 &proto::UpdateChannelBufferCollaborators {
3518 channel_id: rejoined_buffer.buffer.channel_id,
3519 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3520 },
3521 &session.peer,
3522 );
3523 }
3524
3525 response.send(proto::RejoinChannelBuffersResponse {
3526 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3527 })?;
3528
3529 Ok(())
3530}
3531
3532/// Stop editing the channel notes
3533async fn leave_channel_buffer(
3534 request: proto::LeaveChannelBuffer,
3535 response: Response<proto::LeaveChannelBuffer>,
3536 session: MessageContext,
3537) -> Result<()> {
3538 let db = session.db().await;
3539 let channel_id = ChannelId::from_proto(request.channel_id);
3540
3541 let left_buffer = db
3542 .leave_channel_buffer(channel_id, session.connection_id)
3543 .await?;
3544
3545 response.send(Ack {})?;
3546
3547 channel_buffer_updated(
3548 session.connection_id,
3549 left_buffer.connections,
3550 &proto::UpdateChannelBufferCollaborators {
3551 channel_id: channel_id.to_proto(),
3552 collaborators: left_buffer.collaborators,
3553 },
3554 &session.peer,
3555 );
3556
3557 Ok(())
3558}
3559
3560fn channel_buffer_updated<T: EnvelopedMessage>(
3561 sender_id: ConnectionId,
3562 collaborators: impl IntoIterator<Item = ConnectionId>,
3563 message: &T,
3564 peer: &Peer,
3565) {
3566 broadcast(Some(sender_id), collaborators, |peer_id| {
3567 peer.send(peer_id, message.clone())
3568 });
3569}
3570
3571fn send_notifications(
3572 connection_pool: &ConnectionPool,
3573 peer: &Peer,
3574 notifications: db::NotificationBatch,
3575) {
3576 for (user_id, notification) in notifications {
3577 for connection_id in connection_pool.user_connection_ids(user_id) {
3578 if let Err(error) = peer.send(
3579 connection_id,
3580 proto::AddNotification {
3581 notification: Some(notification.clone()),
3582 },
3583 ) {
3584 tracing::error!(
3585 "failed to send notification to {:?} {}",
3586 connection_id,
3587 error
3588 );
3589 }
3590 }
3591 }
3592}
3593
3594/// Send a message to the channel
3595async fn send_channel_message(
3596 _request: proto::SendChannelMessage,
3597 _response: Response<proto::SendChannelMessage>,
3598 _session: MessageContext,
3599) -> Result<()> {
3600 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3601}
3602
3603/// Delete a channel message
3604async fn remove_channel_message(
3605 _request: proto::RemoveChannelMessage,
3606 _response: Response<proto::RemoveChannelMessage>,
3607 _session: MessageContext,
3608) -> Result<()> {
3609 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3610}
3611
3612async fn update_channel_message(
3613 _request: proto::UpdateChannelMessage,
3614 _response: Response<proto::UpdateChannelMessage>,
3615 _session: MessageContext,
3616) -> Result<()> {
3617 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3618}
3619
3620/// Mark a channel message as read
3621async fn acknowledge_channel_message(
3622 _request: proto::AckChannelMessage,
3623 _session: MessageContext,
3624) -> Result<()> {
3625 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3626}
3627
3628/// Mark a buffer version as synced
3629async fn acknowledge_buffer_version(
3630 request: proto::AckBufferOperation,
3631 session: MessageContext,
3632) -> Result<()> {
3633 let buffer_id = BufferId::from_proto(request.buffer_id);
3634 session
3635 .db()
3636 .await
3637 .observe_buffer_version(
3638 buffer_id,
3639 session.user_id(),
3640 request.epoch as i32,
3641 &request.version,
3642 )
3643 .await?;
3644 Ok(())
3645}
3646
3647/// Get a Supermaven API key for the user
3648async fn get_supermaven_api_key(
3649 _request: proto::GetSupermavenApiKey,
3650 response: Response<proto::GetSupermavenApiKey>,
3651 session: MessageContext,
3652) -> Result<()> {
3653 let user_id: String = session.user_id().to_string();
3654 if !session.is_staff() {
3655 return Err(anyhow!("supermaven not enabled for this account"))?;
3656 }
3657
3658 let email = session.email().context("user must have an email")?;
3659
3660 let supermaven_admin_api = session
3661 .supermaven_client
3662 .as_ref()
3663 .context("supermaven not configured")?;
3664
3665 let result = supermaven_admin_api
3666 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3667 .await?;
3668
3669 response.send(proto::GetSupermavenApiKeyResponse {
3670 api_key: result.api_key,
3671 })?;
3672
3673 Ok(())
3674}
3675
3676/// Start receiving chat updates for a channel
3677async fn join_channel_chat(
3678 _request: proto::JoinChannelChat,
3679 _response: Response<proto::JoinChannelChat>,
3680 _session: MessageContext,
3681) -> Result<()> {
3682 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3683}
3684
3685/// Stop receiving chat updates for a channel
3686async fn leave_channel_chat(
3687 _request: proto::LeaveChannelChat,
3688 _session: MessageContext,
3689) -> Result<()> {
3690 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3691}
3692
3693/// Retrieve the chat history for a channel
3694async fn get_channel_messages(
3695 _request: proto::GetChannelMessages,
3696 _response: Response<proto::GetChannelMessages>,
3697 _session: MessageContext,
3698) -> Result<()> {
3699 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3700}
3701
3702/// Retrieve specific chat messages
3703async fn get_channel_messages_by_id(
3704 _request: proto::GetChannelMessagesById,
3705 _response: Response<proto::GetChannelMessagesById>,
3706 _session: MessageContext,
3707) -> Result<()> {
3708 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3709}
3710
3711/// Retrieve the current users notifications
3712async fn get_notifications(
3713 request: proto::GetNotifications,
3714 response: Response<proto::GetNotifications>,
3715 session: MessageContext,
3716) -> Result<()> {
3717 let notifications = session
3718 .db()
3719 .await
3720 .get_notifications(
3721 session.user_id(),
3722 NOTIFICATION_COUNT_PER_PAGE,
3723 request.before_id.map(db::NotificationId::from_proto),
3724 )
3725 .await?;
3726 response.send(proto::GetNotificationsResponse {
3727 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3728 notifications,
3729 })?;
3730 Ok(())
3731}
3732
3733/// Mark notifications as read
3734async fn mark_notification_as_read(
3735 request: proto::MarkNotificationRead,
3736 response: Response<proto::MarkNotificationRead>,
3737 session: MessageContext,
3738) -> Result<()> {
3739 let database = &session.db().await;
3740 let notifications = database
3741 .mark_notification_as_read_by_id(
3742 session.user_id(),
3743 NotificationId::from_proto(request.notification_id),
3744 )
3745 .await?;
3746 send_notifications(
3747 &*session.connection_pool().await,
3748 &session.peer,
3749 notifications,
3750 );
3751 response.send(proto::Ack {})?;
3752 Ok(())
3753}
3754
3755fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3756 let message = match message {
3757 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3758 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3759 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3760 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3761 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3762 code: frame.code.into(),
3763 reason: frame.reason.as_str().to_owned().into(),
3764 })),
3765 // We should never receive a frame while reading the message, according
3766 // to the `tungstenite` maintainers:
3767 //
3768 // > It cannot occur when you read messages from the WebSocket, but it
3769 // > can be used when you want to send the raw frames (e.g. you want to
3770 // > send the frames to the WebSocket without composing the full message first).
3771 // >
3772 // > — https://github.com/snapview/tungstenite-rs/issues/268
3773 TungsteniteMessage::Frame(_) => {
3774 bail!("received an unexpected frame while reading the message")
3775 }
3776 };
3777
3778 Ok(message)
3779}
3780
3781fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3782 match message {
3783 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3784 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3785 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3786 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3787 AxumMessage::Close(frame) => {
3788 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3789 code: frame.code.into(),
3790 reason: frame.reason.as_ref().into(),
3791 }))
3792 }
3793 }
3794}
3795
/// Apply a channel-membership change for `user_id`: sync the connection
/// pool's channel subscriptions with the database result, then push both a
/// membership delta (`UpdateUserChannels`) and the resulting channel list
/// changes (`UpdateChannels`) to all of the user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Bring the pool's subscriptions in line with the new memberships
    // before broadcasting anything.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this builds the membership list by hand instead of
    // calling `build_update_user_channels`, so it leaves
    // `observed_channel_buffer_version` at its default — presumably
    // intentional, but worth confirming.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // Retract any pending invitation to the channel this change came from.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged (via `trace_err`) rather than propagated.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3836
3837fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3838 proto::UpdateUserChannels {
3839 channel_memberships: channels
3840 .channel_memberships
3841 .iter()
3842 .map(|m| proto::ChannelMembership {
3843 channel_id: m.channel_id.to_proto(),
3844 role: m.role.into(),
3845 })
3846 .collect(),
3847 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3848 }
3849}
3850
3851fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3852 let mut update = proto::UpdateChannels::default();
3853
3854 for channel in channels.channels {
3855 update.channels.push(channel.to_proto());
3856 }
3857
3858 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3859
3860 for (channel_id, participants) in channels.channel_participants {
3861 update
3862 .channel_participants
3863 .push(proto::ChannelParticipants {
3864 channel_id: channel_id.to_proto(),
3865 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3866 });
3867 }
3868
3869 for channel in channels.invited_channels {
3870 update.channel_invitations.push(channel.to_proto());
3871 }
3872
3873 update
3874}
3875
3876fn build_initial_contacts_update(
3877 contacts: Vec<db::Contact>,
3878 pool: &ConnectionPool,
3879) -> proto::UpdateContacts {
3880 let mut update = proto::UpdateContacts::default();
3881
3882 for contact in contacts {
3883 match contact {
3884 db::Contact::Accepted { user_id, busy } => {
3885 update.contacts.push(contact_for_user(user_id, busy, pool));
3886 }
3887 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3888 db::Contact::Incoming { user_id } => {
3889 update
3890 .incoming_requests
3891 .push(proto::IncomingContactRequest {
3892 requester_id: user_id.to_proto(),
3893 })
3894 }
3895 }
3896 }
3897
3898 update
3899}
3900
3901fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3902 proto::Contact {
3903 user_id: user_id.to_proto(),
3904 online: pool.is_user_online(user_id),
3905 busy,
3906 }
3907}
3908
3909fn room_updated(room: &proto::Room, peer: &Peer) {
3910 broadcast(
3911 None,
3912 room.participants
3913 .iter()
3914 .filter_map(|participant| Some(participant.peer_id?.into())),
3915 |peer_id| {
3916 peer.send(
3917 peer_id,
3918 proto::RoomUpdated {
3919 room: Some(room.clone()),
3920 },
3921 )
3922 },
3923 );
3924}
3925
3926fn channel_updated(
3927 channel: &db::channel::Model,
3928 room: &proto::Room,
3929 peer: &Peer,
3930 pool: &ConnectionPool,
3931) {
3932 let participants = room
3933 .participants
3934 .iter()
3935 .map(|p| p.user_id)
3936 .collect::<Vec<_>>();
3937
3938 broadcast(
3939 None,
3940 pool.channel_connection_ids(channel.root_id())
3941 .filter_map(|(channel_id, role)| {
3942 role.can_see_channel(channel.visibility)
3943 .then_some(channel_id)
3944 }),
3945 |peer_id| {
3946 peer.send(
3947 peer_id,
3948 proto::UpdateChannels {
3949 channel_participants: vec![proto::ChannelParticipants {
3950 channel_id: channel.id.to_proto(),
3951 participant_user_ids: participants.clone(),
3952 }],
3953 ..Default::default()
3954 },
3955 )
3956 },
3957 );
3958}
3959
3960async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3961 let db = session.db().await;
3962
3963 let contacts = db.get_contacts(user_id).await?;
3964 let busy = db.is_user_busy(user_id).await?;
3965
3966 let pool = session.connection_pool().await;
3967 let updated_contact = contact_for_user(user_id, busy, &pool);
3968 for contact in contacts {
3969 if let db::Contact::Accepted {
3970 user_id: contact_user_id,
3971 ..
3972 } = contact
3973 {
3974 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3975 session
3976 .peer
3977 .send(
3978 contact_conn_id,
3979 proto::UpdateContacts {
3980 contacts: vec![updated_contact.clone()],
3981 remove_contacts: Default::default(),
3982 incoming_requests: Default::default(),
3983 remove_incoming_requests: Default::default(),
3984 outgoing_requests: Default::default(),
3985 remove_outgoing_requests: Default::default(),
3986 },
3987 )
3988 .trace_err();
3989 }
3990 }
3991 }
3992 Ok(())
3993}
3994
/// Remove `connection_id` from the room it is currently in, notifying
/// everyone affected: remaining room participants, the associated channel's
/// subscribers (if any), users whose pending calls were canceled, and this
/// user's contacts. Also mirrors the change in LiveKit when configured.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared ahead of the `if let` so the values extracted from the
    // database result stay available after the guard is released.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell each left project's other collaborators about the departure.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room: nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Cancel any calls this session had outstanding to other users; errors
    // on individual sends are logged and ignored.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Remove this participant from the LiveKit room and, if the room itself
    // was deleted, delete the LiveKit room too. Failures are logged only.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4068
4069async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4070 let left_channel_buffers = session
4071 .db()
4072 .await
4073 .leave_channel_buffers(session.connection_id)
4074 .await?;
4075
4076 for left_buffer in left_channel_buffers {
4077 channel_buffer_updated(
4078 session.connection_id,
4079 left_buffer.connections,
4080 &proto::UpdateChannelBufferCollaborators {
4081 channel_id: left_buffer.channel_id.to_proto(),
4082 collaborators: left_buffer.collaborators,
4083 },
4084 &session.peer,
4085 );
4086 }
4087
4088 Ok(())
4089}
4090
4091fn project_left(project: &db::LeftProject, session: &Session) {
4092 for connection_id in &project.connection_ids {
4093 if project.should_unshare {
4094 session
4095 .peer
4096 .send(
4097 *connection_id,
4098 proto::UnshareProject {
4099 project_id: project.id.to_proto(),
4100 },
4101 )
4102 .trace_err();
4103 } else {
4104 session
4105 .peer
4106 .send(
4107 *connection_id,
4108 proto::RemoveProjectCollaborator {
4109 project_id: project.id.to_proto(),
4110 peer_id: Some(session.connection_id.into()),
4111 },
4112 )
4113 .trace_err();
4114 }
4115 }
4116}
4117
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    type Ok;

    /// Log the `Err` variant (if any) and convert the result into an
    /// `Option`, discarding the error.
    fn trace_err(self) -> Option<Self::Ok>;
}
4123
4124impl<T, E> ResultExt for Result<T, E>
4125where
4126 E: std::fmt::Debug,
4127{
4128 type Ok = T;
4129
4130 #[track_caller]
4131 fn trace_err(self) -> Option<T> {
4132 match self {
4133 Ok(value) => Some(value),
4134 Err(error) => {
4135 tracing::error!("{:?}", error);
4136 None
4137 }
4138 }
4139 }
4140}