1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
9 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
10 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use reqwest_client::ReqwestClient;
37use rpc::proto::{MultiLspQuery, split_repository_update};
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39use tracing::Span;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// How long a disconnected client may attempt to rejoin before its
/// room/project state is considered abandoned.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Pagination sizes and input limits for chat/notification endpoints.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Current number of live connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields recorded by the message-dispatch machinery
// (see `add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler stored in `Server::handlers`, one per message `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
97
98pub struct ConnectionGuard;
99
100impl ConnectionGuard {
101 pub fn try_acquire() -> Result<Self, ()> {
102 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
103 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
104 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
105 tracing::error!(
106 "too many concurrent connections: {}",
107 current_connections + 1
108 );
109 return Err(());
110 }
111 Ok(ConnectionGuard)
112 }
113}
114
115impl Drop for ConnectionGuard {
116 fn drop(&mut self) {
117 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
118 }
119}
120
/// One-shot responder handed to a request handler for a request of type `R`.
///
/// Created by `add_request_handler`; calling `send` flips the shared
/// `responded` flag that the dispatcher checks to verify the handler
/// actually replied before returning `Ok(())`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with `add_request_handler`, which reports an error if the
    // handler finished successfully without ever responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester, consuming the responder.
    fn send(self, payload: R::Response) -> Result<()> {
        // Mark as responded before sending so the dispatcher never
        // double-replies even if `respond` itself fails.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
134
135#[derive(Clone, Debug)]
136pub enum Principal {
137 User(User),
138 Impersonated { user: User, admin: User },
139}
140
141impl Principal {
142 fn update_span(&self, span: &tracing::Span) {
143 match &self {
144 Principal::User(user) => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 }
148 Principal::Impersonated { user, admin } => {
149 span.record("user_id", user.id.0);
150 span.record("login", &user.github_login);
151 span.record("impersonator", &admin.github_login);
152 }
153 }
154 }
155}
156
/// A `Session` bundled with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Let handlers use a `MessageContext` anywhere a `&Session` is expected.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
170
171impl MessageContext {
172 pub fn forward_request<T: RequestMessage>(
173 &self,
174 receiver_id: ConnectionId,
175 request: T,
176 ) -> impl Future<Output = anyhow::Result<T::Response>> {
177 let request_start_time = Instant::now();
178 let span = self.span.clone();
179 tracing::info!("start forwarding request");
180 self.peer
181 .forward_request(self.connection_id, receiver_id, request)
182 .inspect(move |_| {
183 span.record(
184 HOST_WAITING_MS,
185 request_start_time.elapsed().as_micros() as f64 / 1000.0,
186 );
187 })
188 .inspect_err(|_| tracing::error!("error forwarding request"))
189 .inspect_ok(|_| tracing::info!("finished forwarding request"))
190 }
191}
192
/// Per-connection state shared (via `MessageContext`) with every handler that
/// runs on behalf of this connection.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; always access via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
209
210impl Session {
211 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 let guard = self.db.lock().await;
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 guard
218 }
219
220 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
221 #[cfg(test)]
222 tokio::task::yield_now().await;
223 let guard = self.connection_pool.lock();
224 ConnectionPoolGuard {
225 guard,
226 _not_send: PhantomData,
227 }
228 }
229
230 fn is_staff(&self) -> bool {
231 match &self.principal {
232 Principal::User(user) => user.admin,
233 Principal::Impersonated { .. } => true,
234 }
235 }
236
237 fn user_id(&self) -> UserId {
238 match &self.principal {
239 Principal::User(user) => user.id,
240 Principal::Impersonated { user, .. } => user.id,
241 }
242 }
243
244 pub fn email(&self) -> Option<String> {
245 match &self.principal {
246 Principal::User(user) => user.email_address.clone(),
247 Principal::Impersonated { user, .. } => user.email_address.clone(),
248 }
249 }
250}
251
252impl Debug for Session {
253 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254 let mut result = f.debug_struct("Session");
255 match &self.principal {
256 Principal::User(user) => {
257 result.field("user", &user.github_login);
258 }
259 Principal::Impersonated { user, admin } => {
260 result.field("user", &user.github_login);
261 result.field("impersonator", &admin.github_login);
262 }
263 }
264 result.field("connection_id", &self.connection_id).finish()
265 }
266}
267
268struct DbHandle(Arc<Database>);
269
270impl Deref for DbHandle {
271 type Target = Database;
272
273 fn deref(&self) -> &Self::Target {
274 self.0.as_ref()
275 }
276}
277
/// The collab RPC server: owns the peer, the connection pool, and the table
/// of message handlers registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message `TypeId`; duplicate registration panics.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` to all connection tasks when the server shuts down.
    teardown: watch::Sender<bool>,
}
286
/// A locked view of the connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, so it cannot be held across an `.await` point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of the server's peer state and connection pool,
/// used for introspection/diagnostics.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
298
299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
300where
301 S: Serializer,
302 T: Deref<Target = U>,
303 U: Serialize,
304{
305 Serialize::serialize(value.deref(), serializer)
306}
307
308impl Server {
    /// Builds the server and registers every RPC message handler.
    ///
    /// Handlers come in two flavors: `add_request_handler` for messages that
    /// expect a response, and `add_message_handler` for fire-and-forget
    /// messages. Project requests are forwarded to the host either with a
    /// read capability check (`forward_read_only_project_request`) or a write
    /// check (`forward_mutating_project_request`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only LSP/project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            // NOTE(review): `GetCodeLens` is registered as mutating while the
            // surrounding Get* requests are read-only — confirm intentional.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            // Mutating project requests (require write capability).
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer sync: host -> guests broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and notifications.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` modify the
            // working tree but are registered with the read-only forwarder —
            // confirm this capability choice is intentional.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>);

        Arc::new(server)
    }
485
    /// Starts the server's background cleanup task.
    ///
    /// After `CLEANUP_TIMEOUT` (giving terminated pods from a previous deploy
    /// time to finish), it removes stale state left by other server instances:
    /// stale chat participants, channel-buffer collaborators, room
    /// participants (notifying remaining clients and updating contacts), old
    /// worktree entries, and finally the stale server rows themselves.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // tell the remaining collaborators about the new roster.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants; if any remain, broadcast
                        // the refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // the call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-state to the contacts of every
                        // affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
656
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and broadcasts `true` so every connection task exits its loop.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
662
    /// Test-only: tears the server down and re-arms it under a new id,
    /// clearing the teardown flag so connections can be accepted again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only: returns the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
675
    /// Registers `handler` for message type `M`, wrapping it with logging and
    /// span timing (total / processing / queue durations in milliseconds).
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Safe by construction: the dispatcher only routes envelopes
                // whose TypeId matched `M` at registration.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since receipt; processing = handler run
                    // time; queue = how long it waited to be picked up.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
719
    /// Registers a handler for a fire-and-forget message type `M` — one that
    /// carries no receipt and expects no response.
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        // Unwrap the envelope; plain messages only need the payload.
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }
729
    /// Registers a handler for a request/response message type `M`.
    ///
    /// Guarantees the requester always hears back: if the handler succeeds
    /// without calling `Response::send`, that is reported as an error; if it
    /// fails, the error is converted to a proto error and sent to the
    /// requester before being propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Flipped by `Response::send`; lets us detect handlers that
                // returned Ok(()) without responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Preserve structured error codes where available.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
768
    /// Returns a future that runs one client connection for its whole life:
    /// registers it with the peer, sends the initial client update, then
    /// pumps incoming messages through the registered handlers until the
    /// connection closes or the server tears down, finally running
    /// `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        // Slot in the global connection budget; released (dropped) once the
        // initial client update has been sent.
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release its admission slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler permit is free,
                // bounding in-flight handlers per connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                // todo(lsp) remove after Zed Stable hits v0.204.x
                                multi_lsp_query_request=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages may outlive this loop.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
962
    /// Sends the payload a newly-connected client needs to get started:
    /// the `Hello` handshake, its contact list, channel subscriptions, and
    /// any call that is already ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` tells the client its server-assigned peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // If the caller asked to be told the connection id, report it now
        // that the handshake has been sent.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt the contacts UI
                // and persist that the prompt has happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the contact list while
                    // holding the pool lock, so reported online statuses are
                    // consistent with the pool state at registration time.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Version-gated: some client versions are subscribed to their
                // channels by the server here rather than requesting them.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a call that was already ringing before this
                // connection came up.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1018
1019 pub async fn invite_code_redeemed(
1020 self: &Arc<Self>,
1021 inviter_id: UserId,
1022 invitee_id: UserId,
1023 ) -> Result<()> {
1024 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1025 && let Some(code) = &user.invite_code
1026 {
1027 let pool = self.connection_pool.lock();
1028 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1029 for connection_id in pool.user_connection_ids(inviter_id) {
1030 self.peer.send(
1031 connection_id,
1032 proto::UpdateContacts {
1033 contacts: vec![invitee_contact.clone()],
1034 ..Default::default()
1035 },
1036 )?;
1037 self.peer.send(
1038 connection_id,
1039 proto::UpdateInviteInfo {
1040 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1041 count: user.invite_count as u32,
1042 },
1043 )?;
1044 }
1045 }
1046 Ok(())
1047 }
1048
1049 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1050 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1051 && let Some(invite_code) = &user.invite_code
1052 {
1053 let pool = self.connection_pool.lock();
1054 for connection_id in pool.user_connection_ids(user_id) {
1055 self.peer.send(
1056 connection_id,
1057 proto::UpdateInviteInfo {
1058 url: format!(
1059 "{}{}",
1060 self.app_state.config.invite_link_prefix, invite_code
1061 ),
1062 count: user.invite_count as u32,
1063 },
1064 )?;
1065 }
1066 }
1067 Ok(())
1068 }
1069
1070 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1071 ServerSnapshot {
1072 connection_pool: ConnectionPoolGuard {
1073 guard: self.connection_pool.lock(),
1074 _not_send: PhantomData,
1075 },
1076 peer: &self.peer,
1077 }
1078 }
1079}
1080
1081impl Deref for ConnectionPoolGuard<'_> {
1082 type Target = ConnectionPool;
1083
1084 fn deref(&self) -> &Self::Target {
1085 &self.guard
1086 }
1087}
1088
1089impl DerefMut for ConnectionPoolGuard<'_> {
1090 fn deref_mut(&mut self) -> &mut Self::Target {
1091 &mut self.guard
1092 }
1093}
1094
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, validate the pool's internal invariants every time
        // a guard is released, so corruption is caught near its source.
        #[cfg(test)]
        self.check_invariants();
    }
}
1101
1102fn broadcast<F>(
1103 sender_id: Option<ConnectionId>,
1104 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1105 mut f: F,
1106) where
1107 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1108{
1109 for receiver_id in receiver_ids {
1110 if Some(receiver_id) != sender_id
1111 && let Err(error) = f(receiver_id)
1112 {
1113 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1114 }
1115 }
1116}
1117
/// Typed wrapper for the `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
1119
1120impl Header for ProtocolVersion {
1121 fn name() -> &'static HeaderName {
1122 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1123 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1124 }
1125
1126 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1127 where
1128 Self: Sized,
1129 I: Iterator<Item = &'i axum::http::HeaderValue>,
1130 {
1131 let version = values
1132 .next()
1133 .ok_or_else(axum::headers::Error::invalid)?
1134 .to_str()
1135 .map_err(|_| axum::headers::Error::invalid())?
1136 .parse()
1137 .map_err(|_| axum::headers::Error::invalid())?;
1138 Ok(Self(version))
1139 }
1140
1141 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1142 values.extend([self.0.to_string().parse().unwrap()]);
1143 }
1144}
1145
1146pub struct AppVersionHeader(SemanticVersion);
1147impl Header for AppVersionHeader {
1148 fn name() -> &'static HeaderName {
1149 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1150 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1151 }
1152
1153 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1154 where
1155 Self: Sized,
1156 I: Iterator<Item = &'i axum::http::HeaderValue>,
1157 {
1158 let version = values
1159 .next()
1160 .ok_or_else(axum::headers::Error::invalid)?
1161 .to_str()
1162 .map_err(|_| axum::headers::Error::invalid())?
1163 .parse()
1164 .map_err(|_| axum::headers::Error::invalid())?;
1165 Ok(Self(version))
1166 }
1167
1168 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1169 values.extend([self.0.to_string().parse().unwrap()]);
1170 }
1171}
1172
1173#[derive(Debug)]
1174pub struct ReleaseChannelHeader(String);
1175
1176impl Header for ReleaseChannelHeader {
1177 fn name() -> &'static HeaderName {
1178 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1179 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1180 }
1181
1182 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1183 where
1184 Self: Sized,
1185 I: Iterator<Item = &'i axum::http::HeaderValue>,
1186 {
1187 Ok(Self(
1188 values
1189 .next()
1190 .ok_or_else(axum::headers::Error::invalid)?
1191 .to_str()
1192 .map_err(|_| axum::headers::Error::invalid())?
1193 .to_owned(),
1194 ))
1195 }
1196
1197 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1198 values.extend([self.0.parse().unwrap()]);
1199 }
1200}
1201
1202pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1203 Router::new()
1204 .route("/rpc", get(handle_websocket_request))
1205 .layer(
1206 ServiceBuilder::new()
1207 .layer(Extension(server.app_state.clone()))
1208 .layer(middleware::from_fn(auth::validate_header)),
1209 )
1210 .route("/metrics", get(handle_metrics))
1211 .layer(Extension(server))
1212}
1213
1214pub async fn handle_websocket_request(
1215 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1216 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1217 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1218 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1219 Extension(server): Extension<Arc<Server>>,
1220 Extension(principal): Extension<Principal>,
1221 user_agent: Option<TypedHeader<UserAgent>>,
1222 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1223 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1224 ws: WebSocketUpgrade,
1225) -> axum::response::Response {
1226 if protocol_version != rpc::PROTOCOL_VERSION {
1227 return (
1228 StatusCode::UPGRADE_REQUIRED,
1229 "client must be upgraded".to_string(),
1230 )
1231 .into_response();
1232 }
1233
1234 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1235 return (
1236 StatusCode::UPGRADE_REQUIRED,
1237 "no version header found".to_string(),
1238 )
1239 .into_response();
1240 };
1241
1242 let release_channel = release_channel_header.map(|header| header.0.0);
1243
1244 if !version.can_collaborate() {
1245 return (
1246 StatusCode::UPGRADE_REQUIRED,
1247 "client must be upgraded".to_string(),
1248 )
1249 .into_response();
1250 }
1251
1252 let socket_address = socket_address.to_string();
1253
1254 // Acquire connection guard before WebSocket upgrade
1255 let connection_guard = match ConnectionGuard::try_acquire() {
1256 Ok(guard) => guard,
1257 Err(()) => {
1258 return (
1259 StatusCode::SERVICE_UNAVAILABLE,
1260 "Too many concurrent connections",
1261 )
1262 .into_response();
1263 }
1264 };
1265
1266 ws.on_upgrade(move |socket| {
1267 let socket = socket
1268 .map_ok(to_tungstenite_message)
1269 .err_into()
1270 .with(|message| async move { to_axum_message(message) });
1271 let connection = Connection::new(Box::pin(socket));
1272 async move {
1273 server
1274 .handle_connection(
1275 connection,
1276 socket_address,
1277 principal,
1278 version,
1279 release_channel,
1280 user_agent.map(|header| header.to_string()),
1281 country_code_header.map(|header| header.to_string()),
1282 system_id_header.map(|header| header.to_string()),
1283 None,
1284 Executor::Production,
1285 Some(connection_guard),
1286 )
1287 .await;
1288 }
1289 })
1290}
1291
1292pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1293 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1294 let connections_metric = CONNECTIONS_METRIC
1295 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1296
1297 let connections = server
1298 .connection_pool
1299 .lock()
1300 .connections()
1301 .filter(|connection| !connection.admin)
1302 .count();
1303 connections_metric.set(connections as _);
1304
1305 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1306 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1307 register_int_gauge!(
1308 "shared_projects",
1309 "number of open projects with one or more guests"
1310 )
1311 .unwrap()
1312 });
1313
1314 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1315 shared_projects_metric.set(shared_projects as _);
1316
1317 let encoder = prometheus::TextEncoder::new();
1318 let metric_families = prometheus::gather();
1319 let encoded_metrics = encoder
1320 .encode_to_string(&metric_families)
1321 .map_err(|err| anyhow!("{err}"))?;
1322 Ok(encoded_metrics)
1323}
1324
/// Cleans up after a dropped client connection: unregisters it immediately,
/// then waits out a reconnect grace period before releasing the user's
/// rooms, channel buffers, and pending calls.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Tear down the transport and remove the connection from the pool
    // right away.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Record the loss in the database; best-effort (errors only logged).
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client RECONNECT_TIMEOUT to come back before releasing its
    // resources. A teardown signal aborts the wait and skips the cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if the user has no other live
            // connection that might still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1371
