1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
9 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
10 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use reqwest_client::ReqwestClient;
37use rpc::proto::{MultiLspQuery, split_repository_update};
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39use tracing::Span;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
/// Grace period for a dropped client connection.
///
/// NOTE(review): not used in this part of the file — presumably how long a client may take
/// to reconnect before its session state is released; confirm against its users.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are not used in this part of the file; their meaning
// below is inferred from the names — confirm against their users.
/// Presumably the page size when fetching channel message history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum length accepted for a channel message body.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held [`ConnectionGuard`]s (checked in
/// `ConnectionGuard::try_acquire`).
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of [`ConnectionGuard`]s currently alive; incremented by `try_acquire` and
/// decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the timing fields recorded (in milliseconds) on each "receive message" span.
/// Time from the envelope being received until its handler finished.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
/// Time spent inside the handler itself.
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
/// `TOTAL_DURATION_MS - PROCESSING_DURATION_MS`: time spent waiting before the handler ran.
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
/// Time spent waiting on a request forwarded via `MessageContext::forward_request`.
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler stored in `Server::handlers` (keyed by message `TypeId`).
///
/// It receives the boxed envelope, the connection's [`Session`], and the message's tracing
/// span, and returns a boxed future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
97
98pub struct ConnectionGuard;
99
100impl ConnectionGuard {
101 pub fn try_acquire() -> Result<Self, ()> {
102 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
103 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
104 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
105 tracing::error!(
106 "too many concurrent connections: {}",
107 current_connections + 1
108 );
109 return Err(());
110 }
111 Ok(ConnectionGuard)
112 }
113}
114
115impl Drop for ConnectionGuard {
116 fn drop(&mut self) {
117 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
118 }
119}
120
121struct Response<R> {
122 peer: Arc<Peer>,
123 receipt: Receipt<R>,
124 responded: Arc<AtomicBool>,
125}
126
127impl<R: RequestMessage> Response<R> {
128 fn send(self, payload: R::Response) -> Result<()> {
129 self.responded.store(true, SeqCst);
130 self.peer.respond(self.receipt, payload)?;
131 Ok(())
132 }
133}
134
135#[derive(Clone, Debug)]
136pub enum Principal {
137 User(User),
138 Impersonated { user: User, admin: User },
139}
140
141impl Principal {
142 fn update_span(&self, span: &tracing::Span) {
143 match &self {
144 Principal::User(user) => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 }
148 Principal::Impersonated { user, admin } => {
149 span.record("user_id", user.id.0);
150 span.record("login", &user.github_login);
151 span.record("impersonator", &admin.github_login);
152 }
153 }
154 }
155}
156
157#[derive(Clone)]
158struct MessageContext {
159 session: Session,
160 span: tracing::Span,
161}
162
163impl Deref for MessageContext {
164 type Target = Session;
165
166 fn deref(&self) -> &Self::Target {
167 &self.session
168 }
169}
170
171impl MessageContext {
172 pub fn forward_request<T: RequestMessage>(
173 &self,
174 receiver_id: ConnectionId,
175 request: T,
176 ) -> impl Future<Output = anyhow::Result<T::Response>> {
177 let request_start_time = Instant::now();
178 let span = self.span.clone();
179 tracing::info!("start forwarding request");
180 self.peer
181 .forward_request(self.connection_id, receiver_id, request)
182 .inspect(move |_| {
183 span.record(
184 HOST_WAITING_MS,
185 request_start_time.elapsed().as_micros() as f64 / 1000.0,
186 );
187 })
188 .inspect_err(|_| tracing::error!("error forwarding request"))
189 .inspect_ok(|_| tracing::info!("finished forwarding request"))
190 }
191}
192
193#[derive(Clone)]
194struct Session {
195 principal: Principal,
196 connection_id: ConnectionId,
197 db: Arc<tokio::sync::Mutex<DbHandle>>,
198 peer: Arc<Peer>,
199 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
200 app_state: Arc<AppState>,
201 supermaven_client: Option<Arc<SupermavenAdminApi>>,
202 /// The GeoIP country code for the user.
203 #[allow(unused)]
204 geoip_country_code: Option<String>,
205 #[allow(unused)]
206 system_id: Option<String>,
207 _executor: Executor,
208}
209
210impl Session {
211 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 let guard = self.db.lock().await;
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 guard
218 }
219
220 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
221 #[cfg(test)]
222 tokio::task::yield_now().await;
223 let guard = self.connection_pool.lock();
224 ConnectionPoolGuard {
225 guard,
226 _not_send: PhantomData,
227 }
228 }
229
230 fn is_staff(&self) -> bool {
231 match &self.principal {
232 Principal::User(user) => user.admin,
233 Principal::Impersonated { .. } => true,
234 }
235 }
236
237 fn user_id(&self) -> UserId {
238 match &self.principal {
239 Principal::User(user) => user.id,
240 Principal::Impersonated { user, .. } => user.id,
241 }
242 }
243
244 pub fn email(&self) -> Option<String> {
245 match &self.principal {
246 Principal::User(user) => user.email_address.clone(),
247 Principal::Impersonated { user, .. } => user.email_address.clone(),
248 }
249 }
250}
251
252impl Debug for Session {
253 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254 let mut result = f.debug_struct("Session");
255 match &self.principal {
256 Principal::User(user) => {
257 result.field("user", &user.github_login);
258 }
259 Principal::Impersonated { user, admin } => {
260 result.field("user", &user.github_login);
261 result.field("impersonator", &admin.github_login);
262 }
263 }
264 result.field("connection_id", &self.connection_id).finish()
265 }
266}
267
268struct DbHandle(Arc<Database>);
269
270impl Deref for DbHandle {
271 type Target = Database;
272
273 fn deref(&self) -> &Self::Target {
274 self.0.as_ref()
275 }
276}
277
/// The collab RPC server: owns the peer, the shared connection pool, and the message
/// handlers registered in [`Server::new`].
pub struct Server {
    /// Current server id; in this part of the file it is only rewritten by the test-only
    /// `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool shared with every [`Session`] this server creates.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection subscribes and stops when it changes.
    teardown: watch::Sender<bool>,
}
286
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across
/// an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
291
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not serializable; serialize what it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
298
299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
300where
301 S: Serializer,
302 T: Deref<Target = U>,
303 U: Serialize,
304{
305 Serialize::serialize(value.deref(), serializer)
306}
307
308impl Server {
309 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
310 let mut server = Self {
311 id: parking_lot::Mutex::new(id),
312 peer: Peer::new(id.0 as u32),
313 app_state,
314 connection_pool: Default::default(),
315 handlers: Default::default(),
316 teardown: watch::channel(false).0,
317 };
318
319 server
320 .add_request_handler(ping)
321 .add_request_handler(create_room)
322 .add_request_handler(join_room)
323 .add_request_handler(rejoin_room)
324 .add_request_handler(leave_room)
325 .add_request_handler(set_room_participant_role)
326 .add_request_handler(call)
327 .add_request_handler(cancel_call)
328 .add_message_handler(decline_call)
329 .add_request_handler(update_participant_location)
330 .add_request_handler(share_project)
331 .add_message_handler(unshare_project)
332 .add_request_handler(join_project)
333 .add_message_handler(leave_project)
334 .add_request_handler(update_project)
335 .add_request_handler(update_worktree)
336 .add_request_handler(update_repository)
337 .add_request_handler(remove_repository)
338 .add_message_handler(start_language_server)
339 .add_message_handler(update_language_server)
340 .add_message_handler(update_diagnostic_summary)
341 .add_message_handler(update_worktree_settings)
342 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
343 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
344 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
345 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
346 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
347 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
348 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
349 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
350 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
351 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
352 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
353 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
354 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
355 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
356 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
357 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
358 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
359 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
360 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
361 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
362 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
363 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
364 .add_request_handler(
365 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
366 )
367 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
368 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
369 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
370 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
371 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
372 .add_request_handler(
373 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
374 )
375 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
376 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
377 .add_request_handler(
378 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
379 )
380 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
381 .add_request_handler(
382 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
383 )
384 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
385 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
386 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
387 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
388 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
389 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
390 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
391 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
392 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
393 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
394 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
395 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
396 .add_request_handler(
397 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
398 )
399 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
400 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
401 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
402 .add_request_handler(multi_lsp_query)
403 .add_request_handler(lsp_query)
404 .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
405 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
406 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
407 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
408 .add_message_handler(create_buffer_for_peer)
409 .add_request_handler(update_buffer)
410 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
411 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
412 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
413 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
414 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
415 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
416 .add_message_handler(
417 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
418 )
419 .add_request_handler(get_users)
420 .add_request_handler(fuzzy_search_users)
421 .add_request_handler(request_contact)
422 .add_request_handler(remove_contact)
423 .add_request_handler(respond_to_contact_request)
424 .add_message_handler(subscribe_to_channels)
425 .add_request_handler(create_channel)
426 .add_request_handler(delete_channel)
427 .add_request_handler(invite_channel_member)
428 .add_request_handler(remove_channel_member)
429 .add_request_handler(set_channel_member_role)
430 .add_request_handler(set_channel_visibility)
431 .add_request_handler(rename_channel)
432 .add_request_handler(join_channel_buffer)
433 .add_request_handler(leave_channel_buffer)
434 .add_message_handler(update_channel_buffer)
435 .add_request_handler(rejoin_channel_buffers)
436 .add_request_handler(get_channel_members)
437 .add_request_handler(respond_to_channel_invite)
438 .add_request_handler(join_channel)
439 .add_request_handler(join_channel_chat)
440 .add_message_handler(leave_channel_chat)
441 .add_request_handler(send_channel_message)
442 .add_request_handler(remove_channel_message)
443 .add_request_handler(update_channel_message)
444 .add_request_handler(get_channel_messages)
445 .add_request_handler(get_channel_messages_by_id)
446 .add_request_handler(get_notifications)
447 .add_request_handler(mark_notification_as_read)
448 .add_request_handler(move_channel)
449 .add_request_handler(reorder_channel)
450 .add_request_handler(follow)
451 .add_message_handler(unfollow)
452 .add_message_handler(update_followers)
453 .add_message_handler(acknowledge_channel_message)
454 .add_message_handler(acknowledge_buffer_version)
455 .add_request_handler(get_supermaven_api_key)
456 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
457 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
458 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
459 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
460 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
461 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
462 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
463 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
464 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
465 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
466 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
467 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
468 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
469 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
470 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
471 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
472 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
473 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
474 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
475 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
476 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
477 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
478 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
479 .add_message_handler(update_context);
480
481 Arc::new(server)
482 }
483
    /// Spawns a detached cleanup task and returns `Ok(())` immediately.
    ///
    /// After [`CLEANUP_TIMEOUT`] the task does the following. Every failure is logged via
    /// `trace_err` and otherwise ignored.
    /// 1. Deletes stale channel-chat participants.
    /// 2. Fetches stale room ids and channel-buffer ids for the current environment and
    ///    `server_id`.
    /// 3. For each stale channel buffer, clears stale collaborators and sends the refreshed
    ///    collaborator list to the remaining connections.
    /// 4. For each stale room, it:
    ///    - clears stale participants;
    ///    - broadcasts the updated room (and its channel, if any);
    ///    - sends `CallCanceled` to users whose pending calls were canceled;
    ///    - pushes an updated contact entry to the accepted contacts of every affected user;
    ///    - deletes the LiveKit room if no participants remain.
    /// 5. Deletes stale channel-chat participants again, clears old worktree entries, and
    ///    deletes stale servers.
    ///
    /// NOTE(review): what counts as "stale" is decided by the `db` methods, which are not
    /// visible here — presumably state tied to previous server ids; confirm.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created now, awaited inside the task: the cleanup starts CLEANUP_TIMEOUT after
        // `start` is called.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            // Tell every connection still in the buffer about the refreshed
                            // collaborator list.
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // These defaults apply when clearing the room fails: nobody is
                        // notified and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool lock is only held for synchronous sends in this scope; it
                        // is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy state
                        // included) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): same call as at the start of the task; presumably repeated to
                // catch participants left over after the room cleanup above — confirm intent.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
