1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
9 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
10 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use reqwest_client::ReqwestClient;
37use rpc::proto::{MultiLspQuery, split_repository_update};
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39use tracing::Span;
40
41use futures::{
42 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
43 stream::FuturesUnordered,
44};
45use prometheus::{IntGauge, register_int_gauge};
46use rpc::{
47 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 Arc, OnceLock,
65 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{Semaphore, watch};
71use tower::ServiceBuilder;
72use tracing::{
73 Instrument,
74 field::{self},
75 info_span, instrument,
76};
77
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes and message-length limit for chat/notification handlers.
// NOTE(review): the handlers that consume these live outside this chunk — confirm usage there.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Live count of connections currently holding a `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names under which per-message timings are recorded
// (see `Server::add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler stored in `Server::handlers`, keyed by message `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
97
98pub struct ConnectionGuard;
99
100impl ConnectionGuard {
101 pub fn try_acquire() -> Result<Self, ()> {
102 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
103 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
104 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
105 tracing::error!(
106 "too many concurrent connections: {}",
107 current_connections + 1
108 );
109 return Err(());
110 }
111 Ok(ConnectionGuard)
112 }
113}
114
impl Drop for ConnectionGuard {
    /// Releases this guard's slot in the global connection count.
    fn drop(&mut self) {
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}
120
/// One-shot responder handed to request handlers.
///
/// `responded` is shared with the wrapper installed by
/// `Server::add_request_handler`, which uses it to detect handlers that
/// return `Ok(())` without ever sending a response.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Marks the request as answered and sends `payload` back to the
    /// requester. Consumes `self`: a request is answered at most once.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
134
135#[derive(Clone, Debug)]
136pub enum Principal {
137 User(User),
138 Impersonated { user: User, admin: User },
139}
140
141impl Principal {
142 fn update_span(&self, span: &tracing::Span) {
143 match &self {
144 Principal::User(user) => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 }
148 Principal::Impersonated { user, admin } => {
149 span.record("user_id", user.id.0);
150 span.record("login", &user.github_login);
151 span.record("impersonator", &admin.github_login);
152 }
153 }
154 }
155}
156
/// Everything a message handler needs: the per-connection `Session` plus
/// the tracing span created for this specific message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

/// Handlers mostly treat the context as a `Session`; `Deref` lets them
/// call session methods directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
170
171impl MessageContext {
172 pub fn forward_request<T: RequestMessage>(
173 &self,
174 receiver_id: ConnectionId,
175 request: T,
176 ) -> impl Future<Output = anyhow::Result<T::Response>> {
177 let request_start_time = Instant::now();
178 let span = self.span.clone();
179 tracing::info!("start forwarding request");
180 self.peer
181 .forward_request(self.connection_id, receiver_id, request)
182 .inspect(move |_| {
183 span.record(
184 HOST_WAITING_MS,
185 request_start_time.elapsed().as_micros() as f64 / 1000.0,
186 );
187 })
188 .inspect_err(|_| tracing::error!("error forwarding request"))
189 .inspect_ok(|_| tracing::info!("finished forwarding request"))
190 }
191}
192
/// Per-connection state shared (by clone) with every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Async mutex: a session's database operations are serialized.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Pool shared with the whole `Server`; see `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
209
210impl Session {
211 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
212 #[cfg(test)]
213 tokio::task::yield_now().await;
214 let guard = self.db.lock().await;
215 #[cfg(test)]
216 tokio::task::yield_now().await;
217 guard
218 }
219
220 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
221 #[cfg(test)]
222 tokio::task::yield_now().await;
223 let guard = self.connection_pool.lock();
224 ConnectionPoolGuard {
225 guard,
226 _not_send: PhantomData,
227 }
228 }
229
230 fn is_staff(&self) -> bool {
231 match &self.principal {
232 Principal::User(user) => user.admin,
233 Principal::Impersonated { .. } => true,
234 }
235 }
236
237 fn user_id(&self) -> UserId {
238 match &self.principal {
239 Principal::User(user) => user.id,
240 Principal::Impersonated { user, .. } => user.id,
241 }
242 }
243
244 pub fn email(&self) -> Option<String> {
245 match &self.principal {
246 Principal::User(user) => user.email_address.clone(),
247 Principal::Impersonated { user, .. } => user.email_address.clone(),
248 }
249 }
250}
251
252impl Debug for Session {
253 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
254 let mut result = f.debug_struct("Session");
255 match &self.principal {
256 Principal::User(user) => {
257 result.field("user", &user.github_login);
258 }
259 Principal::Impersonated { user, admin } => {
260 result.field("user", &user.github_login);
261 result.field("impersonator", &admin.github_login);
262 }
263 }
264 result.field("connection_id", &self.connection_id).finish()
265 }
266}
267
/// Newtype over the shared `Database` so sessions can hand out
/// `MutexGuard<DbHandle>` while callers use `Database` methods via `Deref`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
277
/// The collab RPC server: owns the peer, the connection pool, and the
/// per-message-type handler table populated in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    handlers: HashMap<TypeId, MessageHandler>,
    // Watch flipped to `true` by `teardown`; each connection task
    // subscribes and exits when it changes.
    teardown: watch::Sender<bool>,
}
286
/// Lock guard over the connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, so it cannot be moved to another thread
/// (e.g. accidentally held across an `.await` in a `Send` future).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
291
/// Serializable view over the server's peer state and (locked) connection
/// pool; the pool serializes through `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
298
299pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
300where
301 S: Serializer,
302 T: Deref<Target = U>,
303 U: Serialize,
304{
305 Serialize::serialize(value.deref(), serializer)
306}
307
308impl Server {
    /// Builds the server and registers a handler for every supported
    /// message type.
    ///
    /// `add_request_handler` registrations expect a response to be sent
    /// back to the client; `add_message_handler` registrations are
    /// fire-and-forget. Project requests are forwarded to the project host
    /// via `forward_read_only_project_request` /
    /// `forward_mutating_project_request` (handlers defined elsewhere in
    /// this file).
    ///
    /// # Panics
    /// `add_handler` panics if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Forwarded project requests, split by whether the guest needs
            // read-only or read-write capability on the project.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host → guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following collaborators.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Contexts, git operations, and debugging.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
481
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT`
    /// (the grace period for pods of a previous deployment to exit),
    /// cleans up state left behind by stale servers: chat participants,
    /// channel-buffer collaborators, room participants, old worktree
    /// entries, and the stale server rows themselves. Affected clients are
    /// notified of the refreshed state, and LiveKit rooms that became
    /// empty are deleted.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop collaborators whose server is gone from each stale
                    // channel buffer, and tell the remaining collaborators.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants from the room; if anyone is
                        // left, broadcast the refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify users whose incoming calls were canceled by the
                        // participant cleanup. Scoped block so the pool lock is
                        // released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated busy status to all
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): repeats the chat-participant cleanup performed
                // above — presumably to catch participants that became stale
                // while the rooms were being refreshed; confirm intent.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
652
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every connection task (via the `teardown` watch)
    /// to exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
658
    /// Test-only: tears the server down, re-initializes it under a new
    /// `ServerId`, and clears the teardown flag so connections are
    /// accepted again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
666
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
671
    /// Core registration primitive: wraps `handler` so it can be stored in
    /// the type-erased `handlers` map, keyed by `TypeId::of::<M>()`.
    ///
    /// The wrapper downcasts the envelope back to `TypedEnvelope<M>`,
    /// invokes the handler, and once the handler future completes records
    /// total/processing/queue durations on the message span.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map is keyed by TypeId::of::<M>(), so this downcast
                // cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Fractional milliseconds: total = since the message
                    // arrived, processing = handler runtime, queue = time the
                    // message waited before the handler started.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
715
716 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
717 where
718 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
719 Fut: 'static + Send + Future<Output = Result<()>>,
720 M: EnvelopedMessage,
721 {
722 self.add_handler(move |envelope, session| handler(envelope.payload, session));
723 self
724 }
725
    /// Registers a handler for a request message `M` that must produce a
    /// response.
    ///
    /// On handler success the wrapper verifies (via the shared `responded`
    /// flag on `Response`) that a response was actually sent; a handler
    /// that returns `Ok(())` without responding is reported as an error.
    /// On handler failure the error is converted to a proto error and sent
    /// back to the requester.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry a proto form;
                        // everything else is wrapped as ErrorCode::Internal.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
764
    /// Returns a future that runs one client connection to completion:
    /// registers the connection with the peer, builds its `Session`, sends
    /// the initial client update, then pumps incoming messages into their
    /// registered handlers until the connection closes, I/O fails, or
    /// server teardown is signalled. Ends by running `connection_lost`
    /// cleanup.
    ///
    /// `connection_guard` (the slot in the global connection budget) is
    /// held until the initial client update has been sent, then released.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signalled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release its slot in the
            // global connection budget.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Don't pull the next message until a handler permit is free;
                // the permit is released when the handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler future
                                // completes, bounding concurrency per connection.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
956
    /// Sends the initial state a freshly-accepted client needs: the `Hello`
    /// handshake, first-connection prompts, the contact list, channel
    /// subscriptions, and any pending incoming call.
    ///
    /// `send_connection_id` is an optional one-shot used to report the
    /// assigned connection id back to the local caller.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The Hello message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // Receiver may already be dropped; that's fine.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt them to add
                // contacts, then persist the flag so this happens only once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contacts
                    // update under the same pool lock so the online/offline
                    // view is consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // NOTE(review): presumably clients below some version rely on
                // the server subscribing them to their channels — see
                // `should_auto_subscribe_to_channels`.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Let the user's contacts know they are now online.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1012
1013 pub async fn invite_code_redeemed(
1014 self: &Arc<Self>,
1015 inviter_id: UserId,
1016 invitee_id: UserId,
1017 ) -> Result<()> {
1018 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1019 if let Some(code) = &user.invite_code {
1020 let pool = self.connection_pool.lock();
1021 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1022 for connection_id in pool.user_connection_ids(inviter_id) {
1023 self.peer.send(
1024 connection_id,
1025 proto::UpdateContacts {
1026 contacts: vec![invitee_contact.clone()],
1027 ..Default::default()
1028 },
1029 )?;
1030 self.peer.send(
1031 connection_id,
1032 proto::UpdateInviteInfo {
1033 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1034 count: user.invite_count as u32,
1035 },
1036 )?;
1037 }
1038 }
1039 }
1040 Ok(())
1041 }
1042
1043 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1044 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1045 if let Some(invite_code) = &user.invite_code {
1046 let pool = self.connection_pool.lock();
1047 for connection_id in pool.user_connection_ids(user_id) {
1048 self.peer.send(
1049 connection_id,
1050 proto::UpdateInviteInfo {
1051 url: format!(
1052 "{}{}",
1053 self.app_state.config.invite_link_prefix, invite_code
1054 ),
1055 count: user.invite_count as u32,
1056 },
1057 )?;
1058 }
1059 }
1060 }
1061 Ok(())
1062 }
1063
1064 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1065 ServerSnapshot {
1066 connection_pool: ConnectionPoolGuard {
1067 guard: self.connection_pool.lock(),
1068 _not_send: PhantomData,
1069 },
1070 peer: &self.peer,
1071 }
1072 }
1073}
1074
/// Read access to the pooled connections while the guard holds the lock.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1082
/// Mutable access to the pooled connections while the guard holds the lock.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1088
impl Drop for ConnectionPoolGuard<'_> {
    /// In tests, validate pool invariants whenever the lock is released so
    /// that corruption is caught as close to the mutation as possible.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1095
1096fn broadcast<F>(
1097 sender_id: Option<ConnectionId>,
1098 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1099 mut f: F,
1100) where
1101 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1102{
1103 for receiver_id in receiver_ids {
1104 if Some(receiver_id) != sender_id {
1105 if let Err(error) = f(receiver_id) {
1106 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1107 }
1108 }
1109 }
1110}
1111
1112pub struct ProtocolVersion(u32);
1113
1114impl Header for ProtocolVersion {
1115 fn name() -> &'static HeaderName {
1116 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1117 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1118 }
1119
1120 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1121 where
1122 Self: Sized,
1123 I: Iterator<Item = &'i axum::http::HeaderValue>,
1124 {
1125 let version = values
1126 .next()
1127 .ok_or_else(axum::headers::Error::invalid)?
1128 .to_str()
1129 .map_err(|_| axum::headers::Error::invalid())?
1130 .parse()
1131 .map_err(|_| axum::headers::Error::invalid())?;
1132 Ok(Self(version))
1133 }
1134
1135 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1136 values.extend([self.0.to_string().parse().unwrap()]);
1137 }
1138}
1139
1140pub struct AppVersionHeader(SemanticVersion);
1141impl Header for AppVersionHeader {
1142 fn name() -> &'static HeaderName {
1143 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1144 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1145 }
1146
1147 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1148 where
1149 Self: Sized,
1150 I: Iterator<Item = &'i axum::http::HeaderValue>,
1151 {
1152 let version = values
1153 .next()
1154 .ok_or_else(axum::headers::Error::invalid)?
1155 .to_str()
1156 .map_err(|_| axum::headers::Error::invalid())?
1157 .parse()
1158 .map_err(|_| axum::headers::Error::invalid())?;
1159 Ok(Self(version))
1160 }
1161
1162 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1163 values.extend([self.0.to_string().parse().unwrap()]);
1164 }
1165}
1166
1167#[derive(Debug)]
1168pub struct ReleaseChannelHeader(String);
1169
1170impl Header for ReleaseChannelHeader {
1171 fn name() -> &'static HeaderName {
1172 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1173 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1174 }
1175
1176 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1177 where
1178 Self: Sized,
1179 I: Iterator<Item = &'i axum::http::HeaderValue>,
1180 {
1181 Ok(Self(
1182 values
1183 .next()
1184 .ok_or_else(axum::headers::Error::invalid)?
1185 .to_str()
1186 .map_err(|_| axum::headers::Error::invalid())?
1187 .to_owned(),
1188 ))
1189 }
1190
1191 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1192 values.extend([self.0.parse().unwrap()]);
1193 }
1194}
1195
1196pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1197 Router::new()
1198 .route("/rpc", get(handle_websocket_request))
1199 .layer(
1200 ServiceBuilder::new()
1201 .layer(Extension(server.app_state.clone()))
1202 .layer(middleware::from_fn(auth::validate_header)),
1203 )
1204 .route("/metrics", get(handle_metrics))
1205 .layer(Extension(server))
1206}
1207
/// Upgrades an authenticated HTTP request to the collab WebSocket connection.
///
/// Rejects with `426 Upgrade Required` when the protocol version mismatches,
/// the app-version header is missing, or the client is too old to
/// collaborate; rejects with `503` when the server is at its connection
/// limit. On success, hands the adapted socket to `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The wire protocol version must match exactly; otherwise the client
    // must upgrade before connecting.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // server can refuse with a plain HTTP 503 instead of accepting and then
    // dropping the socket.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket message types to the tungstenite types
        // that the rpc `Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1285
1286pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1287 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1288 let connections_metric = CONNECTIONS_METRIC
1289 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1290
1291 let connections = server
1292 .connection_pool
1293 .lock()
1294 .connections()
1295 .filter(|connection| !connection.admin)
1296 .count();
1297 connections_metric.set(connections as _);
1298
1299 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1300 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1301 register_int_gauge!(
1302 "shared_projects",
1303 "number of open projects with one or more guests"
1304 )
1305 .unwrap()
1306 });
1307
1308 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1309 shared_projects_metric.set(shared_projects as _);
1310
1311 let encoder = prometheus::TextEncoder::new();
1312 let metric_families = prometheus::gather();
1313 let encoded_metrics = encoder
1314 .encode_to_string(&metric_families)
1315 .map_err(|err| anyhow!("{err}"))?;
1316 Ok(encoded_metrics)
1317}
1318
/// Tears down server-side state after a client's socket is lost, with a
/// grace period during which the client may reconnect and rejoin its rooms.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Remove peer and pool state immediately so no further messages are
    // routed to the dead connection.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Wait out the reconnect window before releasing rooms and channel
    // buffers; a teardown signal (server shutdown) skips the cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last
            // connection; another of their devices may still answer.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1365