1372/// Acknowledges a ping from a client, used to keep the connection alive.
1373async fn ping(
1374 _: proto::Ping,
1375 response: Response<proto::Ping>,
1376 _session: MessageContext,
1377) -> Result<()> {
1378 response.send(proto::Ack {})?;
1379 Ok(())
1380}
1381
1382/// Creates a new room for calling (outside of channels)
1383async fn create_room(
1384 _request: proto::CreateRoom,
1385 response: Response<proto::CreateRoom>,
1386 session: MessageContext,
1387) -> Result<()> {
1388 let livekit_room = nanoid::nanoid!(30);
1389
1390 let live_kit_connection_info = util::maybe!(async {
1391 let live_kit = session.app_state.livekit_client.as_ref();
1392 let live_kit = live_kit?;
1393 let user_id = session.user_id().to_string();
1394
1395 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1396
1397 Some(proto::LiveKitConnectionInfo {
1398 server_url: live_kit.url().into(),
1399 token,
1400 can_publish: true,
1401 })
1402 })
1403 .await;
1404
1405 let room = session
1406 .db()
1407 .await
1408 .create_room(session.user_id(), session.connection_id, &livekit_room)
1409 .await?;
1410
1411 response.send(proto::CreateRoomResponse {
1412 room: Some(room.clone()),
1413 live_kit_connection_info,
1414 })?;
1415
1416 update_user_contacts(session.user_id(), &session).await?;
1417 Ok(())
1418}
1419
1420/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1421async fn join_room(
1422 request: proto::JoinRoom,
1423 response: Response<proto::JoinRoom>,
1424 session: MessageContext,
1425) -> Result<()> {
1426 let room_id = RoomId::from_proto(request.id);
1427
1428 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1429
1430 if let Some(channel_id) = channel_id {
1431 return join_channel_internal(channel_id, Box::new(response), session).await;
1432 }
1433
1434 let joined_room = {
1435 let room = session
1436 .db()
1437 .await
1438 .join_room(room_id, session.user_id(), session.connection_id)
1439 .await?;
1440 room_updated(&room.room, &session.peer);
1441 room.into_inner()
1442 };
1443
1444 for connection_id in session
1445 .connection_pool()
1446 .await
1447 .user_connection_ids(session.user_id())
1448 {
1449 session
1450 .peer
1451 .send(
1452 connection_id,
1453 proto::CallCanceled {
1454 room_id: room_id.to_proto(),
1455 },
1456 )
1457 .trace_err();
1458 }
1459
1460 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1461 {
1462 live_kit
1463 .room_token(
1464 &joined_room.room.livekit_room,
1465 &session.user_id().to_string(),
1466 )
1467 .trace_err()
1468 .map(|token| proto::LiveKitConnectionInfo {
1469 server_url: live_kit.url().into(),
1470 token,
1471 can_publish: true,
1472 })
1473 } else {
1474 None
1475 };
1476
1477 response.send(proto::JoinRoomResponse {
1478 room: Some(joined_room.room),
1479 channel_id: None,
1480 live_kit_connection_info,
1481 })?;
1482
1483 update_user_contacts(session.user_id(), &session).await?;
1484 Ok(())
1485}
1486
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // NOTE(review): `rejoin_room` appears to return a guard-like value
        // (`into_inner` below releases it); keep all peer notifications that
        // need a consistent room view inside this scope.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Respond to the rejoining client first, then fan out updates to
        // the other participants.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator the host's peer id changed from the
            // old (lost) connection to this new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the reshared project's worktree list to everyone
            // except the rejoining host itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Release the guard, keeping only the data needed after this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room is backed by a channel, notify channel observers as well.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1578
