1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
9 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
10 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use reqwest_client::ReqwestClient;
37use rpc::proto::{MultiLspQuery, split_repository_update};
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39use tracing::Span;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes and input limits for the chat / notification APIs.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently held `ConnectionGuard`s (see `ConnectionGuard::try_acquire`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record per-message timing breakdowns.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// A type-erased message handler: consumes an envelope and runs against the
/// given session, instrumented with the given span.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
97
98pub struct ConnectionGuard;
99
100impl ConnectionGuard {
101 pub fn try_acquire() -> Result<Self, ()> {
102 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
103 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
104 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
105 tracing::error!(
106 "too many concurrent connections: {}",
107 current_connections + 1
108 );
109 return Err(());
110 }
111 Ok(ConnectionGuard)
112 }
113}
114
115impl Drop for ConnectionGuard {
116 fn drop(&mut self) {
117 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
118 }
119}
120
121struct Response<R> {
122 peer: Arc<Peer>,
123 receipt: Receipt<R>,
124 responded: Arc<AtomicBool>,
125}
126
127impl<R: RequestMessage> Response<R> {
128 fn send(self, payload: R::Response) -> Result<()> {
129 self.responded.store(true, SeqCst);
130 self.peer.respond(self.receipt, payload)?;
131 Ok(())
132 }
133}
134
135#[derive(Clone, Debug)]
136pub enum Principal {
137 User(User),
138 Impersonated { user: User, admin: User },
139}
140
141impl Principal {
142 fn update_span(&self, span: &tracing::Span) {
143 match &self {
144 Principal::User(user) => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 }
148 Principal::Impersonated { user, admin } => {
149 span.record("user_id", user.id.0);
150 span.record("login", &user.github_login);
151 span.record("impersonator", &admin.github_login);
152 }
153 }
154 }
155}
156
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

/// Lets handlers call `Session` methods directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
170
impl MessageContext {
    /// Forwards a request to another connection (e.g. the project host) and
    /// yields its response, recording the wall-clock time spent waiting in
    /// the span's `host_waiting_ms` field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                // Recorded on completion regardless of success or failure,
                // in milliseconds with microsecond precision.
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
192
/// Per-connection state handed (via `MessageContext`) to every handler.
#[derive(Clone)]
struct Session {
    /// The authenticated user, possibly impersonated by an admin.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared registry of live connections; access through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
209
210impl Session {
211 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 let guard = self.db.lock().await;
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 guard
218 }
219
220 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
221 #[cfg(test)]
222 tokio::task::yield_now().await;
223 let guard = self.connection_pool.lock();
224 ConnectionPoolGuard {
225 guard,
226 _not_send: PhantomData,
227 }
228 }
229
230 fn is_staff(&self) -> bool {
231 match &self.principal {
232 Principal::User(user) => user.admin,
233 Principal::Impersonated { .. } => true,
234 }
235 }
236
237 fn user_id(&self) -> UserId {
238 match &self.principal {
239 Principal::User(user) => user.id,
240 Principal::Impersonated { user, .. } => user.id,
241 }
242 }
243
244 pub fn email(&self) -> Option<String> {
245 match &self.principal {
246 Principal::User(user) => user.email_address.clone(),
247 Principal::Impersonated { user, .. } => user.email_address.clone(),
248 }
249 }
250}
251
252impl Debug for Session {
253 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254 let mut result = f.debug_struct("Session");
255 match &self.principal {
256 Principal::User(user) => {
257 result.field("user", &user.github_login);
258 }
259 Principal::Impersonated { user, admin } => {
260 result.field("user", &user.github_login);
261 result.field("impersonator", &admin.github_login);
262 }
263 }
264 result.field("connection_id", &self.connection_id).finish()
265 }
266}
267
268struct DbHandle(Arc<Database>);
269
270impl Deref for DbHandle {
271 type Target = Database;
272
273 fn deref(&self) -> &Self::Target {
274 self.0.as_ref()
275 }
276}
277
/// The collaboration RPC server: owns the peer, the registry of message
/// handlers, and the pool of live client connections.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Maps each message's `TypeId` to its handler; populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` on shutdown so `handle_connection` tasks can exit
    /// (see `teardown`).
    teardown: watch::Sender<bool>,
}
286
/// Lock guard over the shared `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so the guard
/// cannot be moved to another thread while held.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
291
/// A serializable snapshot of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't serializable; serialize the pool it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
298
299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
300where
301 S: Serializer,
302 T: Deref<Target = U>,
303 U: Serialize,
304{
305 Serialize::serialize(value.deref(), serializer)
306}
307
308impl Server {
    /// Creates the server and registers a handler for every RPC message type
    /// it understands. Request handlers must reply to the sender; message
    /// handlers are fire-and-forget. Registering the same message type twice
    /// panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host, split by whether they
            // can mutate the host's state.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and presence.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
481
    /// Spawns the post-deploy cleanup task and returns immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT` (giving terminated pods time to
    /// disconnect), then removes stale chat participants, channel-buffer
    /// collaborators, and room participants left behind by previous server
    /// instances, notifying affected clients and finally deleting stale
    /// worktree entries and server rows.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining collaborators about the new membership.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants whose server is gone, then broadcast the
                        // refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each called user that the call was
                        // canceled. Scoped so the pool lock is released before awaiting.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to the contacts of every affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
651
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals `handle_connection` tasks via the teardown channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
657
    /// Test-only: tears the server down and restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Clear the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }
665
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
670
    /// Registers `handler` for message type `M`, wrapping it so that total,
    /// processing, and queueing durations are recorded on the message's span
    /// and errors are logged rather than propagated.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The registry is keyed by `TypeId::of::<M>()`, so this
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = since the message arrived; processing = since the
                    // handler was invoked; queue = the difference between them.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
714
715 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
716 where
717 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
718 Fut: 'static + Send + Future<Output = Result<()>>,
719 M: EnvelopedMessage,
720 {
721 self.add_handler(move |envelope, session| handler(envelope.payload, session));
722 self
723 }
724
    /// Registers a request handler for `M`, enforcing the request/response
    /// contract: a handler that returns `Ok` without having called
    /// `Response::send` is treated as an error, and handler failures are
    /// reported back to the requesting client as protocol errors.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // Returning Ok without responding is a handler bug.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Convert the failure into a protocol error for the client,
                        // then propagate it so it is also logged by `add_handler`.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
763
    /// Returns a future that drives one client connection to completion:
    /// registers it with the peer, builds its `Session`, sends the initial
    /// client update, then dispatches incoming messages to the registered
    /// handlers until the connection closes or the server tears down.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection-limit slot is only held through the handshake.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Back-pressure: don't pull the next message until a handler
                // slot (semaphore permit) is available.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    // Release the semaphore permit when the handler finishes.
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
955
    /// Sends the initial message sequence for a freshly established connection:
    /// the `Hello` handshake, first-connection prompts, the contact list,
    /// channel subscriptions, and any pending incoming call.
    ///
    /// `send_connection_id` lets the caller learn the assigned connection id
    /// once the hello message has gone out.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiver may already be gone; that's fine, ignore the error.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    // First-ever connection for this user: prompt the contacts
                    // UI once, then persist that we did so.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contact list
                    // under the same pool lock, so reported online states are
                    // consistent with the pool contents.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver a call that was already ringing this user, if any.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1011
1012 pub async fn invite_code_redeemed(
1013 self: &Arc<Self>,
1014 inviter_id: UserId,
1015 invitee_id: UserId,
1016 ) -> Result<()> {
1017 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1018 && let Some(code) = &user.invite_code {
1019 let pool = self.connection_pool.lock();
1020 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1021 for connection_id in pool.user_connection_ids(inviter_id) {
1022 self.peer.send(
1023 connection_id,
1024 proto::UpdateContacts {
1025 contacts: vec![invitee_contact.clone()],
1026 ..Default::default()
1027 },
1028 )?;
1029 self.peer.send(
1030 connection_id,
1031 proto::UpdateInviteInfo {
1032 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1033 count: user.invite_count as u32,
1034 },
1035 )?;
1036 }
1037 }
1038 Ok(())
1039 }
1040
1041 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1042 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1043 && let Some(invite_code) = &user.invite_code {
1044 let pool = self.connection_pool.lock();
1045 for connection_id in pool.user_connection_ids(user_id) {
1046 self.peer.send(
1047 connection_id,
1048 proto::UpdateInviteInfo {
1049 url: format!(
1050 "{}{}",
1051 self.app_state.config.invite_link_prefix, invite_code
1052 ),
1053 count: user.invite_count as u32,
1054 },
1055 )?;
1056 }
1057 }
1058 Ok(())
1059 }
1060
1061 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1062 ServerSnapshot {
1063 connection_pool: ConnectionPoolGuard {
1064 guard: self.connection_pool.lock(),
1065 _not_send: PhantomData,
1066 },
1067 peer: &self.peer,
1068 }
1069 }
1070}
1071
1072impl Deref for ConnectionPoolGuard<'_> {
1073 type Target = ConnectionPool;
1074
1075 fn deref(&self) -> &Self::Target {
1076 &self.guard
1077 }
1078}
1079
1080impl DerefMut for ConnectionPoolGuard<'_> {
1081 fn deref_mut(&mut self) -> &mut Self::Target {
1082 &mut self.guard
1083 }
1084}
1085
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time a guard is
        // released, so corruption is caught close to the offending mutation.
        #[cfg(test)]
        self.check_invariants();
    }
}
1092
1093fn broadcast<F>(
1094 sender_id: Option<ConnectionId>,
1095 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1096 mut f: F,
1097) where
1098 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1099{
1100 for receiver_id in receiver_ids {
1101 if Some(receiver_id) != sender_id
1102 && let Err(error) = f(receiver_id) {
1103 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104 }
1105 }
1106}
1107
1108pub struct ProtocolVersion(u32);
1109
1110impl Header for ProtocolVersion {
1111 fn name() -> &'static HeaderName {
1112 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1113 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1114 }
1115
1116 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1117 where
1118 Self: Sized,
1119 I: Iterator<Item = &'i axum::http::HeaderValue>,
1120 {
1121 let version = values
1122 .next()
1123 .ok_or_else(axum::headers::Error::invalid)?
1124 .to_str()
1125 .map_err(|_| axum::headers::Error::invalid())?
1126 .parse()
1127 .map_err(|_| axum::headers::Error::invalid())?;
1128 Ok(Self(version))
1129 }
1130
1131 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1132 values.extend([self.0.to_string().parse().unwrap()]);
1133 }
1134}
1135
1136pub struct AppVersionHeader(SemanticVersion);
1137impl Header for AppVersionHeader {
1138 fn name() -> &'static HeaderName {
1139 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1140 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1141 }
1142
1143 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1144 where
1145 Self: Sized,
1146 I: Iterator<Item = &'i axum::http::HeaderValue>,
1147 {
1148 let version = values
1149 .next()
1150 .ok_or_else(axum::headers::Error::invalid)?
1151 .to_str()
1152 .map_err(|_| axum::headers::Error::invalid())?
1153 .parse()
1154 .map_err(|_| axum::headers::Error::invalid())?;
1155 Ok(Self(version))
1156 }
1157
1158 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1159 values.extend([self.0.to_string().parse().unwrap()]);
1160 }
1161}
1162
1163#[derive(Debug)]
1164pub struct ReleaseChannelHeader(String);
1165
1166impl Header for ReleaseChannelHeader {
1167 fn name() -> &'static HeaderName {
1168 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1169 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1170 }
1171
1172 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1173 where
1174 Self: Sized,
1175 I: Iterator<Item = &'i axum::http::HeaderValue>,
1176 {
1177 Ok(Self(
1178 values
1179 .next()
1180 .ok_or_else(axum::headers::Error::invalid)?
1181 .to_str()
1182 .map_err(|_| axum::headers::Error::invalid())?
1183 .to_owned(),
1184 ))
1185 }
1186
1187 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1188 values.extend([self.0.parse().unwrap()]);
1189 }
1190}
1191
1192pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1193 Router::new()
1194 .route("/rpc", get(handle_websocket_request))
1195 .layer(
1196 ServiceBuilder::new()
1197 .layer(Extension(server.app_state.clone()))
1198 .layer(middleware::from_fn(auth::validate_header)),
1199 )
1200 .route("/metrics", get(handle_metrics))
1201 .layer(Extension(server))
1202}
1203
/// Axum handler that upgrades `GET /rpc` to the collaboration WebSocket.
///
/// Rejects clients whose protocol version does not exactly match, or whose app
/// version is missing or too old to collaborate, with `426 Upgrade Required`.
/// A connection slot is reserved *before* the upgrade so we can shed load with
/// `503` instead of accepting sockets we cannot serve.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly; anything else must upgrade.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    // The release channel is optional; absent for older clients.
    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's WebSocket message type to the tungstenite type that
        // `rpc::Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1281