1366/// Acknowledges a ping from a client, used to keep the connection alive.
1367async fn ping(
1368 _: proto::Ping,
1369 response: Response<proto::Ping>,
1370 _session: MessageContext,
1371) -> Result<()> {
1372 response.send(proto::Ack {})?;
1373 Ok(())
1374}
1375
1376/// Creates a new room for calling (outside of channels)
1377async fn create_room(
1378 _request: proto::CreateRoom,
1379 response: Response<proto::CreateRoom>,
1380 session: MessageContext,
1381) -> Result<()> {
1382 let livekit_room = nanoid::nanoid!(30);
1383
1384 let live_kit_connection_info = util::maybe!(async {
1385 let live_kit = session.app_state.livekit_client.as_ref();
1386 let live_kit = live_kit?;
1387 let user_id = session.user_id().to_string();
1388
1389 let token = live_kit
1390 .room_token(&livekit_room, &user_id.to_string())
1391 .trace_err()?;
1392
1393 Some(proto::LiveKitConnectionInfo {
1394 server_url: live_kit.url().into(),
1395 token,
1396 can_publish: true,
1397 })
1398 })
1399 .await;
1400
1401 let room = session
1402 .db()
1403 .await
1404 .create_room(session.user_id(), session.connection_id, &livekit_room)
1405 .await?;
1406
1407 response.send(proto::CreateRoomResponse {
1408 room: Some(room.clone()),
1409 live_kit_connection_info,
1410 })?;
1411
1412 update_user_contacts(session.user_id(), &session).await?;
1413 Ok(())
1414}
1415
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms are joined through the channel path instead, so
    // channel membership and permissions are applied.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        // The scope limits how long the DB guard is held; only the inner
        // room data escapes via `into_inner`.
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the incoming-call notification on all of the joining user's
    // other connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Token minting is best-effort: failures are logged and the response
    // simply omits the LiveKit connection info.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1482
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // The guard returned by `rejoin_room` is held for this whole scope so
        // room state cannot change while we fan updates out to collaborators.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this session hosts again, tell its collaborators
        // about the host's new peer id and the latest worktree metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-stream state for projects this session was a guest in.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel updates happen after the DB guard above has been released.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1574
/// Re-sends project state to a client that has rejoined: announces the
/// client's new peer id to other collaborators, then streams worktree
/// entries, diagnostics, settings, and repository state back to the
/// rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First let every collaborator map the old connection to the new one.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their contents can be
        // consumed without cloning.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are chunked to respect message size limits.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1659