/// Streams the state of each rejoined project back to the rejoining client
/// (worktree entries, diagnostics, settings, repositories) and informs the
/// other collaborators of the client's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: tell every collaborator that this user's peer id changed
    // from the old (lost) connection to the current one.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream each project's state to the rejoining client.
    // `mem::take` moves the payloads out so they aren't cloned per message.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are chunked to keep individual messages small.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream the worktree's settings files, one message each.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository updates (chunked) and repository removals.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1663
1664/// leave room disconnects from the room.
1665async fn leave_room(
1666 _: proto::LeaveRoom,
1667 response: Response<proto::LeaveRoom>,
1668 session: MessageContext,
1669) -> Result<()> {
1670 leave_room_for_session(&session, session.connection_id).await?;
1671 response.send(proto::Ack {})?;
1672 Ok(())
1673}
1674
1675/// Updates the permissions of someone else in the room.
1676async fn set_room_participant_role(
1677 request: proto::SetRoomParticipantRole,
1678 response: Response<proto::SetRoomParticipantRole>,
1679 session: MessageContext,
1680) -> Result<()> {
1681 let user_id = UserId::from_proto(request.user_id);
1682 let role = ChannelRole::from(request.role());
1683
1684 let (livekit_room, can_publish) = {
1685 let room = session
1686 .db()
1687 .await
1688 .set_room_participant_role(
1689 session.user_id(),
1690 RoomId::from_proto(request.room_id),
1691 user_id,
1692 role,
1693 )
1694 .await?;
1695
1696 let livekit_room = room.livekit_room.clone();
1697 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1698 room_updated(&room, &session.peer);
1699 (livekit_room, can_publish)
1700 };
1701
1702 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1703 live_kit
1704 .update_participant(
1705 livekit_room.clone(),
1706 request.user_id.to_string(),
1707 livekit_api::proto::ParticipantPermission {
1708 can_subscribe: true,
1709 can_publish,
1710 can_publish_data: can_publish,
1711 hidden: false,
1712 recorder: false,
1713 },
1714 )
1715 .await
1716 .trace_err();
1717 }
1718
1719 response.send(proto::Ack {})?;
1720 Ok(())
1721}
1722
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call in the room and take the IncomingCall message
    // to ring the callee with.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every live connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // First successful acknowledgement wins; failed attempts are only
    // logged while the remaining connections keep ringing.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the ring: roll back the pending call and
    // re-publish contact state before reporting failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1791