654
655 pub fn teardown(&self) {
656 self.peer.teardown();
657 self.connection_pool.lock().reset();
658 let _ = self.teardown.send(true);
659 }
660
661 #[cfg(test)]
662 pub fn reset(&self, id: ServerId) {
663 self.teardown();
664 *self.id.lock() = id;
665 self.peer.reset(id.0 as u32);
666 let _ = self.teardown.send(false);
667 }
668
669 #[cfg(test)]
670 pub fn id(&self) -> ServerId {
671 *self.id.lock()
672 }
673
674 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
675 where
676 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
677 Fut: 'static + Send + Future<Output = Result<()>>,
678 M: EnvelopedMessage,
679 {
680 let prev_handler = self.handlers.insert(
681 TypeId::of::<M>(),
682 Box::new(move |envelope, session, span| {
683 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
684 let received_at = envelope.received_at;
685 tracing::info!("message received");
686 let start_time = Instant::now();
687 let future = (handler)(
688 *envelope,
689 MessageContext {
690 session,
691 span: span.clone(),
692 },
693 );
694 async move {
695 let result = future.await;
696 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
697 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
698 let queue_duration_ms = total_duration_ms - processing_duration_ms;
699 span.record(TOTAL_DURATION_MS, total_duration_ms);
700 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
701 span.record(QUEUE_DURATION_MS, queue_duration_ms);
702 match result {
703 Err(error) => {
704 tracing::error!(?error, "error handling message")
705 }
706 Ok(()) => tracing::info!("finished handling message"),
707 }
708 }
709 .boxed()
710 }),
711 );
712 if prev_handler.is_some() {
713 panic!("registered a handler for the same message twice");
714 }
715 self
716 }
717
718 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
719 where
720 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
721 Fut: 'static + Send + Future<Output = Result<()>>,
722 M: EnvelopedMessage,
723 {
724 self.add_handler(move |envelope, session| handler(envelope.payload, session));
725 self
726 }
727
728 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
729 where
730 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
731 Fut: Send + Future<Output = Result<()>>,
732 M: RequestMessage,
733 {
734 let handler = Arc::new(handler);
735 self.add_handler(move |envelope, session| {
736 let receipt = envelope.receipt();
737 let handler = handler.clone();
738 async move {
739 let peer = session.peer.clone();
740 let responded = Arc::new(AtomicBool::default());
741 let response = Response {
742 peer: peer.clone(),
743 responded: responded.clone(),
744 receipt,
745 };
746 match (handler)(envelope.payload, response, session).await {
747 Ok(()) => {
748 if responded.load(std::sync::atomic::Ordering::SeqCst) {
749 Ok(())
750 } else {
751 Err(anyhow!("handler did not send a response"))?
752 }
753 }
754 Err(error) => {
755 let proto_err = match &error {
756 Error::Internal(err) => err.to_proto(),
757 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
758 };
759 peer.respond_with_error(receipt, proto_err)?;
760 Err(error)
761 }
762 }
763 }
764 })
765 }
766
    /// Builds the future that serves a single client connection until it ends.
    ///
    /// The returned future:
    /// - returns immediately if the server is already tearing down;
    /// - registers `connection` with the peer and builds the connection's [`Session`];
    /// - sends the initial client update (which also delivers the new [`ConnectionId`]
    ///   through `send_connection_id`, if given), then drops `connection_guard`;
    /// - dispatches incoming messages to the registered handlers: background messages are
    ///   spawned on `executor`, and foreground ones run concurrently on this task. Every
    ///   handler holds a semaphore permit, so at most `MAX_CONCURRENT_HANDLERS` run at once
    ///   for this connection.
    /// - stops on teardown (returning without `connection_lost`), on an I/O error, or when
    ///   the connection closes. In the latter two cases it drops pending foreground
    ///   handlers and calls `connection_lost`.
    ///
    /// Early failures (HTTP client creation, initial update) are logged and end the future.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields declared `Empty` here are filled in below, by `Principal::update_span`,
        // or (for `connection_id`) once the peer assigns an id.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // Within this function the HTTP client is only used for the Supermaven admin API.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Setup is complete, so release the concurrent-connection slot here. The guard
            // therefore bounds connections that are still being set up, not established ones.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, causing a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight for this connection.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is taken *before* reading the next message, so no new message is
                // read while all permits are in use. This future is rebuilt each iteration;
                // if another branch wins, the unused permit is released.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: branches are polled in the order written, so teardown and I/O
                // completion take priority over new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                // todo(lsp) remove after Zed Stable hits v0.204.x
                                multi_lsp_query_request=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            // Enter the span only while the handler's future is constructed;
                            // its execution is instrumented with the span below.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still pending before signing out.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
960
961 async fn send_initial_client_update(
962 &self,
963 connection_id: ConnectionId,
964 zed_version: ZedVersion,
965 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
966 session: &Session,
967 ) -> Result<()> {
968 self.peer.send(
969 connection_id,
970 proto::Hello {
971 peer_id: Some(connection_id.into()),
972 },
973 )?;
974 tracing::info!("sent hello message");
975 if let Some(send_connection_id) = send_connection_id.take() {
976 let _ = send_connection_id.send(connection_id);
977 }
978
979 match &session.principal {
980 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
981 if !user.connected_once {
982 self.peer.send(connection_id, proto::ShowContacts {})?;
983 self.app_state
984 .db
985 .set_user_connected_once(user.id, true)
986 .await?;
987 }
988
989 let contacts = self.app_state.db.get_contacts(user.id).await?;
990
991 {
992 let mut pool = self.connection_pool.lock();
993 pool.add_connection(connection_id, user.id, user.admin, zed_version);
994 self.peer.send(
995 connection_id,
996 build_initial_contacts_update(contacts, &pool),
997 )?;
998 }
999
1000 if should_auto_subscribe_to_channels(zed_version) {
1001 subscribe_user_to_channels(user.id, session).await?;
1002 }
1003
1004 if let Some(incoming_call) =
1005 self.app_state.db.incoming_call_for_user(user.id).await?
1006 {
1007 self.peer.send(connection_id, incoming_call)?;
1008 }
1009
1010 update_user_contacts(user.id, session).await?;
1011 }
1012 }
1013
1014 Ok(())
1015 }
1016
1017 pub async fn invite_code_redeemed(
1018 self: &Arc<Self>,
1019 inviter_id: UserId,
1020 invitee_id: UserId,
1021 ) -> Result<()> {
1022 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await?
1023 && let Some(code) = &user.invite_code
1024 {
1025 let pool = self.connection_pool.lock();
1026 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1027 for connection_id in pool.user_connection_ids(inviter_id) {
1028 self.peer.send(
1029 connection_id,
1030 proto::UpdateContacts {
1031 contacts: vec![invitee_contact.clone()],
1032 ..Default::default()
1033 },
1034 )?;
1035 self.peer.send(
1036 connection_id,
1037 proto::UpdateInviteInfo {
1038 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1039 count: user.invite_count as u32,
1040 },
1041 )?;
1042 }
1043 }
1044 Ok(())
1045 }
1046
1047 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1048 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await?
1049 && let Some(invite_code) = &user.invite_code
1050 {
1051 let pool = self.connection_pool.lock();
1052 for connection_id in pool.user_connection_ids(user_id) {
1053 self.peer.send(
1054 connection_id,
1055 proto::UpdateInviteInfo {
1056 url: format!(
1057 "{}{}",
1058 self.app_state.config.invite_link_prefix, invite_code
1059 ),
1060 count: user.invite_count as u32,
1061 },
1062 )?;
1063 }
1064 }
1065 Ok(())
1066 }
1067
1068 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1069 ServerSnapshot {
1070 connection_pool: ConnectionPoolGuard {
1071 guard: self.connection_pool.lock(),
1072 _not_send: PhantomData,
1073 },
1074 peer: &self.peer,
1075 }
1076 }
1077}
1078
1079impl Deref for ConnectionPoolGuard<'_> {
1080 type Target = ConnectionPool;
1081
1082 fn deref(&self) -> &Self::Target {
1083 &self.guard
1084 }
1085}
1086
1087impl DerefMut for ConnectionPoolGuard<'_> {
1088 fn deref_mut(&mut self) -> &mut Self::Target {
1089 &mut self.guard
1090 }
1091}
1092
1093impl Drop for ConnectionPoolGuard<'_> {
1094 fn drop(&mut self) {
1095 #[cfg(test)]
1096 self.check_invariants();
1097 }
1098}
1099
1100fn broadcast<F>(
1101 sender_id: Option<ConnectionId>,
1102 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1103 mut f: F,
1104) where
1105 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1106{
1107 for receiver_id in receiver_ids {
1108 if Some(receiver_id) != sender_id
1109 && let Err(error) = f(receiver_id)
1110 {
1111 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1112 }
1113 }
1114}
1115
1116pub struct ProtocolVersion(u32);
1117
1118impl Header for ProtocolVersion {
1119 fn name() -> &'static HeaderName {
1120 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1121 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1122 }
1123
1124 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1125 where
1126 Self: Sized,
1127 I: Iterator<Item = &'i axum::http::HeaderValue>,
1128 {
1129 let version = values
1130 .next()
1131 .ok_or_else(axum::headers::Error::invalid)?
1132 .to_str()
1133 .map_err(|_| axum::headers::Error::invalid())?
1134 .parse()
1135 .map_err(|_| axum::headers::Error::invalid())?;
1136 Ok(Self(version))
1137 }
1138
1139 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1140 values.extend([self.0.to_string().parse().unwrap()]);
1141 }
1142}
1143
1144pub struct AppVersionHeader(SemanticVersion);
1145impl Header for AppVersionHeader {
1146 fn name() -> &'static HeaderName {
1147 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1148 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1149 }
1150
1151 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1152 where
1153 Self: Sized,
1154 I: Iterator<Item = &'i axum::http::HeaderValue>,
1155 {
1156 let version = values
1157 .next()
1158 .ok_or_else(axum::headers::Error::invalid)?
1159 .to_str()
1160 .map_err(|_| axum::headers::Error::invalid())?
1161 .parse()
1162 .map_err(|_| axum::headers::Error::invalid())?;
1163 Ok(Self(version))
1164 }
1165
1166 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1167 values.extend([self.0.to_string().parse().unwrap()]);
1168 }
1169}
1170
1171#[derive(Debug)]
1172pub struct ReleaseChannelHeader(String);
1173
1174impl Header for ReleaseChannelHeader {
1175 fn name() -> &'static HeaderName {
1176 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1177 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1178 }
1179
1180 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1181 where
1182 Self: Sized,
1183 I: Iterator<Item = &'i axum::http::HeaderValue>,
1184 {
1185 Ok(Self(
1186 values
1187 .next()
1188 .ok_or_else(axum::headers::Error::invalid)?
1189 .to_str()
1190 .map_err(|_| axum::headers::Error::invalid())?
1191 .to_owned(),
1192 ))
1193 }
1194
1195 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1196 values.extend([self.0.parse().unwrap()]);
1197 }
1198}
1199
1200pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1201 Router::new()
1202 .route("/rpc", get(handle_websocket_request))
1203 .layer(
1204 ServiceBuilder::new()
1205 .layer(Extension(server.app_state.clone()))
1206 .layer(middleware::from_fn(auth::validate_header)),
1207 )
1208 .route("/metrics", get(handle_metrics))
1209 .layer(Extension(server))
1210}
1211
1212pub async fn handle_websocket_request(
1213 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1214 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1215 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1216 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1217 Extension(server): Extension<Arc<Server>>,
1218 Extension(principal): Extension<Principal>,
1219 user_agent: Option<TypedHeader<UserAgent>>,
1220 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1221 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1222 ws: WebSocketUpgrade,
1223) -> axum::response::Response {
1224 if protocol_version != rpc::PROTOCOL_VERSION {
1225 return (
1226 StatusCode::UPGRADE_REQUIRED,
1227 "client must be upgraded".to_string(),
1228 )
1229 .into_response();
1230 }
1231
1232 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1233 return (
1234 StatusCode::UPGRADE_REQUIRED,
1235 "no version header found".to_string(),
1236 )
1237 .into_response();
1238 };
1239
1240 let release_channel = release_channel_header.map(|header| header.0.0);
1241
1242 if !version.can_collaborate() {
1243 return (
1244 StatusCode::UPGRADE_REQUIRED,
1245 "client must be upgraded".to_string(),
1246 )
1247 .into_response();
1248 }
1249
1250 let socket_address = socket_address.to_string();
1251
1252 // Acquire connection guard before WebSocket upgrade
1253 let connection_guard = match ConnectionGuard::try_acquire() {
1254 Ok(guard) => guard,
1255 Err(()) => {
1256 return (
1257 StatusCode::SERVICE_UNAVAILABLE,
1258 "Too many concurrent connections",
1259 )
1260 .into_response();
1261 }
1262 };
1263
1264 ws.on_upgrade(move |socket| {
1265 let socket = socket
1266 .map_ok(to_tungstenite_message)
1267 .err_into()
1268 .with(|message| async move { to_axum_message(message) });
1269 let connection = Connection::new(Box::pin(socket));
1270 async move {
1271 server
1272 .handle_connection(
1273 connection,
1274 socket_address,
1275 principal,
1276 version,
1277 release_channel,
1278 user_agent.map(|header| header.to_string()),
1279 country_code_header.map(|header| header.to_string()),
1280 system_id_header.map(|header| header.to_string()),
1281 None,
1282 Executor::Production,
1283 Some(connection_guard),
1284 )
1285 .await;
1286 }
1287 })
1288}
1289
1290pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1291 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1292 let connections_metric = CONNECTIONS_METRIC
1293 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1294
1295 let connections = server
1296 .connection_pool
1297 .lock()
1298 .connections()
1299 .filter(|connection| !connection.admin)
1300 .count();
1301 connections_metric.set(connections as _);
1302
1303 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1304 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1305 register_int_gauge!(
1306 "shared_projects",
1307 "number of open projects with one or more guests"
1308 )
1309 .unwrap()
1310 });
1311
1312 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1313 shared_projects_metric.set(shared_projects as _);
1314
1315 let encoder = prometheus::TextEncoder::new();
1316 let metric_families = prometheus::gather();
1317 let encoded_metrics = encoder
1318 .encode_to_string(&metric_families)
1319 .map_err(|err| anyhow!("{err}"))?;
1320 Ok(encoded_metrics)
1321}
1322
/// Cleans up server-side state for a connection whose transport has closed.
///
/// The following steps run immediately:
/// - the peer is disconnected;
/// - the connection is removed from the connection pool;
/// - the loss is recorded in the database. A failure here is only logged.
///
/// The function then waits for either `RECONNECT_TIMEOUT` to elapse or
/// `teardown` to change. This presumably gives the client a window to
/// reconnect; TODO confirm.
/// - On timeout, the session leaves its room and channel buffers. If the
///   user has no other online connection, `decline_call(None, ..)` is issued
///   and any room it returns is broadcast. Finally the user's contacts are
///   refreshed.
/// - On teardown, nothing further happens.
///
/// # Errors
/// Returns an error only if removing the connection from the pool fails or if
/// the final `update_user_contacts` fails. Every other failure is logged via
/// `trace_err`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a database failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls in the order written: if both futures are ready,
    // the timeout branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only touch calls when no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1369