1282pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1283 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1284 let connections_metric = CONNECTIONS_METRIC
1285 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1286
1287 let connections = server
1288 .connection_pool
1289 .lock()
1290 .connections()
1291 .filter(|connection| !connection.admin)
1292 .count();
1293 connections_metric.set(connections as _);
1294
1295 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1296 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1297 register_int_gauge!(
1298 "shared_projects",
1299 "number of open projects with one or more guests"
1300 )
1301 .unwrap()
1302 });
1303
1304 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1305 shared_projects_metric.set(shared_projects as _);
1306
1307 let encoder = prometheus::TextEncoder::new();
1308 let metric_families = prometheus::gather();
1309 let encoded_metrics = encoder
1310 .encode_to_string(&metric_families)
1311 .map_err(|err| anyhow!("{err}"))?;
1312 Ok(encoded_metrics)
1313}
1314
/// Cleans up after a client's connection drops.
///
/// Immediately detaches the connection from the peer and the pool, then waits
/// up to `RECONNECT_TIMEOUT` before tearing down room and channel-buffer state,
/// so a client that reconnects quickly can resume where it left off. The
/// teardown signal short-circuits the wait (e.g. during server shutdown).
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Record the lost connection in the database; best-effort only.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If this was the user's last connection, decline any call that is
            // still ringing them so callers aren't left waiting forever.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1361
1362/// Acknowledges a ping from a client, used to keep the connection alive.
1363async fn ping(
1364 _: proto::Ping,
1365 response: Response<proto::Ping>,
1366 _session: MessageContext,
1367) -> Result<()> {
1368 response.send(proto::Ack {})?;
1369 Ok(())
1370}
1371
/// Creates a new room for calling (outside of channels).
///
/// Generates a random LiveKit room name and, when a LiveKit client is
/// configured, mints a publish-capable connection token for the caller. The
/// room is then persisted and returned with the optional LiveKit info.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: if LiveKit is unconfigured or token minting fails, the room
    // is still created, just without audio/video connection info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit
            .room_token(&livekit_room, &user_id.to_string())
            .trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // Contacts see the user as "busy" once they are in a room.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1411
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel rooms take the channel join path instead (membership checks etc.).
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Joining from one device cancels the ringing call on the user's other
    // connections; failures for individual connections are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; the join still succeeds without audio/video.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1478
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the client's membership, re-announces it as the collaborator on
/// each project it had shared (with its new peer id), and streams the state of
/// projects it had joined as a guest back down to it.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Respond first so the client can proceed while we fan out updates.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's peer id changed from the
            // old (dead) connection to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree metadata to everyone else.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, refresh the channel's room state too.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1570
/// Streams the current state of each rejoined project back to the reconnecting
/// client, and tells the other collaborators about its new peer id.
///
/// Worktree entries, diagnostics, settings files, and repository updates are
/// sent as separate (and, where large, split) messages to the session's own
/// connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First, swap every collaborator's record of this client over to the new
    // connection id; failures for individual peers are only logged.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository changes that occurred while disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1655