1792/// Cancel an outgoing call.
1793async fn cancel_call(
1794 request: proto::CancelCall,
1795 response: Response<proto::CancelCall>,
1796 session: MessageContext,
1797) -> Result<()> {
1798 let called_user_id = UserId::from_proto(request.called_user_id);
1799 let room_id = RoomId::from_proto(request.room_id);
1800 {
1801 let room = session
1802 .db()
1803 .await
1804 .cancel_call(room_id, session.connection_id, called_user_id)
1805 .await?;
1806 room_updated(&room, &session.peer);
1807 }
1808
1809 for connection_id in session
1810 .connection_pool()
1811 .await
1812 .user_connection_ids(called_user_id)
1813 {
1814 session
1815 .peer
1816 .send(
1817 connection_id,
1818 proto::CallCanceled {
1819 room_id: room_id.to_proto(),
1820 },
1821 )
1822 .trace_err();
1823 }
1824 response.send(proto::Ack {})?;
1825
1826 update_user_contacts(called_user_id, &session).await?;
1827 Ok(())
1828}
1829
1830/// Decline an incoming call.
1831async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1832 let room_id = RoomId::from_proto(message.room_id);
1833 {
1834 let room = session
1835 .db()
1836 .await
1837 .decline_call(Some(room_id), session.user_id())
1838 .await?
1839 .context("declining call")?;
1840 room_updated(&room, &session.peer);
1841 }
1842
1843 for connection_id in session
1844 .connection_pool()
1845 .await
1846 .user_connection_ids(session.user_id())
1847 {
1848 session
1849 .peer
1850 .send(
1851 connection_id,
1852 proto::CallCanceled {
1853 room_id: room_id.to_proto(),
1854 },
1855 )
1856 .trace_err();
1857 }
1858 update_user_contacts(session.user_id(), &session).await?;
1859 Ok(())
1860}
1861
1862/// Updates other participants in the room with your current location.
1863async fn update_participant_location(
1864 request: proto::UpdateParticipantLocation,
1865 response: Response<proto::UpdateParticipantLocation>,
1866 session: MessageContext,
1867) -> Result<()> {
1868 let room_id = RoomId::from_proto(request.room_id);
1869 let location = request.location.context("invalid location")?;
1870
1871 let db = session.db().await;
1872 let room = db
1873 .update_room_participant_location(room_id, session.connection_id, location)
1874 .await?;
1875
1876 room_updated(&room, &session.peer);
1877 response.send(proto::Ack {})?;
1878 Ok(())
1879}
1880
1881/// Share a project into the room.
1882async fn share_project(
1883 request: proto::ShareProject,
1884 response: Response<proto::ShareProject>,
1885 session: MessageContext,
1886) -> Result<()> {
1887 let (project_id, room) = &*session
1888 .db()
1889 .await
1890 .share_project(
1891 RoomId::from_proto(request.room_id),
1892 session.connection_id,
1893 &request.worktrees,
1894 request.is_ssh_project,
1895 )
1896 .await?;
1897 response.send(proto::ShareProjectResponse {
1898 project_id: project_id.to_proto(),
1899 })?;
1900 room_updated(room, &session.peer);
1901
1902 Ok(())
1903}
1904
1905/// Unshare a project from the room.
1906async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1907 let project_id = ProjectId::from_proto(message.project_id);
1908 unshare_project_internal(project_id, session.connection_id, &session).await
1909}
1910
/// Removes a shared project: notifies guests, refreshes the room, and
/// deletes the project record when the database asks for it.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // NOTE(review): `unshare_project` appears to return a guard-like
        // value; the notifications below happen inside its scope so the
        // room state stays consistent, and the guard is released before the
        // delete call.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharing connection) the project
        // is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The database signals via `delete` whether the project record itself
    // should be removed.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1948
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    // Register the guest in the database, obtaining the project snapshot
    // and the guest's replica id. The db guard is dropped immediately so
    // it isn't held while streaming state below.
    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joining connection.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to the existing participants.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    // Then stream each worktree's full state to the joining client.
    // `mem::take` moves the payloads out so they aren't cloned per message.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are chunked to keep individual messages small.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream the worktree's settings files, one message each.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state (chunked, like worktree updates).
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Prime the client's language-server state for each running server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2098
2099/// Leave someone elses shared project.
2100async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2101 let sender_id = session.connection_id;
2102 let project_id = ProjectId::from_proto(request.project_id);
2103 let db = session.db().await;
2104
2105 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2106 tracing::info!(
2107 %project_id,
2108 "leave project"
2109 );
2110
2111 project_left(project, &session);
2112 if let Some(room) = room {
2113 room_updated(room, &session.peer);
2114 }
2115
2116 Ok(())
2117}
2118
2119/// Updates other participants with changes to the project
2120async fn update_project(
2121 request: proto::UpdateProject,
2122 response: Response<proto::UpdateProject>,
2123 session: MessageContext,
2124) -> Result<()> {
2125 let project_id = ProjectId::from_proto(request.project_id);
2126 let (room, guest_connection_ids) = &*session
2127 .db()
2128 .await
2129 .update_project(project_id, session.connection_id, &request.worktrees)
2130 .await?;
2131 broadcast(
2132 Some(session.connection_id),
2133 guest_connection_ids.iter().copied(),
2134 |connection_id| {
2135 session
2136 .peer
2137 .forward_send(session.connection_id, connection_id, request.clone())
2138 },
2139 );
2140 if let Some(room) = room {
2141 room_updated(room, &session.peer);
2142 }
2143 response.send(proto::Ack {})?;
2144
2145 Ok(())
2146}
2147
2148/// Updates other participants with changes to the worktree
2149async fn update_worktree(
2150 request: proto::UpdateWorktree,
2151 response: Response<proto::UpdateWorktree>,
2152 session: MessageContext,
2153) -> Result<()> {
2154 let guest_connection_ids = session
2155 .db()
2156 .await
2157 .update_worktree(&request, session.connection_id)
2158 .await?;
2159
2160 broadcast(
2161 Some(session.connection_id),
2162 guest_connection_ids.iter().copied(),
2163 |connection_id| {
2164 session
2165 .peer
2166 .forward_send(session.connection_id, connection_id, request.clone())
2167 },
2168 );
2169 response.send(proto::Ack {})?;
2170 Ok(())
2171}
2172
2173async fn update_repository(
2174 request: proto::UpdateRepository,
2175 response: Response<proto::UpdateRepository>,
2176 session: MessageContext,
2177) -> Result<()> {
2178 let guest_connection_ids = session
2179 .db()
2180 .await
2181 .update_repository(&request, session.connection_id)
2182 .await?;
2183
2184 broadcast(
2185 Some(session.connection_id),
2186 guest_connection_ids.iter().copied(),
2187 |connection_id| {
2188 session
2189 .peer
2190 .forward_send(session.connection_id, connection_id, request.clone())
2191 },
2192 );
2193 response.send(proto::Ack {})?;
2194 Ok(())
2195}
2196
2197async fn remove_repository(
2198 request: proto::RemoveRepository,
2199 response: Response<proto::RemoveRepository>,
2200 session: MessageContext,
2201) -> Result<()> {
2202 let guest_connection_ids = session
2203 .db()
2204 .await
2205 .remove_repository(&request, session.connection_id)
2206 .await?;
2207
2208 broadcast(
2209 Some(session.connection_id),
2210 guest_connection_ids.iter().copied(),
2211 |connection_id| {
2212 session
2213 .peer
2214 .forward_send(session.connection_id, connection_id, request.clone())
2215 },
2216 );
2217 response.send(proto::Ack {})?;
2218 Ok(())
2219}
2220
2221/// Updates other participants with changes to the diagnostics
2222async fn update_diagnostic_summary(
2223 message: proto::UpdateDiagnosticSummary,
2224 session: MessageContext,
2225) -> Result<()> {
2226 let guest_connection_ids = session
2227 .db()
2228 .await
2229 .update_diagnostic_summary(&message, session.connection_id)
2230 .await?;
2231
2232 broadcast(
2233 Some(session.connection_id),
2234 guest_connection_ids.iter().copied(),
2235 |connection_id| {
2236 session
2237 .peer
2238 .forward_send(session.connection_id, connection_id, message.clone())
2239 },
2240 );
2241
2242 Ok(())
2243}
2244
2245/// Updates other participants with changes to the worktree settings
2246async fn update_worktree_settings(
2247 message: proto::UpdateWorktreeSettings,
2248 session: MessageContext,
2249) -> Result<()> {
2250 let guest_connection_ids = session
2251 .db()
2252 .await
2253 .update_worktree_settings(&message, session.connection_id)
2254 .await?;
2255
2256 broadcast(
2257 Some(session.connection_id),
2258 guest_connection_ids.iter().copied(),
2259 |connection_id| {
2260 session
2261 .peer
2262 .forward_send(session.connection_id, connection_id, message.clone())
2263 },
2264 );
2265
2266 Ok(())
2267}
2268
2269/// Notify other participants that a language server has started.
2270async fn start_language_server(
2271 request: proto::StartLanguageServer,
2272 session: MessageContext,
2273) -> Result<()> {
2274 let guest_connection_ids = session
2275 .db()
2276 .await
2277 .start_language_server(&request, session.connection_id)
2278 .await?;
2279
2280 broadcast(
2281 Some(session.connection_id),
2282 guest_connection_ids.iter().copied(),
2283 |connection_id| {
2284 session
2285 .peer
2286 .forward_send(session.connection_id, connection_id, request.clone())
2287 },
2288 );
2289 Ok(())
2290}
2291
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capabilities carried by a `MetadataUpdated` variant are persisted so
    // they can be replayed later (they are included in `JoinProjectResponse`
    // when a guest joins).
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Forward the update verbatim to every other connection in the project.
    // NOTE(review): the trailing `true` flag's meaning isn't visible here —
    // confirm against `project_connection_ids`'s signature.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2321
2322/// forward a project request to the host. These requests should be read only
2323/// as guests are allowed to send them.
2324async fn forward_read_only_project_request<T>(
2325 request: T,
2326 response: Response<T>,
2327 session: MessageContext,
2328) -> Result<()>
2329where
2330 T: EntityMessage + RequestMessage,
2331{
2332 let project_id = ProjectId::from_proto(request.remote_entity_id());
2333 let host_connection_id = session
2334 .db()
2335 .await
2336 .host_for_read_only_project_request(project_id, session.connection_id)
2337 .await?;
2338 let payload = session.forward_request(host_connection_id, request).await?;
2339 response.send(payload)?;
2340 Ok(())
2341}
2342
2343/// forward a project request to the host. These requests are disallowed
2344/// for guests.
2345async fn forward_mutating_project_request<T>(
2346 request: T,
2347 response: Response<T>,
2348 session: MessageContext,
2349) -> Result<()>
2350where
2351 T: EntityMessage + RequestMessage,
2352{
2353 let project_id = ProjectId::from_proto(request.remote_entity_id());
2354
2355 let host_connection_id = session
2356 .db()
2357 .await
2358 .host_for_mutating_project_request(project_id, session.connection_id)
2359 .await?;
2360 let payload = session.forward_request(host_connection_id, request).await?;
2361 response.send(payload)?;
2362 Ok(())
2363}
2364
2365// todo(lsp) remove after Zed Stable hits v0.204.x
2366async fn multi_lsp_query(
2367 request: MultiLspQuery,
2368 response: Response<MultiLspQuery>,
2369 session: MessageContext,
2370) -> Result<()> {
2371 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2372 tracing::info!("multi_lsp_query message received");
2373 forward_mutating_project_request(request, response, session).await
2374}
2375
2376async fn lsp_query(
2377 request: proto::LspQuery,
2378 response: Response<proto::LspQuery>,
2379 session: MessageContext,
2380) -> Result<()> {
2381 let (name, should_write) = request.query_name_and_write_permissions();
2382 tracing::Span::current().record("lsp_query_request", name);
2383 tracing::info!("lsp_query message received");
2384 if should_write {
2385 forward_mutating_project_request(request, response, session).await
2386 } else {
2387 forward_read_only_project_request(request, response, session).await
2388 }
2389}
2390
2391/// Notify other participants that a new buffer has been created
2392async fn create_buffer_for_peer(
2393 request: proto::CreateBufferForPeer,
2394 session: MessageContext,
2395) -> Result<()> {
2396 session
2397 .db()
2398 .await
2399 .check_user_is_project_host(
2400 ProjectId::from_proto(request.project_id),
2401 session.connection_id,
2402 )
2403 .await?;
2404 let peer_id = request.peer_id.context("invalid peer id")?;
2405 session
2406 .peer
2407 .forward_send(session.connection_id, peer_id.into(), request)?;
2408 Ok(())
2409}
2410
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Selection-only operations may come from read-only guests; any other
    // operation requires write access, which the DB call below enforces via
    // the computed capability.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the DB guard tightly: it is dropped at the end of this block so
    // it is not held across the host acknowledgement round-trip below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the raw operations out to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // When the sender is not the host, wait for the host to acknowledge the
    // update before acking the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2457
/// Relay a context update to the project's host and guests, enforcing write
/// access for anything beyond selection-only buffer operations.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirror `update_buffer`'s rule: selection-only buffer operations are safe
    // for read-only guests; anything else needs write access. A buffer
    // operation with a missing payload is treated conservatively as a write.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast and no
    // acknowledgement round-trip is performed.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2499