1660/// leave room disconnects from the room.
1661async fn leave_room(
1662 _: proto::LeaveRoom,
1663 response: Response<proto::LeaveRoom>,
1664 session: MessageContext,
1665) -> Result<()> {
1666 leave_room_for_session(&session, session.connection_id).await?;
1667 response.send(proto::Ack {})?;
1668 Ok(())
1669}
1670
1671/// Updates the permissions of someone else in the room.
1672async fn set_room_participant_role(
1673 request: proto::SetRoomParticipantRole,
1674 response: Response<proto::SetRoomParticipantRole>,
1675 session: MessageContext,
1676) -> Result<()> {
1677 let user_id = UserId::from_proto(request.user_id);
1678 let role = ChannelRole::from(request.role());
1679
1680 let (livekit_room, can_publish) = {
1681 let room = session
1682 .db()
1683 .await
1684 .set_room_participant_role(
1685 session.user_id(),
1686 RoomId::from_proto(request.room_id),
1687 user_id,
1688 role,
1689 )
1690 .await?;
1691
1692 let livekit_room = room.livekit_room.clone();
1693 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1694 room_updated(&room, &session.peer);
1695 (livekit_room, can_publish)
1696 };
1697
1698 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1699 live_kit
1700 .update_participant(
1701 livekit_room.clone(),
1702 request.user_id.to_string(),
1703 livekit_api::proto::ParticipantPermission {
1704 can_subscribe: true,
1705 can_publish,
1706 can_publish_data: can_publish,
1707 hidden: false,
1708 recorder: false,
1709 },
1710 )
1711 .await
1712 .trace_err();
1713 }
1714
1715 response.send(proto::Ack {})?;
1716 Ok(())
1717}
1718
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        // The scope bounds the DB guard; only the incoming-call message is
        // moved out of it.
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the callee has; requests resolve as each device
    // responds.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to answer successfully wins; failures from other
    // connections are merely logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection picked up: roll the callee back out of the room and
    // refresh their contact state before reporting failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1787
1788/// Cancel an outgoing call.
1789async fn cancel_call(
1790 request: proto::CancelCall,
1791 response: Response<proto::CancelCall>,
1792 session: MessageContext,
1793) -> Result<()> {
1794 let called_user_id = UserId::from_proto(request.called_user_id);
1795 let room_id = RoomId::from_proto(request.room_id);
1796 {
1797 let room = session
1798 .db()
1799 .await
1800 .cancel_call(room_id, session.connection_id, called_user_id)
1801 .await?;
1802 room_updated(&room, &session.peer);
1803 }
1804
1805 for connection_id in session
1806 .connection_pool()
1807 .await
1808 .user_connection_ids(called_user_id)
1809 {
1810 session
1811 .peer
1812 .send(
1813 connection_id,
1814 proto::CallCanceled {
1815 room_id: room_id.to_proto(),
1816 },
1817 )
1818 .trace_err();
1819 }
1820 response.send(proto::Ack {})?;
1821
1822 update_user_contacts(called_user_id, &session).await?;
1823 Ok(())
1824}
1825
1826/// Decline an incoming call.
1827async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1828 let room_id = RoomId::from_proto(message.room_id);
1829 {
1830 let room = session
1831 .db()
1832 .await
1833 .decline_call(Some(room_id), session.user_id())
1834 .await?
1835 .context("declining call")?;
1836 room_updated(&room, &session.peer);
1837 }
1838
1839 for connection_id in session
1840 .connection_pool()
1841 .await
1842 .user_connection_ids(session.user_id())
1843 {
1844 session
1845 .peer
1846 .send(
1847 connection_id,
1848 proto::CallCanceled {
1849 room_id: room_id.to_proto(),
1850 },
1851 )
1852 .trace_err();
1853 }
1854 update_user_contacts(session.user_id(), &session).await?;
1855 Ok(())
1856}
1857
1858/// Updates other participants in the room with your current location.
1859async fn update_participant_location(
1860 request: proto::UpdateParticipantLocation,
1861 response: Response<proto::UpdateParticipantLocation>,
1862 session: MessageContext,
1863) -> Result<()> {
1864 let room_id = RoomId::from_proto(request.room_id);
1865 let location = request.location.context("invalid location")?;
1866
1867 let db = session.db().await;
1868 let room = db
1869 .update_room_participant_location(room_id, session.connection_id, location)
1870 .await?;
1871
1872 room_updated(&room, &session.peer);
1873 response.send(proto::Ack {})?;
1874 Ok(())
1875}
1876
1877/// Share a project into the room.
1878async fn share_project(
1879 request: proto::ShareProject,
1880 response: Response<proto::ShareProject>,
1881 session: MessageContext,
1882) -> Result<()> {
1883 let (project_id, room) = &*session
1884 .db()
1885 .await
1886 .share_project(
1887 RoomId::from_proto(request.room_id),
1888 session.connection_id,
1889 &request.worktrees,
1890 request.is_ssh_project,
1891 )
1892 .await?;
1893 response.send(proto::ShareProjectResponse {
1894 project_id: project_id.to_proto(),
1895 })?;
1896 room_updated(room, &session.peer);
1897
1898 Ok(())
1899}
1900
1901/// Unshare a project from the room.
1902async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1903 let project_id = ProjectId::from_proto(message.project_id);
1904 unshare_project_internal(project_id, session.connection_id, &session).await
1905}
1906
/// Shared implementation of project unsharing: notifies guests, broadcasts
/// the new room state, and deletes the project row when the database says
/// it should no longer exist.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // The guard must be dropped before the `delete_project` call below,
        // hence this scope; only the `delete` flag escapes it.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection itself) that
        // the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1944
/// Join someone else's shared project.
///
/// Registers the guest as a collaborator, announces them to the existing
/// collaborators, responds with the project metadata, and then streams the
/// full project state (worktree entries, diagnostics, settings, repository
/// state, language-server status) to the joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the DB guard before the message fan-out below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new guest to every existing collaborator.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are chunked to respect message size limits.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Finally, nudge the client to refresh disk-based diagnostics for each
    // language server that is already running.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2094