1656/// leave room disconnects from the room.
1657async fn leave_room(
1658 _: proto::LeaveRoom,
1659 response: Response<proto::LeaveRoom>,
1660 session: MessageContext,
1661) -> Result<()> {
1662 leave_room_for_session(&session, session.connection_id).await?;
1663 response.send(proto::Ack {})?;
1664 Ok(())
1665}
1666
1667/// Updates the permissions of someone else in the room.
1668async fn set_room_participant_role(
1669 request: proto::SetRoomParticipantRole,
1670 response: Response<proto::SetRoomParticipantRole>,
1671 session: MessageContext,
1672) -> Result<()> {
1673 let user_id = UserId::from_proto(request.user_id);
1674 let role = ChannelRole::from(request.role());
1675
1676 let (livekit_room, can_publish) = {
1677 let room = session
1678 .db()
1679 .await
1680 .set_room_participant_role(
1681 session.user_id(),
1682 RoomId::from_proto(request.room_id),
1683 user_id,
1684 role,
1685 )
1686 .await?;
1687
1688 let livekit_room = room.livekit_room.clone();
1689 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1690 room_updated(&room, &session.peer);
1691 (livekit_room, can_publish)
1692 };
1693
1694 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1695 live_kit
1696 .update_participant(
1697 livekit_room.clone(),
1698 request.user_id.to_string(),
1699 livekit_api::proto::ParticipantPermission {
1700 can_subscribe: true,
1701 can_publish,
1702 can_publish_data: can_publish,
1703 hidden: false,
1704 recorder: false,
1705 },
1706 )
1707 .await
1708 .trace_err();
1709 }
1710
1711 response.send(proto::Ack {})?;
1712 Ok(())
1713}
1714
/// Call someone else into the current room.
///
/// Rings every active connection of the callee concurrently; the first one to
/// accept wins and the handler acks. If no connection accepts, the call is
/// marked failed in the database and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may call each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    // Contacts now see the callee as having a pending call.
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once; resolve on the first
    // response as they complete.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1783
/// Cancel an outgoing call.
///
/// Removes the pending call from the room, then tells every connection of the
/// callee to stop ringing, and refreshes the callee's contact state.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Stop the ring on every one of the callee's connections; individual send
    // failures are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1821
/// Decline an incoming call.
///
/// Removes the caller's pending call for this user, then tells all of the
/// user's own connections to stop ringing (a decline on one device silences
/// the others), and refreshes the user's contact state.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Silence the ring on every one of this user's connections; individual
    // send failures are only logged.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1853
1854/// Updates other participants in the room with your current location.
1855async fn update_participant_location(
1856 request: proto::UpdateParticipantLocation,
1857 response: Response<proto::UpdateParticipantLocation>,
1858 session: MessageContext,
1859) -> Result<()> {
1860 let room_id = RoomId::from_proto(request.room_id);
1861 let location = request.location.context("invalid location")?;
1862
1863 let db = session.db().await;
1864 let room = db
1865 .update_room_participant_location(room_id, session.connection_id, location)
1866 .await?;
1867
1868 room_updated(&room, &session.peer);
1869 response.send(proto::Ack {})?;
1870 Ok(())
1871}
1872
1873/// Share a project into the room.
1874async fn share_project(
1875 request: proto::ShareProject,
1876 response: Response<proto::ShareProject>,
1877 session: MessageContext,
1878) -> Result<()> {
1879 let (project_id, room) = &*session
1880 .db()
1881 .await
1882 .share_project(
1883 RoomId::from_proto(request.room_id),
1884 session.connection_id,
1885 &request.worktrees,
1886 request.is_ssh_project,
1887 )
1888 .await?;
1889 response.send(proto::ShareProjectResponse {
1890 project_id: project_id.to_proto(),
1891 })?;
1892 room_updated(room, &session.peer);
1893
1894 Ok(())
1895}
1896
1897/// Unshare a project from the room.
1898async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1899 let project_id = ProjectId::from_proto(message.project_id);
1900 unshare_project_internal(project_id, session.connection_id, &session).await
1901}
1902
/// Shared unshare implementation, used by the `UnshareProject` handler and by
/// other cleanup paths.
///
/// Notifies all guests that the project is gone, broadcasts the room update,
/// and — when the database indicates the project row is no longer needed —
/// deletes it.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the initiating connection); per-guest
        // failures are only logged by `broadcast`.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Delete outside the block above so the room guard is released before we
    // take the database again.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1940
/// Join someone else's shared project.
///
/// Adds the caller as a collaborator, announces it to the existing
/// collaborators, responds with the project's metadata, and then streams the
/// full project state (worktree entries, diagnostics, settings, repositories,
/// language-server notices) down to the new guest.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long fan-out below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Announce the new guest to every existing collaborator; individual send
    // failures are only logged.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream repository state, split into size-bounded chunks.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Nudge the guest to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2090
2091/// Leave someone elses shared project.
2092async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2093 let sender_id = session.connection_id;
2094 let project_id = ProjectId::from_proto(request.project_id);
2095 let db = session.db().await;
2096
2097 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2098 tracing::info!(
2099 %project_id,
2100 "leave project"
2101 );
2102
2103 project_left(project, &session);
2104 if let Some(room) = room {
2105 room_updated(room, &session.peer);
2106 }
2107
2108 Ok(())
2109}
2110
2111/// Updates other participants with changes to the project
2112async fn update_project(
2113 request: proto::UpdateProject,
2114 response: Response<proto::UpdateProject>,
2115 session: MessageContext,
2116) -> Result<()> {
2117 let project_id = ProjectId::from_proto(request.project_id);
2118 let (room, guest_connection_ids) = &*session
2119 .db()
2120 .await
2121 .update_project(project_id, session.connection_id, &request.worktrees)
2122 .await?;
2123 broadcast(
2124 Some(session.connection_id),
2125 guest_connection_ids.iter().copied(),
2126 |connection_id| {
2127 session
2128 .peer
2129 .forward_send(session.connection_id, connection_id, request.clone())
2130 },
2131 );
2132 if let Some(room) = room {
2133 room_updated(room, &session.peer);
2134 }
2135 response.send(proto::Ack {})?;
2136
2137 Ok(())
2138}
2139
2140/// Updates other participants with changes to the worktree
2141async fn update_worktree(
2142 request: proto::UpdateWorktree,
2143 response: Response<proto::UpdateWorktree>,
2144 session: MessageContext,
2145) -> Result<()> {
2146 let guest_connection_ids = session
2147 .db()
2148 .await
2149 .update_worktree(&request, session.connection_id)
2150 .await?;
2151
2152 broadcast(
2153 Some(session.connection_id),
2154 guest_connection_ids.iter().copied(),
2155 |connection_id| {
2156 session
2157 .peer
2158 .forward_send(session.connection_id, connection_id, request.clone())
2159 },
2160 );
2161 response.send(proto::Ack {})?;
2162 Ok(())
2163}
2164
2165async fn update_repository(
2166 request: proto::UpdateRepository,
2167 response: Response<proto::UpdateRepository>,
2168 session: MessageContext,
2169) -> Result<()> {
2170 let guest_connection_ids = session
2171 .db()
2172 .await
2173 .update_repository(&request, session.connection_id)
2174 .await?;
2175
2176 broadcast(
2177 Some(session.connection_id),
2178 guest_connection_ids.iter().copied(),
2179 |connection_id| {
2180 session
2181 .peer
2182 .forward_send(session.connection_id, connection_id, request.clone())
2183 },
2184 );
2185 response.send(proto::Ack {})?;
2186 Ok(())
2187}
2188
2189async fn remove_repository(
2190 request: proto::RemoveRepository,
2191 response: Response<proto::RemoveRepository>,
2192 session: MessageContext,
2193) -> Result<()> {
2194 let guest_connection_ids = session
2195 .db()
2196 .await
2197 .remove_repository(&request, session.connection_id)
2198 .await?;
2199
2200 broadcast(
2201 Some(session.connection_id),
2202 guest_connection_ids.iter().copied(),
2203 |connection_id| {
2204 session
2205 .peer
2206 .forward_send(session.connection_id, connection_id, request.clone())
2207 },
2208 );
2209 response.send(proto::Ack {})?;
2210 Ok(())
2211}
2212
2213/// Updates other participants with changes to the diagnostics
2214async fn update_diagnostic_summary(
2215 message: proto::UpdateDiagnosticSummary,
2216 session: MessageContext,
2217) -> Result<()> {
2218 let guest_connection_ids = session
2219 .db()
2220 .await
2221 .update_diagnostic_summary(&message, session.connection_id)
2222 .await?;
2223
2224 broadcast(
2225 Some(session.connection_id),
2226 guest_connection_ids.iter().copied(),
2227 |connection_id| {
2228 session
2229 .peer
2230 .forward_send(session.connection_id, connection_id, message.clone())
2231 },
2232 );
2233
2234 Ok(())
2235}
2236
2237/// Updates other participants with changes to the worktree settings
2238async fn update_worktree_settings(
2239 message: proto::UpdateWorktreeSettings,
2240 session: MessageContext,
2241) -> Result<()> {
2242 let guest_connection_ids = session
2243 .db()
2244 .await
2245 .update_worktree_settings(&message, session.connection_id)
2246 .await?;
2247
2248 broadcast(
2249 Some(session.connection_id),
2250 guest_connection_ids.iter().copied(),
2251 |connection_id| {
2252 session
2253 .peer
2254 .forward_send(session.connection_id, connection_id, message.clone())
2255 },
2256 );
2257
2258 Ok(())
2259}
2260
2261/// Notify other participants that a language server has started.
2262async fn start_language_server(
2263 request: proto::StartLanguageServer,
2264 session: MessageContext,
2265) -> Result<()> {
2266 let guest_connection_ids = session
2267 .db()
2268 .await
2269 .start_language_server(&request, session.connection_id)
2270 .await?;
2271
2272 broadcast(
2273 Some(session.connection_id),
2274 guest_connection_ids.iter().copied(),
2275 |connection_id| {
2276 session
2277 .peer
2278 .forward_send(session.connection_id, connection_id, request.clone())
2279 },
2280 );
2281 Ok(())
2282}
2283
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Metadata updates may carry refreshed server capabilities; persist them so
    // they can be served from the database later.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone() {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Forward the update verbatim to the project's other connections.
    // NOTE(review): the trailing `true` flag differs from the `false` passed in
    // `broadcast_project_message_from_host` — confirm its meaning against
    // `project_connection_ids`'s signature.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2312