2500/// Notify other participants that a project has been updated.
2501async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2502 request: T,
2503 session: MessageContext,
2504) -> Result<()> {
2505 let project_id = ProjectId::from_proto(request.remote_entity_id());
2506 let project_connection_ids = session
2507 .db()
2508 .await
2509 .project_connection_ids(project_id, session.connection_id, false)
2510 .await?;
2511
2512 broadcast(
2513 Some(session.connection_id),
2514 project_connection_ids.iter().copied(),
2515 |connection_id| {
2516 session
2517 .peer
2518 .forward_send(session.connection_id, connection_id, request.clone())
2519 },
2520 );
2521 Ok(())
2522}
2523
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and the follower are actually in this room
    // before forwarding anything.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay that back to the
    // follower as the response.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are recorded in the room; the updated room
    // state (including follower lists) is then broadcast to participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2555
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both parties are in the room before forwarding.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Fire-and-forget notification to the leader — unlike `follow`, no
    // response is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Follows are only recorded when project-scoped; mirror that when
    // removing the record, then broadcast the new room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2584
2585/// Notify everyone following you of your current location.
2586async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2587 let room_id = RoomId::from_proto(request.room_id);
2588 let database = session.db.lock().await;
2589
2590 let connection_ids = if let Some(project_id) = request.project_id {
2591 let project_id = ProjectId::from_proto(project_id);
2592 database
2593 .project_connection_ids(project_id, session.connection_id, true)
2594 .await?
2595 } else {
2596 database
2597 .room_connection_ids(room_id, session.connection_id)
2598 .await?
2599 };
2600
2601 // For now, don't send view update messages back to that view's current leader.
2602 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2603 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2604 _ => None,
2605 });
2606
2607 for connection_id in connection_ids.iter().cloned() {
2608 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2609 session
2610 .peer
2611 .forward_send(session.connection_id, connection_id, request.clone())?;
2612 }
2613 }
2614 Ok(())
2615}
2616
2617/// Get public data about users.
2618async fn get_users(
2619 request: proto::GetUsers,
2620 response: Response<proto::GetUsers>,
2621 session: MessageContext,
2622) -> Result<()> {
2623 let user_ids = request
2624 .user_ids
2625 .into_iter()
2626 .map(UserId::from_proto)
2627 .collect();
2628 let users = session
2629 .db()
2630 .await
2631 .get_users_by_ids(user_ids)
2632 .await?
2633 .into_iter()
2634 .map(|user| proto::User {
2635 id: user.id.to_proto(),
2636 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2637 github_login: user.github_login,
2638 name: user.name,
2639 })
2640 .collect();
2641 response.send(proto::UsersResponse { users })?;
2642 Ok(())
2643}
2644
/// Search for users (to invite) by GitHub login.
///
/// Empty queries return nothing; one- or two-character queries are resolved
/// as an exact login lookup; longer queries use fuzzy matching capped at 10
/// results. The requesting user is always filtered out of the results.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // `len()` is in bytes, so a single multi-byte character also takes the
    // exact-lookup path.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2676
2677/// Send a contact request to another user.
2678async fn request_contact(
2679 request: proto::RequestContact,
2680 response: Response<proto::RequestContact>,
2681 session: MessageContext,
2682) -> Result<()> {
2683 let requester_id = session.user_id();
2684 let responder_id = UserId::from_proto(request.responder_id);
2685 if requester_id == responder_id {
2686 return Err(anyhow!("cannot add yourself as a contact"))?;
2687 }
2688
2689 let notifications = session
2690 .db()
2691 .await
2692 .send_contact_request(requester_id, responder_id)
2693 .await?;
2694
2695 // Update outgoing contact requests of requester
2696 let mut update = proto::UpdateContacts::default();
2697 update.outgoing_requests.push(responder_id.to_proto());
2698 for connection_id in session
2699 .connection_pool()
2700 .await
2701 .user_connection_ids(requester_id)
2702 {
2703 session.peer.send(connection_id, update.clone())?;
2704 }
2705
2706 // Update incoming contact requests of responder
2707 let mut update = proto::UpdateContacts::default();
2708 update
2709 .incoming_requests
2710 .push(proto::IncomingContactRequest {
2711 requester_id: requester_id.to_proto(),
2712 });
2713 let connection_pool = session.connection_pool().await;
2714 for connection_id in connection_pool.user_connection_ids(responder_id) {
2715 session.peer.send(connection_id, update.clone())?;
2716 }
2717
2718 send_notifications(&connection_pool, &session.peer, notifications);
2719
2720 response.send(proto::Ack {})?;
2721 Ok(())
2722}
2723
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismissal only clears the notification; the request itself is not
    // accepted or declined here.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is fetched so each side's new contact entry can show
        // availability immediately.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The incoming request disappears whether it was accepted or declined.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Likewise, the requester's outgoing request is removed either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2781
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request; the deltas sent below differ
    // accordingly.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the responder's contact-request notification, if the
        // DB call above deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2832
/// Whether the server should subscribe this client to its channels
/// automatically on connect. Presumably clients at or above 0.139 request
/// their subscriptions explicitly via `SubscribeToChannels` — confirm at the
/// connection-setup call site.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2836
2837async fn subscribe_to_channels(
2838 _: proto::SubscribeToChannels,
2839 session: MessageContext,
2840) -> Result<()> {
2841 subscribe_user_to_channels(session.user_id(), &session).await?;
2842 Ok(())
2843}
2844
/// Subscribe all of `user_id`'s connections to the channels they belong to,
/// then send this session their channel memberships and the channel list.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register interest in each channel so subsequent channel events are
    // routed to this user's connections.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Memberships first, then the channels payload (which consumes the value).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2861
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning updates out to other members.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // A membership row is only sometimes returned (presumably for channels
    // without a parent — confirm in `Database::create_channel`). When present,
    // subscribe the creator and push their new membership to all of their
    // connections.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to every connection subscribed to the root
    // channel, skipping roles that cannot see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2916
2917/// Delete a channel
2918async fn delete_channel(
2919 request: proto::DeleteChannel,
2920 response: Response<proto::DeleteChannel>,
2921 session: MessageContext,
2922) -> Result<()> {
2923 let db = session.db().await;
2924
2925 let channel_id = request.channel_id;
2926 let (root_channel, removed_channels) = db
2927 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2928 .await?;
2929 response.send(proto::Ack {})?;
2930
2931 // Notify members of removed channels
2932 let mut update = proto::UpdateChannels::default();
2933 update
2934 .delete_channels
2935 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2936
2937 let connection_pool = session.connection_pool().await;
2938 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2939 session.peer.send(connection_id, update.clone())?;
2940 }
2941
2942 Ok(())
2943}
2944
2945/// Invite someone to join a channel.
2946async fn invite_channel_member(
2947 request: proto::InviteChannelMember,
2948 response: Response<proto::InviteChannelMember>,
2949 session: MessageContext,
2950) -> Result<()> {
2951 let db = session.db().await;
2952 let channel_id = ChannelId::from_proto(request.channel_id);
2953 let invitee_id = UserId::from_proto(request.user_id);
2954 let InviteMemberResult {
2955 channel,
2956 notifications,
2957 } = db
2958 .invite_channel_member(
2959 channel_id,
2960 invitee_id,
2961 session.user_id(),
2962 request.role().into(),
2963 )
2964 .await?;
2965
2966 let update = proto::UpdateChannels {
2967 channel_invitations: vec![channel.to_proto()],
2968 ..Default::default()
2969 };
2970
2971 let connection_pool = session.connection_pool().await;
2972 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2973 session.peer.send(connection_id, update.clone())?;
2974 }
2975
2976 send_notifications(&connection_pool, &session.peer, notifications);
2977
2978 response.send(proto::Ack {})?;
2979 Ok(())
2980}
2981
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    let mut connection_pool = session.connection_pool().await;
    // Propagate the membership change (this also updates the pool's channel
    // subscriptions for the removed member).
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the member's invite notification, if one was deleted. Send
    // failures are logged rather than propagated, since the removal itself
    // has already taken effect.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3023
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user list up front (`collect`) because the loop body
    // mutates the pool's channel subscriptions while we iterate.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get a (re)subscription and the
        // updated record; everyone else is unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3070
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The DB distinguishes changing an accepted member's role from changing
    // the role on a still-pending invitation.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Re-send the invitation with its updated role to the invitee's
            // open connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3118