2095/// Leave someone elses shared project.
2096async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2097 let sender_id = session.connection_id;
2098 let project_id = ProjectId::from_proto(request.project_id);
2099 let db = session.db().await;
2100
2101 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2102 tracing::info!(
2103 %project_id,
2104 "leave project"
2105 );
2106
2107 project_left(project, &session);
2108 if let Some(room) = room {
2109 room_updated(room, &session.peer);
2110 }
2111
2112 Ok(())
2113}
2114
2115/// Updates other participants with changes to the project
2116async fn update_project(
2117 request: proto::UpdateProject,
2118 response: Response<proto::UpdateProject>,
2119 session: MessageContext,
2120) -> Result<()> {
2121 let project_id = ProjectId::from_proto(request.project_id);
2122 let (room, guest_connection_ids) = &*session
2123 .db()
2124 .await
2125 .update_project(project_id, session.connection_id, &request.worktrees)
2126 .await?;
2127 broadcast(
2128 Some(session.connection_id),
2129 guest_connection_ids.iter().copied(),
2130 |connection_id| {
2131 session
2132 .peer
2133 .forward_send(session.connection_id, connection_id, request.clone())
2134 },
2135 );
2136 if let Some(room) = room {
2137 room_updated(room, &session.peer);
2138 }
2139 response.send(proto::Ack {})?;
2140
2141 Ok(())
2142}
2143
2144/// Updates other participants with changes to the worktree
2145async fn update_worktree(
2146 request: proto::UpdateWorktree,
2147 response: Response<proto::UpdateWorktree>,
2148 session: MessageContext,
2149) -> Result<()> {
2150 let guest_connection_ids = session
2151 .db()
2152 .await
2153 .update_worktree(&request, session.connection_id)
2154 .await?;
2155
2156 broadcast(
2157 Some(session.connection_id),
2158 guest_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, request.clone())
2163 },
2164 );
2165 response.send(proto::Ack {})?;
2166 Ok(())
2167}
2168
2169async fn update_repository(
2170 request: proto::UpdateRepository,
2171 response: Response<proto::UpdateRepository>,
2172 session: MessageContext,
2173) -> Result<()> {
2174 let guest_connection_ids = session
2175 .db()
2176 .await
2177 .update_repository(&request, session.connection_id)
2178 .await?;
2179
2180 broadcast(
2181 Some(session.connection_id),
2182 guest_connection_ids.iter().copied(),
2183 |connection_id| {
2184 session
2185 .peer
2186 .forward_send(session.connection_id, connection_id, request.clone())
2187 },
2188 );
2189 response.send(proto::Ack {})?;
2190 Ok(())
2191}
2192
2193async fn remove_repository(
2194 request: proto::RemoveRepository,
2195 response: Response<proto::RemoveRepository>,
2196 session: MessageContext,
2197) -> Result<()> {
2198 let guest_connection_ids = session
2199 .db()
2200 .await
2201 .remove_repository(&request, session.connection_id)
2202 .await?;
2203
2204 broadcast(
2205 Some(session.connection_id),
2206 guest_connection_ids.iter().copied(),
2207 |connection_id| {
2208 session
2209 .peer
2210 .forward_send(session.connection_id, connection_id, request.clone())
2211 },
2212 );
2213 response.send(proto::Ack {})?;
2214 Ok(())
2215}
2216
2217/// Updates other participants with changes to the diagnostics
2218async fn update_diagnostic_summary(
2219 message: proto::UpdateDiagnosticSummary,
2220 session: MessageContext,
2221) -> Result<()> {
2222 let guest_connection_ids = session
2223 .db()
2224 .await
2225 .update_diagnostic_summary(&message, session.connection_id)
2226 .await?;
2227
2228 broadcast(
2229 Some(session.connection_id),
2230 guest_connection_ids.iter().copied(),
2231 |connection_id| {
2232 session
2233 .peer
2234 .forward_send(session.connection_id, connection_id, message.clone())
2235 },
2236 );
2237
2238 Ok(())
2239}
2240
2241/// Updates other participants with changes to the worktree settings
2242async fn update_worktree_settings(
2243 message: proto::UpdateWorktreeSettings,
2244 session: MessageContext,
2245) -> Result<()> {
2246 let guest_connection_ids = session
2247 .db()
2248 .await
2249 .update_worktree_settings(&message, session.connection_id)
2250 .await?;
2251
2252 broadcast(
2253 Some(session.connection_id),
2254 guest_connection_ids.iter().copied(),
2255 |connection_id| {
2256 session
2257 .peer
2258 .forward_send(session.connection_id, connection_id, message.clone())
2259 },
2260 );
2261
2262 Ok(())
2263}
2264
2265/// Notify other participants that a language server has started.
2266async fn start_language_server(
2267 request: proto::StartLanguageServer,
2268 session: MessageContext,
2269) -> Result<()> {
2270 let guest_connection_ids = session
2271 .db()
2272 .await
2273 .start_language_server(&request, session.connection_id)
2274 .await?;
2275
2276 broadcast(
2277 Some(session.connection_id),
2278 guest_connection_ids.iter().copied(),
2279 |connection_id| {
2280 session
2281 .peer
2282 .forward_send(session.connection_id, connection_id, request.clone())
2283 },
2284 );
2285 Ok(())
2286}
2287
/// Notify other participants that a language server has changed.
///
/// If the update carries new server capabilities, they are persisted first so
/// they can be served to participants later (capabilities are also sent in
/// the join-project response elsewhere in this file).
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist capabilities before broadcasting, so the stored state never
    // lags behind what participants have been told.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
    {
        if let Some(capabilities) = update.capabilities.clone() {
            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
                .await?;
        }
    }

    // `true` here includes the host among the recipients (cf. the `false`
    // used by `broadcast_project_message_from_host`).
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2318
2319/// forward a project request to the host. These requests should be read only
2320/// as guests are allowed to send them.
2321async fn forward_read_only_project_request<T>(
2322 request: T,
2323 response: Response<T>,
2324 session: MessageContext,
2325) -> Result<()>
2326where
2327 T: EntityMessage + RequestMessage,
2328{
2329 let project_id = ProjectId::from_proto(request.remote_entity_id());
2330 let host_connection_id = session
2331 .db()
2332 .await
2333 .host_for_read_only_project_request(project_id, session.connection_id)
2334 .await?;
2335 let payload = session.forward_request(host_connection_id, request).await?;
2336 response.send(payload)?;
2337 Ok(())
2338}
2339
2340/// forward a project request to the host. These requests are disallowed
2341/// for guests.
2342async fn forward_mutating_project_request<T>(
2343 request: T,
2344 response: Response<T>,
2345 session: MessageContext,
2346) -> Result<()>
2347where
2348 T: EntityMessage + RequestMessage,
2349{
2350 let project_id = ProjectId::from_proto(request.remote_entity_id());
2351
2352 let host_connection_id = session
2353 .db()
2354 .await
2355 .host_for_mutating_project_request(project_id, session.connection_id)
2356 .await?;
2357 let payload = session.forward_request(host_connection_id, request).await?;
2358 response.send(payload)?;
2359 Ok(())
2360}
2361
2362async fn multi_lsp_query(
2363 request: MultiLspQuery,
2364 response: Response<MultiLspQuery>,
2365 session: MessageContext,
2366) -> Result<()> {
2367 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2368 tracing::info!("multi_lsp_query message received");
2369 forward_mutating_project_request(request, response, session).await
2370}
2371
2372/// Notify other participants that a new buffer has been created
2373async fn create_buffer_for_peer(
2374 request: proto::CreateBufferForPeer,
2375 session: MessageContext,
2376) -> Result<()> {
2377 session
2378 .db()
2379 .await
2380 .check_user_is_project_host(
2381 ProjectId::from_proto(request.project_id),
2382 session.connection_id,
2383 )
2384 .await?;
2385 let peer_id = request.peer_id.context("invalid peer id")?;
2386 session
2387 .peer
2388 .forward_send(session.connection_id, peer_id.into(), request)?;
2389 Ok(())
2390}
2391
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations may come from read-only guests;
    // any other operation variant requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the DB guard so it is dropped before awaiting the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Relay the update to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Wait for the host to acknowledge the update before acking the sender —
    // unless the sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2438
/// Broadcast a context (assistant conversation) operation to the project.
///
/// Mirrors `update_buffer`'s permission rule: selection-only buffer
/// operations are allowed for read-only guests, everything else needs write
/// access. Unlike `update_buffer`, the host is included as a plain recipient
/// and no acknowledgement is awaited.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Classify the operation to decide the capability required of the sender.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no inner payload is treated
                // conservatively as a write.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Send to guests and the host alike; `broadcast` skips the sender.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2480
2481/// Notify other participants that a project has been updated.
2482async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2483 request: T,
2484 session: MessageContext,
2485) -> Result<()> {
2486 let project_id = ProjectId::from_proto(request.remote_entity_id());
2487 let project_connection_ids = session
2488 .db()
2489 .await
2490 .project_connection_ids(project_id, session.connection_id, false)
2491 .await?;
2492
2493 broadcast(
2494 Some(session.connection_id),
2495 project_connection_ids.iter().copied(),
2496 |connection_id| {
2497 session
2498 .peer
2499 .forward_send(session.connection_id, connection_id, request.clone())
2500 },
2501 );
2502 Ok(())
2503}
2504
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and this connection are participants in the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay the reply to the
    // follower before recording the follow in the database.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are persisted and announced to the room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2536
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and this connection are participants in the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader it lost this follower; no reply is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Only project-scoped follows were persisted (see `follow`), so only
    // those need to be removed and announced to the room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2565
2566/// Notify everyone following you of your current location.
2567async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2568 let room_id = RoomId::from_proto(request.room_id);
2569 let database = session.db.lock().await;
2570
2571 let connection_ids = if let Some(project_id) = request.project_id {
2572 let project_id = ProjectId::from_proto(project_id);
2573 database
2574 .project_connection_ids(project_id, session.connection_id, true)
2575 .await?
2576 } else {
2577 database
2578 .room_connection_ids(room_id, session.connection_id)
2579 .await?
2580 };
2581
2582 // For now, don't send view update messages back to that view's current leader.
2583 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2584 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2585 _ => None,
2586 });
2587
2588 for connection_id in connection_ids.iter().cloned() {
2589 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2590 session
2591 .peer
2592 .forward_send(session.connection_id, connection_id, request.clone())?;
2593 }
2594 }
2595 Ok(())
2596}
2597
2598/// Get public data about users.
2599async fn get_users(
2600 request: proto::GetUsers,
2601 response: Response<proto::GetUsers>,
2602 session: MessageContext,
2603) -> Result<()> {
2604 let user_ids = request
2605 .user_ids
2606 .into_iter()
2607 .map(UserId::from_proto)
2608 .collect();
2609 let users = session
2610 .db()
2611 .await
2612 .get_users_by_ids(user_ids)
2613 .await?
2614 .into_iter()
2615 .map(|user| proto::User {
2616 id: user.id.to_proto(),
2617 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2618 github_login: user.github_login,
2619 name: user.name,
2620 })
2621 .collect();
2622 response.send(proto::UsersResponse { users })?;
2623 Ok(())
2624}
2625
2626/// Search for users (to invite) buy Github login
2627async fn fuzzy_search_users(
2628 request: proto::FuzzySearchUsers,
2629 response: Response<proto::FuzzySearchUsers>,
2630 session: MessageContext,
2631) -> Result<()> {
2632 let query = request.query;
2633 let users = match query.len() {
2634 0 => vec![],
2635 1 | 2 => session
2636 .db()
2637 .await
2638 .get_user_by_github_login(&query)
2639 .await?
2640 .into_iter()
2641 .collect(),
2642 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2643 };
2644 let users = users
2645 .into_iter()
2646 .filter(|user| user.id != session.user_id())
2647 .map(|user| proto::User {
2648 id: user.id.to_proto(),
2649 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2650 github_login: user.github_login,
2651 name: user.name,
2652 })
2653 .collect();
2654 response.send(proto::UsersResponse { users })?;
2655 Ok(())
2656}
2657
2658/// Send a contact request to another user.
2659async fn request_contact(
2660 request: proto::RequestContact,
2661 response: Response<proto::RequestContact>,
2662 session: MessageContext,
2663) -> Result<()> {
2664 let requester_id = session.user_id();
2665 let responder_id = UserId::from_proto(request.responder_id);
2666 if requester_id == responder_id {
2667 return Err(anyhow!("cannot add yourself as a contact"))?;
2668 }
2669
2670 let notifications = session
2671 .db()
2672 .await
2673 .send_contact_request(requester_id, responder_id)
2674 .await?;
2675
2676 // Update outgoing contact requests of requester
2677 let mut update = proto::UpdateContacts::default();
2678 update.outgoing_requests.push(responder_id.to_proto());
2679 for connection_id in session
2680 .connection_pool()
2681 .await
2682 .user_connection_ids(requester_id)
2683 {
2684 session.peer.send(connection_id, update.clone())?;
2685 }
2686
2687 // Update incoming contact requests of responder
2688 let mut update = proto::UpdateContacts::default();
2689 update
2690 .incoming_requests
2691 .push(proto::IncomingContactRequest {
2692 requester_id: requester_id.to_proto(),
2693 });
2694 let connection_pool = session.connection_pool().await;
2695 for connection_id in connection_pool.user_connection_ids(responder_id) {
2696 session.peer.send(connection_id, update.clone())?;
2697 }
2698
2699 send_notifications(&connection_pool, &session.peer, notifications);
2700
2701 response.send(proto::Ack {})?;
2702 Ok(())
2703}
2704
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it does not answer the request.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request disappears for
        // the responder.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // ...and the outgoing request disappears for the requester.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2762
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes an established contact from a still-
    // pending request; it decides which list each side removes the entry from.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the original contact-request notification, if one
        // was deleted by the DB layer.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2813