1370/// Acknowledges a ping from a client, used to keep the connection alive.
1371async fn ping(
1372 _: proto::Ping,
1373 response: Response<proto::Ping>,
1374 _session: MessageContext,
1375) -> Result<()> {
1376 response.send(proto::Ack {})?;
1377 Ok(())
1378}
1379
1380/// Creates a new room for calling (outside of channels)
1381async fn create_room(
1382 _request: proto::CreateRoom,
1383 response: Response<proto::CreateRoom>,
1384 session: MessageContext,
1385) -> Result<()> {
1386 let livekit_room = nanoid::nanoid!(30);
1387
1388 let live_kit_connection_info = util::maybe!(async {
1389 let live_kit = session.app_state.livekit_client.as_ref();
1390 let live_kit = live_kit?;
1391 let user_id = session.user_id().to_string();
1392
1393 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1394
1395 Some(proto::LiveKitConnectionInfo {
1396 server_url: live_kit.url().into(),
1397 token,
1398 can_publish: true,
1399 })
1400 })
1401 .await;
1402
1403 let room = session
1404 .db()
1405 .await
1406 .create_room(session.user_id(), session.connection_id, &livekit_room)
1407 .await?;
1408
1409 response.send(proto::CreateRoomResponse {
1410 room: Some(room.clone()),
1411 live_kit_connection_info,
1412 })?;
1413
1414 update_user_contacts(session.user_id(), &session).await?;
1415 Ok(())
1416}
1417
/// Joins a room the caller was invited to, using the caller's connection.
///
/// If the room belongs to a channel, handling is delegated to
/// `join_channel_internal`, which is also given the response to send.
/// Otherwise the handler runs these steps in order:
/// 1. Joins the room in the database and broadcasts the updated room.
/// 2. Sends `CallCanceled` for this room to every connection of the joining
///    user. This is presumably so the user's other devices stop ringing;
///    TODO confirm client side.
/// 3. Responds, attaching LiveKit connection info when a LiveKit client is
///    configured and a token could be minted. A token failure is only logged.
/// 4. Refreshes the user's contacts.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel rooms follow channel-membership rules instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Scoped so only the unwrapped room (via `into_inner`) outlives the
    // database call; the wrapper returned by `join_room` is dropped here.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1484
/// Re-establishes the caller's room membership after connection errors.
///
/// The database returns two lists: `reshared_projects`, presumably projects
/// the caller hosts, and `rejoined_projects`, presumably projects it had
/// joined as a guest; TODO confirm in `Database::rejoin_room`. The handler
/// then runs these steps in order:
/// 1. Sends the caller a `RejoinRoomResponse` describing both lists and
///    broadcasts the room.
/// 2. For reshared projects, tells each collaborator the caller's new peer id
///    and forwards the project's worktree list to every collaborator except
///    the caller.
/// 3. Streams rejoined-project state to the caller via
///    `notify_rejoined_projects`.
/// 4. If the room belongs to a channel, broadcasts the channel.
/// 5. Refreshes the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Everything that touches `rejoined_room` is scoped here, so it is dropped
    // before the connection pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Map the caller's old peer id to its new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1576
/// Re-synchronizes collaborators and the rejoining connection for each
/// rejoined project.
///
/// First, every collaborator of every project is told that the caller's peer
/// id changed from the project's `old_connection_id` to
/// `session.connection_id`. Failures here are only logged.
///
/// Then the caller is streamed, per project:
/// - for each worktree: its entries (split by `proto::split_worktree_update`),
///   its diagnostic summaries (batched into one message), and its settings
///   files;
/// - the project's updated repositories (split by `split_repository_update`);
/// - the project's removed repositories.
///
/// The worktree and repository collections are moved out with `mem::take`, so
/// they are empty in `rejoined_projects` afterwards.
///
/// # Errors
/// Returns the first error from sending a message to the caller.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics: the first summary goes in
            // `summary`, the rest in `more_summaries`.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1661
1662/// leave room disconnects from the room.
1663async fn leave_room(
1664 _: proto::LeaveRoom,
1665 response: Response<proto::LeaveRoom>,
1666 session: MessageContext,
1667) -> Result<()> {
1668 leave_room_for_session(&session, session.connection_id).await?;
1669 response.send(proto::Ack {})?;
1670 Ok(())
1671}
1672
/// Changes another participant's role in a room and mirrors it into LiveKit.
///
/// The database update is made on behalf of the requester
/// (`session.user_id()`). That call presumably checks that the requester may
/// change roles; TODO confirm in `Database::set_room_participant_role`. The
/// updated room is then broadcast.
///
/// If a LiveKit client is configured, the target participant's publish and
/// data permissions are set from whether the new role can use the microphone.
/// LiveKit errors are only logged. Finally the handler responds with `Ack`.
async fn set_room_participant_role(
    request: proto::SetRoomParticipantRole,
    response: Response<proto::SetRoomParticipantRole>,
    session: MessageContext,
) -> Result<()> {
    let user_id = UserId::from_proto(request.user_id);
    let role = ChannelRole::from(request.role());

    // Scoped so the room returned by the database is dropped before the
    // LiveKit request is awaited.
    let (livekit_room, can_publish) = {
        let room = session
            .db()
            .await
            .set_room_participant_role(
                session.user_id(),
                RoomId::from_proto(request.room_id),
                user_id,
                role,
            )
            .await?;

        let livekit_room = room.livekit_room.clone();
        // NOTE(review): recomputes the same value as `role` above.
        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
        room_updated(&room, &session.peer);
        (livekit_room, can_publish)
    };

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // The LiveKit participant identity is the raw proto user id as a string.
        live_kit
            .update_participant(
                livekit_room.clone(),
                request.user_id.to_string(),
                livekit_api::proto::ParticipantPermission {
                    can_subscribe: true,
                    can_publish,
                    can_publish_data: can_publish,
                    hidden: false,
                    recorder: false,
                },
            )
            .await
            .trace_err();
    }

    response.send(proto::Ack {})?;
    Ok(())
}
1720
/// Rings `request.called_user_id` into the caller's room.
///
/// The callee must be one of the caller's contacts. The handler then runs
/// these steps in order:
/// 1. Records the call in the database and broadcasts the room.
/// 2. Refreshes the callee's contacts.
/// 3. Sends the incoming call to every one of the callee's connections
///    concurrently. The first successful reply triggers an `Ack` to the
///    caller, and the handler returns; failed replies are only logged.
///
/// If no connection replies successfully (including when the callee has no
/// connections), the call is marked failed in the database, the room is
/// broadcast again, and the callee's contacts are refreshed.
///
/// # Errors
/// Fails if the callee is not a contact, if a database call fails, or with
/// "failed to ring user" when no connection accepted the call.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the database result is dropped after the incoming-call payload
    // has been moved out of it.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once; the first success wins.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody picked up: roll back the pending call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1789
/// Cancels an outgoing call to `request.called_user_id`.
///
/// The handler runs these steps in order:
/// 1. Records the cancellation for the caller's connection in the database
///    and broadcasts the room.
/// 2. Sends `CallCanceled` for the room to every connection of the called
///    user. Send failures are only logged.
/// 3. Acknowledges the request.
/// 4. Refreshes the called user's contacts.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the pool is locked.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1827
/// Declines an incoming call to the sender's user for `message.room_id`.
///
/// The decline is recorded in the database, failing with "declining call" if
/// the database returns no room, and the room is broadcast. Every connection
/// of the declining user is then sent `CallCanceled` for that room, with send
/// failures only logged. Finally the user's contacts are refreshed.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scoped so the database result is dropped before the pool is locked.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1859
1860/// Updates other participants in the room with your current location.
1861async fn update_participant_location(
1862 request: proto::UpdateParticipantLocation,
1863 response: Response<proto::UpdateParticipantLocation>,
1864 session: MessageContext,
1865) -> Result<()> {
1866 let room_id = RoomId::from_proto(request.room_id);
1867 let location = request.location.context("invalid location")?;
1868
1869 let db = session.db().await;
1870 let room = db
1871 .update_room_participant_location(room_id, session.connection_id, location)
1872 .await?;
1873
1874 room_updated(&room, &session.peer);
1875 response.send(proto::Ack {})?;
1876 Ok(())
1877}
1878
/// Shares the sender's worktrees into a room as a new project.
///
/// The database creates the project for the sender's connection. The sender is
/// answered with the new project id, and the updated room is then broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // `&*` keeps the database result alive (temporary lifetime extension)
    // until the end of this function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1902
1903/// Unshare a project from the room.
1904async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1905 let project_id = ProjectId::from_proto(message.project_id);
1906 unshare_project_internal(project_id, session.connection_id, &session).await
1907}
1908
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The handler runs these steps in order:
/// 1. Sends `UnshareProject` to every guest connection reported by the
///    database, except `connection_id`. Failures are only logged.
/// 2. Broadcasts the room, if the project belonged to one.
/// 3. Deletes the project from the database when the unshare result's
///    `delete` flag is set.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The guard returned by `unshare_project` is dropped at the end of this
    // block, before the database is used again for deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1946