2313/// forward a project request to the host. These requests should be read only
2314/// as guests are allowed to send them.
2315async fn forward_read_only_project_request<T>(
2316 request: T,
2317 response: Response<T>,
2318 session: MessageContext,
2319) -> Result<()>
2320where
2321 T: EntityMessage + RequestMessage,
2322{
2323 let project_id = ProjectId::from_proto(request.remote_entity_id());
2324 let host_connection_id = session
2325 .db()
2326 .await
2327 .host_for_read_only_project_request(project_id, session.connection_id)
2328 .await?;
2329 let payload = session.forward_request(host_connection_id, request).await?;
2330 response.send(payload)?;
2331 Ok(())
2332}
2333
2334/// forward a project request to the host. These requests are disallowed
2335/// for guests.
2336async fn forward_mutating_project_request<T>(
2337 request: T,
2338 response: Response<T>,
2339 session: MessageContext,
2340) -> Result<()>
2341where
2342 T: EntityMessage + RequestMessage,
2343{
2344 let project_id = ProjectId::from_proto(request.remote_entity_id());
2345
2346 let host_connection_id = session
2347 .db()
2348 .await
2349 .host_for_mutating_project_request(project_id, session.connection_id)
2350 .await?;
2351 let payload = session.forward_request(host_connection_id, request).await?;
2352 response.send(payload)?;
2353 Ok(())
2354}
2355
2356async fn multi_lsp_query(
2357 request: MultiLspQuery,
2358 response: Response<MultiLspQuery>,
2359 session: MessageContext,
2360) -> Result<()> {
2361 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2362 tracing::info!("multi_lsp_query message received");
2363 forward_mutating_project_request(request, response, session).await
2364}
2365
2366/// Notify other participants that a new buffer has been created
2367async fn create_buffer_for_peer(
2368 request: proto::CreateBufferForPeer,
2369 session: MessageContext,
2370) -> Result<()> {
2371 session
2372 .db()
2373 .await
2374 .check_user_is_project_host(
2375 ProjectId::from_proto(request.project_id),
2376 session.connection_id,
2377 )
2378 .await?;
2379 let peer_id = request.peer_id.context("invalid peer id")?;
2380 session
2381 .peer
2382 .forward_send(session.connection_id, peer_id.into(), request)?;
2383 Ok(())
2384}
2385
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates (or empty operations) need just read access; any
    // other operation variant escalates the required capability to write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // Scope the database guard so it is released before we await the
        // host round-trip below; only the copied host id escapes the block.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget the update to all guests except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Wait for the host to process the update before acking the sender,
    // unless the sender *is* the host.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2432
/// Broadcast a context operation to everyone in the project, host included.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Determine the capability the sender needs: selection-only (or empty)
    // buffer operations are read-only; any other buffer edit, any non-buffer
    // operation, or a buffer operation with no inner payload requires write.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is just another recipient here and no
    // acknowledgement round-trip is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2474
2475/// Notify other participants that a project has been updated.
2476async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2477 request: T,
2478 session: MessageContext,
2479) -> Result<()> {
2480 let project_id = ProjectId::from_proto(request.remote_entity_id());
2481 let project_connection_ids = session
2482 .db()
2483 .await
2484 .project_connection_ids(project_id, session.connection_id, false)
2485 .await?;
2486
2487 broadcast(
2488 Some(session.connection_id),
2489 project_connection_ids.iter().copied(),
2490 |connection_id| {
2491 session
2492 .peer
2493 .forward_send(session.connection_id, connection_id, request.clone())
2494 },
2495 );
2496 Ok(())
2497}
2498
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the given room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Relay the Follow request to the leader and pass its reply back verbatim.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are recorded in room state and broadcast.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2530
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the given room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Inform the leader directly; no acknowledgement is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Only project-scoped follows were recorded, so only those need removal.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2559
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of going through
    // `session.db()` like every other handler in this file — confirm whether
    // skipping `db()` is intentional here.
    let database = session.db.lock().await;

    // Project-scoped updates go to the project's connections; otherwise fall
    // back to everyone in the room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Forward to everyone except the omitted leader and the sender itself.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2591