/// Whether the server should push channel data to this client without being
/// asked. NOTE(review): presumably clients at or above 0.139 request channel
/// data themselves via `SubscribeToChannels` (see `subscribe_to_channels`) —
/// confirm against the client-side protocol.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2817
2818async fn subscribe_to_channels(
2819 _: proto::SubscribeToChannels,
2820 session: MessageContext,
2821) -> Result<()> {
2822 subscribe_user_to_channels(session.user_id(), &session).await?;
2823 Ok(())
2824}
2825
2826async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2827 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2828 let mut pool = session.connection_pool().await;
2829 for membership in &channels_for_user.channel_memberships {
2830 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2831 }
2832 session.peer.send(
2833 session.connection_id,
2834 build_update_user_channels(&channels_for_user),
2835 )?;
2836 session.peer.send(
2837 session.connection_id,
2838 build_channels_update(channels_for_user),
2839 )?;
2840 Ok(())
2841}
2842
2843/// Creates a new channel.
2844async fn create_channel(
2845 request: proto::CreateChannel,
2846 response: Response<proto::CreateChannel>,
2847 session: MessageContext,
2848) -> Result<()> {
2849 let db = session.db().await;
2850
2851 let parent_id = request.parent_id.map(ChannelId::from_proto);
2852 let (channel, membership) = db
2853 .create_channel(&request.name, parent_id, session.user_id())
2854 .await?;
2855
2856 let root_id = channel.root_id();
2857 let channel = Channel::from_model(channel);
2858
2859 response.send(proto::CreateChannelResponse {
2860 channel: Some(channel.to_proto()),
2861 parent_id: request.parent_id,
2862 })?;
2863
2864 let mut connection_pool = session.connection_pool().await;
2865 if let Some(membership) = membership {
2866 connection_pool.subscribe_to_channel(
2867 membership.user_id,
2868 membership.channel_id,
2869 membership.role,
2870 );
2871 let update = proto::UpdateUserChannels {
2872 channel_memberships: vec![proto::ChannelMembership {
2873 channel_id: membership.channel_id.to_proto(),
2874 role: membership.role.into(),
2875 }],
2876 ..Default::default()
2877 };
2878 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2879 session.peer.send(connection_id, update.clone())?;
2880 }
2881 }
2882
2883 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2884 if !role.can_see_channel(channel.visibility) {
2885 continue;
2886 }
2887
2888 let update = proto::UpdateChannels {
2889 channels: vec![channel.to_proto()],
2890 ..Default::default()
2891 };
2892 session.peer.send(connection_id, update.clone())?;
2893 }
2894
2895 Ok(())
2896}
2897
2898/// Delete a channel
2899async fn delete_channel(
2900 request: proto::DeleteChannel,
2901 response: Response<proto::DeleteChannel>,
2902 session: MessageContext,
2903) -> Result<()> {
2904 let db = session.db().await;
2905
2906 let channel_id = request.channel_id;
2907 let (root_channel, removed_channels) = db
2908 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2909 .await?;
2910 response.send(proto::Ack {})?;
2911
2912 // Notify members of removed channels
2913 let mut update = proto::UpdateChannels::default();
2914 update
2915 .delete_channels
2916 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2917
2918 let connection_pool = session.connection_pool().await;
2919 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2920 session.peer.send(connection_id, update.clone())?;
2921 }
2922
2923 Ok(())
2924}
2925
2926/// Invite someone to join a channel.
2927async fn invite_channel_member(
2928 request: proto::InviteChannelMember,
2929 response: Response<proto::InviteChannelMember>,
2930 session: MessageContext,
2931) -> Result<()> {
2932 let db = session.db().await;
2933 let channel_id = ChannelId::from_proto(request.channel_id);
2934 let invitee_id = UserId::from_proto(request.user_id);
2935 let InviteMemberResult {
2936 channel,
2937 notifications,
2938 } = db
2939 .invite_channel_member(
2940 channel_id,
2941 invitee_id,
2942 session.user_id(),
2943 request.role().into(),
2944 )
2945 .await?;
2946
2947 let update = proto::UpdateChannels {
2948 channel_invitations: vec![channel.to_proto()],
2949 ..Default::default()
2950 };
2951
2952 let connection_pool = session.connection_pool().await;
2953 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2954 session.peer.send(connection_id, update.clone())?;
2955 }
2956
2957 send_notifications(&connection_pool, &session.peer, notifications);
2958
2959 response.send(proto::Ack {})?;
2960 Ok(())
2961}
2962
2963/// remove someone from a channel
2964async fn remove_channel_member(
2965 request: proto::RemoveChannelMember,
2966 response: Response<proto::RemoveChannelMember>,
2967 session: MessageContext,
2968) -> Result<()> {
2969 let db = session.db().await;
2970 let channel_id = ChannelId::from_proto(request.channel_id);
2971 let member_id = UserId::from_proto(request.user_id);
2972
2973 let RemoveChannelMemberResult {
2974 membership_update,
2975 notification_id,
2976 } = db
2977 .remove_channel_member(channel_id, member_id, session.user_id())
2978 .await?;
2979
2980 let mut connection_pool = session.connection_pool().await;
2981 notify_membership_updated(
2982 &mut connection_pool,
2983 membership_update,
2984 member_id,
2985 &session.peer,
2986 );
2987 for connection_id in connection_pool.user_connection_ids(member_id) {
2988 if let Some(notification_id) = notification_id {
2989 session
2990 .peer
2991 .send(
2992 connection_id,
2993 proto::DeleteNotification {
2994 notification_id: notification_id.to_proto(),
2995 },
2996 )
2997 .trace_err();
2998 }
2999 }
3000
3001 response.send(proto::Ack {})?;
3002 Ok(())
3003}
3004
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the iterator borrows the pool, which is
    // mutated inside the loop by (un)subscribing.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can see the channel under its new visibility get
        // subscribed and sent the channel; everyone else gets unsubscribed
        // and told the channel was deleted.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3051
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The member had already joined: push a full membership update.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // The member only has a pending invite: refresh the invitation entry
        // on all of their connections.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3099