1947/// Join someone elses shared project.
1948async fn join_project(
1949 request: proto::JoinProject,
1950 response: Response<proto::JoinProject>,
1951 session: MessageContext,
1952) -> Result<()> {
1953 let project_id = ProjectId::from_proto(request.project_id);
1954
1955 tracing::info!(%project_id, "join project");
1956
1957 let db = session.db().await;
1958 let (project, replica_id) = &mut *db
1959 .join_project(
1960 project_id,
1961 session.connection_id,
1962 session.user_id(),
1963 request.committer_name.clone(),
1964 request.committer_email.clone(),
1965 )
1966 .await?;
1967 drop(db);
1968 tracing::info!(%project_id, "join remote project");
1969 let collaborators = project
1970 .collaborators
1971 .iter()
1972 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1973 .map(|collaborator| collaborator.to_proto())
1974 .collect::<Vec<_>>();
1975 let project_id = project.id;
1976 let guest_user_id = session.user_id();
1977
1978 let worktrees = project
1979 .worktrees
1980 .iter()
1981 .map(|(id, worktree)| proto::WorktreeMetadata {
1982 id: *id,
1983 root_name: worktree.root_name.clone(),
1984 visible: worktree.visible,
1985 abs_path: worktree.abs_path.clone(),
1986 })
1987 .collect::<Vec<_>>();
1988
1989 let add_project_collaborator = proto::AddProjectCollaborator {
1990 project_id: project_id.to_proto(),
1991 collaborator: Some(proto::Collaborator {
1992 peer_id: Some(session.connection_id.into()),
1993 replica_id: replica_id.0 as u32,
1994 user_id: guest_user_id.to_proto(),
1995 is_host: false,
1996 committer_name: request.committer_name.clone(),
1997 committer_email: request.committer_email.clone(),
1998 }),
1999 };
2000
2001 for collaborator in &collaborators {
2002 session
2003 .peer
2004 .send(
2005 collaborator.peer_id.unwrap().into(),
2006 add_project_collaborator.clone(),
2007 )
2008 .trace_err();
2009 }
2010
2011 // First, we send the metadata associated with each worktree.
2012 let (language_servers, language_server_capabilities) = project
2013 .language_servers
2014 .clone()
2015 .into_iter()
2016 .map(|server| (server.server, server.capabilities))
2017 .unzip();
2018 response.send(proto::JoinProjectResponse {
2019 project_id: project.id.0 as u64,
2020 worktrees,
2021 replica_id: replica_id.0 as u32,
2022 collaborators,
2023 language_servers,
2024 language_server_capabilities,
2025 role: project.role.into(),
2026 })?;
2027
2028 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2029 // Stream this worktree's entries.
2030 let message = proto::UpdateWorktree {
2031 project_id: project_id.to_proto(),
2032 worktree_id,
2033 abs_path: worktree.abs_path.clone(),
2034 root_name: worktree.root_name,
2035 updated_entries: worktree.entries,
2036 removed_entries: Default::default(),
2037 scan_id: worktree.scan_id,
2038 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2039 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2040 removed_repositories: Default::default(),
2041 };
2042 for update in proto::split_worktree_update(message) {
2043 session.peer.send(session.connection_id, update.clone())?;
2044 }
2045
2046 // Stream this worktree's diagnostics.
2047 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2048 if let Some(summary) = worktree_diagnostics.next() {
2049 let message = proto::UpdateDiagnosticSummary {
2050 project_id: project.id.to_proto(),
2051 worktree_id: worktree.id,
2052 summary: Some(summary),
2053 more_summaries: worktree_diagnostics.collect(),
2054 };
2055 session.peer.send(session.connection_id, message)?;
2056 }
2057
2058 for settings_file in worktree.settings_files {
2059 session.peer.send(
2060 session.connection_id,
2061 proto::UpdateWorktreeSettings {
2062 project_id: project_id.to_proto(),
2063 worktree_id: worktree.id,
2064 path: settings_file.path,
2065 content: Some(settings_file.content),
2066 kind: Some(settings_file.kind.to_proto() as i32),
2067 },
2068 )?;
2069 }
2070 }
2071
2072 for repository in mem::take(&mut project.repositories) {
2073 for update in split_repository_update(repository) {
2074 session.peer.send(session.connection_id, update)?;
2075 }
2076 }
2077
2078 for language_server in &project.language_servers {
2079 session.peer.send(
2080 session.connection_id,
2081 proto::UpdateLanguageServer {
2082 project_id: project_id.to_proto(),
2083 server_name: Some(language_server.server.name.clone()),
2084 language_server_id: language_server.server.id,
2085 variant: Some(
2086 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2087 proto::LspDiskBasedDiagnosticsUpdated {},
2088 ),
2089 ),
2090 },
2091 )?;
2092 }
2093
2094 Ok(())
2095}
2096
2097/// Leave someone elses shared project.
2098async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2099 let sender_id = session.connection_id;
2100 let project_id = ProjectId::from_proto(request.project_id);
2101 let db = session.db().await;
2102
2103 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2104 tracing::info!(
2105 %project_id,
2106 "leave project"
2107 );
2108
2109 project_left(project, &session);
2110 if let Some(room) = room {
2111 room_updated(room, &session.peer);
2112 }
2113
2114 Ok(())
2115}
2116
2117/// Updates other participants with changes to the project
2118async fn update_project(
2119 request: proto::UpdateProject,
2120 response: Response<proto::UpdateProject>,
2121 session: MessageContext,
2122) -> Result<()> {
2123 let project_id = ProjectId::from_proto(request.project_id);
2124 let (room, guest_connection_ids) = &*session
2125 .db()
2126 .await
2127 .update_project(project_id, session.connection_id, &request.worktrees)
2128 .await?;
2129 broadcast(
2130 Some(session.connection_id),
2131 guest_connection_ids.iter().copied(),
2132 |connection_id| {
2133 session
2134 .peer
2135 .forward_send(session.connection_id, connection_id, request.clone())
2136 },
2137 );
2138 if let Some(room) = room {
2139 room_updated(room, &session.peer);
2140 }
2141 response.send(proto::Ack {})?;
2142
2143 Ok(())
2144}
2145
2146/// Updates other participants with changes to the worktree
2147async fn update_worktree(
2148 request: proto::UpdateWorktree,
2149 response: Response<proto::UpdateWorktree>,
2150 session: MessageContext,
2151) -> Result<()> {
2152 let guest_connection_ids = session
2153 .db()
2154 .await
2155 .update_worktree(&request, session.connection_id)
2156 .await?;
2157
2158 broadcast(
2159 Some(session.connection_id),
2160 guest_connection_ids.iter().copied(),
2161 |connection_id| {
2162 session
2163 .peer
2164 .forward_send(session.connection_id, connection_id, request.clone())
2165 },
2166 );
2167 response.send(proto::Ack {})?;
2168 Ok(())
2169}
2170
2171async fn update_repository(
2172 request: proto::UpdateRepository,
2173 response: Response<proto::UpdateRepository>,
2174 session: MessageContext,
2175) -> Result<()> {
2176 let guest_connection_ids = session
2177 .db()
2178 .await
2179 .update_repository(&request, session.connection_id)
2180 .await?;
2181
2182 broadcast(
2183 Some(session.connection_id),
2184 guest_connection_ids.iter().copied(),
2185 |connection_id| {
2186 session
2187 .peer
2188 .forward_send(session.connection_id, connection_id, request.clone())
2189 },
2190 );
2191 response.send(proto::Ack {})?;
2192 Ok(())
2193}
2194
2195async fn remove_repository(
2196 request: proto::RemoveRepository,
2197 response: Response<proto::RemoveRepository>,
2198 session: MessageContext,
2199) -> Result<()> {
2200 let guest_connection_ids = session
2201 .db()
2202 .await
2203 .remove_repository(&request, session.connection_id)
2204 .await?;
2205
2206 broadcast(
2207 Some(session.connection_id),
2208 guest_connection_ids.iter().copied(),
2209 |connection_id| {
2210 session
2211 .peer
2212 .forward_send(session.connection_id, connection_id, request.clone())
2213 },
2214 );
2215 response.send(proto::Ack {})?;
2216 Ok(())
2217}
2218
2219/// Updates other participants with changes to the diagnostics
2220async fn update_diagnostic_summary(
2221 message: proto::UpdateDiagnosticSummary,
2222 session: MessageContext,
2223) -> Result<()> {
2224 let guest_connection_ids = session
2225 .db()
2226 .await
2227 .update_diagnostic_summary(&message, session.connection_id)
2228 .await?;
2229
2230 broadcast(
2231 Some(session.connection_id),
2232 guest_connection_ids.iter().copied(),
2233 |connection_id| {
2234 session
2235 .peer
2236 .forward_send(session.connection_id, connection_id, message.clone())
2237 },
2238 );
2239
2240 Ok(())
2241}
2242
2243/// Updates other participants with changes to the worktree settings
2244async fn update_worktree_settings(
2245 message: proto::UpdateWorktreeSettings,
2246 session: MessageContext,
2247) -> Result<()> {
2248 let guest_connection_ids = session
2249 .db()
2250 .await
2251 .update_worktree_settings(&message, session.connection_id)
2252 .await?;
2253
2254 broadcast(
2255 Some(session.connection_id),
2256 guest_connection_ids.iter().copied(),
2257 |connection_id| {
2258 session
2259 .peer
2260 .forward_send(session.connection_id, connection_id, message.clone())
2261 },
2262 );
2263
2264 Ok(())
2265}
2266
2267/// Notify other participants that a language server has started.
2268async fn start_language_server(
2269 request: proto::StartLanguageServer,
2270 session: MessageContext,
2271) -> Result<()> {
2272 let guest_connection_ids = session
2273 .db()
2274 .await
2275 .start_language_server(&request, session.connection_id)
2276 .await?;
2277
2278 broadcast(
2279 Some(session.connection_id),
2280 guest_connection_ids.iter().copied(),
2281 |connection_id| {
2282 session
2283 .peer
2284 .forward_send(session.connection_id, connection_id, request.clone())
2285 },
2286 );
2287 Ok(())
2288}
2289
2290/// Notify other participants that a language server has changed.
2291async fn update_language_server(
2292 request: proto::UpdateLanguageServer,
2293 session: MessageContext,
2294) -> Result<()> {
2295 let project_id = ProjectId::from_proto(request.project_id);
2296 let db = session.db().await;
2297
2298 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2299 && let Some(capabilities) = update.capabilities.clone()
2300 {
2301 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2302 .await?;
2303 }
2304
2305 let project_connection_ids = db
2306 .project_connection_ids(project_id, session.connection_id, true)
2307 .await?;
2308 broadcast(
2309 Some(session.connection_id),
2310 project_connection_ids.iter().copied(),
2311 |connection_id| {
2312 session
2313 .peer
2314 .forward_send(session.connection_id, connection_id, request.clone())
2315 },
2316 );
2317 Ok(())
2318}
2319
2320/// forward a project request to the host. These requests should be read only
2321/// as guests are allowed to send them.
2322async fn forward_read_only_project_request<T>(
2323 request: T,
2324 response: Response<T>,
2325 session: MessageContext,
2326) -> Result<()>
2327where
2328 T: EntityMessage + RequestMessage,
2329{
2330 let project_id = ProjectId::from_proto(request.remote_entity_id());
2331 let host_connection_id = session
2332 .db()
2333 .await
2334 .host_for_read_only_project_request(project_id, session.connection_id)
2335 .await?;
2336 let payload = session.forward_request(host_connection_id, request).await?;
2337 response.send(payload)?;
2338 Ok(())
2339}
2340
2341/// forward a project request to the host. These requests are disallowed
2342/// for guests.
2343async fn forward_mutating_project_request<T>(
2344 request: T,
2345 response: Response<T>,
2346 session: MessageContext,
2347) -> Result<()>
2348where
2349 T: EntityMessage + RequestMessage,
2350{
2351 let project_id = ProjectId::from_proto(request.remote_entity_id());
2352
2353 let host_connection_id = session
2354 .db()
2355 .await
2356 .host_for_mutating_project_request(project_id, session.connection_id)
2357 .await?;
2358 let payload = session.forward_request(host_connection_id, request).await?;
2359 response.send(payload)?;
2360 Ok(())
2361}
2362
2363// todo(lsp) remove after Zed Stable hits v0.204.x
2364async fn multi_lsp_query(
2365 request: MultiLspQuery,
2366 response: Response<MultiLspQuery>,
2367 session: MessageContext,
2368) -> Result<()> {
2369 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2370 tracing::info!("multi_lsp_query message received");
2371 forward_mutating_project_request(request, response, session).await
2372}
2373
2374async fn lsp_query(
2375 request: proto::LspQuery,
2376 response: Response<proto::LspQuery>,
2377 session: MessageContext,
2378) -> Result<()> {
2379 let (name, should_write) = request.query_name_and_write_permissions();
2380 tracing::Span::current().record("lsp_query_request", name);
2381 tracing::info!("lsp_query message received");
2382 if should_write {
2383 forward_mutating_project_request(request, response, session).await
2384 } else {
2385 forward_read_only_project_request(request, response, session).await
2386 }
2387}
2388
2389/// Notify other participants that a new buffer has been created
2390async fn create_buffer_for_peer(
2391 request: proto::CreateBufferForPeer,
2392 session: MessageContext,
2393) -> Result<()> {
2394 session
2395 .db()
2396 .await
2397 .check_user_is_project_host(
2398 ProjectId::from_proto(request.project_id),
2399 session.connection_id,
2400 )
2401 .await?;
2402 let peer_id = request.peer_id.context("invalid peer id")?;
2403 session
2404 .peer
2405 .forward_send(session.connection_id, peer_id.into(), request)?;
2406 Ok(())
2407}
2408
2409/// Notify other participants that a buffer has been updated. This is
2410/// allowed for guests as long as the update is limited to selections.
2411async fn update_buffer(
2412 request: proto::UpdateBuffer,
2413 response: Response<proto::UpdateBuffer>,
2414 session: MessageContext,
2415) -> Result<()> {
2416 let project_id = ProjectId::from_proto(request.project_id);
2417 let mut capability = Capability::ReadOnly;
2418
2419 for op in request.operations.iter() {
2420 match op.variant {
2421 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2422 Some(_) => capability = Capability::ReadWrite,
2423 }
2424 }
2425
2426 let host = {
2427 let guard = session
2428 .db()
2429 .await
2430 .connections_for_buffer_update(project_id, session.connection_id, capability)
2431 .await?;
2432
2433 let (host, guests) = &*guard;
2434
2435 broadcast(
2436 Some(session.connection_id),
2437 guests.clone(),
2438 |connection_id| {
2439 session
2440 .peer
2441 .forward_send(session.connection_id, connection_id, request.clone())
2442 },
2443 );
2444
2445 *host
2446 };
2447
2448 if host != session.connection_id {
2449 session.forward_request(host, request.clone()).await?;
2450 }
2451
2452 response.send(proto::Ack {})?;
2453 Ok(())
2454}
2455
2456async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2457 let project_id = ProjectId::from_proto(message.project_id);
2458
2459 let operation = message.operation.as_ref().context("invalid operation")?;
2460 let capability = match operation.variant.as_ref() {
2461 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2462 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2463 match buffer_op.variant {
2464 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2465 Capability::ReadOnly
2466 }
2467 _ => Capability::ReadWrite,
2468 }
2469 } else {
2470 Capability::ReadWrite
2471 }
2472 }
2473 Some(_) => Capability::ReadWrite,
2474 None => Capability::ReadOnly,
2475 };
2476
2477 let guard = session
2478 .db()
2479 .await
2480 .connections_for_buffer_update(project_id, session.connection_id, capability)
2481 .await?;
2482
2483 let (host, guests) = &*guard;
2484
2485 broadcast(
2486 Some(session.connection_id),
2487 guests.iter().chain([host]).copied(),
2488 |connection_id| {
2489 session
2490 .peer
2491 .forward_send(session.connection_id, connection_id, message.clone())
2492 },
2493 );
2494
2495 Ok(())
2496}
2497
2498/// Notify other participants that a project has been updated.
2499async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2500 request: T,
2501 session: MessageContext,
2502) -> Result<()> {
2503 let project_id = ProjectId::from_proto(request.remote_entity_id());
2504 let project_connection_ids = session
2505 .db()
2506 .await
2507 .project_connection_ids(project_id, session.connection_id, false)
2508 .await?;
2509
2510 broadcast(
2511 Some(session.connection_id),
2512 project_connection_ids.iter().copied(),
2513 |connection_id| {
2514 session
2515 .peer
2516 .forward_send(session.connection_id, connection_id, request.clone())
2517 },
2518 );
2519 Ok(())
2520}
2521
2522/// Start following another user in a call.
2523async fn follow(
2524 request: proto::Follow,
2525 response: Response<proto::Follow>,
2526 session: MessageContext,
2527) -> Result<()> {
2528 let room_id = RoomId::from_proto(request.room_id);
2529 let project_id = request.project_id.map(ProjectId::from_proto);
2530 let leader_id = request.leader_id.context("invalid leader id")?.into();
2531 let follower_id = session.connection_id;
2532
2533 session
2534 .db()
2535 .await
2536 .check_room_participants(room_id, leader_id, session.connection_id)
2537 .await?;
2538
2539 let response_payload = session.forward_request(leader_id, request).await?;
2540 response.send(response_payload)?;
2541
2542 if let Some(project_id) = project_id {
2543 let room = session
2544 .db()
2545 .await
2546 .follow(room_id, project_id, leader_id, follower_id)
2547 .await?;
2548 room_updated(&room, &session.peer);
2549 }
2550
2551 Ok(())
2552}
2553
2554/// Stop following another user in a call.
2555async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2556 let room_id = RoomId::from_proto(request.room_id);
2557 let project_id = request.project_id.map(ProjectId::from_proto);
2558 let leader_id = request.leader_id.context("invalid leader id")?.into();
2559 let follower_id = session.connection_id;
2560
2561 session
2562 .db()
2563 .await
2564 .check_room_participants(room_id, leader_id, session.connection_id)
2565 .await?;
2566
2567 session
2568 .peer
2569 .forward_send(session.connection_id, leader_id, request)?;
2570
2571 if let Some(project_id) = project_id {
2572 let room = session
2573 .db()
2574 .await
2575 .unfollow(room_id, project_id, leader_id, follower_id)
2576 .await?;
2577 room_updated(&room, &session.peer);
2578 }
2579
2580 Ok(())
2581}
2582
2583/// Notify everyone following you of your current location.
2584async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2585 let room_id = RoomId::from_proto(request.room_id);
2586 let database = session.db.lock().await;
2587
2588 let connection_ids = if let Some(project_id) = request.project_id {
2589 let project_id = ProjectId::from_proto(project_id);
2590 database
2591 .project_connection_ids(project_id, session.connection_id, true)
2592 .await?
2593 } else {
2594 database
2595 .room_connection_ids(room_id, session.connection_id)
2596 .await?
2597 };
2598
2599 // For now, don't send view update messages back to that view's current leader.
2600 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2601 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2602 _ => None,
2603 });
2604
2605 for connection_id in connection_ids.iter().cloned() {
2606 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2607 session
2608 .peer
2609 .forward_send(session.connection_id, connection_id, request.clone())?;
2610 }
2611 }
2612 Ok(())
2613}
2614
2615/// Get public data about users.
2616async fn get_users(
2617 request: proto::GetUsers,
2618 response: Response<proto::GetUsers>,
2619 session: MessageContext,
2620) -> Result<()> {
2621 let user_ids = request
2622 .user_ids
2623 .into_iter()
2624 .map(UserId::from_proto)
2625 .collect();
2626 let users = session
2627 .db()
2628 .await
2629 .get_users_by_ids(user_ids)
2630 .await?
2631 .into_iter()
2632 .map(|user| proto::User {
2633 id: user.id.to_proto(),
2634 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2635 github_login: user.github_login,
2636 name: user.name,
2637 })
2638 .collect();
2639 response.send(proto::UsersResponse { users })?;
2640 Ok(())
2641}
2642
2643/// Search for users (to invite) buy Github login
2644async fn fuzzy_search_users(
2645 request: proto::FuzzySearchUsers,
2646 response: Response<proto::FuzzySearchUsers>,
2647 session: MessageContext,
2648) -> Result<()> {
2649 let query = request.query;
2650 let users = match query.len() {
2651 0 => vec![],
2652 1 | 2 => session
2653 .db()
2654 .await
2655 .get_user_by_github_login(&query)
2656 .await?
2657 .into_iter()
2658 .collect(),
2659 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2660 };
2661 let users = users
2662 .into_iter()
2663 .filter(|user| user.id != session.user_id())
2664 .map(|user| proto::User {
2665 id: user.id.to_proto(),
2666 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2667 github_login: user.github_login,
2668 name: user.name,
2669 })
2670 .collect();
2671 response.send(proto::UsersResponse { users })?;
2672 Ok(())
2673}
2674
2675/// Send a contact request to another user.
2676async fn request_contact(
2677 request: proto::RequestContact,
2678 response: Response<proto::RequestContact>,
2679 session: MessageContext,
2680) -> Result<()> {
2681 let requester_id = session.user_id();
2682 let responder_id = UserId::from_proto(request.responder_id);
2683 if requester_id == responder_id {
2684 return Err(anyhow!("cannot add yourself as a contact"))?;
2685 }
2686
2687 let notifications = session
2688 .db()
2689 .await
2690 .send_contact_request(requester_id, responder_id)
2691 .await?;
2692
2693 // Update outgoing contact requests of requester
2694 let mut update = proto::UpdateContacts::default();
2695 update.outgoing_requests.push(responder_id.to_proto());
2696 for connection_id in session
2697 .connection_pool()
2698 .await
2699 .user_connection_ids(requester_id)
2700 {
2701 session.peer.send(connection_id, update.clone())?;
2702 }
2703
2704 // Update incoming contact requests of responder
2705 let mut update = proto::UpdateContacts::default();
2706 update
2707 .incoming_requests
2708 .push(proto::IncomingContactRequest {
2709 requester_id: requester_id.to_proto(),
2710 });
2711 let connection_pool = session.connection_pool().await;
2712 for connection_id in connection_pool.user_connection_ids(responder_id) {
2713 session.peer.send(connection_id, update.clone())?;
2714 }
2715
2716 send_notifications(&connection_pool, &session.peer, notifications);
2717
2718 response.send(proto::Ack {})?;
2719 Ok(())
2720}
2721
2722/// Accept or decline a contact request
2723async fn respond_to_contact_request(
2724 request: proto::RespondToContactRequest,
2725 response: Response<proto::RespondToContactRequest>,
2726 session: MessageContext,
2727) -> Result<()> {
2728 let responder_id = session.user_id();
2729 let requester_id = UserId::from_proto(request.requester_id);
2730 let db = session.db().await;
2731 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2732 db.dismiss_contact_notification(responder_id, requester_id)
2733 .await?;
2734 } else {
2735 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2736
2737 let notifications = db
2738 .respond_to_contact_request(responder_id, requester_id, accept)
2739 .await?;
2740 let requester_busy = db.is_user_busy(requester_id).await?;
2741 let responder_busy = db.is_user_busy(responder_id).await?;
2742
2743 let pool = session.connection_pool().await;
2744 // Update responder with new contact
2745 let mut update = proto::UpdateContacts::default();
2746 if accept {
2747 update
2748 .contacts
2749 .push(contact_for_user(requester_id, requester_busy, &pool));
2750 }
2751 update
2752 .remove_incoming_requests
2753 .push(requester_id.to_proto());
2754 for connection_id in pool.user_connection_ids(responder_id) {
2755 session.peer.send(connection_id, update.clone())?;
2756 }
2757
2758 // Update requester with new contact
2759 let mut update = proto::UpdateContacts::default();
2760 if accept {
2761 update
2762 .contacts
2763 .push(contact_for_user(responder_id, responder_busy, &pool));
2764 }
2765 update
2766 .remove_outgoing_requests
2767 .push(responder_id.to_proto());
2768
2769 for connection_id in pool.user_connection_ids(requester_id) {
2770 session.peer.send(connection_id, update.clone())?;
2771 }
2772
2773 send_notifications(&pool, &session.peer, notifications);
2774 }
2775
2776 response.send(proto::Ack {})?;
2777 Ok(())
2778}
2779
2780/// Remove a contact.
2781async fn remove_contact(
2782 request: proto::RemoveContact,
2783 response: Response<proto::RemoveContact>,
2784 session: MessageContext,
2785) -> Result<()> {
2786 let requester_id = session.user_id();
2787 let responder_id = UserId::from_proto(request.user_id);
2788 let db = session.db().await;
2789 let (contact_accepted, deleted_notification_id) =
2790 db.remove_contact(requester_id, responder_id).await?;
2791
2792 let pool = session.connection_pool().await;
2793 // Update outgoing contact requests of requester
2794 let mut update = proto::UpdateContacts::default();
2795 if contact_accepted {
2796 update.remove_contacts.push(responder_id.to_proto());
2797 } else {
2798 update
2799 .remove_outgoing_requests
2800 .push(responder_id.to_proto());
2801 }
2802 for connection_id in pool.user_connection_ids(requester_id) {
2803 session.peer.send(connection_id, update.clone())?;
2804 }
2805
2806 // Update incoming contact requests of responder
2807 let mut update = proto::UpdateContacts::default();
2808 if contact_accepted {
2809 update.remove_contacts.push(requester_id.to_proto());
2810 } else {
2811 update
2812 .remove_incoming_requests
2813 .push(requester_id.to_proto());
2814 }
2815 for connection_id in pool.user_connection_ids(responder_id) {
2816 session.peer.send(connection_id, update.clone())?;
2817 if let Some(notification_id) = deleted_notification_id {
2818 session.peer.send(
2819 connection_id,
2820 proto::DeleteNotification {
2821 notification_id: notification_id.to_proto(),
2822 },
2823 )?;
2824 }
2825 }
2826
2827 response.send(proto::Ack {})?;
2828 Ok(())
2829}
2830
2831fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2832 version.0.minor() < 139
2833}
2834
2835async fn subscribe_to_channels(
2836 _: proto::SubscribeToChannels,
2837 session: MessageContext,
2838) -> Result<()> {
2839 subscribe_user_to_channels(session.user_id(), &session).await?;
2840 Ok(())
2841}
2842
2843async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2844 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2845 let mut pool = session.connection_pool().await;
2846 for membership in &channels_for_user.channel_memberships {
2847 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2848 }
2849 session.peer.send(
2850 session.connection_id,
2851 build_update_user_channels(&channels_for_user),
2852 )?;
2853 session.peer.send(
2854 session.connection_id,
2855 build_channels_update(channels_for_user),
2856 )?;
2857 Ok(())
2858}
2859
2860/// Creates a new channel.
2861async fn create_channel(
2862 request: proto::CreateChannel,
2863 response: Response<proto::CreateChannel>,
2864 session: MessageContext,
2865) -> Result<()> {
2866 let db = session.db().await;
2867
2868 let parent_id = request.parent_id.map(ChannelId::from_proto);
2869 let (channel, membership) = db
2870 .create_channel(&request.name, parent_id, session.user_id())
2871 .await?;
2872
2873 let root_id = channel.root_id();
2874 let channel = Channel::from_model(channel);
2875
2876 response.send(proto::CreateChannelResponse {
2877 channel: Some(channel.to_proto()),
2878 parent_id: request.parent_id,
2879 })?;
2880
2881 let mut connection_pool = session.connection_pool().await;
2882 if let Some(membership) = membership {
2883 connection_pool.subscribe_to_channel(
2884 membership.user_id,
2885 membership.channel_id,
2886 membership.role,
2887 );
2888 let update = proto::UpdateUserChannels {
2889 channel_memberships: vec![proto::ChannelMembership {
2890 channel_id: membership.channel_id.to_proto(),
2891 role: membership.role.into(),
2892 }],
2893 ..Default::default()
2894 };
2895 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2896 session.peer.send(connection_id, update.clone())?;
2897 }
2898 }
2899
2900 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2901 if !role.can_see_channel(channel.visibility) {
2902 continue;
2903 }
2904
2905 let update = proto::UpdateChannels {
2906 channels: vec![channel.to_proto()],
2907 ..Default::default()
2908 };
2909 session.peer.send(connection_id, update.clone())?;
2910 }
2911
2912 Ok(())
2913}
2914
2915/// Delete a channel
2916async fn delete_channel(
2917 request: proto::DeleteChannel,
2918 response: Response<proto::DeleteChannel>,
2919 session: MessageContext,
2920) -> Result<()> {
2921 let db = session.db().await;
2922
2923 let channel_id = request.channel_id;
2924 let (root_channel, removed_channels) = db
2925 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2926 .await?;
2927 response.send(proto::Ack {})?;
2928
2929 // Notify members of removed channels
2930 let mut update = proto::UpdateChannels::default();
2931 update
2932 .delete_channels
2933 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2934
2935 let connection_pool = session.connection_pool().await;
2936 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2937 session.peer.send(connection_id, update.clone())?;
2938 }
2939
2940 Ok(())
2941}
2942
2943/// Invite someone to join a channel.
2944async fn invite_channel_member(
2945 request: proto::InviteChannelMember,
2946 response: Response<proto::InviteChannelMember>,
2947 session: MessageContext,
2948) -> Result<()> {
2949 let db = session.db().await;
2950 let channel_id = ChannelId::from_proto(request.channel_id);
2951 let invitee_id = UserId::from_proto(request.user_id);
2952 let InviteMemberResult {
2953 channel,
2954 notifications,
2955 } = db
2956 .invite_channel_member(
2957 channel_id,
2958 invitee_id,
2959 session.user_id(),
2960 request.role().into(),
2961 )
2962 .await?;
2963
2964 let update = proto::UpdateChannels {
2965 channel_invitations: vec![channel.to_proto()],
2966 ..Default::default()
2967 };
2968
2969 let connection_pool = session.connection_pool().await;
2970 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2971 session.peer.send(connection_id, update.clone())?;
2972 }
2973
2974 send_notifications(&connection_pool, &session.peer, notifications);
2975
2976 response.send(proto::Ack {})?;
2977 Ok(())
2978}
2979
2980/// remove someone from a channel
2981async fn remove_channel_member(
2982 request: proto::RemoveChannelMember,
2983 response: Response<proto::RemoveChannelMember>,
2984 session: MessageContext,
2985) -> Result<()> {
2986 let db = session.db().await;
2987 let channel_id = ChannelId::from_proto(request.channel_id);
2988 let member_id = UserId::from_proto(request.user_id);
2989
2990 let RemoveChannelMemberResult {
2991 membership_update,
2992 notification_id,
2993 } = db
2994 .remove_channel_member(channel_id, member_id, session.user_id())
2995 .await?;
2996
2997 let mut connection_pool = session.connection_pool().await;
2998 notify_membership_updated(
2999 &mut connection_pool,
3000 membership_update,
3001 member_id,
3002 &session.peer,
3003 );
3004 for connection_id in connection_pool.user_connection_ids(member_id) {
3005 if let Some(notification_id) = notification_id {
3006 session
3007 .peer
3008 .send(
3009 connection_id,
3010 proto::DeleteNotification {
3011 notification_id: notification_id.to_proto(),
3012 },
3013 )
3014 .trace_err();
3015 }
3016 }
3017
3018 response.send(proto::Ack {})?;
3019 Ok(())
3020}
3021
3022/// Toggle the channel between public and private.
3023/// Care is taken to maintain the invariant that public channels only descend from public channels,
3024/// (though members-only channels can appear at any point in the hierarchy).
3025async fn set_channel_visibility(
3026 request: proto::SetChannelVisibility,
3027 response: Response<proto::SetChannelVisibility>,
3028 session: MessageContext,
3029) -> Result<()> {
3030 let db = session.db().await;
3031 let channel_id = ChannelId::from_proto(request.channel_id);
3032 let visibility = request.visibility().into();
3033
3034 let channel_model = db
3035 .set_channel_visibility(channel_id, visibility, session.user_id())
3036 .await?;
3037 let root_id = channel_model.root_id();
3038 let channel = Channel::from_model(channel_model);
3039
3040 let mut connection_pool = session.connection_pool().await;
3041 for (user_id, role) in connection_pool
3042 .channel_user_ids(root_id)
3043 .collect::<Vec<_>>()
3044 .into_iter()
3045 {
3046 let update = if role.can_see_channel(channel.visibility) {
3047 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3048 proto::UpdateChannels {
3049 channels: vec![channel.to_proto()],
3050 ..Default::default()
3051 }
3052 } else {
3053 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3054 proto::UpdateChannels {
3055 delete_channels: vec![channel.id.to_proto()],
3056 ..Default::default()
3057 }
3058 };
3059
3060 for connection_id in connection_pool.user_connection_ids(user_id) {
3061 session.peer.send(connection_id, update.clone())?;
3062 }
3063 }
3064
3065 response.send(proto::Ack {})?;
3066 Ok(())
3067}
3068
3069/// Alter the role for a user in the channel.
3070async fn set_channel_member_role(
3071 request: proto::SetChannelMemberRole,
3072 response: Response<proto::SetChannelMemberRole>,
3073 session: MessageContext,
3074) -> Result<()> {
3075 let db = session.db().await;
3076 let channel_id = ChannelId::from_proto(request.channel_id);
3077 let member_id = UserId::from_proto(request.user_id);
3078 let result = db
3079 .set_channel_member_role(
3080 channel_id,
3081 session.user_id(),
3082 member_id,
3083 request.role().into(),
3084 )
3085 .await?;
3086
3087 match result {
3088 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3089 let mut connection_pool = session.connection_pool().await;
3090 notify_membership_updated(
3091 &mut connection_pool,
3092 membership_update,
3093 member_id,
3094 &session.peer,
3095 )
3096 }
3097 db::SetMemberRoleResult::InviteUpdated(channel) => {
3098 let update = proto::UpdateChannels {
3099 channel_invitations: vec![channel.to_proto()],
3100 ..Default::default()
3101 };
3102
3103 for connection_id in session
3104 .connection_pool()
3105 .await
3106 .user_connection_ids(member_id)
3107 {
3108 session.peer.send(connection_id, update.clone())?;
3109 }
3110 }
3111 }
3112
3113 response.send(proto::Ack {})?;
3114 Ok(())
3115}
3116
3117/// Change the name of a channel
3118async fn rename_channel(
3119 request: proto::RenameChannel,
3120 response: Response<proto::RenameChannel>,
3121 session: MessageContext,
3122) -> Result<()> {
3123 let db = session.db().await;
3124 let channel_id = ChannelId::from_proto(request.channel_id);
3125 let channel_model = db
3126 .rename_channel(channel_id, session.user_id(), &request.name)
3127 .await?;
3128 let root_id = channel_model.root_id();
3129 let channel = Channel::from_model(channel_model);
3130
3131 response.send(proto::RenameChannelResponse {
3132 channel: Some(channel.to_proto()),
3133 })?;
3134
3135 let connection_pool = session.connection_pool().await;
3136 let update = proto::UpdateChannels {
3137 channels: vec![channel.to_proto()],
3138 ..Default::default()
3139 };
3140 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3141 if role.can_see_channel(channel.visibility) {
3142 session.peer.send(connection_id, update.clone())?;
3143 }
3144 }
3145
3146 Ok(())
3147}
3148
3149/// Move a channel to a new parent.
3150async fn move_channel(
3151 request: proto::MoveChannel,
3152 response: Response<proto::MoveChannel>,
3153 session: MessageContext,
3154) -> Result<()> {
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let to = ChannelId::from_proto(request.to);
3157
3158 let (root_id, channels) = session
3159 .db()
3160 .await
3161 .move_channel(channel_id, to, session.user_id())
3162 .await?;
3163
3164 let connection_pool = session.connection_pool().await;
3165 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3166 let channels = channels
3167 .iter()
3168 .filter_map(|channel| {
3169 if role.can_see_channel(channel.visibility) {
3170 Some(channel.to_proto())
3171 } else {
3172 None
3173 }
3174 })
3175 .collect::<Vec<_>>();
3176 if channels.is_empty() {
3177 continue;
3178 }
3179
3180 let update = proto::UpdateChannels {
3181 channels,
3182 ..Default::default()
3183 };
3184
3185 session.peer.send(connection_id, update.clone())?;
3186 }
3187
3188 response.send(Ack {})?;
3189 Ok(())
3190}
3191
3192async fn reorder_channel(
3193 request: proto::ReorderChannel,
3194 response: Response<proto::ReorderChannel>,
3195 session: MessageContext,
3196) -> Result<()> {
3197 let channel_id = ChannelId::from_proto(request.channel_id);
3198 let direction = request.direction();
3199
3200 let updated_channels = session
3201 .db()
3202 .await
3203 .reorder_channel(channel_id, direction, session.user_id())
3204 .await?;
3205
3206 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3207 let connection_pool = session.connection_pool().await;
3208 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3209 let channels = updated_channels
3210 .iter()
3211 .filter_map(|channel| {
3212 if role.can_see_channel(channel.visibility) {
3213 Some(channel.to_proto())
3214 } else {
3215 None
3216 }
3217 })
3218 .collect::<Vec<_>>();
3219
3220 if channels.is_empty() {
3221 continue;
3222 }
3223
3224 let update = proto::UpdateChannels {
3225 channels,
3226 ..Default::default()
3227 };
3228
3229 session.peer.send(connection_id, update.clone())?;
3230 }
3231 }
3232
3233 response.send(Ack {})?;
3234 Ok(())
3235}
3236
3237/// Get the list of channel members
3238async fn get_channel_members(
3239 request: proto::GetChannelMembers,
3240 response: Response<proto::GetChannelMembers>,
3241 session: MessageContext,
3242) -> Result<()> {
3243 let db = session.db().await;
3244 let channel_id = ChannelId::from_proto(request.channel_id);
3245 let limit = if request.limit == 0 {
3246 u16::MAX as u64
3247 } else {
3248 request.limit
3249 };
3250 let (members, users) = db
3251 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3252 .await?;
3253 response.send(proto::GetChannelMembersResponse { members, users })?;
3254 Ok(())
3255}
3256
3257/// Accept or decline a channel invitation.
3258async fn respond_to_channel_invite(
3259 request: proto::RespondToChannelInvite,
3260 response: Response<proto::RespondToChannelInvite>,
3261 session: MessageContext,
3262) -> Result<()> {
3263 let db = session.db().await;
3264 let channel_id = ChannelId::from_proto(request.channel_id);
3265 let RespondToChannelInvite {
3266 membership_update,
3267 notifications,
3268 } = db
3269 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3270 .await?;
3271
3272 let mut connection_pool = session.connection_pool().await;
3273 if let Some(membership_update) = membership_update {
3274 notify_membership_updated(
3275 &mut connection_pool,
3276 membership_update,
3277 session.user_id(),
3278 &session.peer,
3279 );
3280 } else {
3281 let update = proto::UpdateChannels {
3282 remove_channel_invitations: vec![channel_id.to_proto()],
3283 ..Default::default()
3284 };
3285
3286 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3287 session.peer.send(connection_id, update.clone())?;
3288 }
3289 };
3290
3291 send_notifications(&connection_pool, &session.peer, notifications);
3292
3293 response.send(proto::Ack {})?;
3294
3295 Ok(())
3296}
3297
3298/// Join the channels' room
3299async fn join_channel(
3300 request: proto::JoinChannel,
3301 response: Response<proto::JoinChannel>,
3302 session: MessageContext,
3303) -> Result<()> {
3304 let channel_id = ChannelId::from_proto(request.channel_id);
3305 join_channel_internal(channel_id, Box::new(response), session).await
3306}
3307
/// A response handle that can answer a channel join with a `JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Lets `join_channel_internal` reply to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path is intended to resolve to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Lets `join_channel_internal` reply to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The type-qualified path is intended to resolve to the inherent
        // `Response::send`, not back into this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3321
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is first removed from its room via `leave_room_for_session`.
/// 2. The user joins the channel's room in the database. This also yields an
///    optional membership update and the user's role in the channel.
/// 3. When a LiveKit client is configured, a token is created:
///    - `ChannelRole::Guest` gets a guest token and `can_publish = false`;
///    - every other role gets a room token and `can_publish = true`.
///    A token error goes through `trace_err` and results in no LiveKit info
///    rather than failing the join.
/// 4. The `JoinRoomResponse` is sent, any membership change is pushed, and the
///    room's participants receive the updated room.
/// 5. Connections subscribed to the channel receive the new participant list,
///    and the user's accepted contacts receive the user's updated status.
///
/// NOTE(review): the value returned by `session.db()` (presumably a lock
/// guard) is bound to `db` inside the inner block. It therefore stays alive
/// through steps 2-4, including while the connection-pool lock is taken.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so release
            // our handle first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may only subscribe to media, so they get a guest token and
        // cannot publish. Any token failure yields `None` (no LiveKit info).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The response has already been sent at this point. A missing channel
    // surfaces as an error to the caller but does not undo the join.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3415