2592/// Get public data about users.
2593async fn get_users(
2594 request: proto::GetUsers,
2595 response: Response<proto::GetUsers>,
2596 session: MessageContext,
2597) -> Result<()> {
2598 let user_ids = request
2599 .user_ids
2600 .into_iter()
2601 .map(UserId::from_proto)
2602 .collect();
2603 let users = session
2604 .db()
2605 .await
2606 .get_users_by_ids(user_ids)
2607 .await?
2608 .into_iter()
2609 .map(|user| proto::User {
2610 id: user.id.to_proto(),
2611 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2612 github_login: user.github_login,
2613 name: user.name,
2614 })
2615 .collect();
2616 response.send(proto::UsersResponse { users })?;
2617 Ok(())
2618}
2619
2620/// Search for users (to invite) buy Github login
2621async fn fuzzy_search_users(
2622 request: proto::FuzzySearchUsers,
2623 response: Response<proto::FuzzySearchUsers>,
2624 session: MessageContext,
2625) -> Result<()> {
2626 let query = request.query;
2627 let users = match query.len() {
2628 0 => vec![],
2629 1 | 2 => session
2630 .db()
2631 .await
2632 .get_user_by_github_login(&query)
2633 .await?
2634 .into_iter()
2635 .collect(),
2636 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2637 };
2638 let users = users
2639 .into_iter()
2640 .filter(|user| user.id != session.user_id())
2641 .map(|user| proto::User {
2642 id: user.id.to_proto(),
2643 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2644 github_login: user.github_login,
2645 name: user.name,
2646 })
2647 .collect();
2648 response.send(proto::UsersResponse { users })?;
2649 Ok(())
2650}
2651
2652/// Send a contact request to another user.
2653async fn request_contact(
2654 request: proto::RequestContact,
2655 response: Response<proto::RequestContact>,
2656 session: MessageContext,
2657) -> Result<()> {
2658 let requester_id = session.user_id();
2659 let responder_id = UserId::from_proto(request.responder_id);
2660 if requester_id == responder_id {
2661 return Err(anyhow!("cannot add yourself as a contact"))?;
2662 }
2663
2664 let notifications = session
2665 .db()
2666 .await
2667 .send_contact_request(requester_id, responder_id)
2668 .await?;
2669
2670 // Update outgoing contact requests of requester
2671 let mut update = proto::UpdateContacts::default();
2672 update.outgoing_requests.push(responder_id.to_proto());
2673 for connection_id in session
2674 .connection_pool()
2675 .await
2676 .user_connection_ids(requester_id)
2677 {
2678 session.peer.send(connection_id, update.clone())?;
2679 }
2680
2681 // Update incoming contact requests of responder
2682 let mut update = proto::UpdateContacts::default();
2683 update
2684 .incoming_requests
2685 .push(proto::IncomingContactRequest {
2686 requester_id: requester_id.to_proto(),
2687 });
2688 let connection_pool = session.connection_pool().await;
2689 for connection_id in connection_pool.user_connection_ids(responder_id) {
2690 session.peer.send(connection_id, update.clone())?;
2691 }
2692
2693 send_notifications(&connection_pool, &session.peer, notifications);
2694
2695 response.send(proto::Ack {})?;
2696 Ok(())
2697}
2698
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; the request itself stays pending.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        // Record the decision; on accept this creates the contact on both sides.
        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Either way, the incoming request disappears from the responder's list.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Either way, the outgoing request disappears from the requester's list.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2756
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` tells us whether this was an established contact or
    // just a pending request, which determines which list to prune below.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the responder's contact-request notification, if any.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2807
2808fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2809 version.0.minor() < 139
2810}
2811
2812async fn subscribe_to_channels(
2813 _: proto::SubscribeToChannels,
2814 session: MessageContext,
2815) -> Result<()> {
2816 subscribe_user_to_channels(session.user_id(), &session).await?;
2817 Ok(())
2818}
2819
/// Subscribe all of `user_id`'s connections to the channels they belong to,
/// then send this connection its membership list followed by the full
/// channels snapshot.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register interest so future channel updates get routed to this user.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Memberships first, then the channel data derived from them.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2836
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    // A missing parent means a new root channel; `membership` is only produced
    // when the creator gains a membership row (see the branch below).
    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Answer the caller before fanning out to other members.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Route future updates for the new channel to the creator and tell all
        // of their connections about the new membership.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to every connection under the root, filtered by
    // whether their role may see a channel of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2891
2892/// Delete a channel
2893async fn delete_channel(
2894 request: proto::DeleteChannel,
2895 response: Response<proto::DeleteChannel>,
2896 session: MessageContext,
2897) -> Result<()> {
2898 let db = session.db().await;
2899
2900 let channel_id = request.channel_id;
2901 let (root_channel, removed_channels) = db
2902 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2903 .await?;
2904 response.send(proto::Ack {})?;
2905
2906 // Notify members of removed channels
2907 let mut update = proto::UpdateChannels::default();
2908 update
2909 .delete_channels
2910 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2911
2912 let connection_pool = session.connection_pool().await;
2913 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2914 session.peer.send(connection_id, update.clone())?;
2915 }
2916
2917 Ok(())
2918}
2919
2920/// Invite someone to join a channel.
2921async fn invite_channel_member(
2922 request: proto::InviteChannelMember,
2923 response: Response<proto::InviteChannelMember>,
2924 session: MessageContext,
2925) -> Result<()> {
2926 let db = session.db().await;
2927 let channel_id = ChannelId::from_proto(request.channel_id);
2928 let invitee_id = UserId::from_proto(request.user_id);
2929 let InviteMemberResult {
2930 channel,
2931 notifications,
2932 } = db
2933 .invite_channel_member(
2934 channel_id,
2935 invitee_id,
2936 session.user_id(),
2937 request.role().into(),
2938 )
2939 .await?;
2940
2941 let update = proto::UpdateChannels {
2942 channel_invitations: vec![channel.to_proto()],
2943 ..Default::default()
2944 };
2945
2946 let connection_pool = session.connection_pool().await;
2947 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2948 session.peer.send(connection_id, update.clone())?;
2949 }
2950
2951 send_notifications(&connection_pool, &session.peer, notifications);
2952
2953 response.send(proto::Ack {})?;
2954 Ok(())
2955}
2956
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    // `notification_id` refers to the invite notification to retract, if the
    // removed member still had one pending.
    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Unsubscribe the member and push the updated channel view to everyone
    // affected (pool is mutable because subscriptions change).
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the stale invite notification on each of the member's
    // connections; failures here are logged, not fatal.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2998
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: `channel_user_ids` borrows the pool,
    // and we mutate the pool's subscriptions inside the loop.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get the updated record; users
        // who no longer can are unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3045
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The database distinguishes changing an accepted member's role from
    // changing the role on a still-pending invitation.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            // Re-send the member's channel view with the new permissions.
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Only the invitee's own connections need the refreshed invitation.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3093
3094/// Change the name of a channel
3095async fn rename_channel(
3096 request: proto::RenameChannel,
3097 response: Response<proto::RenameChannel>,
3098 session: MessageContext,
3099) -> Result<()> {
3100 let db = session.db().await;
3101 let channel_id = ChannelId::from_proto(request.channel_id);
3102 let channel_model = db
3103 .rename_channel(channel_id, session.user_id(), &request.name)
3104 .await?;
3105 let root_id = channel_model.root_id();
3106 let channel = Channel::from_model(channel_model);
3107
3108 response.send(proto::RenameChannelResponse {
3109 channel: Some(channel.to_proto()),
3110 })?;
3111
3112 let connection_pool = session.connection_pool().await;
3113 let update = proto::UpdateChannels {
3114 channels: vec![channel.to_proto()],
3115 ..Default::default()
3116 };
3117 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118 if role.can_see_channel(channel.visibility) {
3119 session.peer.send(connection_id, update.clone())?;
3120 }
3121 }
3122
3123 Ok(())
3124}
3125
3126/// Move a channel to a new parent.
3127async fn move_channel(
3128 request: proto::MoveChannel,
3129 response: Response<proto::MoveChannel>,
3130 session: MessageContext,
3131) -> Result<()> {
3132 let channel_id = ChannelId::from_proto(request.channel_id);
3133 let to = ChannelId::from_proto(request.to);
3134
3135 let (root_id, channels) = session
3136 .db()
3137 .await
3138 .move_channel(channel_id, to, session.user_id())
3139 .await?;
3140
3141 let connection_pool = session.connection_pool().await;
3142 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3143 let channels = channels
3144 .iter()
3145 .filter_map(|channel| {
3146 if role.can_see_channel(channel.visibility) {
3147 Some(channel.to_proto())
3148 } else {
3149 None
3150 }
3151 })
3152 .collect::<Vec<_>>();
3153 if channels.is_empty() {
3154 continue;
3155 }
3156
3157 let update = proto::UpdateChannels {
3158 channels,
3159 ..Default::default()
3160 };
3161
3162 session.peer.send(connection_id, update.clone())?;
3163 }
3164
3165 response.send(Ack {})?;
3166 Ok(())
3167}
3168
3169async fn reorder_channel(
3170 request: proto::ReorderChannel,
3171 response: Response<proto::ReorderChannel>,
3172 session: MessageContext,
3173) -> Result<()> {
3174 let channel_id = ChannelId::from_proto(request.channel_id);
3175 let direction = request.direction();
3176
3177 let updated_channels = session
3178 .db()
3179 .await
3180 .reorder_channel(channel_id, direction, session.user_id())
3181 .await?;
3182
3183 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3184 let connection_pool = session.connection_pool().await;
3185 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3186 let channels = updated_channels
3187 .iter()
3188 .filter_map(|channel| {
3189 if role.can_see_channel(channel.visibility) {
3190 Some(channel.to_proto())
3191 } else {
3192 None
3193 }
3194 })
3195 .collect::<Vec<_>>();
3196
3197 if channels.is_empty() {
3198 continue;
3199 }
3200
3201 let update = proto::UpdateChannels {
3202 channels,
3203 ..Default::default()
3204 };
3205
3206 session.peer.send(connection_id, update.clone())?;
3207 }
3208 }
3209
3210 response.send(Ack {})?;
3211 Ok(())
3212}
3213
3214/// Get the list of channel members
3215async fn get_channel_members(
3216 request: proto::GetChannelMembers,
3217 response: Response<proto::GetChannelMembers>,
3218 session: MessageContext,
3219) -> Result<()> {
3220 let db = session.db().await;
3221 let channel_id = ChannelId::from_proto(request.channel_id);
3222 let limit = if request.limit == 0 {
3223 u16::MAX as u64
3224 } else {
3225 request.limit
3226 };
3227 let (members, users) = db
3228 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3229 .await?;
3230 response.send(proto::GetChannelMembersResponse { members, users })?;
3231 Ok(())
3232}
3233
3234/// Accept or decline a channel invitation.
3235async fn respond_to_channel_invite(
3236 request: proto::RespondToChannelInvite,
3237 response: Response<proto::RespondToChannelInvite>,
3238 session: MessageContext,
3239) -> Result<()> {
3240 let db = session.db().await;
3241 let channel_id = ChannelId::from_proto(request.channel_id);
3242 let RespondToChannelInvite {
3243 membership_update,
3244 notifications,
3245 } = db
3246 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3247 .await?;
3248
3249 let mut connection_pool = session.connection_pool().await;
3250 if let Some(membership_update) = membership_update {
3251 notify_membership_updated(
3252 &mut connection_pool,
3253 membership_update,
3254 session.user_id(),
3255 &session.peer,
3256 );
3257 } else {
3258 let update = proto::UpdateChannels {
3259 remove_channel_invitations: vec![channel_id.to_proto()],
3260 ..Default::default()
3261 };
3262
3263 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3264 session.peer.send(connection_id, update.clone())?;
3265 }
3266 };
3267
3268 send_notifications(&connection_pool, &session.peer, notifications);
3269
3270 response.send(proto::Ack {})?;
3271
3272 Ok(())
3273}
3274
3275/// Join the channels' room
3276async fn join_channel(
3277 request: proto::JoinChannel,
3278 response: Response<proto::JoinChannel>,
3279 session: MessageContext,
3280) -> Result<()> {
3281 let channel_id = ChannelId::from_proto(request.channel_id);
3282 join_channel_internal(channel_id, Box::new(response), session).await
3283}
3284
/// Abstraction over the response handle for the two request types that can
/// join a channel's room (`JoinChannel` and `JoinRoom`), letting
/// `join_channel_internal` send a `JoinRoomResponse` through either.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3298
/// Shared implementation for `JoinChannel` and `JoinRoom`: joins the room
/// associated with `channel_id`, replies with the room state (plus LiveKit
/// connection info when a LiveKit client is configured), and fans out the
/// resulting room/channel/contact updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so we
            // must release our guard first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, every other role may publish. Token-creation
        // failures are logged and degrade to "no LiveKit info" via trace_err.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // If joining produced a membership change, subscribe this user's
        // connections to the channel and send them the new channel state.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's other participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Update the channel's participant list for everyone who can see it.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Joining a room changes the user's busy state; refresh their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3392
3393/// Start editing the channel notes
3394async fn join_channel_buffer(
3395 request: proto::JoinChannelBuffer,
3396 response: Response<proto::JoinChannelBuffer>,
3397 session: MessageContext,
3398) -> Result<()> {
3399 let db = session.db().await;
3400 let channel_id = ChannelId::from_proto(request.channel_id);
3401
3402 let open_response = db
3403 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3404 .await?;
3405
3406 let collaborators = open_response.collaborators.clone();
3407 response.send(open_response)?;
3408
3409 let update = UpdateChannelBufferCollaborators {
3410 channel_id: channel_id.to_proto(),
3411 collaborators: collaborators.clone(),
3412 };
3413 channel_buffer_updated(
3414 session.connection_id,
3415 collaborators
3416 .iter()
3417 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3418 &update,
3419 &session.peer,
3420 );
3421
3422 Ok(())
3423}
3424
3425/// Edit the channel notes
3426async fn update_channel_buffer(
3427 request: proto::UpdateChannelBuffer,
3428 session: MessageContext,
3429) -> Result<()> {
3430 let db = session.db().await;
3431 let channel_id = ChannelId::from_proto(request.channel_id);
3432
3433 let (collaborators, epoch, version) = db
3434 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3435 .await?;
3436
3437 channel_buffer_updated(
3438 session.connection_id,
3439 collaborators.clone(),
3440 &proto::UpdateChannelBuffer {
3441 channel_id: channel_id.to_proto(),
3442 operations: request.operations,
3443 },
3444 &session.peer,
3445 );
3446
3447 let pool = &*session.connection_pool().await;
3448
3449 let non_collaborators =
3450 pool.channel_connection_ids(channel_id)
3451 .filter_map(|(connection_id, _)| {
3452 if collaborators.contains(&connection_id) {
3453 None
3454 } else {
3455 Some(connection_id)
3456 }
3457 });
3458
3459 broadcast(None, non_collaborators, |peer_id| {
3460 session.peer.send(
3461 peer_id,
3462 proto::UpdateChannels {
3463 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3464 channel_id: channel_id.to_proto(),
3465 epoch: epoch as u64,
3466 version: version.clone(),
3467 }],
3468 ..Default::default()
3469 },
3470 )
3471 });
3472
3473 Ok(())
3474}
3475
3476/// Rejoin the channel notes after a connection blip
3477async fn rejoin_channel_buffers(
3478 request: proto::RejoinChannelBuffers,
3479 response: Response<proto::RejoinChannelBuffers>,
3480 session: MessageContext,
3481) -> Result<()> {
3482 let db = session.db().await;
3483 let buffers = db
3484 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3485 .await?;
3486
3487 for rejoined_buffer in &buffers {
3488 let collaborators_to_notify = rejoined_buffer
3489 .buffer
3490 .collaborators
3491 .iter()
3492 .filter_map(|c| Some(c.peer_id?.into()));
3493 channel_buffer_updated(
3494 session.connection_id,
3495 collaborators_to_notify,
3496 &proto::UpdateChannelBufferCollaborators {
3497 channel_id: rejoined_buffer.buffer.channel_id,
3498 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3499 },
3500 &session.peer,
3501 );
3502 }
3503
3504 response.send(proto::RejoinChannelBuffersResponse {
3505 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3506 })?;
3507
3508 Ok(())
3509}
3510
3511/// Stop editing the channel notes
3512async fn leave_channel_buffer(
3513 request: proto::LeaveChannelBuffer,
3514 response: Response<proto::LeaveChannelBuffer>,
3515 session: MessageContext,
3516) -> Result<()> {
3517 let db = session.db().await;
3518 let channel_id = ChannelId::from_proto(request.channel_id);
3519
3520 let left_buffer = db
3521 .leave_channel_buffer(channel_id, session.connection_id)
3522 .await?;
3523
3524 response.send(Ack {})?;
3525
3526 channel_buffer_updated(
3527 session.connection_id,
3528 left_buffer.connections,
3529 &proto::UpdateChannelBufferCollaborators {
3530 channel_id: channel_id.to_proto(),
3531 collaborators: left_buffer.collaborators,
3532 },
3533 &session.peer,
3534 );
3535
3536 Ok(())
3537}
3538
3539fn channel_buffer_updated<T: EnvelopedMessage>(
3540 sender_id: ConnectionId,
3541 collaborators: impl IntoIterator<Item = ConnectionId>,
3542 message: &T,
3543 peer: &Peer,
3544) {
3545 broadcast(Some(sender_id), collaborators, |peer_id| {
3546 peer.send(peer_id, message.clone())
3547 });
3548}
3549
3550fn send_notifications(
3551 connection_pool: &ConnectionPool,
3552 peer: &Peer,
3553 notifications: db::NotificationBatch,
3554) {
3555 for (user_id, notification) in notifications {
3556 for connection_id in connection_pool.user_connection_ids(user_id) {
3557 if let Err(error) = peer.send(
3558 connection_id,
3559 proto::AddNotification {
3560 notification: Some(notification.clone()),
3561 },
3562 ) {
3563 tracing::error!(
3564 "failed to send notification to {:?} {}",
3565 connection_id,
3566 error
3567 );
3568 }
3569 }
3570 }
3571}
3572
3573/// Send a message to the channel
3574async fn send_channel_message(
3575 request: proto::SendChannelMessage,
3576 response: Response<proto::SendChannelMessage>,
3577 session: MessageContext,
3578) -> Result<()> {
3579 // Validate the message body.
3580 let body = request.body.trim().to_string();
3581 if body.len() > MAX_MESSAGE_LEN {
3582 return Err(anyhow!("message is too long"))?;
3583 }
3584 if body.is_empty() {
3585 return Err(anyhow!("message can't be blank"))?;
3586 }
3587
3588 // TODO: adjust mentions if body is trimmed
3589
3590 let timestamp = OffsetDateTime::now_utc();
3591 let nonce = request.nonce.context("nonce can't be blank")?;
3592
3593 let channel_id = ChannelId::from_proto(request.channel_id);
3594 let CreatedChannelMessage {
3595 message_id,
3596 participant_connection_ids,
3597 notifications,
3598 } = session
3599 .db()
3600 .await
3601 .create_channel_message(
3602 channel_id,
3603 session.user_id(),
3604 &body,
3605 &request.mentions,
3606 timestamp,
3607 nonce.clone().into(),
3608 request.reply_to_message_id.map(MessageId::from_proto),
3609 )
3610 .await?;
3611
3612 let message = proto::ChannelMessage {
3613 sender_id: session.user_id().to_proto(),
3614 id: message_id.to_proto(),
3615 body,
3616 mentions: request.mentions,
3617 timestamp: timestamp.unix_timestamp() as u64,
3618 nonce: Some(nonce),
3619 reply_to_message_id: request.reply_to_message_id,
3620 edited_at: None,
3621 };
3622 broadcast(
3623 Some(session.connection_id),
3624 participant_connection_ids.clone(),
3625 |connection| {
3626 session.peer.send(
3627 connection,
3628 proto::ChannelMessageSent {
3629 channel_id: channel_id.to_proto(),
3630 message: Some(message.clone()),
3631 },
3632 )
3633 },
3634 );
3635 response.send(proto::SendChannelMessageResponse {
3636 message: Some(message),
3637 })?;
3638
3639 let pool = &*session.connection_pool().await;
3640 let non_participants =
3641 pool.channel_connection_ids(channel_id)
3642 .filter_map(|(connection_id, _)| {
3643 if participant_connection_ids.contains(&connection_id) {
3644 None
3645 } else {
3646 Some(connection_id)
3647 }
3648 });
3649 broadcast(None, non_participants, |peer_id| {
3650 session.peer.send(
3651 peer_id,
3652 proto::UpdateChannels {
3653 latest_channel_message_ids: vec![proto::ChannelMessageId {
3654 channel_id: channel_id.to_proto(),
3655 message_id: message_id.to_proto(),
3656 }],
3657 ..Default::default()
3658 },
3659 )
3660 });
3661 send_notifications(pool, &session.peer, notifications);
3662
3663 Ok(())
3664}
3665
3666/// Delete a channel message
3667async fn remove_channel_message(
3668 request: proto::RemoveChannelMessage,
3669 response: Response<proto::RemoveChannelMessage>,
3670 session: MessageContext,
3671) -> Result<()> {
3672 let channel_id = ChannelId::from_proto(request.channel_id);
3673 let message_id = MessageId::from_proto(request.message_id);
3674 let (connection_ids, existing_notification_ids) = session
3675 .db()
3676 .await
3677 .remove_channel_message(channel_id, message_id, session.user_id())
3678 .await?;
3679
3680 broadcast(
3681 Some(session.connection_id),
3682 connection_ids,
3683 move |connection| {
3684 session.peer.send(connection, request.clone())?;
3685
3686 for notification_id in &existing_notification_ids {
3687 session.peer.send(
3688 connection,
3689 proto::DeleteNotification {
3690 notification_id: (*notification_id).to_proto(),
3691 },
3692 )?;
3693 }
3694
3695 Ok(())
3696 },
3697 );
3698 response.send(proto::Ack {})?;
3699 Ok(())
3700}
3701
/// Edit a previously-sent channel message in place.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: `message_id` from the request is shadowed below by the id the
    // database returns for the updated message.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Rebuild the full message, marked as edited via `edited_at`.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    // Send the edited message to the other chat participants, along with any
    // mention-notification deletions/updates the edit caused.
    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    // Deliver any new notifications (e.g. for newly-added mentions).
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3785
3786/// Mark a channel message as read
3787async fn acknowledge_channel_message(
3788 request: proto::AckChannelMessage,
3789 session: MessageContext,
3790) -> Result<()> {
3791 let channel_id = ChannelId::from_proto(request.channel_id);
3792 let message_id = MessageId::from_proto(request.message_id);
3793 let notifications = session
3794 .db()
3795 .await
3796 .observe_channel_message(channel_id, session.user_id(), message_id)
3797 .await?;
3798 send_notifications(
3799 &*session.connection_pool().await,
3800 &session.peer,
3801 notifications,
3802 );
3803 Ok(())
3804}
3805
3806/// Mark a buffer version as synced
3807async fn acknowledge_buffer_version(
3808 request: proto::AckBufferOperation,
3809 session: MessageContext,
3810) -> Result<()> {
3811 let buffer_id = BufferId::from_proto(request.buffer_id);
3812 session
3813 .db()
3814 .await
3815 .observe_buffer_version(
3816 buffer_id,
3817 session.user_id(),
3818 request.epoch as i32,
3819 &request.version,
3820 )
3821 .await?;
3822 Ok(())
3823}
3824
3825/// Get a Supermaven API key for the user
3826async fn get_supermaven_api_key(
3827 _request: proto::GetSupermavenApiKey,
3828 response: Response<proto::GetSupermavenApiKey>,
3829 session: MessageContext,
3830) -> Result<()> {
3831 let user_id: String = session.user_id().to_string();
3832 if !session.is_staff() {
3833 return Err(anyhow!("supermaven not enabled for this account"))?;
3834 }
3835
3836 let email = session.email().context("user must have an email")?;
3837
3838 let supermaven_admin_api = session
3839 .supermaven_client
3840 .as_ref()
3841 .context("supermaven not configured")?;
3842
3843 let result = supermaven_admin_api
3844 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3845 .await?;
3846
3847 response.send(proto::GetSupermavenApiKeyResponse {
3848 api_key: result.api_key,
3849 })?;
3850
3851 Ok(())
3852}
3853
3854/// Start receiving chat updates for a channel
3855async fn join_channel_chat(
3856 request: proto::JoinChannelChat,
3857 response: Response<proto::JoinChannelChat>,
3858 session: MessageContext,
3859) -> Result<()> {
3860 let channel_id = ChannelId::from_proto(request.channel_id);
3861
3862 let db = session.db().await;
3863 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3864 .await?;
3865 let messages = db
3866 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3867 .await?;
3868 response.send(proto::JoinChannelChatResponse {
3869 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3870 messages,
3871 })?;
3872 Ok(())
3873}
3874
3875/// Stop receiving chat updates for a channel
3876async fn leave_channel_chat(
3877 request: proto::LeaveChannelChat,
3878 session: MessageContext,
3879) -> Result<()> {
3880 let channel_id = ChannelId::from_proto(request.channel_id);
3881 session
3882 .db()
3883 .await
3884 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3885 .await?;
3886 Ok(())
3887}
3888
3889/// Retrieve the chat history for a channel
3890async fn get_channel_messages(
3891 request: proto::GetChannelMessages,
3892 response: Response<proto::GetChannelMessages>,
3893 session: MessageContext,
3894) -> Result<()> {
3895 let channel_id = ChannelId::from_proto(request.channel_id);
3896 let messages = session
3897 .db()
3898 .await
3899 .get_channel_messages(
3900 channel_id,
3901 session.user_id(),
3902 MESSAGE_COUNT_PER_PAGE,
3903 Some(MessageId::from_proto(request.before_message_id)),
3904 )
3905 .await?;
3906 response.send(proto::GetChannelMessagesResponse {
3907 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3908 messages,
3909 })?;
3910 Ok(())
3911}
3912
3913/// Retrieve specific chat messages
3914async fn get_channel_messages_by_id(
3915 request: proto::GetChannelMessagesById,
3916 response: Response<proto::GetChannelMessagesById>,
3917 session: MessageContext,
3918) -> Result<()> {
3919 let message_ids = request
3920 .message_ids
3921 .iter()
3922 .map(|id| MessageId::from_proto(*id))
3923 .collect::<Vec<_>>();
3924 let messages = session
3925 .db()
3926 .await
3927 .get_channel_messages_by_id(session.user_id(), &message_ids)
3928 .await?;
3929 response.send(proto::GetChannelMessagesResponse {
3930 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3931 messages,
3932 })?;
3933 Ok(())
3934}
3935
3936/// Retrieve the current users notifications
3937async fn get_notifications(
3938 request: proto::GetNotifications,
3939 response: Response<proto::GetNotifications>,
3940 session: MessageContext,
3941) -> Result<()> {
3942 let notifications = session
3943 .db()
3944 .await
3945 .get_notifications(
3946 session.user_id(),
3947 NOTIFICATION_COUNT_PER_PAGE,
3948 request.before_id.map(db::NotificationId::from_proto),
3949 )
3950 .await?;
3951 response.send(proto::GetNotificationsResponse {
3952 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3953 notifications,
3954 })?;
3955 Ok(())
3956}
3957
3958/// Mark notifications as read
3959async fn mark_notification_as_read(
3960 request: proto::MarkNotificationRead,
3961 response: Response<proto::MarkNotificationRead>,
3962 session: MessageContext,
3963) -> Result<()> {
3964 let database = &session.db().await;
3965 let notifications = database
3966 .mark_notification_as_read_by_id(
3967 session.user_id(),
3968 NotificationId::from_proto(request.notification_id),
3969 )
3970 .await?;
3971 send_notifications(
3972 &*session.connection_pool().await,
3973 &session.peer,
3974 notifications,
3975 );
3976 response.send(proto::Ack {})?;
3977 Ok(())
3978}
3979
3980fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3981 let message = match message {
3982 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3983 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3984 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3985 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3986 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3987 code: frame.code.into(),
3988 reason: frame.reason.as_str().to_owned().into(),
3989 })),
3990 // We should never receive a frame while reading the message, according
3991 // to the `tungstenite` maintainers:
3992 //
3993 // > It cannot occur when you read messages from the WebSocket, but it
3994 // > can be used when you want to send the raw frames (e.g. you want to
3995 // > send the frames to the WebSocket without composing the full message first).
3996 // >
3997 // > — https://github.com/snapview/tungstenite-rs/issues/268
3998 TungsteniteMessage::Frame(_) => {
3999 bail!("received an unexpected frame while reading the message")
4000 }
4001 };
4002
4003 Ok(message)
4004}
4005
4006fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4007 match message {
4008 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4009 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4010 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4011 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4012 AxumMessage::Close(frame) => {
4013 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4014 code: frame.code.into(),
4015 reason: frame.reason.as_ref().into(),
4016 }))
4017 }
4018 }
4019}
4020
/// After a user's channel memberships changed, sync the connection pool's
/// channel subscriptions for that user and push the new channel state to all
/// of their connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the pool's per-channel subscriptions in sync with the result.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Built by hand rather than via `build_update_user_channels`: only the
    // memberships are sent here; the observed buffer/message fields are left
    // at their defaults.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The invitation that led to this membership change is now resolved.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged via `trace_err` rather than propagated.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
4061
4062fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4063 proto::UpdateUserChannels {
4064 channel_memberships: channels
4065 .channel_memberships
4066 .iter()
4067 .map(|m| proto::ChannelMembership {
4068 channel_id: m.channel_id.to_proto(),
4069 role: m.role.into(),
4070 })
4071 .collect(),
4072 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4073 observed_channel_message_id: channels.observed_channel_messages.clone(),
4074 }
4075}
4076
4077fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4078 let mut update = proto::UpdateChannels::default();
4079
4080 for channel in channels.channels {
4081 update.channels.push(channel.to_proto());
4082 }
4083
4084 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4085 update.latest_channel_message_ids = channels.latest_channel_messages;
4086
4087 for (channel_id, participants) in channels.channel_participants {
4088 update
4089 .channel_participants
4090 .push(proto::ChannelParticipants {
4091 channel_id: channel_id.to_proto(),
4092 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4093 });
4094 }
4095
4096 for channel in channels.invited_channels {
4097 update.channel_invitations.push(channel.to_proto());
4098 }
4099
4100 update
4101}
4102
4103fn build_initial_contacts_update(
4104 contacts: Vec<db::Contact>,
4105 pool: &ConnectionPool,
4106) -> proto::UpdateContacts {
4107 let mut update = proto::UpdateContacts::default();
4108
4109 for contact in contacts {
4110 match contact {
4111 db::Contact::Accepted { user_id, busy } => {
4112 update.contacts.push(contact_for_user(user_id, busy, pool));
4113 }
4114 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4115 db::Contact::Incoming { user_id } => {
4116 update
4117 .incoming_requests
4118 .push(proto::IncomingContactRequest {
4119 requester_id: user_id.to_proto(),
4120 })
4121 }
4122 }
4123 }
4124
4125 update
4126}
4127
4128fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4129 proto::Contact {
4130 user_id: user_id.to_proto(),
4131 online: pool.is_user_online(user_id),
4132 busy,
4133 }
4134}
4135
4136fn room_updated(room: &proto::Room, peer: &Peer) {
4137 broadcast(
4138 None,
4139 room.participants
4140 .iter()
4141 .filter_map(|participant| Some(participant.peer_id?.into())),
4142 |peer_id| {
4143 peer.send(
4144 peer_id,
4145 proto::RoomUpdated {
4146 room: Some(room.clone()),
4147 },
4148 )
4149 },
4150 );
4151}
4152
4153fn channel_updated(
4154 channel: &db::channel::Model,
4155 room: &proto::Room,
4156 peer: &Peer,
4157 pool: &ConnectionPool,
4158) {
4159 let participants = room
4160 .participants
4161 .iter()
4162 .map(|p| p.user_id)
4163 .collect::<Vec<_>>();
4164
4165 broadcast(
4166 None,
4167 pool.channel_connection_ids(channel.root_id())
4168 .filter_map(|(channel_id, role)| {
4169 role.can_see_channel(channel.visibility)
4170 .then_some(channel_id)
4171 }),
4172 |peer_id| {
4173 peer.send(
4174 peer_id,
4175 proto::UpdateChannels {
4176 channel_participants: vec![proto::ChannelParticipants {
4177 channel_id: channel.id.to_proto(),
4178 participant_user_ids: participants.clone(),
4179 }],
4180 ..Default::default()
4181 },
4182 )
4183 },
4184 );
4185}
4186
/// Push `user_id`'s current contact entry (online and busy state) to every
/// connection of each of their accepted contacts.
async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
    let db = session.db().await;

    let contacts = db.get_contacts(user_id).await?;
    let busy = db.is_user_busy(user_id).await?;

    let pool = session.connection_pool().await;
    let updated_contact = contact_for_user(user_id, busy, &pool);
    for contact in contacts {
        // Only accepted contacts receive presence updates; pending incoming
        // and outgoing requests are skipped.
        if let db::Contact::Accepted {
            user_id: contact_user_id,
            ..
        } = contact
        {
            for contact_conn_id in pool.user_connection_ids(contact_user_id) {
                session
                    .peer
                    .send(
                        contact_conn_id,
                        proto::UpdateContacts {
                            contacts: vec![updated_contact.clone()],
                            remove_contacts: Default::default(),
                            incoming_requests: Default::default(),
                            remove_incoming_requests: Default::default(),
                            outgoing_requests: Default::default(),
                            remove_outgoing_requests: Default::default(),
                        },
                    )
                    // Best effort: a failed send is logged, not fatal.
                    .trace_err();
            }
        }
    }
    Ok(())
}
4221
/// Tear down room state for `connection_id`: leave the room in the database,
/// notify remaining participants and canceled callees, refresh affected
/// contacts' presence, and clean up LiveKit membership.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Notify collaborators of each project this connection left.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to tear down.
        return Ok(());
    }

    // If this was a channel call, refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell users whose pending calls were canceled by this departure.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Busy state may have changed for the leaver and each canceled callee.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // `deleted` presumably means the room is now empty — confirm in
        // db::leave_room. Both LiveKit calls are best-effort.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4295
4296async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4297 let left_channel_buffers = session
4298 .db()
4299 .await
4300 .leave_channel_buffers(session.connection_id)
4301 .await?;
4302
4303 for left_buffer in left_channel_buffers {
4304 channel_buffer_updated(
4305 session.connection_id,
4306 left_buffer.connections,
4307 &proto::UpdateChannelBufferCollaborators {
4308 channel_id: left_buffer.channel_id.to_proto(),
4309 collaborators: left_buffer.collaborators,
4310 },
4311 &session.peer,
4312 );
4313 }
4314
4315 Ok(())
4316}
4317
4318fn project_left(project: &db::LeftProject, session: &Session) {
4319 for connection_id in &project.connection_ids {
4320 if project.should_unshare {
4321 session
4322 .peer
4323 .send(
4324 *connection_id,
4325 proto::UnshareProject {
4326 project_id: project.id.to_proto(),
4327 },
4328 )
4329 .trace_err();
4330 } else {
4331 session
4332 .peer
4333 .send(
4334 *connection_id,
4335 proto::RemoveProjectCollaborator {
4336 project_id: project.id.to_proto(),
4337 peer_id: Some(session.connection_id.into()),
4338 },
4339 )
4340 .trace_err();
4341 }
4342 }
4343}
4344
/// Extension trait for collapsing a `Result` into an `Option`, reporting the
/// error as it is discarded (see the `impl` for `Result` below).
pub trait ResultExt {
    /// The success type carried by the underlying result.
    type Ok;

    /// Returns `Some(value)` on success; on failure, reports the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4350
4351impl<T, E> ResultExt for Result<T, E>
4352where
4353 E: std::fmt::Debug,
4354{
4355 type Ok = T;
4356
4357 #[track_caller]
4358 fn trace_err(self) -> Option<T> {
4359 match self {
4360 Ok(value) => Some(value),
4361 Err(error) => {
4362 tracing::error!("{:?}", error);
4363 None
4364 }
4365 }
4366 }
4367}