3100/// Change the name of a channel
3101async fn rename_channel(
3102 request: proto::RenameChannel,
3103 response: Response<proto::RenameChannel>,
3104 session: MessageContext,
3105) -> Result<()> {
3106 let db = session.db().await;
3107 let channel_id = ChannelId::from_proto(request.channel_id);
3108 let channel_model = db
3109 .rename_channel(channel_id, session.user_id(), &request.name)
3110 .await?;
3111 let root_id = channel_model.root_id();
3112 let channel = Channel::from_model(channel_model);
3113
3114 response.send(proto::RenameChannelResponse {
3115 channel: Some(channel.to_proto()),
3116 })?;
3117
3118 let connection_pool = session.connection_pool().await;
3119 let update = proto::UpdateChannels {
3120 channels: vec![channel.to_proto()],
3121 ..Default::default()
3122 };
3123 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3124 if role.can_see_channel(channel.visibility) {
3125 session.peer.send(connection_id, update.clone())?;
3126 }
3127 }
3128
3129 Ok(())
3130}
3131
3132/// Move a channel to a new parent.
3133async fn move_channel(
3134 request: proto::MoveChannel,
3135 response: Response<proto::MoveChannel>,
3136 session: MessageContext,
3137) -> Result<()> {
3138 let channel_id = ChannelId::from_proto(request.channel_id);
3139 let to = ChannelId::from_proto(request.to);
3140
3141 let (root_id, channels) = session
3142 .db()
3143 .await
3144 .move_channel(channel_id, to, session.user_id())
3145 .await?;
3146
3147 let connection_pool = session.connection_pool().await;
3148 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3149 let channels = channels
3150 .iter()
3151 .filter_map(|channel| {
3152 if role.can_see_channel(channel.visibility) {
3153 Some(channel.to_proto())
3154 } else {
3155 None
3156 }
3157 })
3158 .collect::<Vec<_>>();
3159 if channels.is_empty() {
3160 continue;
3161 }
3162
3163 let update = proto::UpdateChannels {
3164 channels,
3165 ..Default::default()
3166 };
3167
3168 session.peer.send(connection_id, update.clone())?;
3169 }
3170
3171 response.send(Ack {})?;
3172 Ok(())
3173}
3174
3175async fn reorder_channel(
3176 request: proto::ReorderChannel,
3177 response: Response<proto::ReorderChannel>,
3178 session: MessageContext,
3179) -> Result<()> {
3180 let channel_id = ChannelId::from_proto(request.channel_id);
3181 let direction = request.direction();
3182
3183 let updated_channels = session
3184 .db()
3185 .await
3186 .reorder_channel(channel_id, direction, session.user_id())
3187 .await?;
3188
3189 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3190 let connection_pool = session.connection_pool().await;
3191 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3192 let channels = updated_channels
3193 .iter()
3194 .filter_map(|channel| {
3195 if role.can_see_channel(channel.visibility) {
3196 Some(channel.to_proto())
3197 } else {
3198 None
3199 }
3200 })
3201 .collect::<Vec<_>>();
3202
3203 if channels.is_empty() {
3204 continue;
3205 }
3206
3207 let update = proto::UpdateChannels {
3208 channels,
3209 ..Default::default()
3210 };
3211
3212 session.peer.send(connection_id, update.clone())?;
3213 }
3214 }
3215
3216 response.send(Ack {})?;
3217 Ok(())
3218}
3219
3220/// Get the list of channel members
3221async fn get_channel_members(
3222 request: proto::GetChannelMembers,
3223 response: Response<proto::GetChannelMembers>,
3224 session: MessageContext,
3225) -> Result<()> {
3226 let db = session.db().await;
3227 let channel_id = ChannelId::from_proto(request.channel_id);
3228 let limit = if request.limit == 0 {
3229 u16::MAX as u64
3230 } else {
3231 request.limit
3232 };
3233 let (members, users) = db
3234 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3235 .await?;
3236 response.send(proto::GetChannelMembersResponse { members, users })?;
3237 Ok(())
3238}
3239
3240/// Accept or decline a channel invitation.
3241async fn respond_to_channel_invite(
3242 request: proto::RespondToChannelInvite,
3243 response: Response<proto::RespondToChannelInvite>,
3244 session: MessageContext,
3245) -> Result<()> {
3246 let db = session.db().await;
3247 let channel_id = ChannelId::from_proto(request.channel_id);
3248 let RespondToChannelInvite {
3249 membership_update,
3250 notifications,
3251 } = db
3252 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3253 .await?;
3254
3255 let mut connection_pool = session.connection_pool().await;
3256 if let Some(membership_update) = membership_update {
3257 notify_membership_updated(
3258 &mut connection_pool,
3259 membership_update,
3260 session.user_id(),
3261 &session.peer,
3262 );
3263 } else {
3264 let update = proto::UpdateChannels {
3265 remove_channel_invitations: vec![channel_id.to_proto()],
3266 ..Default::default()
3267 };
3268
3269 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3270 session.peer.send(connection_id, update.clone())?;
3271 }
3272 };
3273
3274 send_notifications(&connection_pool, &session.peer, notifications);
3275
3276 response.send(proto::Ack {})?;
3277
3278 Ok(())
3279}
3280
3281/// Join the channels' room
3282async fn join_channel(
3283 request: proto::JoinChannel,
3284 response: Response<proto::JoinChannel>,
3285 session: MessageContext,
3286) -> Result<()> {
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288 join_channel_internal(channel_id, Box::new(response), session).await
3289}
3290
/// Abstraction over the two response handles (`JoinChannel` and `JoinRoom`)
/// that both reply with a `proto::JoinRoomResponse`, letting
/// `join_channel_internal` serve either RPC.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to the concrete `Response<JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to the concrete `Response<JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3304
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` RPCs (see `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection for the user, joins the channel's
/// room in the database, mints a LiveKit token when a client is configured,
/// replies to the caller, and then notifies memberships, room participants,
/// channel subscribers, and contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before `leave_room_for_session` (which
            // acquires its own), then re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a token via `guest_token`, everyone else via
        // `room_token`; `can_publish` is false only for guests. Token-minting
        // failures are logged and result in no connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting any updates.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell channel subscribers about the new participant set.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Joining a room changes the user's busy state for their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3398