3416/// Start editing the channel notes
3417async fn join_channel_buffer(
3418 request: proto::JoinChannelBuffer,
3419 response: Response<proto::JoinChannelBuffer>,
3420 session: MessageContext,
3421) -> Result<()> {
3422 let db = session.db().await;
3423 let channel_id = ChannelId::from_proto(request.channel_id);
3424
3425 let open_response = db
3426 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3427 .await?;
3428
3429 let collaborators = open_response.collaborators.clone();
3430 response.send(open_response)?;
3431
3432 let update = UpdateChannelBufferCollaborators {
3433 channel_id: channel_id.to_proto(),
3434 collaborators: collaborators.clone(),
3435 };
3436 channel_buffer_updated(
3437 session.connection_id,
3438 collaborators
3439 .iter()
3440 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3441 &update,
3442 &session.peer,
3443 );
3444
3445 Ok(())
3446}
3447
3448/// Edit the channel notes
3449async fn update_channel_buffer(
3450 request: proto::UpdateChannelBuffer,
3451 session: MessageContext,
3452) -> Result<()> {
3453 let db = session.db().await;
3454 let channel_id = ChannelId::from_proto(request.channel_id);
3455
3456 let (collaborators, epoch, version) = db
3457 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3458 .await?;
3459
3460 channel_buffer_updated(
3461 session.connection_id,
3462 collaborators.clone(),
3463 &proto::UpdateChannelBuffer {
3464 channel_id: channel_id.to_proto(),
3465 operations: request.operations,
3466 },
3467 &session.peer,
3468 );
3469
3470 let pool = &*session.connection_pool().await;
3471
3472 let non_collaborators =
3473 pool.channel_connection_ids(channel_id)
3474 .filter_map(|(connection_id, _)| {
3475 if collaborators.contains(&connection_id) {
3476 None
3477 } else {
3478 Some(connection_id)
3479 }
3480 });
3481
3482 broadcast(None, non_collaborators, |peer_id| {
3483 session.peer.send(
3484 peer_id,
3485 proto::UpdateChannels {
3486 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3487 channel_id: channel_id.to_proto(),
3488 epoch: epoch as u64,
3489 version: version.clone(),
3490 }],
3491 ..Default::default()
3492 },
3493 )
3494 });
3495
3496 Ok(())
3497}
3498
3499/// Rejoin the channel notes after a connection blip
3500async fn rejoin_channel_buffers(
3501 request: proto::RejoinChannelBuffers,
3502 response: Response<proto::RejoinChannelBuffers>,
3503 session: MessageContext,
3504) -> Result<()> {
3505 let db = session.db().await;
3506 let buffers = db
3507 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3508 .await?;
3509
3510 for rejoined_buffer in &buffers {
3511 let collaborators_to_notify = rejoined_buffer
3512 .buffer
3513 .collaborators
3514 .iter()
3515 .filter_map(|c| Some(c.peer_id?.into()));
3516 channel_buffer_updated(
3517 session.connection_id,
3518 collaborators_to_notify,
3519 &proto::UpdateChannelBufferCollaborators {
3520 channel_id: rejoined_buffer.buffer.channel_id,
3521 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3522 },
3523 &session.peer,
3524 );
3525 }
3526
3527 response.send(proto::RejoinChannelBuffersResponse {
3528 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3529 })?;
3530
3531 Ok(())
3532}
3533
3534/// Stop editing the channel notes
3535async fn leave_channel_buffer(
3536 request: proto::LeaveChannelBuffer,
3537 response: Response<proto::LeaveChannelBuffer>,
3538 session: MessageContext,
3539) -> Result<()> {
3540 let db = session.db().await;
3541 let channel_id = ChannelId::from_proto(request.channel_id);
3542
3543 let left_buffer = db
3544 .leave_channel_buffer(channel_id, session.connection_id)
3545 .await?;
3546
3547 response.send(Ack {})?;
3548
3549 channel_buffer_updated(
3550 session.connection_id,
3551 left_buffer.connections,
3552 &proto::UpdateChannelBufferCollaborators {
3553 channel_id: channel_id.to_proto(),
3554 collaborators: left_buffer.collaborators,
3555 },
3556 &session.peer,
3557 );
3558
3559 Ok(())
3560}
3561
3562fn channel_buffer_updated<T: EnvelopedMessage>(
3563 sender_id: ConnectionId,
3564 collaborators: impl IntoIterator<Item = ConnectionId>,
3565 message: &T,
3566 peer: &Peer,
3567) {
3568 broadcast(Some(sender_id), collaborators, |peer_id| {
3569 peer.send(peer_id, message.clone())
3570 });
3571}
3572
3573fn send_notifications(
3574 connection_pool: &ConnectionPool,
3575 peer: &Peer,
3576 notifications: db::NotificationBatch,
3577) {
3578 for (user_id, notification) in notifications {
3579 for connection_id in connection_pool.user_connection_ids(user_id) {
3580 if let Err(error) = peer.send(
3581 connection_id,
3582 proto::AddNotification {
3583 notification: Some(notification.clone()),
3584 },
3585 ) {
3586 tracing::error!(
3587 "failed to send notification to {:?} {}",
3588 connection_id,
3589 error
3590 );
3591 }
3592 }
3593 }
3594}
3595
3596/// Send a message to the channel
3597async fn send_channel_message(
3598 request: proto::SendChannelMessage,
3599 response: Response<proto::SendChannelMessage>,
3600 session: MessageContext,
3601) -> Result<()> {
3602 // Validate the message body.
3603 let body = request.body.trim().to_string();
3604 if body.len() > MAX_MESSAGE_LEN {
3605 return Err(anyhow!("message is too long"))?;
3606 }
3607 if body.is_empty() {
3608 return Err(anyhow!("message can't be blank"))?;
3609 }
3610
3611 // TODO: adjust mentions if body is trimmed
3612
3613 let timestamp = OffsetDateTime::now_utc();
3614 let nonce = request.nonce.context("nonce can't be blank")?;
3615
3616 let channel_id = ChannelId::from_proto(request.channel_id);
3617 let CreatedChannelMessage {
3618 message_id,
3619 participant_connection_ids,
3620 notifications,
3621 } = session
3622 .db()
3623 .await
3624 .create_channel_message(
3625 channel_id,
3626 session.user_id(),
3627 &body,
3628 &request.mentions,
3629 timestamp,
3630 nonce.clone().into(),
3631 request.reply_to_message_id.map(MessageId::from_proto),
3632 )
3633 .await?;
3634
3635 let message = proto::ChannelMessage {
3636 sender_id: session.user_id().to_proto(),
3637 id: message_id.to_proto(),
3638 body,
3639 mentions: request.mentions,
3640 timestamp: timestamp.unix_timestamp() as u64,
3641 nonce: Some(nonce),
3642 reply_to_message_id: request.reply_to_message_id,
3643 edited_at: None,
3644 };
3645 broadcast(
3646 Some(session.connection_id),
3647 participant_connection_ids.clone(),
3648 |connection| {
3649 session.peer.send(
3650 connection,
3651 proto::ChannelMessageSent {
3652 channel_id: channel_id.to_proto(),
3653 message: Some(message.clone()),
3654 },
3655 )
3656 },
3657 );
3658 response.send(proto::SendChannelMessageResponse {
3659 message: Some(message),
3660 })?;
3661
3662 let pool = &*session.connection_pool().await;
3663 let non_participants =
3664 pool.channel_connection_ids(channel_id)
3665 .filter_map(|(connection_id, _)| {
3666 if participant_connection_ids.contains(&connection_id) {
3667 None
3668 } else {
3669 Some(connection_id)
3670 }
3671 });
3672 broadcast(None, non_participants, |peer_id| {
3673 session.peer.send(
3674 peer_id,
3675 proto::UpdateChannels {
3676 latest_channel_message_ids: vec![proto::ChannelMessageId {
3677 channel_id: channel_id.to_proto(),
3678 message_id: message_id.to_proto(),
3679 }],
3680 ..Default::default()
3681 },
3682 )
3683 });
3684 send_notifications(pool, &session.peer, notifications);
3685
3686 Ok(())
3687}
3688
3689/// Delete a channel message
3690async fn remove_channel_message(
3691 request: proto::RemoveChannelMessage,
3692 response: Response<proto::RemoveChannelMessage>,
3693 session: MessageContext,
3694) -> Result<()> {
3695 let channel_id = ChannelId::from_proto(request.channel_id);
3696 let message_id = MessageId::from_proto(request.message_id);
3697 let (connection_ids, existing_notification_ids) = session
3698 .db()
3699 .await
3700 .remove_channel_message(channel_id, message_id, session.user_id())
3701 .await?;
3702
3703 broadcast(
3704 Some(session.connection_id),
3705 connection_ids,
3706 move |connection| {
3707 session.peer.send(connection, request.clone())?;
3708
3709 for notification_id in &existing_notification_ids {
3710 session.peer.send(
3711 connection,
3712 proto::DeleteNotification {
3713 notification_id: (*notification_id).to_proto(),
3714 },
3715 )?;
3716 }
3717
3718 Ok(())
3719 },
3720 );
3721 response.send(proto::Ack {})?;
3722 Ok(())
3723}
3724
3725async fn update_channel_message(
3726 request: proto::UpdateChannelMessage,
3727 response: Response<proto::UpdateChannelMessage>,
3728 session: MessageContext,
3729) -> Result<()> {
3730 let channel_id = ChannelId::from_proto(request.channel_id);
3731 let message_id = MessageId::from_proto(request.message_id);
3732 let updated_at = OffsetDateTime::now_utc();
3733 let UpdatedChannelMessage {
3734 message_id,
3735 participant_connection_ids,
3736 notifications,
3737 reply_to_message_id,
3738 timestamp,
3739 deleted_mention_notification_ids,
3740 updated_mention_notifications,
3741 } = session
3742 .db()
3743 .await
3744 .update_channel_message(
3745 channel_id,
3746 message_id,
3747 session.user_id(),
3748 request.body.as_str(),
3749 &request.mentions,
3750 updated_at,
3751 )
3752 .await?;
3753
3754 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3755
3756 let message = proto::ChannelMessage {
3757 sender_id: session.user_id().to_proto(),
3758 id: message_id.to_proto(),
3759 body: request.body.clone(),
3760 mentions: request.mentions.clone(),
3761 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3762 nonce: Some(nonce),
3763 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3764 edited_at: Some(updated_at.unix_timestamp() as u64),
3765 };
3766
3767 response.send(proto::Ack {})?;
3768
3769 let pool = &*session.connection_pool().await;
3770 broadcast(
3771 Some(session.connection_id),
3772 participant_connection_ids,
3773 |connection| {
3774 session.peer.send(
3775 connection,
3776 proto::ChannelMessageUpdate {
3777 channel_id: channel_id.to_proto(),
3778 message: Some(message.clone()),
3779 },
3780 )?;
3781
3782 for notification_id in &deleted_mention_notification_ids {
3783 session.peer.send(
3784 connection,
3785 proto::DeleteNotification {
3786 notification_id: (*notification_id).to_proto(),
3787 },
3788 )?;
3789 }
3790
3791 for notification in &updated_mention_notifications {
3792 session.peer.send(
3793 connection,
3794 proto::UpdateNotification {
3795 notification: Some(notification.clone()),
3796 },
3797 )?;
3798 }
3799
3800 Ok(())
3801 },
3802 );
3803
3804 send_notifications(pool, &session.peer, notifications);
3805
3806 Ok(())
3807}
3808
3809/// Mark a channel message as read
3810async fn acknowledge_channel_message(
3811 request: proto::AckChannelMessage,
3812 session: MessageContext,
3813) -> Result<()> {
3814 let channel_id = ChannelId::from_proto(request.channel_id);
3815 let message_id = MessageId::from_proto(request.message_id);
3816 let notifications = session
3817 .db()
3818 .await
3819 .observe_channel_message(channel_id, session.user_id(), message_id)
3820 .await?;
3821 send_notifications(
3822 &*session.connection_pool().await,
3823 &session.peer,
3824 notifications,
3825 );
3826 Ok(())
3827}
3828
3829/// Mark a buffer version as synced
3830async fn acknowledge_buffer_version(
3831 request: proto::AckBufferOperation,
3832 session: MessageContext,
3833) -> Result<()> {
3834 let buffer_id = BufferId::from_proto(request.buffer_id);
3835 session
3836 .db()
3837 .await
3838 .observe_buffer_version(
3839 buffer_id,
3840 session.user_id(),
3841 request.epoch as i32,
3842 &request.version,
3843 )
3844 .await?;
3845 Ok(())
3846}
3847
3848/// Get a Supermaven API key for the user
3849async fn get_supermaven_api_key(
3850 _request: proto::GetSupermavenApiKey,
3851 response: Response<proto::GetSupermavenApiKey>,
3852 session: MessageContext,
3853) -> Result<()> {
3854 let user_id: String = session.user_id().to_string();
3855 if !session.is_staff() {
3856 return Err(anyhow!("supermaven not enabled for this account"))?;
3857 }
3858
3859 let email = session.email().context("user must have an email")?;
3860
3861 let supermaven_admin_api = session
3862 .supermaven_client
3863 .as_ref()
3864 .context("supermaven not configured")?;
3865
3866 let result = supermaven_admin_api
3867 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3868 .await?;
3869
3870 response.send(proto::GetSupermavenApiKeyResponse {
3871 api_key: result.api_key,
3872 })?;
3873
3874 Ok(())
3875}
3876
3877/// Start receiving chat updates for a channel
3878async fn join_channel_chat(
3879 request: proto::JoinChannelChat,
3880 response: Response<proto::JoinChannelChat>,
3881 session: MessageContext,
3882) -> Result<()> {
3883 let channel_id = ChannelId::from_proto(request.channel_id);
3884
3885 let db = session.db().await;
3886 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3887 .await?;
3888 let messages = db
3889 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3890 .await?;
3891 response.send(proto::JoinChannelChatResponse {
3892 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3893 messages,
3894 })?;
3895 Ok(())
3896}
3897
3898/// Stop receiving chat updates for a channel
3899async fn leave_channel_chat(
3900 request: proto::LeaveChannelChat,
3901 session: MessageContext,
3902) -> Result<()> {
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 session
3905 .db()
3906 .await
3907 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3908 .await?;
3909 Ok(())
3910}
3911
3912/// Retrieve the chat history for a channel
3913async fn get_channel_messages(
3914 request: proto::GetChannelMessages,
3915 response: Response<proto::GetChannelMessages>,
3916 session: MessageContext,
3917) -> Result<()> {
3918 let channel_id = ChannelId::from_proto(request.channel_id);
3919 let messages = session
3920 .db()
3921 .await
3922 .get_channel_messages(
3923 channel_id,
3924 session.user_id(),
3925 MESSAGE_COUNT_PER_PAGE,
3926 Some(MessageId::from_proto(request.before_message_id)),
3927 )
3928 .await?;
3929 response.send(proto::GetChannelMessagesResponse {
3930 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3931 messages,
3932 })?;
3933 Ok(())
3934}
3935
3936/// Retrieve specific chat messages
3937async fn get_channel_messages_by_id(
3938 request: proto::GetChannelMessagesById,
3939 response: Response<proto::GetChannelMessagesById>,
3940 session: MessageContext,
3941) -> Result<()> {
3942 let message_ids = request
3943 .message_ids
3944 .iter()
3945 .map(|id| MessageId::from_proto(*id))
3946 .collect::<Vec<_>>();
3947 let messages = session
3948 .db()
3949 .await
3950 .get_channel_messages_by_id(session.user_id(), &message_ids)
3951 .await?;
3952 response.send(proto::GetChannelMessagesResponse {
3953 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3954 messages,
3955 })?;
3956 Ok(())
3957}
3958
3959/// Retrieve the current users notifications
3960async fn get_notifications(
3961 request: proto::GetNotifications,
3962 response: Response<proto::GetNotifications>,
3963 session: MessageContext,
3964) -> Result<()> {
3965 let notifications = session
3966 .db()
3967 .await
3968 .get_notifications(
3969 session.user_id(),
3970 NOTIFICATION_COUNT_PER_PAGE,
3971 request.before_id.map(db::NotificationId::from_proto),
3972 )
3973 .await?;
3974 response.send(proto::GetNotificationsResponse {
3975 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3976 notifications,
3977 })?;
3978 Ok(())
3979}
3980
3981/// Mark notifications as read
3982async fn mark_notification_as_read(
3983 request: proto::MarkNotificationRead,
3984 response: Response<proto::MarkNotificationRead>,
3985 session: MessageContext,
3986) -> Result<()> {
3987 let database = &session.db().await;
3988 let notifications = database
3989 .mark_notification_as_read_by_id(
3990 session.user_id(),
3991 NotificationId::from_proto(request.notification_id),
3992 )
3993 .await?;
3994 send_notifications(
3995 &*session.connection_pool().await,
3996 &session.peer,
3997 notifications,
3998 );
3999 response.send(proto::Ack {})?;
4000 Ok(())
4001}
4002
4003fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4004 let message = match message {
4005 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4006 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4007 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4008 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4009 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4010 code: frame.code.into(),
4011 reason: frame.reason.as_str().to_owned().into(),
4012 })),
4013 // We should never receive a frame while reading the message, according
4014 // to the `tungstenite` maintainers:
4015 //
4016 // > It cannot occur when you read messages from the WebSocket, but it
4017 // > can be used when you want to send the raw frames (e.g. you want to
4018 // > send the frames to the WebSocket without composing the full message first).
4019 // >
4020 // > — https://github.com/snapview/tungstenite-rs/issues/268
4021 TungsteniteMessage::Frame(_) => {
4022 bail!("received an unexpected frame while reading the message")
4023 }
4024 };
4025
4026 Ok(message)
4027}
4028
4029fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4030 match message {
4031 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4032 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4033 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4034 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4035 AxumMessage::Close(frame) => {
4036 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4037 code: frame.code.into(),
4038 reason: frame.reason.as_ref().into(),
4039 }))
4040 }
4041 }
4042}
4043
4044fn notify_membership_updated(
4045 connection_pool: &mut ConnectionPool,
4046 result: MembershipUpdated,
4047 user_id: UserId,
4048 peer: &Peer,
4049) {
4050 for membership in &result.new_channels.channel_memberships {
4051 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4052 }
4053 for channel_id in &result.removed_channels {
4054 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4055 }
4056
4057 let user_channels_update = proto::UpdateUserChannels {
4058 channel_memberships: result
4059 .new_channels
4060 .channel_memberships
4061 .iter()
4062 .map(|cm| proto::ChannelMembership {
4063 channel_id: cm.channel_id.to_proto(),
4064 role: cm.role.into(),
4065 })
4066 .collect(),
4067 ..Default::default()
4068 };
4069
4070 let mut update = build_channels_update(result.new_channels);
4071 update.delete_channels = result
4072 .removed_channels
4073 .into_iter()
4074 .map(|id| id.to_proto())
4075 .collect();
4076 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4077
4078 for connection_id in connection_pool.user_connection_ids(user_id) {
4079 peer.send(connection_id, user_channels_update.clone())
4080 .trace_err();
4081 peer.send(connection_id, update.clone()).trace_err();
4082 }
4083}
4084
4085fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4086 proto::UpdateUserChannels {
4087 channel_memberships: channels
4088 .channel_memberships
4089 .iter()
4090 .map(|m| proto::ChannelMembership {
4091 channel_id: m.channel_id.to_proto(),
4092 role: m.role.into(),
4093 })
4094 .collect(),
4095 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4096 observed_channel_message_id: channels.observed_channel_messages.clone(),
4097 }
4098}
4099
4100fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4101 let mut update = proto::UpdateChannels::default();
4102
4103 for channel in channels.channels {
4104 update.channels.push(channel.to_proto());
4105 }
4106
4107 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4108 update.latest_channel_message_ids = channels.latest_channel_messages;
4109
4110 for (channel_id, participants) in channels.channel_participants {
4111 update
4112 .channel_participants
4113 .push(proto::ChannelParticipants {
4114 channel_id: channel_id.to_proto(),
4115 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4116 });
4117 }
4118
4119 for channel in channels.invited_channels {
4120 update.channel_invitations.push(channel.to_proto());
4121 }
4122
4123 update
4124}
4125
4126fn build_initial_contacts_update(
4127 contacts: Vec<db::Contact>,
4128 pool: &ConnectionPool,
4129) -> proto::UpdateContacts {
4130 let mut update = proto::UpdateContacts::default();
4131
4132 for contact in contacts {
4133 match contact {
4134 db::Contact::Accepted { user_id, busy } => {
4135 update.contacts.push(contact_for_user(user_id, busy, pool));
4136 }
4137 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4138 db::Contact::Incoming { user_id } => {
4139 update
4140 .incoming_requests
4141 .push(proto::IncomingContactRequest {
4142 requester_id: user_id.to_proto(),
4143 })
4144 }
4145 }
4146 }
4147
4148 update
4149}
4150
4151fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4152 proto::Contact {
4153 user_id: user_id.to_proto(),
4154 online: pool.is_user_online(user_id),
4155 busy,
4156 }
4157}
4158
4159fn room_updated(room: &proto::Room, peer: &Peer) {
4160 broadcast(
4161 None,
4162 room.participants
4163 .iter()
4164 .filter_map(|participant| Some(participant.peer_id?.into())),
4165 |peer_id| {
4166 peer.send(
4167 peer_id,
4168 proto::RoomUpdated {
4169 room: Some(room.clone()),
4170 },
4171 )
4172 },
4173 );
4174}
4175
4176fn channel_updated(
4177 channel: &db::channel::Model,
4178 room: &proto::Room,
4179 peer: &Peer,
4180 pool: &ConnectionPool,
4181) {
4182 let participants = room
4183 .participants
4184 .iter()
4185 .map(|p| p.user_id)
4186 .collect::<Vec<_>>();
4187
4188 broadcast(
4189 None,
4190 pool.channel_connection_ids(channel.root_id())
4191 .filter_map(|(channel_id, role)| {
4192 role.can_see_channel(channel.visibility)
4193 .then_some(channel_id)
4194 }),
4195 |peer_id| {
4196 peer.send(
4197 peer_id,
4198 proto::UpdateChannels {
4199 channel_participants: vec![proto::ChannelParticipants {
4200 channel_id: channel.id.to_proto(),
4201 participant_user_ids: participants.clone(),
4202 }],
4203 ..Default::default()
4204 },
4205 )
4206 },
4207 );
4208}
4209
4210async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4211 let db = session.db().await;
4212
4213 let contacts = db.get_contacts(user_id).await?;
4214 let busy = db.is_user_busy(user_id).await?;
4215
4216 let pool = session.connection_pool().await;
4217 let updated_contact = contact_for_user(user_id, busy, &pool);
4218 for contact in contacts {
4219 if let db::Contact::Accepted {
4220 user_id: contact_user_id,
4221 ..
4222 } = contact
4223 {
4224 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4225 session
4226 .peer
4227 .send(
4228 contact_conn_id,
4229 proto::UpdateContacts {
4230 contacts: vec![updated_contact.clone()],
4231 remove_contacts: Default::default(),
4232 incoming_requests: Default::default(),
4233 remove_incoming_requests: Default::default(),
4234 outgoing_requests: Default::default(),
4235 remove_outgoing_requests: Default::default(),
4236 },
4237 )
4238 .trace_err();
4239 }
4240 }
4241 }
4242 Ok(())
4243}
4244
/// Remove `connection_id` from the room it is in and propagate the change.
///
/// If the database reports no room for the connection, this is a no-op.
/// Otherwise:
/// - `project_left` runs for every project the database says was left, and the
///   room's participants receive the updated room;
/// - if the room belonged to a channel, the channel's subscribers receive the
///   new participant list;
/// - every user whose call the database canceled receives `CallCanceled` on
///   each of their connections;
/// - contact status is refreshed for this session's user and every canceled
///   callee;
/// - with a LiveKit client configured, this session's user is removed from
///   the LiveKit room, and that room is deleted when `left_room.deleted` is
///   set. LiveKit errors go through `trace_err` and do not fail the call.
///
/// NOTE(review): the LiveKit removal uses `session.user_id()`, so this assumes
/// `connection_id` belongs to the session's user. That holds for the stale
/// connection cleanup in `join_channel_internal`; confirm for other callers.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned only in the `Some` branch below; the
    // `else` branch returns, so they are initialized wherever they are read.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The value returned by `session.db()` is a temporary of this `if let`
    // scrutinee, so it stays alive for the body of the `Some` branch.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the room broadcast below carries
        // an empty `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4318