3119/// Change the name of a channel
3120async fn rename_channel(
3121 request: proto::RenameChannel,
3122 response: Response<proto::RenameChannel>,
3123 session: MessageContext,
3124) -> Result<()> {
3125 let db = session.db().await;
3126 let channel_id = ChannelId::from_proto(request.channel_id);
3127 let channel_model = db
3128 .rename_channel(channel_id, session.user_id(), &request.name)
3129 .await?;
3130 let root_id = channel_model.root_id();
3131 let channel = Channel::from_model(channel_model);
3132
3133 response.send(proto::RenameChannelResponse {
3134 channel: Some(channel.to_proto()),
3135 })?;
3136
3137 let connection_pool = session.connection_pool().await;
3138 let update = proto::UpdateChannels {
3139 channels: vec![channel.to_proto()],
3140 ..Default::default()
3141 };
3142 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143 if role.can_see_channel(channel.visibility) {
3144 session.peer.send(connection_id, update.clone())?;
3145 }
3146 }
3147
3148 Ok(())
3149}
3150
3151/// Move a channel to a new parent.
3152async fn move_channel(
3153 request: proto::MoveChannel,
3154 response: Response<proto::MoveChannel>,
3155 session: MessageContext,
3156) -> Result<()> {
3157 let channel_id = ChannelId::from_proto(request.channel_id);
3158 let to = ChannelId::from_proto(request.to);
3159
3160 let (root_id, channels) = session
3161 .db()
3162 .await
3163 .move_channel(channel_id, to, session.user_id())
3164 .await?;
3165
3166 let connection_pool = session.connection_pool().await;
3167 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3168 let channels = channels
3169 .iter()
3170 .filter_map(|channel| {
3171 if role.can_see_channel(channel.visibility) {
3172 Some(channel.to_proto())
3173 } else {
3174 None
3175 }
3176 })
3177 .collect::<Vec<_>>();
3178 if channels.is_empty() {
3179 continue;
3180 }
3181
3182 let update = proto::UpdateChannels {
3183 channels,
3184 ..Default::default()
3185 };
3186
3187 session.peer.send(connection_id, update.clone())?;
3188 }
3189
3190 response.send(Ack {})?;
3191 Ok(())
3192}
3193
3194async fn reorder_channel(
3195 request: proto::ReorderChannel,
3196 response: Response<proto::ReorderChannel>,
3197 session: MessageContext,
3198) -> Result<()> {
3199 let channel_id = ChannelId::from_proto(request.channel_id);
3200 let direction = request.direction();
3201
3202 let updated_channels = session
3203 .db()
3204 .await
3205 .reorder_channel(channel_id, direction, session.user_id())
3206 .await?;
3207
3208 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3209 let connection_pool = session.connection_pool().await;
3210 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3211 let channels = updated_channels
3212 .iter()
3213 .filter_map(|channel| {
3214 if role.can_see_channel(channel.visibility) {
3215 Some(channel.to_proto())
3216 } else {
3217 None
3218 }
3219 })
3220 .collect::<Vec<_>>();
3221
3222 if channels.is_empty() {
3223 continue;
3224 }
3225
3226 let update = proto::UpdateChannels {
3227 channels,
3228 ..Default::default()
3229 };
3230
3231 session.peer.send(connection_id, update.clone())?;
3232 }
3233 }
3234
3235 response.send(Ack {})?;
3236 Ok(())
3237}
3238
3239/// Get the list of channel members
3240async fn get_channel_members(
3241 request: proto::GetChannelMembers,
3242 response: Response<proto::GetChannelMembers>,
3243 session: MessageContext,
3244) -> Result<()> {
3245 let db = session.db().await;
3246 let channel_id = ChannelId::from_proto(request.channel_id);
3247 let limit = if request.limit == 0 {
3248 u16::MAX as u64
3249 } else {
3250 request.limit
3251 };
3252 let (members, users) = db
3253 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3254 .await?;
3255 response.send(proto::GetChannelMembersResponse { members, users })?;
3256 Ok(())
3257}
3258
3259/// Accept or decline a channel invitation.
3260async fn respond_to_channel_invite(
3261 request: proto::RespondToChannelInvite,
3262 response: Response<proto::RespondToChannelInvite>,
3263 session: MessageContext,
3264) -> Result<()> {
3265 let db = session.db().await;
3266 let channel_id = ChannelId::from_proto(request.channel_id);
3267 let RespondToChannelInvite {
3268 membership_update,
3269 notifications,
3270 } = db
3271 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3272 .await?;
3273
3274 let mut connection_pool = session.connection_pool().await;
3275 if let Some(membership_update) = membership_update {
3276 notify_membership_updated(
3277 &mut connection_pool,
3278 membership_update,
3279 session.user_id(),
3280 &session.peer,
3281 );
3282 } else {
3283 let update = proto::UpdateChannels {
3284 remove_channel_invitations: vec![channel_id.to_proto()],
3285 ..Default::default()
3286 };
3287
3288 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3289 session.peer.send(connection_id, update.clone())?;
3290 }
3291 };
3292
3293 send_notifications(&connection_pool, &session.peer, notifications);
3294
3295 response.send(proto::Ack {})?;
3296
3297 Ok(())
3298}
3299
/// Join the channels' room
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    // Boxed so the shared helper can serve either responder type through the
    // `JoinChannelInternalResponse` trait (JoinChannel here, JoinRoom elsewhere).
    join_channel_internal(channel_id, Box::new(response), session).await
}
3309
/// Abstraction over the two response types that can answer a channel join
/// (`JoinChannel` and `JoinRoom`), letting `join_channel_internal` serve both.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to the typed `Response::send` for `JoinChannel` requests.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to the typed `Response::send` for `JoinRoom` requests.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3323
/// Shared implementation of joining a channel's call, used by both the
/// `JoinChannel` and `JoinRoom` handlers via `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before the cleanup call (which
            // acquires `session.db()` itself), then re-acquire it for the join.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a media server is configured. Guests get
        // a token with `can_publish: false`; everyone else may also publish.
        // Token minting failures are logged and treated as "no media info".
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester before notifying everyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Joining may have changed the user's channel memberships; propagate
        // that to all of the user's connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };
    // The scope above ends here so the database guard is released before the
    // broader channel/contact notifications below.

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3417
3418/// Start editing the channel notes
3419async fn join_channel_buffer(
3420 request: proto::JoinChannelBuffer,
3421 response: Response<proto::JoinChannelBuffer>,
3422 session: MessageContext,
3423) -> Result<()> {
3424 let db = session.db().await;
3425 let channel_id = ChannelId::from_proto(request.channel_id);
3426
3427 let open_response = db
3428 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3429 .await?;
3430
3431 let collaborators = open_response.collaborators.clone();
3432 response.send(open_response)?;
3433
3434 let update = UpdateChannelBufferCollaborators {
3435 channel_id: channel_id.to_proto(),
3436 collaborators: collaborators.clone(),
3437 };
3438 channel_buffer_updated(
3439 session.connection_id,
3440 collaborators
3441 .iter()
3442 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3443 &update,
3444 &session.peer,
3445 );
3446
3447 Ok(())
3448}
3449
3450/// Edit the channel notes
3451async fn update_channel_buffer(
3452 request: proto::UpdateChannelBuffer,
3453 session: MessageContext,
3454) -> Result<()> {
3455 let db = session.db().await;
3456 let channel_id = ChannelId::from_proto(request.channel_id);
3457
3458 let (collaborators, epoch, version) = db
3459 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3460 .await?;
3461
3462 channel_buffer_updated(
3463 session.connection_id,
3464 collaborators.clone(),
3465 &proto::UpdateChannelBuffer {
3466 channel_id: channel_id.to_proto(),
3467 operations: request.operations,
3468 },
3469 &session.peer,
3470 );
3471
3472 let pool = &*session.connection_pool().await;
3473
3474 let non_collaborators =
3475 pool.channel_connection_ids(channel_id)
3476 .filter_map(|(connection_id, _)| {
3477 if collaborators.contains(&connection_id) {
3478 None
3479 } else {
3480 Some(connection_id)
3481 }
3482 });
3483
3484 broadcast(None, non_collaborators, |peer_id| {
3485 session.peer.send(
3486 peer_id,
3487 proto::UpdateChannels {
3488 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3489 channel_id: channel_id.to_proto(),
3490 epoch: epoch as u64,
3491 version: version.clone(),
3492 }],
3493 ..Default::default()
3494 },
3495 )
3496 });
3497
3498 Ok(())
3499}
3500
3501/// Rejoin the channel notes after a connection blip
3502async fn rejoin_channel_buffers(
3503 request: proto::RejoinChannelBuffers,
3504 response: Response<proto::RejoinChannelBuffers>,
3505 session: MessageContext,
3506) -> Result<()> {
3507 let db = session.db().await;
3508 let buffers = db
3509 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3510 .await?;
3511
3512 for rejoined_buffer in &buffers {
3513 let collaborators_to_notify = rejoined_buffer
3514 .buffer
3515 .collaborators
3516 .iter()
3517 .filter_map(|c| Some(c.peer_id?.into()));
3518 channel_buffer_updated(
3519 session.connection_id,
3520 collaborators_to_notify,
3521 &proto::UpdateChannelBufferCollaborators {
3522 channel_id: rejoined_buffer.buffer.channel_id,
3523 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3524 },
3525 &session.peer,
3526 );
3527 }
3528
3529 response.send(proto::RejoinChannelBuffersResponse {
3530 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3531 })?;
3532
3533 Ok(())
3534}
3535
3536/// Stop editing the channel notes
3537async fn leave_channel_buffer(
3538 request: proto::LeaveChannelBuffer,
3539 response: Response<proto::LeaveChannelBuffer>,
3540 session: MessageContext,
3541) -> Result<()> {
3542 let db = session.db().await;
3543 let channel_id = ChannelId::from_proto(request.channel_id);
3544
3545 let left_buffer = db
3546 .leave_channel_buffer(channel_id, session.connection_id)
3547 .await?;
3548
3549 response.send(Ack {})?;
3550
3551 channel_buffer_updated(
3552 session.connection_id,
3553 left_buffer.connections,
3554 &proto::UpdateChannelBufferCollaborators {
3555 channel_id: channel_id.to_proto(),
3556 collaborators: left_buffer.collaborators,
3557 },
3558 &session.peer,
3559 );
3560
3561 Ok(())
3562}
3563
3564fn channel_buffer_updated<T: EnvelopedMessage>(
3565 sender_id: ConnectionId,
3566 collaborators: impl IntoIterator<Item = ConnectionId>,
3567 message: &T,
3568 peer: &Peer,
3569) {
3570 broadcast(Some(sender_id), collaborators, |peer_id| {
3571 peer.send(peer_id, message.clone())
3572 });
3573}
3574
3575fn send_notifications(
3576 connection_pool: &ConnectionPool,
3577 peer: &Peer,
3578 notifications: db::NotificationBatch,
3579) {
3580 for (user_id, notification) in notifications {
3581 for connection_id in connection_pool.user_connection_ids(user_id) {
3582 if let Err(error) = peer.send(
3583 connection_id,
3584 proto::AddNotification {
3585 notification: Some(notification.clone()),
3586 },
3587 ) {
3588 tracing::error!(
3589 "failed to send notification to {:?} {}",
3590 connection_id,
3591 error
3592 );
3593 }
3594 }
3595 }
3596}
3597
/// Send a message to the channel
///
/// Persists the message, echoes it back to the sender via the response,
/// broadcasts it to chat participants, and pushes a latest-message-id update
/// to non-participant subscribers.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // Client-supplied nonce; echoed back in the stored message — presumably
    // used by clients to correlate/de-duplicate against their local echo.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    // Wire representation of the freshly-created message.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Deliver the full message to every chat participant except the sender.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers who are not in the chat only get the new latest-message id,
    // without shipping the message body.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    // Notification batch produced by the database call (likely mentions).
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3690
3691/// Delete a channel message
3692async fn remove_channel_message(
3693 request: proto::RemoveChannelMessage,
3694 response: Response<proto::RemoveChannelMessage>,
3695 session: MessageContext,
3696) -> Result<()> {
3697 let channel_id = ChannelId::from_proto(request.channel_id);
3698 let message_id = MessageId::from_proto(request.message_id);
3699 let (connection_ids, existing_notification_ids) = session
3700 .db()
3701 .await
3702 .remove_channel_message(channel_id, message_id, session.user_id())
3703 .await?;
3704
3705 broadcast(
3706 Some(session.connection_id),
3707 connection_ids,
3708 move |connection| {
3709 session.peer.send(connection, request.clone())?;
3710
3711 for notification_id in &existing_notification_ids {
3712 session.peer.send(
3713 connection,
3714 proto::DeleteNotification {
3715 notification_id: (*notification_id).to_proto(),
3716 },
3717 )?;
3718 }
3719
3720 Ok(())
3721 },
3722 );
3723 response.send(proto::Ack {})?;
3724 Ok(())
3725}
3726
/// Edit the body of a previously-sent channel message and fan the update out
/// to channel participants, fixing up any affected mention notifications.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: `message_id` is intentionally shadowed by the destructuring below;
    // from here on it is the id returned by the database.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Wire representation of the edited message, marked with `edited_at`.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    // For each participant (except the editor): send the updated message,
    // delete notifications for removed mentions, and update the ones whose
    // mention text changed.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    // Newly-created notifications (e.g. for new mentions) go to their
    // recipients' connections.
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3810
3811/// Mark a channel message as read
3812async fn acknowledge_channel_message(
3813 request: proto::AckChannelMessage,
3814 session: MessageContext,
3815) -> Result<()> {
3816 let channel_id = ChannelId::from_proto(request.channel_id);
3817 let message_id = MessageId::from_proto(request.message_id);
3818 let notifications = session
3819 .db()
3820 .await
3821 .observe_channel_message(channel_id, session.user_id(), message_id)
3822 .await?;
3823 send_notifications(
3824 &*session.connection_pool().await,
3825 &session.peer,
3826 notifications,
3827 );
3828 Ok(())
3829}
3830
3831/// Mark a buffer version as synced
3832async fn acknowledge_buffer_version(
3833 request: proto::AckBufferOperation,
3834 session: MessageContext,
3835) -> Result<()> {
3836 let buffer_id = BufferId::from_proto(request.buffer_id);
3837 session
3838 .db()
3839 .await
3840 .observe_buffer_version(
3841 buffer_id,
3842 session.user_id(),
3843 request.epoch as i32,
3844 &request.version,
3845 )
3846 .await?;
3847 Ok(())
3848}
3849
3850/// Get a Supermaven API key for the user
3851async fn get_supermaven_api_key(
3852 _request: proto::GetSupermavenApiKey,
3853 response: Response<proto::GetSupermavenApiKey>,
3854 session: MessageContext,
3855) -> Result<()> {
3856 let user_id: String = session.user_id().to_string();
3857 if !session.is_staff() {
3858 return Err(anyhow!("supermaven not enabled for this account"))?;
3859 }
3860
3861 let email = session.email().context("user must have an email")?;
3862
3863 let supermaven_admin_api = session
3864 .supermaven_client
3865 .as_ref()
3866 .context("supermaven not configured")?;
3867
3868 let result = supermaven_admin_api
3869 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3870 .await?;
3871
3872 response.send(proto::GetSupermavenApiKeyResponse {
3873 api_key: result.api_key,
3874 })?;
3875
3876 Ok(())
3877}
3878
3879/// Start receiving chat updates for a channel
3880async fn join_channel_chat(
3881 request: proto::JoinChannelChat,
3882 response: Response<proto::JoinChannelChat>,
3883 session: MessageContext,
3884) -> Result<()> {
3885 let channel_id = ChannelId::from_proto(request.channel_id);
3886
3887 let db = session.db().await;
3888 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3889 .await?;
3890 let messages = db
3891 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3892 .await?;
3893 response.send(proto::JoinChannelChatResponse {
3894 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3895 messages,
3896 })?;
3897 Ok(())
3898}
3899
3900/// Stop receiving chat updates for a channel
3901async fn leave_channel_chat(
3902 request: proto::LeaveChannelChat,
3903 session: MessageContext,
3904) -> Result<()> {
3905 let channel_id = ChannelId::from_proto(request.channel_id);
3906 session
3907 .db()
3908 .await
3909 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3910 .await?;
3911 Ok(())
3912}
3913
3914/// Retrieve the chat history for a channel
3915async fn get_channel_messages(
3916 request: proto::GetChannelMessages,
3917 response: Response<proto::GetChannelMessages>,
3918 session: MessageContext,
3919) -> Result<()> {
3920 let channel_id = ChannelId::from_proto(request.channel_id);
3921 let messages = session
3922 .db()
3923 .await
3924 .get_channel_messages(
3925 channel_id,
3926 session.user_id(),
3927 MESSAGE_COUNT_PER_PAGE,
3928 Some(MessageId::from_proto(request.before_message_id)),
3929 )
3930 .await?;
3931 response.send(proto::GetChannelMessagesResponse {
3932 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3933 messages,
3934 })?;
3935 Ok(())
3936}
3937
3938/// Retrieve specific chat messages
3939async fn get_channel_messages_by_id(
3940 request: proto::GetChannelMessagesById,
3941 response: Response<proto::GetChannelMessagesById>,
3942 session: MessageContext,
3943) -> Result<()> {
3944 let message_ids = request
3945 .message_ids
3946 .iter()
3947 .map(|id| MessageId::from_proto(*id))
3948 .collect::<Vec<_>>();
3949 let messages = session
3950 .db()
3951 .await
3952 .get_channel_messages_by_id(session.user_id(), &message_ids)
3953 .await?;
3954 response.send(proto::GetChannelMessagesResponse {
3955 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3956 messages,
3957 })?;
3958 Ok(())
3959}
3960
3961/// Retrieve the current users notifications
3962async fn get_notifications(
3963 request: proto::GetNotifications,
3964 response: Response<proto::GetNotifications>,
3965 session: MessageContext,
3966) -> Result<()> {
3967 let notifications = session
3968 .db()
3969 .await
3970 .get_notifications(
3971 session.user_id(),
3972 NOTIFICATION_COUNT_PER_PAGE,
3973 request.before_id.map(db::NotificationId::from_proto),
3974 )
3975 .await?;
3976 response.send(proto::GetNotificationsResponse {
3977 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3978 notifications,
3979 })?;
3980 Ok(())
3981}
3982
3983/// Mark notifications as read
3984async fn mark_notification_as_read(
3985 request: proto::MarkNotificationRead,
3986 response: Response<proto::MarkNotificationRead>,
3987 session: MessageContext,
3988) -> Result<()> {
3989 let database = &session.db().await;
3990 let notifications = database
3991 .mark_notification_as_read_by_id(
3992 session.user_id(),
3993 NotificationId::from_proto(request.notification_id),
3994 )
3995 .await?;
3996 send_notifications(
3997 &*session.connection_pool().await,
3998 &session.peer,
3999 notifications,
4000 );
4001 response.send(proto::Ack {})?;
4002 Ok(())
4003}
4004
4005fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4006 let message = match message {
4007 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4008 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4009 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4010 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4011 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4012 code: frame.code.into(),
4013 reason: frame.reason.as_str().to_owned().into(),
4014 })),
4015 // We should never receive a frame while reading the message, according
4016 // to the `tungstenite` maintainers:
4017 //
4018 // > It cannot occur when you read messages from the WebSocket, but it
4019 // > can be used when you want to send the raw frames (e.g. you want to
4020 // > send the frames to the WebSocket without composing the full message first).
4021 // >
4022 // > — https://github.com/snapview/tungstenite-rs/issues/268
4023 TungsteniteMessage::Frame(_) => {
4024 bail!("received an unexpected frame while reading the message")
4025 }
4026 };
4027
4028 Ok(message)
4029}
4030
4031fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4032 match message {
4033 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4034 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4035 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4036 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4037 AxumMessage::Close(frame) => {
4038 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4039 code: frame.code.into(),
4040 reason: frame.reason.as_ref().into(),
4041 }))
4042 }
4043 }
4044}
4045
4046fn notify_membership_updated(
4047 connection_pool: &mut ConnectionPool,
4048 result: MembershipUpdated,
4049 user_id: UserId,
4050 peer: &Peer,
4051) {
4052 for membership in &result.new_channels.channel_memberships {
4053 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4054 }
4055 for channel_id in &result.removed_channels {
4056 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4057 }
4058
4059 let user_channels_update = proto::UpdateUserChannels {
4060 channel_memberships: result
4061 .new_channels
4062 .channel_memberships
4063 .iter()
4064 .map(|cm| proto::ChannelMembership {
4065 channel_id: cm.channel_id.to_proto(),
4066 role: cm.role.into(),
4067 })
4068 .collect(),
4069 ..Default::default()
4070 };
4071
4072 let mut update = build_channels_update(result.new_channels);
4073 update.delete_channels = result
4074 .removed_channels
4075 .into_iter()
4076 .map(|id| id.to_proto())
4077 .collect();
4078 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4079
4080 for connection_id in connection_pool.user_connection_ids(user_id) {
4081 peer.send(connection_id, user_channels_update.clone())
4082 .trace_err();
4083 peer.send(connection_id, update.clone()).trace_err();
4084 }
4085}
4086
4087fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4088 proto::UpdateUserChannels {
4089 channel_memberships: channels
4090 .channel_memberships
4091 .iter()
4092 .map(|m| proto::ChannelMembership {
4093 channel_id: m.channel_id.to_proto(),
4094 role: m.role.into(),
4095 })
4096 .collect(),
4097 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4098 observed_channel_message_id: channels.observed_channel_messages.clone(),
4099 }
4100}
4101
4102fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4103 let mut update = proto::UpdateChannels::default();
4104
4105 for channel in channels.channels {
4106 update.channels.push(channel.to_proto());
4107 }
4108
4109 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4110 update.latest_channel_message_ids = channels.latest_channel_messages;
4111
4112 for (channel_id, participants) in channels.channel_participants {
4113 update
4114 .channel_participants
4115 .push(proto::ChannelParticipants {
4116 channel_id: channel_id.to_proto(),
4117 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4118 });
4119 }
4120
4121 for channel in channels.invited_channels {
4122 update.channel_invitations.push(channel.to_proto());
4123 }
4124
4125 update
4126}
4127
4128fn build_initial_contacts_update(
4129 contacts: Vec<db::Contact>,
4130 pool: &ConnectionPool,
4131) -> proto::UpdateContacts {
4132 let mut update = proto::UpdateContacts::default();
4133
4134 for contact in contacts {
4135 match contact {
4136 db::Contact::Accepted { user_id, busy } => {
4137 update.contacts.push(contact_for_user(user_id, busy, pool));
4138 }
4139 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4140 db::Contact::Incoming { user_id } => {
4141 update
4142 .incoming_requests
4143 .push(proto::IncomingContactRequest {
4144 requester_id: user_id.to_proto(),
4145 })
4146 }
4147 }
4148 }
4149
4150 update
4151}
4152
/// Build the wire representation of a contact, with online status taken from
/// the connection pool's live view of the user's connections.
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}
4160
4161fn room_updated(room: &proto::Room, peer: &Peer) {
4162 broadcast(
4163 None,
4164 room.participants
4165 .iter()
4166 .filter_map(|participant| Some(participant.peer_id?.into())),
4167 |peer_id| {
4168 peer.send(
4169 peer_id,
4170 proto::RoomUpdated {
4171 room: Some(room.clone()),
4172 },
4173 )
4174 },
4175 );
4176}
4177
4178fn channel_updated(
4179 channel: &db::channel::Model,
4180 room: &proto::Room,
4181 peer: &Peer,
4182 pool: &ConnectionPool,
4183) {
4184 let participants = room
4185 .participants
4186 .iter()
4187 .map(|p| p.user_id)
4188 .collect::<Vec<_>>();
4189
4190 broadcast(
4191 None,
4192 pool.channel_connection_ids(channel.root_id())
4193 .filter_map(|(channel_id, role)| {
4194 role.can_see_channel(channel.visibility)
4195 .then_some(channel_id)
4196 }),
4197 |peer_id| {
4198 peer.send(
4199 peer_id,
4200 proto::UpdateChannels {
4201 channel_participants: vec![proto::ChannelParticipants {
4202 channel_id: channel.id.to_proto(),
4203 participant_user_ids: participants.clone(),
4204 }],
4205 ..Default::default()
4206 },
4207 )
4208 },
4209 );
4210}
4211
4212async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4213 let db = session.db().await;
4214
4215 let contacts = db.get_contacts(user_id).await?;
4216 let busy = db.is_user_busy(user_id).await?;
4217
4218 let pool = session.connection_pool().await;
4219 let updated_contact = contact_for_user(user_id, busy, &pool);
4220 for contact in contacts {
4221 if let db::Contact::Accepted {
4222 user_id: contact_user_id,
4223 ..
4224 } = contact
4225 {
4226 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4227 session
4228 .peer
4229 .send(
4230 contact_conn_id,
4231 proto::UpdateContacts {
4232 contacts: vec![updated_contact.clone()],
4233 remove_contacts: Default::default(),
4234 incoming_requests: Default::default(),
4235 remove_incoming_requests: Default::default(),
4236 outgoing_requests: Default::default(),
4237 remove_outgoing_requests: Default::default(),
4238 },
4239 )
4240 .trace_err();
4241 }
4242 }
4243 }
4244 Ok(())
4245}
4246
/// Tears down `connection_id`'s participation in its current room:
/// unshares/leaves its projects, notifies remaining room and channel
/// members, cancels pending outgoing calls, refreshes contacts' busy
/// state, and removes the user from (and possibly deletes) the LiveKit
/// room. Returns `Ok(())` immediately if the connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized: all of these are assigned exactly once inside
    // the `if let` branch below (the early return covers the other path).
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE: the `session.db().await` guard temporary lives for the whole
    // `if let` expression, so the db handle is held across this branch.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // The leaving user's own busy state may have changed.
        contacts_to_update.insert(session.user_id());

        // Tell remaining collaborators this connection left each project.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // Connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belongs to a channel, refresh the channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Notify every connection of each callee whose incoming call was
        // canceled by this user leaving.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            // Canceled callees' busy state may have changed too.
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4320
4321async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4322 let left_channel_buffers = session
4323 .db()
4324 .await
4325 .leave_channel_buffers(session.connection_id)
4326 .await?;
4327
4328 for left_buffer in left_channel_buffers {
4329 channel_buffer_updated(
4330 session.connection_id,
4331 left_buffer.connections,
4332 &proto::UpdateChannelBufferCollaborators {
4333 channel_id: left_buffer.channel_id.to_proto(),
4334 collaborators: left_buffer.collaborators,
4335 },
4336 &session.peer,
4337 );
4338 }
4339
4340 Ok(())
4341}
4342
4343fn project_left(project: &db::LeftProject, session: &Session) {
4344 for connection_id in &project.connection_ids {
4345 if project.should_unshare {
4346 session
4347 .peer
4348 .send(
4349 *connection_id,
4350 proto::UnshareProject {
4351 project_id: project.id.to_proto(),
4352 },
4353 )
4354 .trace_err();
4355 } else {
4356 session
4357 .peer
4358 .send(
4359 *connection_id,
4360 proto::RemoveProjectCollaborator {
4361 project_id: project.id.to_proto(),
4362 peer_id: Some(session.connection_id.into()),
4363 },
4364 )
4365 .trace_err();
4366 }
4367 }
4368}
4369
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    // The success type carried by the implementing result.
    type Ok;

    /// Logs the error (if any) and discards it, yielding `Some(ok_value)`
    /// on success and `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4375
4376impl<T, E> ResultExt for Result<T, E>
4377where
4378 E: std::fmt::Debug,
4379{
4380 type Ok = T;
4381
4382 #[track_caller]
4383 fn trace_err(self) -> Option<T> {
4384 match self {
4385 Ok(value) => Some(value),
4386 Err(error) => {
4387 tracing::error!("{:?}", error);
4388 None
4389 }
4390 }
4391 }
4392}