3399/// Start editing the channel notes
3400async fn join_channel_buffer(
3401 request: proto::JoinChannelBuffer,
3402 response: Response<proto::JoinChannelBuffer>,
3403 session: MessageContext,
3404) -> Result<()> {
3405 let db = session.db().await;
3406 let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408 let open_response = db
3409 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3410 .await?;
3411
3412 let collaborators = open_response.collaborators.clone();
3413 response.send(open_response)?;
3414
3415 let update = UpdateChannelBufferCollaborators {
3416 channel_id: channel_id.to_proto(),
3417 collaborators: collaborators.clone(),
3418 };
3419 channel_buffer_updated(
3420 session.connection_id,
3421 collaborators
3422 .iter()
3423 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3424 &update,
3425 &session.peer,
3426 );
3427
3428 Ok(())
3429}
3430
3431/// Edit the channel notes
3432async fn update_channel_buffer(
3433 request: proto::UpdateChannelBuffer,
3434 session: MessageContext,
3435) -> Result<()> {
3436 let db = session.db().await;
3437 let channel_id = ChannelId::from_proto(request.channel_id);
3438
3439 let (collaborators, epoch, version) = db
3440 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3441 .await?;
3442
3443 channel_buffer_updated(
3444 session.connection_id,
3445 collaborators.clone(),
3446 &proto::UpdateChannelBuffer {
3447 channel_id: channel_id.to_proto(),
3448 operations: request.operations,
3449 },
3450 &session.peer,
3451 );
3452
3453 let pool = &*session.connection_pool().await;
3454
3455 let non_collaborators =
3456 pool.channel_connection_ids(channel_id)
3457 .filter_map(|(connection_id, _)| {
3458 if collaborators.contains(&connection_id) {
3459 None
3460 } else {
3461 Some(connection_id)
3462 }
3463 });
3464
3465 broadcast(None, non_collaborators, |peer_id| {
3466 session.peer.send(
3467 peer_id,
3468 proto::UpdateChannels {
3469 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3470 channel_id: channel_id.to_proto(),
3471 epoch: epoch as u64,
3472 version: version.clone(),
3473 }],
3474 ..Default::default()
3475 },
3476 )
3477 });
3478
3479 Ok(())
3480}
3481
3482/// Rejoin the channel notes after a connection blip
3483async fn rejoin_channel_buffers(
3484 request: proto::RejoinChannelBuffers,
3485 response: Response<proto::RejoinChannelBuffers>,
3486 session: MessageContext,
3487) -> Result<()> {
3488 let db = session.db().await;
3489 let buffers = db
3490 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3491 .await?;
3492
3493 for rejoined_buffer in &buffers {
3494 let collaborators_to_notify = rejoined_buffer
3495 .buffer
3496 .collaborators
3497 .iter()
3498 .filter_map(|c| Some(c.peer_id?.into()));
3499 channel_buffer_updated(
3500 session.connection_id,
3501 collaborators_to_notify,
3502 &proto::UpdateChannelBufferCollaborators {
3503 channel_id: rejoined_buffer.buffer.channel_id,
3504 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3505 },
3506 &session.peer,
3507 );
3508 }
3509
3510 response.send(proto::RejoinChannelBuffersResponse {
3511 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3512 })?;
3513
3514 Ok(())
3515}
3516
3517/// Stop editing the channel notes
3518async fn leave_channel_buffer(
3519 request: proto::LeaveChannelBuffer,
3520 response: Response<proto::LeaveChannelBuffer>,
3521 session: MessageContext,
3522) -> Result<()> {
3523 let db = session.db().await;
3524 let channel_id = ChannelId::from_proto(request.channel_id);
3525
3526 let left_buffer = db
3527 .leave_channel_buffer(channel_id, session.connection_id)
3528 .await?;
3529
3530 response.send(Ack {})?;
3531
3532 channel_buffer_updated(
3533 session.connection_id,
3534 left_buffer.connections,
3535 &proto::UpdateChannelBufferCollaborators {
3536 channel_id: channel_id.to_proto(),
3537 collaborators: left_buffer.collaborators,
3538 },
3539 &session.peer,
3540 );
3541
3542 Ok(())
3543}
3544
3545fn channel_buffer_updated<T: EnvelopedMessage>(
3546 sender_id: ConnectionId,
3547 collaborators: impl IntoIterator<Item = ConnectionId>,
3548 message: &T,
3549 peer: &Peer,
3550) {
3551 broadcast(Some(sender_id), collaborators, |peer_id| {
3552 peer.send(peer_id, message.clone())
3553 });
3554}
3555
3556fn send_notifications(
3557 connection_pool: &ConnectionPool,
3558 peer: &Peer,
3559 notifications: db::NotificationBatch,
3560) {
3561 for (user_id, notification) in notifications {
3562 for connection_id in connection_pool.user_connection_ids(user_id) {
3563 if let Err(error) = peer.send(
3564 connection_id,
3565 proto::AddNotification {
3566 notification: Some(notification.clone()),
3567 },
3568 ) {
3569 tracing::error!(
3570 "failed to send notification to {:?} {}",
3571 connection_id,
3572 error
3573 );
3574 }
3575 }
3576 }
3577}
3578
3579/// Send a message to the channel
3580async fn send_channel_message(
3581 request: proto::SendChannelMessage,
3582 response: Response<proto::SendChannelMessage>,
3583 session: MessageContext,
3584) -> Result<()> {
3585 // Validate the message body.
3586 let body = request.body.trim().to_string();
3587 if body.len() > MAX_MESSAGE_LEN {
3588 return Err(anyhow!("message is too long"))?;
3589 }
3590 if body.is_empty() {
3591 return Err(anyhow!("message can't be blank"))?;
3592 }
3593
3594 // TODO: adjust mentions if body is trimmed
3595
3596 let timestamp = OffsetDateTime::now_utc();
3597 let nonce = request.nonce.context("nonce can't be blank")?;
3598
3599 let channel_id = ChannelId::from_proto(request.channel_id);
3600 let CreatedChannelMessage {
3601 message_id,
3602 participant_connection_ids,
3603 notifications,
3604 } = session
3605 .db()
3606 .await
3607 .create_channel_message(
3608 channel_id,
3609 session.user_id(),
3610 &body,
3611 &request.mentions,
3612 timestamp,
3613 nonce.clone().into(),
3614 request.reply_to_message_id.map(MessageId::from_proto),
3615 )
3616 .await?;
3617
3618 let message = proto::ChannelMessage {
3619 sender_id: session.user_id().to_proto(),
3620 id: message_id.to_proto(),
3621 body,
3622 mentions: request.mentions,
3623 timestamp: timestamp.unix_timestamp() as u64,
3624 nonce: Some(nonce),
3625 reply_to_message_id: request.reply_to_message_id,
3626 edited_at: None,
3627 };
3628 broadcast(
3629 Some(session.connection_id),
3630 participant_connection_ids.clone(),
3631 |connection| {
3632 session.peer.send(
3633 connection,
3634 proto::ChannelMessageSent {
3635 channel_id: channel_id.to_proto(),
3636 message: Some(message.clone()),
3637 },
3638 )
3639 },
3640 );
3641 response.send(proto::SendChannelMessageResponse {
3642 message: Some(message),
3643 })?;
3644
3645 let pool = &*session.connection_pool().await;
3646 let non_participants =
3647 pool.channel_connection_ids(channel_id)
3648 .filter_map(|(connection_id, _)| {
3649 if participant_connection_ids.contains(&connection_id) {
3650 None
3651 } else {
3652 Some(connection_id)
3653 }
3654 });
3655 broadcast(None, non_participants, |peer_id| {
3656 session.peer.send(
3657 peer_id,
3658 proto::UpdateChannels {
3659 latest_channel_message_ids: vec![proto::ChannelMessageId {
3660 channel_id: channel_id.to_proto(),
3661 message_id: message_id.to_proto(),
3662 }],
3663 ..Default::default()
3664 },
3665 )
3666 });
3667 send_notifications(pool, &session.peer, notifications);
3668
3669 Ok(())
3670}
3671
3672/// Delete a channel message
3673async fn remove_channel_message(
3674 request: proto::RemoveChannelMessage,
3675 response: Response<proto::RemoveChannelMessage>,
3676 session: MessageContext,
3677) -> Result<()> {
3678 let channel_id = ChannelId::from_proto(request.channel_id);
3679 let message_id = MessageId::from_proto(request.message_id);
3680 let (connection_ids, existing_notification_ids) = session
3681 .db()
3682 .await
3683 .remove_channel_message(channel_id, message_id, session.user_id())
3684 .await?;
3685
3686 broadcast(
3687 Some(session.connection_id),
3688 connection_ids,
3689 move |connection| {
3690 session.peer.send(connection, request.clone())?;
3691
3692 for notification_id in &existing_notification_ids {
3693 session.peer.send(
3694 connection,
3695 proto::DeleteNotification {
3696 notification_id: (*notification_id).to_proto(),
3697 },
3698 )?;
3699 }
3700
3701 Ok(())
3702 },
3703 );
3704 response.send(proto::Ack {})?;
3705 Ok(())
3706}
3707
/// Edit the body and mentions of an existing channel message.
///
/// Persists the edit, acks the sender, relays the updated message to the
/// other chat participants, and reconciles mention notifications: deleting
/// the ones whose mention was removed, updating the ones whose content
/// changed, and delivering any newly-created ones.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // Note: `message_id` is intentionally shadowed below by the id the
    // database returns for the updated message.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // The message as clients should now display it: original creation
    // timestamp preserved, `edited_at` set to the edit time.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    // Ack the sender before broadcasting to the other participants.
    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // Remove mention notifications that no longer apply after the edit.
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // Refresh the mention notifications whose content changed.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    // Deliver any notifications newly created by the edit.
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3791
3792/// Mark a channel message as read
3793async fn acknowledge_channel_message(
3794 request: proto::AckChannelMessage,
3795 session: MessageContext,
3796) -> Result<()> {
3797 let channel_id = ChannelId::from_proto(request.channel_id);
3798 let message_id = MessageId::from_proto(request.message_id);
3799 let notifications = session
3800 .db()
3801 .await
3802 .observe_channel_message(channel_id, session.user_id(), message_id)
3803 .await?;
3804 send_notifications(
3805 &*session.connection_pool().await,
3806 &session.peer,
3807 notifications,
3808 );
3809 Ok(())
3810}
3811
3812/// Mark a buffer version as synced
3813async fn acknowledge_buffer_version(
3814 request: proto::AckBufferOperation,
3815 session: MessageContext,
3816) -> Result<()> {
3817 let buffer_id = BufferId::from_proto(request.buffer_id);
3818 session
3819 .db()
3820 .await
3821 .observe_buffer_version(
3822 buffer_id,
3823 session.user_id(),
3824 request.epoch as i32,
3825 &request.version,
3826 )
3827 .await?;
3828 Ok(())
3829}
3830
3831/// Get a Supermaven API key for the user
3832async fn get_supermaven_api_key(
3833 _request: proto::GetSupermavenApiKey,
3834 response: Response<proto::GetSupermavenApiKey>,
3835 session: MessageContext,
3836) -> Result<()> {
3837 let user_id: String = session.user_id().to_string();
3838 if !session.is_staff() {
3839 return Err(anyhow!("supermaven not enabled for this account"))?;
3840 }
3841
3842 let email = session.email().context("user must have an email")?;
3843
3844 let supermaven_admin_api = session
3845 .supermaven_client
3846 .as_ref()
3847 .context("supermaven not configured")?;
3848
3849 let result = supermaven_admin_api
3850 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3851 .await?;
3852
3853 response.send(proto::GetSupermavenApiKeyResponse {
3854 api_key: result.api_key,
3855 })?;
3856
3857 Ok(())
3858}
3859
3860/// Start receiving chat updates for a channel
3861async fn join_channel_chat(
3862 request: proto::JoinChannelChat,
3863 response: Response<proto::JoinChannelChat>,
3864 session: MessageContext,
3865) -> Result<()> {
3866 let channel_id = ChannelId::from_proto(request.channel_id);
3867
3868 let db = session.db().await;
3869 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3870 .await?;
3871 let messages = db
3872 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3873 .await?;
3874 response.send(proto::JoinChannelChatResponse {
3875 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3876 messages,
3877 })?;
3878 Ok(())
3879}
3880
3881/// Stop receiving chat updates for a channel
3882async fn leave_channel_chat(
3883 request: proto::LeaveChannelChat,
3884 session: MessageContext,
3885) -> Result<()> {
3886 let channel_id = ChannelId::from_proto(request.channel_id);
3887 session
3888 .db()
3889 .await
3890 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3891 .await?;
3892 Ok(())
3893}
3894
3895/// Retrieve the chat history for a channel
3896async fn get_channel_messages(
3897 request: proto::GetChannelMessages,
3898 response: Response<proto::GetChannelMessages>,
3899 session: MessageContext,
3900) -> Result<()> {
3901 let channel_id = ChannelId::from_proto(request.channel_id);
3902 let messages = session
3903 .db()
3904 .await
3905 .get_channel_messages(
3906 channel_id,
3907 session.user_id(),
3908 MESSAGE_COUNT_PER_PAGE,
3909 Some(MessageId::from_proto(request.before_message_id)),
3910 )
3911 .await?;
3912 response.send(proto::GetChannelMessagesResponse {
3913 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3914 messages,
3915 })?;
3916 Ok(())
3917}
3918
3919/// Retrieve specific chat messages
3920async fn get_channel_messages_by_id(
3921 request: proto::GetChannelMessagesById,
3922 response: Response<proto::GetChannelMessagesById>,
3923 session: MessageContext,
3924) -> Result<()> {
3925 let message_ids = request
3926 .message_ids
3927 .iter()
3928 .map(|id| MessageId::from_proto(*id))
3929 .collect::<Vec<_>>();
3930 let messages = session
3931 .db()
3932 .await
3933 .get_channel_messages_by_id(session.user_id(), &message_ids)
3934 .await?;
3935 response.send(proto::GetChannelMessagesResponse {
3936 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3937 messages,
3938 })?;
3939 Ok(())
3940}
3941
3942/// Retrieve the current users notifications
3943async fn get_notifications(
3944 request: proto::GetNotifications,
3945 response: Response<proto::GetNotifications>,
3946 session: MessageContext,
3947) -> Result<()> {
3948 let notifications = session
3949 .db()
3950 .await
3951 .get_notifications(
3952 session.user_id(),
3953 NOTIFICATION_COUNT_PER_PAGE,
3954 request.before_id.map(db::NotificationId::from_proto),
3955 )
3956 .await?;
3957 response.send(proto::GetNotificationsResponse {
3958 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3959 notifications,
3960 })?;
3961 Ok(())
3962}
3963
3964/// Mark notifications as read
3965async fn mark_notification_as_read(
3966 request: proto::MarkNotificationRead,
3967 response: Response<proto::MarkNotificationRead>,
3968 session: MessageContext,
3969) -> Result<()> {
3970 let database = &session.db().await;
3971 let notifications = database
3972 .mark_notification_as_read_by_id(
3973 session.user_id(),
3974 NotificationId::from_proto(request.notification_id),
3975 )
3976 .await?;
3977 send_notifications(
3978 &*session.connection_pool().await,
3979 &session.peer,
3980 notifications,
3981 );
3982 response.send(proto::Ack {})?;
3983 Ok(())
3984}
3985
3986fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3987 let message = match message {
3988 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3989 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3990 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3991 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3992 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3993 code: frame.code.into(),
3994 reason: frame.reason.as_str().to_owned().into(),
3995 })),
3996 // We should never receive a frame while reading the message, according
3997 // to the `tungstenite` maintainers:
3998 //
3999 // > It cannot occur when you read messages from the WebSocket, but it
4000 // > can be used when you want to send the raw frames (e.g. you want to
4001 // > send the frames to the WebSocket without composing the full message first).
4002 // >
4003 // > — https://github.com/snapview/tungstenite-rs/issues/268
4004 TungsteniteMessage::Frame(_) => {
4005 bail!("received an unexpected frame while reading the message")
4006 }
4007 };
4008
4009 Ok(message)
4010}
4011
4012fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4013 match message {
4014 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4015 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4016 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4017 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4018 AxumMessage::Close(frame) => {
4019 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4020 code: frame.code.into(),
4021 reason: frame.reason.as_ref().into(),
4022 }))
4023 }
4024 }
4025}
4026
4027fn notify_membership_updated(
4028 connection_pool: &mut ConnectionPool,
4029 result: MembershipUpdated,
4030 user_id: UserId,
4031 peer: &Peer,
4032) {
4033 for membership in &result.new_channels.channel_memberships {
4034 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4035 }
4036 for channel_id in &result.removed_channels {
4037 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4038 }
4039
4040 let user_channels_update = proto::UpdateUserChannels {
4041 channel_memberships: result
4042 .new_channels
4043 .channel_memberships
4044 .iter()
4045 .map(|cm| proto::ChannelMembership {
4046 channel_id: cm.channel_id.to_proto(),
4047 role: cm.role.into(),
4048 })
4049 .collect(),
4050 ..Default::default()
4051 };
4052
4053 let mut update = build_channels_update(result.new_channels);
4054 update.delete_channels = result
4055 .removed_channels
4056 .into_iter()
4057 .map(|id| id.to_proto())
4058 .collect();
4059 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4060
4061 for connection_id in connection_pool.user_connection_ids(user_id) {
4062 peer.send(connection_id, user_channels_update.clone())
4063 .trace_err();
4064 peer.send(connection_id, update.clone()).trace_err();
4065 }
4066}
4067
4068fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4069 proto::UpdateUserChannels {
4070 channel_memberships: channels
4071 .channel_memberships
4072 .iter()
4073 .map(|m| proto::ChannelMembership {
4074 channel_id: m.channel_id.to_proto(),
4075 role: m.role.into(),
4076 })
4077 .collect(),
4078 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4079 observed_channel_message_id: channels.observed_channel_messages.clone(),
4080 }
4081}
4082
4083fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4084 let mut update = proto::UpdateChannels::default();
4085
4086 for channel in channels.channels {
4087 update.channels.push(channel.to_proto());
4088 }
4089
4090 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4091 update.latest_channel_message_ids = channels.latest_channel_messages;
4092
4093 for (channel_id, participants) in channels.channel_participants {
4094 update
4095 .channel_participants
4096 .push(proto::ChannelParticipants {
4097 channel_id: channel_id.to_proto(),
4098 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4099 });
4100 }
4101
4102 for channel in channels.invited_channels {
4103 update.channel_invitations.push(channel.to_proto());
4104 }
4105
4106 update
4107}
4108
4109fn build_initial_contacts_update(
4110 contacts: Vec<db::Contact>,
4111 pool: &ConnectionPool,
4112) -> proto::UpdateContacts {
4113 let mut update = proto::UpdateContacts::default();
4114
4115 for contact in contacts {
4116 match contact {
4117 db::Contact::Accepted { user_id, busy } => {
4118 update.contacts.push(contact_for_user(user_id, busy, pool));
4119 }
4120 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4121 db::Contact::Incoming { user_id } => {
4122 update
4123 .incoming_requests
4124 .push(proto::IncomingContactRequest {
4125 requester_id: user_id.to_proto(),
4126 })
4127 }
4128 }
4129 }
4130
4131 update
4132}
4133
4134fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4135 proto::Contact {
4136 user_id: user_id.to_proto(),
4137 online: pool.is_user_online(user_id),
4138 busy,
4139 }
4140}
4141
4142fn room_updated(room: &proto::Room, peer: &Peer) {
4143 broadcast(
4144 None,
4145 room.participants
4146 .iter()
4147 .filter_map(|participant| Some(participant.peer_id?.into())),
4148 |peer_id| {
4149 peer.send(
4150 peer_id,
4151 proto::RoomUpdated {
4152 room: Some(room.clone()),
4153 },
4154 )
4155 },
4156 );
4157}
4158
4159fn channel_updated(
4160 channel: &db::channel::Model,
4161 room: &proto::Room,
4162 peer: &Peer,
4163 pool: &ConnectionPool,
4164) {
4165 let participants = room
4166 .participants
4167 .iter()
4168 .map(|p| p.user_id)
4169 .collect::<Vec<_>>();
4170
4171 broadcast(
4172 None,
4173 pool.channel_connection_ids(channel.root_id())
4174 .filter_map(|(channel_id, role)| {
4175 role.can_see_channel(channel.visibility)
4176 .then_some(channel_id)
4177 }),
4178 |peer_id| {
4179 peer.send(
4180 peer_id,
4181 proto::UpdateChannels {
4182 channel_participants: vec![proto::ChannelParticipants {
4183 channel_id: channel.id.to_proto(),
4184 participant_user_ids: participants.clone(),
4185 }],
4186 ..Default::default()
4187 },
4188 )
4189 },
4190 );
4191}
4192
4193async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4194 let db = session.db().await;
4195
4196 let contacts = db.get_contacts(user_id).await?;
4197 let busy = db.is_user_busy(user_id).await?;
4198
4199 let pool = session.connection_pool().await;
4200 let updated_contact = contact_for_user(user_id, busy, &pool);
4201 for contact in contacts {
4202 if let db::Contact::Accepted {
4203 user_id: contact_user_id,
4204 ..
4205 } = contact
4206 {
4207 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4208 session
4209 .peer
4210 .send(
4211 contact_conn_id,
4212 proto::UpdateContacts {
4213 contacts: vec![updated_contact.clone()],
4214 remove_contacts: Default::default(),
4215 incoming_requests: Default::default(),
4216 remove_incoming_requests: Default::default(),
4217 outgoing_requests: Default::default(),
4218 remove_outgoing_requests: Default::default(),
4219 },
4220 )
4221 .trace_err();
4222 }
4223 }
4224 }
4225 Ok(())
4226}
4227
/// Remove `connection_id` from the room it is in (if any), then propagate the
/// consequences: project departures, room/channel participant updates,
/// canceled outgoing calls, contact state refreshes, and LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so they can be assigned inside the `if let` branch
    // and used after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // The leaving user's own contact state (busy flag) changes too.
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list for
    // channel subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell every connection of each user whose incoming call was canceled,
        // and remember them so their contact state is refreshed below.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Mirror the departure in LiveKit; delete the LiveKit room when the
    // database reported the room as deleted. Both calls are best-effort.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4301