4319async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4320 let left_channel_buffers = session
4321 .db()
4322 .await
4323 .leave_channel_buffers(session.connection_id)
4324 .await?;
4325
4326 for left_buffer in left_channel_buffers {
4327 channel_buffer_updated(
4328 session.connection_id,
4329 left_buffer.connections,
4330 &proto::UpdateChannelBufferCollaborators {
4331 channel_id: left_buffer.channel_id.to_proto(),
4332 collaborators: left_buffer.collaborators,
4333 },
4334 &session.peer,
4335 );
4336 }
4337
4338 Ok(())
4339}
4340
4341fn project_left(project: &db::LeftProject, session: &Session) {
4342 for connection_id in &project.connection_ids {
4343 if project.should_unshare {
4344 session
4345 .peer
4346 .send(
4347 *connection_id,
4348 proto::UnshareProject {
4349 project_id: project.id.to_proto(),
4350 },
4351 )
4352 .trace_err();
4353 } else {
4354 session
4355 .peer
4356 .send(
4357 *connection_id,
4358 proto::RemoveProjectCollaborator {
4359 project_id: project.id.to_proto(),
4360 peer_id: Some(session.connection_id.into()),
4361 },
4362 )
4363 .trace_err();
4364 }
4365 }
4366}
4367
4368pub trait ResultExt {
4369 type Ok;
4370
4371 fn trace_err(self) -> Option<Self::Ok>;
4372}
4373
4374impl<T, E> ResultExt for Result<T, E>
4375where
4376 E: std::fmt::Debug,
4377{
4378 type Ok = T;
4379
4380 #[track_caller]
4381 fn trace_err(self) -> Option<T> {
4382 match self {
4383 Ok(value) => Some(value),
4384 Err(error) => {
4385 tracing::error!("{:?}", error);
4386 None
4387 }
4388 }
4389 }
4390}