4302async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4303 let left_channel_buffers = session
4304 .db()
4305 .await
4306 .leave_channel_buffers(session.connection_id)
4307 .await?;
4308
4309 for left_buffer in left_channel_buffers {
4310 channel_buffer_updated(
4311 session.connection_id,
4312 left_buffer.connections,
4313 &proto::UpdateChannelBufferCollaborators {
4314 channel_id: left_buffer.channel_id.to_proto(),
4315 collaborators: left_buffer.collaborators,
4316 },
4317 &session.peer,
4318 );
4319 }
4320
4321 Ok(())
4322}
4323
4324fn project_left(project: &db::LeftProject, session: &Session) {
4325 for connection_id in &project.connection_ids {
4326 if project.should_unshare {
4327 session
4328 .peer
4329 .send(
4330 *connection_id,
4331 proto::UnshareProject {
4332 project_id: project.id.to_proto(),
4333 },
4334 )
4335 .trace_err();
4336 } else {
4337 session
4338 .peer
4339 .send(
4340 *connection_id,
4341 proto::RemoveProjectCollaborator {
4342 project_id: project.id.to_proto(),
4343 peer_id: Some(session.connection_id.into()),
4344 },
4345 )
4346 .trace_err();
4347 }
4348 }
4349}
4350
/// Extension trait converting a result-like value into an `Option`,
/// reporting the failure case instead of propagating it (see the
/// `Result` impl below, which logs via `tracing::error!`).
pub trait ResultExt {
    /// The success value produced when no error occurred.
    type Ok;

    /// Returns `Some(value)` on success, `None` after reporting the error.
    fn trace_err(self) -> Option<Self::Ok>;
}
4356
4357impl<T, E> ResultExt for Result<T, E>
4358where
4359 E: std::fmt::Debug,
4360{
4361 type Ok = T;
4362
4363 #[track_caller]
4364 fn trace_err(self) -> Option<T> {
4365 match self {
4366 Ok(value) => Some(value),
4367 Err(error) => {
4368 tracing::error!("{:?}", error);
4369 None
4370 }
4371 }
4372 }
4373}