1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
9 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
10 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use chrono::Utc;
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use futures::TryFutureExt as _;
37use reqwest_client::ReqwestClient;
38use rpc::proto::{MultiLspQuery, split_repository_update};
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40use tracing::Span;
41
42use futures::{
43 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
44 stream::FuturesUnordered,
45};
46use prometheus::{IntGauge, register_int_gauge};
47use rpc::{
48 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 Arc, OnceLock,
66 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{Semaphore, watch};
72use tower::ServiceBuilder;
73use tracing::{
74 Instrument,
75 field::{self},
76 info_span, instrument,
77};
78
/// How long a disconnected client has to reconnect before its
/// server-side state is cleaned up.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for paginated channel-message queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Maximum length of a single channel message (chars vs bytes — confirm at the enforcement site).
const MAX_MESSAGE_LEN: usize = 1024;
// Page size for paginated notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields used to record per-message timing (see `add_handler`
// and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler for one message type: receives the envelope, the session
// it arrived on, and the span to record timing fields into.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
/// One-shot responder handed to request handlers. It tracks whether a
/// response was actually sent so `add_request_handler` can detect
/// handlers that returned `Ok` without responding.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared flag read back by the request-handler wrapper once the
    // handler future completes.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester. Consumes `self`, so a
    /// request can be answered at most once.
    fn send(self, payload: R::Response) -> Result<()> {
        // Flag first, then send: if `respond` fails the error propagates,
        // but the request is still considered answered by this handler.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
135
/// The authenticated identity behind a connection: a user acting as
/// themselves, or a user being impersonated by an admin.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
    /// `user` is the account being acted as; `admin` is the staff member
    /// doing the impersonating.
    Impersonated { user: User, admin: User },
}
141
142impl Principal {
143 fn update_span(&self, span: &tracing::Span) {
144 match &self {
145 Principal::User(user) => {
146 span.record("user_id", user.id.0);
147 span.record("login", &user.github_login);
148 }
149 Principal::Impersonated { user, admin } => {
150 span.record("user_id", user.id.0);
151 span.record("login", &user.github_login);
152 span.record("impersonator", &admin.github_login);
153 }
154 }
155 }
156}
157
/// Per-message wrapper around a `Session` that also carries the tracing
/// span created for the message, so handlers can record timing fields
/// (e.g. `HOST_WAITING_MS`) onto it.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Deref to the underlying `Session` so handlers can call session methods
// directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
171
172impl MessageContext {
173 pub fn forward_request<T: RequestMessage>(
174 &self,
175 receiver_id: ConnectionId,
176 request: T,
177 ) -> impl Future<Output = anyhow::Result<T::Response>> {
178 let request_start_time = Instant::now();
179 let span = self.span.clone();
180 tracing::info!("start forwarding request");
181 self.peer
182 .forward_request(self.connection_id, receiver_id, request)
183 .inspect(move |_| {
184 span.record(
185 HOST_WAITING_MS,
186 request_start_time.elapsed().as_micros() as f64 / 1000.0,
187 );
188 })
189 .inspect_err(|_| tracing::error!("error forwarding request"))
190 .inspect_ok(|_| tracing::info!("finished forwarding request"))
191 }
192}
193
/// Per-connection state shared with every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; access it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
210
211impl Session {
212 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
213 #[cfg(test)]
214 tokio::task::yield_now().await;
215 let guard = self.db.lock().await;
216 #[cfg(test)]
217 tokio::task::yield_now().await;
218 guard
219 }
220
221 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
222 #[cfg(test)]
223 tokio::task::yield_now().await;
224 let guard = self.connection_pool.lock();
225 ConnectionPoolGuard {
226 guard,
227 _not_send: PhantomData,
228 }
229 }
230
231 fn is_staff(&self) -> bool {
232 match &self.principal {
233 Principal::User(user) => user.admin,
234 Principal::Impersonated { .. } => true,
235 }
236 }
237
238 fn user_id(&self) -> UserId {
239 match &self.principal {
240 Principal::User(user) => user.id,
241 Principal::Impersonated { user, .. } => user.id,
242 }
243 }
244
245 pub fn email(&self) -> Option<String> {
246 match &self.principal {
247 Principal::User(user) => user.email_address.clone(),
248 Principal::Impersonated { user, .. } => user.email_address.clone(),
249 }
250 }
251}
252
253impl Debug for Session {
254 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
255 let mut result = f.debug_struct("Session");
256 match &self.principal {
257 Principal::User(user) => {
258 result.field("user", &user.github_login);
259 }
260 Principal::Impersonated { user, admin } => {
261 result.field("user", &user.github_login);
262 result.field("impersonator", &admin.github_login);
263 }
264 }
265 result.field("connection_id", &self.connection_id).finish()
266 }
267}
268
// Newtype around the shared `Database` so it can live behind the
// session's async mutex while still exposing `Database`'s API via Deref.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
278
/// The collab RPC server: owns the peer, the connection pool, and the
/// registry of message handlers.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message `TypeId`; populated in `new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Set to `true` to tell every connection-handling task to shut down.
    teardown: watch::Sender<bool>,
}
287
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>`
/// marker makes the guard `!Send` — presumably to prevent holding it
/// across an `.await` point.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
292
/// Serializable snapshot of the server's peer and (locked) connection
/// pool; the pool is serialized through its lock guard via
/// `serialize_deref`.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
299
300pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
301where
302 S: Serializer,
303 T: Deref<Target = U>,
304 U: Serialize,
305{
306 Serialize::serialize(value.deref(), serializer)
307}
308
309impl Server {
    /// Builds the server and registers the handler for every RPC message
    /// type it understands. `add_handler` panics if a message type is
    /// registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Sender half of the teardown flag; each connection task
            // subscribes to the receiver half in `handle_connection`.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // LSP/project requests forwarded to the project host.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            // NOTE(review): GetCodeLens and LspExtRunnables are registered as
            // mutating while the surrounding queries are read-only — confirm
            // this classification is intended.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host → guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Follower state.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts, git operations, and debugging, all
            // forwarded to the project host.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
483
    /// Spawns the startup cleanup task: waits `CLEANUP_TIMEOUT` (so pods
    /// from a previous deploy have finished their graceful shutdown — see
    /// the comment on `CLEANUP_TIMEOUT`), then releases room, channel
    /// buffer, and worktree state still attributed to stale servers, and
    /// notifies affected clients.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and tell the
                    // remaining collaborators about the new set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Remove stale participants from each room, cancel their
                    // pending calls, and refresh contact state for everyone
                    // affected.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each canceled callee.
                        // Scoped so the pool lock is released before the
                        // `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the affected users' updated busy status to all
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made before the cleanup
                // loop — presumably to catch participants that went stale
                // while the loop ran; confirm both calls are intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
654
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and flips the teardown flag that every connection task
    /// watches.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // `send` fails only when there are no subscribers; that's fine.
        let _ = self.teardown.send(true);
    }

    /// Test-only: tears down and re-initializes the server under a new
    /// server id, clearing the teardown flag so connections are accepted
    /// again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
673
    /// Registers `handler` for message type `M`, wrapping it so every
    /// invocation logs start/finish and records queue/processing
    /// durations on the message's span.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The registry is keyed by `M`'s TypeId, so this downcast
                // cannot fail for a correctly-dispatched envelope.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Total = time since the envelope arrived; queue time is
                    // whatever portion wasn't spent inside the handler.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
717
718 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
719 where
720 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
721 Fut: 'static + Send + Future<Output = Result<()>>,
722 M: EnvelopedMessage,
723 {
724 self.add_handler(move |envelope, session| handler(envelope.payload, session));
725 self
726 }
727
    /// Registers a handler for a request/response message. The wrapper
    /// guarantees the requester always hears back: handler errors are
    /// forwarded as error responses, and a handler that returns `Ok`
    /// without calling `Response::send` is itself reported as an error.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set to true by `Response::send`; checked afterwards to
                // catch handlers that forgot to respond.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry a proto encoding;
                        // anything else is wrapped as an internal error.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
766
    /// Returns a future that drives one client connection to completion:
    /// registers it with the peer, sends the initial client update, then
    /// dispatches incoming messages to handlers until the connection
    /// closes or the server tears down, and finally runs disconnection
    /// cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client used by per-session services (currently the
            // Supermaven admin API).
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Setup succeeded: release the concurrent-connection slot
            // (decremented by `ConnectionGuard::drop`).
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of in-flight handlers = permits taken from the semaphore.
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler slot is free.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler finishes,
                                // so the semaphore reflects in-flight work.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            // Run per-connection cleanup; see `connection_lost`.
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
958
    /// Sends the initial state a freshly-connected client needs: the `Hello`
    /// handshake, the contact list, channel subscriptions, and any call that
    /// is currently ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client which peer id was assigned to
        // this connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id to the caller, if one was requested.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection ever, prompt the
                // contacts UI once and persist that we did so.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contact
                    // update while holding the pool lock, so the snapshot the
                    // client receives is consistent with the pool's state.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver a call that started ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1014
1015 pub async fn invite_code_redeemed(
1016 self: &Arc<Self>,
1017 inviter_id: UserId,
1018 invitee_id: UserId,
1019 ) -> Result<()> {
1020 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1021 if let Some(code) = &user.invite_code {
1022 let pool = self.connection_pool.lock();
1023 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1024 for connection_id in pool.user_connection_ids(inviter_id) {
1025 self.peer.send(
1026 connection_id,
1027 proto::UpdateContacts {
1028 contacts: vec![invitee_contact.clone()],
1029 ..Default::default()
1030 },
1031 )?;
1032 self.peer.send(
1033 connection_id,
1034 proto::UpdateInviteInfo {
1035 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1036 count: user.invite_count as u32,
1037 },
1038 )?;
1039 }
1040 }
1041 }
1042 Ok(())
1043 }
1044
1045 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1046 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1047 if let Some(invite_code) = &user.invite_code {
1048 let pool = self.connection_pool.lock();
1049 for connection_id in pool.user_connection_ids(user_id) {
1050 self.peer.send(
1051 connection_id,
1052 proto::UpdateInviteInfo {
1053 url: format!(
1054 "{}{}",
1055 self.app_state.config.invite_link_prefix, invite_code
1056 ),
1057 count: user.invite_count as u32,
1058 },
1059 )?;
1060 }
1061 }
1062 }
1063 Ok(())
1064 }
1065
1066 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1067 ServerSnapshot {
1068 connection_pool: ConnectionPoolGuard {
1069 guard: self.connection_pool.lock(),
1070 _not_send: PhantomData,
1071 },
1072 peer: &self.peer,
1073 }
1074 }
1075}
1076
1077impl Deref for ConnectionPoolGuard<'_> {
1078 type Target = ConnectionPool;
1079
1080 fn deref(&self) -> &Self::Target {
1081 &self.guard
1082 }
1083}
1084
1085impl DerefMut for ConnectionPoolGuard<'_> {
1086 fn deref_mut(&mut self) -> &mut Self::Target {
1087 &mut self.guard
1088 }
1089}
1090
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released, so corruption is caught near where it happened.
        #[cfg(test)]
        self.check_invariants();
    }
}
1097
1098fn broadcast<F>(
1099 sender_id: Option<ConnectionId>,
1100 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1101 mut f: F,
1102) where
1103 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1104{
1105 for receiver_id in receiver_ids {
1106 if Some(receiver_id) != sender_id {
1107 if let Err(error) = f(receiver_id) {
1108 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1109 }
1110 }
1111 }
1112}
1113
1114pub struct ProtocolVersion(u32);
1115
1116impl Header for ProtocolVersion {
1117 fn name() -> &'static HeaderName {
1118 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1119 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1120 }
1121
1122 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1123 where
1124 Self: Sized,
1125 I: Iterator<Item = &'i axum::http::HeaderValue>,
1126 {
1127 let version = values
1128 .next()
1129 .ok_or_else(axum::headers::Error::invalid)?
1130 .to_str()
1131 .map_err(|_| axum::headers::Error::invalid())?
1132 .parse()
1133 .map_err(|_| axum::headers::Error::invalid())?;
1134 Ok(Self(version))
1135 }
1136
1137 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1138 values.extend([self.0.to_string().parse().unwrap()]);
1139 }
1140}
1141
1142pub struct AppVersionHeader(SemanticVersion);
1143impl Header for AppVersionHeader {
1144 fn name() -> &'static HeaderName {
1145 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1146 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1147 }
1148
1149 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1150 where
1151 Self: Sized,
1152 I: Iterator<Item = &'i axum::http::HeaderValue>,
1153 {
1154 let version = values
1155 .next()
1156 .ok_or_else(axum::headers::Error::invalid)?
1157 .to_str()
1158 .map_err(|_| axum::headers::Error::invalid())?
1159 .parse()
1160 .map_err(|_| axum::headers::Error::invalid())?;
1161 Ok(Self(version))
1162 }
1163
1164 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1165 values.extend([self.0.to_string().parse().unwrap()]);
1166 }
1167}
1168
1169#[derive(Debug)]
1170pub struct ReleaseChannelHeader(String);
1171
1172impl Header for ReleaseChannelHeader {
1173 fn name() -> &'static HeaderName {
1174 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1175 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1176 }
1177
1178 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1179 where
1180 Self: Sized,
1181 I: Iterator<Item = &'i axum::http::HeaderValue>,
1182 {
1183 Ok(Self(
1184 values
1185 .next()
1186 .ok_or_else(axum::headers::Error::invalid)?
1187 .to_str()
1188 .map_err(|_| axum::headers::Error::invalid())?
1189 .to_owned(),
1190 ))
1191 }
1192
1193 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1194 values.extend([self.0.parse().unwrap()]);
1195 }
1196}
1197
1198pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1199 Router::new()
1200 .route("/rpc", get(handle_websocket_request))
1201 .layer(
1202 ServiceBuilder::new()
1203 .layer(Extension(server.app_state.clone()))
1204 .layer(middleware::from_fn(auth::validate_header)),
1205 )
1206 .route("/metrics", get(handle_metrics))
1207 .layer(Extension(server))
1208}
1209
/// Handles the HTTP request that upgrades to the collab websocket.
///
/// Rejects clients with a mismatched protocol version, a missing app-version
/// header, or a version too old to collaborate; sheds load when the
/// concurrent-connection limit is hit. On success the socket is handed to
/// `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // An exact RPC protocol match is required to decode messages at all.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // The app version may speak the protocol yet still be locked out.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    // (so overload is reported as an HTTP 503 rather than a failed socket).
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket stream/sink to the tungstenite message
        // types the RPC `Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1287
1288pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1289 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1290 let connections_metric = CONNECTIONS_METRIC
1291 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1292
1293 let connections = server
1294 .connection_pool
1295 .lock()
1296 .connections()
1297 .filter(|connection| !connection.admin)
1298 .count();
1299 connections_metric.set(connections as _);
1300
1301 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1302 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1303 register_int_gauge!(
1304 "shared_projects",
1305 "number of open projects with one or more guests"
1306 )
1307 .unwrap()
1308 });
1309
1310 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1311 shared_projects_metric.set(shared_projects as _);
1312
1313 let encoder = prometheus::TextEncoder::new();
1314 let metric_families = prometheus::gather();
1315 let encoded_metrics = encoder
1316 .encode_to_string(&metric_families)
1317 .map_err(|err| anyhow!("{err}"))?;
1318 Ok(encoded_metrics)
1319}
1320
/// Tears down a lost connection and, unless the client reconnects within
/// `RECONNECT_TIMEOUT`, releases all resources (room, channel buffers, calls)
/// that the connection held.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediately disconnect the transport and unregister from the pool.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Record the loss in the database; best effort, errors are only logged.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client RECONNECT_TIMEOUT to come back before cleaning up.
    // A teardown signal (server shutdown) skips the cleanup entirely.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls when this was the user's last
            // remaining connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1367
1368/// Acknowledges a ping from a client, used to keep the connection alive.
1369async fn ping(
1370 _: proto::Ping,
1371 response: Response<proto::Ping>,
1372 _session: MessageContext,
1373) -> Result<()> {
1374 response.send(proto::Ack {})?;
1375 Ok(())
1376}
1377
1378/// Creates a new room for calling (outside of channels)
1379async fn create_room(
1380 _request: proto::CreateRoom,
1381 response: Response<proto::CreateRoom>,
1382 session: MessageContext,
1383) -> Result<()> {
1384 let livekit_room = nanoid::nanoid!(30);
1385
1386 let live_kit_connection_info = util::maybe!(async {
1387 let live_kit = session.app_state.livekit_client.as_ref();
1388 let live_kit = live_kit?;
1389 let user_id = session.user_id().to_string();
1390
1391 let token = live_kit
1392 .room_token(&livekit_room, &user_id.to_string())
1393 .trace_err()?;
1394
1395 Some(proto::LiveKitConnectionInfo {
1396 server_url: live_kit.url().into(),
1397 token,
1398 can_publish: true,
1399 })
1400 })
1401 .await;
1402
1403 let room = session
1404 .db()
1405 .await
1406 .create_room(session.user_id(), session.connection_id, &livekit_room)
1407 .await?;
1408
1409 response.send(proto::CreateRoomResponse {
1410 room: Some(room.clone()),
1411 live_kit_connection_info,
1412 })?;
1413
1414 update_user_contacts(session.user_id(), &session).await?;
1415 Ok(())
1416}
1417
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // A room that backs a channel is joined through the channel flow instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the ringing-call UI on all of this user's connections, since
    // the call has now been answered here.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token when media is configured; failures are logged and
    // the client simply joins without audio/video info.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1484
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // Restore the caller's room membership and project shares in the
        // database under a single guard.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project the caller re-shared: swap the stale peer id for
        // the new connection id and rebroadcast current worktree metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Extract the values that must outlive the database guard.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1576
/// Streams the current state of each rejoined project to the reconnecting
/// client (worktree entries, diagnostics, settings, repositories) and tells
/// the other collaborators about the client's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell existing collaborators to remap the old connection to the new one.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into multiple messages before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Re-send each per-worktree settings file.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Re-send repository state that changed while disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1661
1662/// leave room disconnects from the room.
1663async fn leave_room(
1664 _: proto::LeaveRoom,
1665 response: Response<proto::LeaveRoom>,
1666 session: MessageContext,
1667) -> Result<()> {
1668 leave_room_for_session(&session, session.connection_id).await?;
1669 response.send(proto::Ack {})?;
1670 Ok(())
1671}
1672
1673/// Updates the permissions of someone else in the room.
1674async fn set_room_participant_role(
1675 request: proto::SetRoomParticipantRole,
1676 response: Response<proto::SetRoomParticipantRole>,
1677 session: MessageContext,
1678) -> Result<()> {
1679 let user_id = UserId::from_proto(request.user_id);
1680 let role = ChannelRole::from(request.role());
1681
1682 let (livekit_room, can_publish) = {
1683 let room = session
1684 .db()
1685 .await
1686 .set_room_participant_role(
1687 session.user_id(),
1688 RoomId::from_proto(request.room_id),
1689 user_id,
1690 role,
1691 )
1692 .await?;
1693
1694 let livekit_room = room.livekit_room.clone();
1695 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1696 room_updated(&room, &session.peer);
1697 (livekit_room, can_publish)
1698 };
1699
1700 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1701 live_kit
1702 .update_participant(
1703 livekit_room.clone(),
1704 request.user_id.to_string(),
1705 livekit_api::proto::ParticipantPermission {
1706 can_subscribe: true,
1707 can_publish,
1708 can_publish_data: can_publish,
1709 hidden: false,
1710 recorder: false,
1711 },
1712 )
1713 .await
1714 .trace_err();
1715 }
1716
1717 response.send(proto::Ack {})?;
1718 Ok(())
1719}
1720
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may call each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the database and broadcast the updated room state.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee; the first one that responds
    // successfully answers the call for all of them.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failed connection and keep trying the others.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll back the pending call.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1789
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Remove the pending call from the room; the block scopes the database
    // guard so it is released before notifying the callee's connections.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ringing-call UI on every connection of the callee.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1827
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Record the decline; the block scopes the database guard so it is
    // released before notifying connections.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the incoming-call UI on all of this user's connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1859
1860/// Updates other participants in the room with your current location.
1861async fn update_participant_location(
1862 request: proto::UpdateParticipantLocation,
1863 response: Response<proto::UpdateParticipantLocation>,
1864 session: MessageContext,
1865) -> Result<()> {
1866 let room_id = RoomId::from_proto(request.room_id);
1867 let location = request.location.context("invalid location")?;
1868
1869 let db = session.db().await;
1870 let room = db
1871 .update_room_participant_location(room_id, session.connection_id, location)
1872 .await?;
1873
1874 room_updated(&room, &session.peer);
1875 response.send(proto::Ack {})?;
1876 Ok(())
1877}
1878
1879/// Share a project into the room.
1880async fn share_project(
1881 request: proto::ShareProject,
1882 response: Response<proto::ShareProject>,
1883 session: MessageContext,
1884) -> Result<()> {
1885 let (project_id, room) = &*session
1886 .db()
1887 .await
1888 .share_project(
1889 RoomId::from_proto(request.room_id),
1890 session.connection_id,
1891 &request.worktrees,
1892 request.is_ssh_project,
1893 )
1894 .await?;
1895 response.send(proto::ShareProjectResponse {
1896 project_id: project_id.to_proto(),
1897 })?;
1898 room_updated(room, &session.peer);
1899
1900 Ok(())
1901}
1902
1903/// Unshare a project from the room.
1904async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1905 let project_id = ProjectId::from_proto(message.project_id);
1906 unshare_project_internal(project_id, session.connection_id, &session).await
1907}
1908
/// Removes a shared project, notifies its guests, and deletes the project
/// record when the database indicates it should be deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The block scopes the room guard so it is released before the project
    // deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the initiating connection) the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1946
1947/// Join someone elses shared project.
1948async fn join_project(
1949 request: proto::JoinProject,
1950 response: Response<proto::JoinProject>,
1951 session: MessageContext,
1952) -> Result<()> {
1953 let project_id = ProjectId::from_proto(request.project_id);
1954
1955 tracing::info!(%project_id, "join project");
1956
1957 let db = session.db().await;
1958 let (project, replica_id) = &mut *db
1959 .join_project(
1960 project_id,
1961 session.connection_id,
1962 session.user_id(),
1963 request.committer_name.clone(),
1964 request.committer_email.clone(),
1965 )
1966 .await?;
1967 drop(db);
1968 tracing::info!(%project_id, "join remote project");
1969 let collaborators = project
1970 .collaborators
1971 .iter()
1972 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1973 .map(|collaborator| collaborator.to_proto())
1974 .collect::<Vec<_>>();
1975 let project_id = project.id;
1976 let guest_user_id = session.user_id();
1977
1978 let worktrees = project
1979 .worktrees
1980 .iter()
1981 .map(|(id, worktree)| proto::WorktreeMetadata {
1982 id: *id,
1983 root_name: worktree.root_name.clone(),
1984 visible: worktree.visible,
1985 abs_path: worktree.abs_path.clone(),
1986 })
1987 .collect::<Vec<_>>();
1988
1989 let add_project_collaborator = proto::AddProjectCollaborator {
1990 project_id: project_id.to_proto(),
1991 collaborator: Some(proto::Collaborator {
1992 peer_id: Some(session.connection_id.into()),
1993 replica_id: replica_id.0 as u32,
1994 user_id: guest_user_id.to_proto(),
1995 is_host: false,
1996 committer_name: request.committer_name.clone(),
1997 committer_email: request.committer_email.clone(),
1998 }),
1999 };
2000
2001 for collaborator in &collaborators {
2002 session
2003 .peer
2004 .send(
2005 collaborator.peer_id.unwrap().into(),
2006 add_project_collaborator.clone(),
2007 )
2008 .trace_err();
2009 }
2010
2011 // First, we send the metadata associated with each worktree.
2012 let (language_servers, language_server_capabilities) = project
2013 .language_servers
2014 .clone()
2015 .into_iter()
2016 .map(|server| (server.server, server.capabilities))
2017 .unzip();
2018 response.send(proto::JoinProjectResponse {
2019 project_id: project.id.0 as u64,
2020 worktrees: worktrees.clone(),
2021 replica_id: replica_id.0 as u32,
2022 collaborators: collaborators.clone(),
2023 language_servers,
2024 language_server_capabilities,
2025 role: project.role.into(),
2026 })?;
2027
2028 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2029 // Stream this worktree's entries.
2030 let message = proto::UpdateWorktree {
2031 project_id: project_id.to_proto(),
2032 worktree_id,
2033 abs_path: worktree.abs_path.clone(),
2034 root_name: worktree.root_name,
2035 updated_entries: worktree.entries,
2036 removed_entries: Default::default(),
2037 scan_id: worktree.scan_id,
2038 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2039 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2040 removed_repositories: Default::default(),
2041 };
2042 for update in proto::split_worktree_update(message) {
2043 session.peer.send(session.connection_id, update.clone())?;
2044 }
2045
2046 // Stream this worktree's diagnostics.
2047 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2048 if let Some(summary) = worktree_diagnostics.next() {
2049 let message = proto::UpdateDiagnosticSummary {
2050 project_id: project.id.to_proto(),
2051 worktree_id: worktree.id,
2052 summary: Some(summary),
2053 more_summaries: worktree_diagnostics.collect(),
2054 };
2055 session.peer.send(session.connection_id, message)?;
2056 }
2057
2058 for settings_file in worktree.settings_files {
2059 session.peer.send(
2060 session.connection_id,
2061 proto::UpdateWorktreeSettings {
2062 project_id: project_id.to_proto(),
2063 worktree_id: worktree.id,
2064 path: settings_file.path,
2065 content: Some(settings_file.content),
2066 kind: Some(settings_file.kind.to_proto() as i32),
2067 },
2068 )?;
2069 }
2070 }
2071
2072 for repository in mem::take(&mut project.repositories) {
2073 for update in split_repository_update(repository) {
2074 session.peer.send(session.connection_id, update)?;
2075 }
2076 }
2077
2078 for language_server in &project.language_servers {
2079 session.peer.send(
2080 session.connection_id,
2081 proto::UpdateLanguageServer {
2082 project_id: project_id.to_proto(),
2083 server_name: Some(language_server.server.name.clone()),
2084 language_server_id: language_server.server.id,
2085 variant: Some(
2086 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2087 proto::LspDiskBasedDiagnosticsUpdated {},
2088 ),
2089 ),
2090 },
2091 )?;
2092 }
2093
2094 Ok(())
2095}
2096
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Remove this connection from the project. The DB returns the room the
    // project belongs to (if any) plus the project state needed for cleanup.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell the remaining collaborators that this peer left, and refresh the
    // room state when the project was part of a call.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2116
/// Updates other participants with changes to the project
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Persist the new worktree list. The `&*` keeps the DB guard alive for as
    // long as `room` / `guest_connection_ids` are borrowed below.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Relay the unmodified request to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    // Also refresh room state when the project is associated with a room.
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2145
2146/// Updates other participants with changes to the worktree
2147async fn update_worktree(
2148 request: proto::UpdateWorktree,
2149 response: Response<proto::UpdateWorktree>,
2150 session: MessageContext,
2151) -> Result<()> {
2152 let guest_connection_ids = session
2153 .db()
2154 .await
2155 .update_worktree(&request, session.connection_id)
2156 .await?;
2157
2158 broadcast(
2159 Some(session.connection_id),
2160 guest_connection_ids.iter().copied(),
2161 |connection_id| {
2162 session
2163 .peer
2164 .forward_send(session.connection_id, connection_id, request.clone())
2165 },
2166 );
2167 response.send(proto::Ack {})?;
2168 Ok(())
2169}
2170
2171async fn update_repository(
2172 request: proto::UpdateRepository,
2173 response: Response<proto::UpdateRepository>,
2174 session: MessageContext,
2175) -> Result<()> {
2176 let guest_connection_ids = session
2177 .db()
2178 .await
2179 .update_repository(&request, session.connection_id)
2180 .await?;
2181
2182 broadcast(
2183 Some(session.connection_id),
2184 guest_connection_ids.iter().copied(),
2185 |connection_id| {
2186 session
2187 .peer
2188 .forward_send(session.connection_id, connection_id, request.clone())
2189 },
2190 );
2191 response.send(proto::Ack {})?;
2192 Ok(())
2193}
2194
2195async fn remove_repository(
2196 request: proto::RemoveRepository,
2197 response: Response<proto::RemoveRepository>,
2198 session: MessageContext,
2199) -> Result<()> {
2200 let guest_connection_ids = session
2201 .db()
2202 .await
2203 .remove_repository(&request, session.connection_id)
2204 .await?;
2205
2206 broadcast(
2207 Some(session.connection_id),
2208 guest_connection_ids.iter().copied(),
2209 |connection_id| {
2210 session
2211 .peer
2212 .forward_send(session.connection_id, connection_id, request.clone())
2213 },
2214 );
2215 response.send(proto::Ack {})?;
2216 Ok(())
2217}
2218
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
    message: proto::UpdateDiagnosticSummary,
    session: MessageContext,
) -> Result<()> {
    // Persist the summary so late joiners receive it (the join-project flow
    // streams diagnostic summaries back out of the DB), and get the guest
    // connections to notify now.
    let guest_connection_ids = session
        .db()
        .await
        .update_diagnostic_summary(&message, session.connection_id)
        .await?;

    // Relay the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2242
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
    message: proto::UpdateWorktreeSettings,
    session: MessageContext,
) -> Result<()> {
    // Persist the settings change and get the guest connections to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree_settings(&message, session.connection_id)
        .await?;

    // Relay the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2266
2267/// Notify other participants that a language server has started.
2268async fn start_language_server(
2269 request: proto::StartLanguageServer,
2270 session: MessageContext,
2271) -> Result<()> {
2272 let guest_connection_ids = session
2273 .db()
2274 .await
2275 .start_language_server(&request, session.connection_id)
2276 .await?;
2277
2278 broadcast(
2279 Some(session.connection_id),
2280 guest_connection_ids.iter().copied(),
2281 |connection_id| {
2282 session
2283 .peer
2284 .forward_send(session.connection_id, connection_id, request.clone())
2285 },
2286 );
2287 Ok(())
2288}
2289
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capability updates are persisted so that late-joining guests can read
    // them back (the join-project flow sends stored capabilities).
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
    {
        if let Some(capabilities) = update.capabilities.clone() {
            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
                .await?;
        }
    }

    // NOTE(review): the trailing `true` flag differs from
    // `broadcast_project_message_from_host` (`false`) — confirm its meaning
    // against `project_connection_ids`' definition.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2320
/// forward a project request to the host. These requests should be read only
/// as guests are allowed to send them.
async fn forward_read_only_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // Resolve the host's connection for this project (the DB call also gates
    // whether this sender may issue the request).
    let host_connection_id = session
        .db()
        .await
        .host_for_read_only_project_request(project_id, session.connection_id)
        .await?;
    // Proxy the request to the host and relay its reply back to the requester.
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2341
/// forward a project request to the host. These requests are disallowed
/// for guests.
async fn forward_mutating_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());

    // Resolve the host's connection; unlike the read-only variant, this DB
    // call requires the sender to have mutating access.
    let host_connection_id = session
        .db()
        .await
        .host_for_mutating_project_request(project_id, session.connection_id)
        .await?;
    // Proxy the request to the host and relay its reply back to the requester.
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2363
2364async fn multi_lsp_query(
2365 request: MultiLspQuery,
2366 response: Response<MultiLspQuery>,
2367 session: MessageContext,
2368) -> Result<()> {
2369 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2370 tracing::info!("multi_lsp_query message received");
2371 forward_mutating_project_request(request, response, session).await
2372}
2373
/// Notify other participants that a new buffer has been created
async fn create_buffer_for_peer(
    request: proto::CreateBufferForPeer,
    session: MessageContext,
) -> Result<()> {
    // Only the project host may push buffers to peers.
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    let peer_id = request.peer_id.context("invalid peer id")?;
    // Forward the buffer directly to the single targeted peer.
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
2393
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations may come from read-only guests;
    // any other operation variant requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // This block scopes the DB guard: it is dropped before the (potentially
    // slow) acknowledged round-trip to the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget fan-out to guests while the guard is still held.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host's copy is sent as a request and awaited (unless we ARE the
    // host), so the Ack below implies the host received the update.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2440
/// Broadcasts a context (assistant conversation) operation to the other
/// participants, applying the same read-only rule as `update_buffer`.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirror `update_buffer`'s rule: selection-only buffer operations are
    // read-only; any other operation needs write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the fan-out and no
    // acknowledged round-trip to the host is performed.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2482
2483/// Notify other participants that a project has been updated.
2484async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2485 request: T,
2486 session: MessageContext,
2487) -> Result<()> {
2488 let project_id = ProjectId::from_proto(request.remote_entity_id());
2489 let project_connection_ids = session
2490 .db()
2491 .await
2492 .project_connection_ids(project_id, session.connection_id, false)
2493 .await?;
2494
2495 broadcast(
2496 Some(session.connection_id),
2497 project_connection_ids.iter().copied(),
2498 |connection_id| {
2499 session
2500 .peer
2501 .forward_send(session.connection_id, connection_id, request.clone())
2502 },
2503 );
2504 Ok(())
2505}
2506
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader directly; its reply is passed straight through to the
    // follower before the follow is recorded.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Project-scoped follows are recorded in room state and broadcast via
    // `room_updated`.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2538
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader (fire-and-forget, unlike `follow` which awaits a reply).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Project-scoped follows are recorded in room state; remove the record and
    // broadcast the updated room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2567
2568/// Notify everyone following you of your current location.
2569async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2570 let room_id = RoomId::from_proto(request.room_id);
2571 let database = session.db.lock().await;
2572
2573 let connection_ids = if let Some(project_id) = request.project_id {
2574 let project_id = ProjectId::from_proto(project_id);
2575 database
2576 .project_connection_ids(project_id, session.connection_id, true)
2577 .await?
2578 } else {
2579 database
2580 .room_connection_ids(room_id, session.connection_id)
2581 .await?
2582 };
2583
2584 // For now, don't send view update messages back to that view's current leader.
2585 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2586 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2587 _ => None,
2588 });
2589
2590 for connection_id in connection_ids.iter().cloned() {
2591 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2592 session
2593 .peer
2594 .forward_send(session.connection_id, connection_id, request.clone())?;
2595 }
2596 }
2597 Ok(())
2598}
2599
2600/// Get public data about users.
2601async fn get_users(
2602 request: proto::GetUsers,
2603 response: Response<proto::GetUsers>,
2604 session: MessageContext,
2605) -> Result<()> {
2606 let user_ids = request
2607 .user_ids
2608 .into_iter()
2609 .map(UserId::from_proto)
2610 .collect();
2611 let users = session
2612 .db()
2613 .await
2614 .get_users_by_ids(user_ids)
2615 .await?
2616 .into_iter()
2617 .map(|user| proto::User {
2618 id: user.id.to_proto(),
2619 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2620 github_login: user.github_login,
2621 name: user.name,
2622 })
2623 .collect();
2624 response.send(proto::UsersResponse { users })?;
2625 Ok(())
2626}
2627
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // One- or two-character queries use an exact-login lookup; longer queries
    // get a fuzzy search capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requester themself, and map to the public wire shape.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2659
2660/// Send a contact request to another user.
2661async fn request_contact(
2662 request: proto::RequestContact,
2663 response: Response<proto::RequestContact>,
2664 session: MessageContext,
2665) -> Result<()> {
2666 let requester_id = session.user_id();
2667 let responder_id = UserId::from_proto(request.responder_id);
2668 if requester_id == responder_id {
2669 return Err(anyhow!("cannot add yourself as a contact"))?;
2670 }
2671
2672 let notifications = session
2673 .db()
2674 .await
2675 .send_contact_request(requester_id, responder_id)
2676 .await?;
2677
2678 // Update outgoing contact requests of requester
2679 let mut update = proto::UpdateContacts::default();
2680 update.outgoing_requests.push(responder_id.to_proto());
2681 for connection_id in session
2682 .connection_pool()
2683 .await
2684 .user_connection_ids(requester_id)
2685 {
2686 session.peer.send(connection_id, update.clone())?;
2687 }
2688
2689 // Update incoming contact requests of responder
2690 let mut update = proto::UpdateContacts::default();
2691 update
2692 .incoming_requests
2693 .push(proto::IncomingContactRequest {
2694 requester_id: requester_id.to_proto(),
2695 });
2696 let connection_pool = session.connection_pool().await;
2697 for connection_id in connection_pool.user_connection_ids(responder_id) {
2698 session.peer.send(connection_id, update.clone())?;
2699 }
2700
2701 send_notifications(&connection_pool, &session.peer, notifications);
2702
2703 response.send(proto::Ack {})?;
2704 Ok(())
2705}
2706
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact payloads built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Accepted or declined, the incoming request leaves the responder's list.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // ...and the outgoing request leaves the requester's list.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2764
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If removing the contact deleted a notification, retract it on each
        // of the responder's connections as well.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2815
/// Version gate for automatic channel subscription: true when the client's
/// minor version is below 139. (Presumably newer clients request channel data
/// explicitly via `SubscribeToChannels` — see `subscribe_to_channels`.)
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2819
2820async fn subscribe_to_channels(
2821 _: proto::SubscribeToChannels,
2822 session: MessageContext,
2823) -> Result<()> {
2824 subscribe_user_to_channels(session.user_id(), &session).await?;
2825 Ok(())
2826}
2827
/// Subscribes all of the user's connections to the channels they belong to,
/// then sends their membership roles and current channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send membership roles first, then the full channels payload (the second
    // builder consumes `channels_for_user` by value).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2844
2845/// Creates a new channel.
2846async fn create_channel(
2847 request: proto::CreateChannel,
2848 response: Response<proto::CreateChannel>,
2849 session: MessageContext,
2850) -> Result<()> {
2851 let db = session.db().await;
2852
2853 let parent_id = request.parent_id.map(ChannelId::from_proto);
2854 let (channel, membership) = db
2855 .create_channel(&request.name, parent_id, session.user_id())
2856 .await?;
2857
2858 let root_id = channel.root_id();
2859 let channel = Channel::from_model(channel);
2860
2861 response.send(proto::CreateChannelResponse {
2862 channel: Some(channel.to_proto()),
2863 parent_id: request.parent_id,
2864 })?;
2865
2866 let mut connection_pool = session.connection_pool().await;
2867 if let Some(membership) = membership {
2868 connection_pool.subscribe_to_channel(
2869 membership.user_id,
2870 membership.channel_id,
2871 membership.role,
2872 );
2873 let update = proto::UpdateUserChannels {
2874 channel_memberships: vec![proto::ChannelMembership {
2875 channel_id: membership.channel_id.to_proto(),
2876 role: membership.role.into(),
2877 }],
2878 ..Default::default()
2879 };
2880 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2881 session.peer.send(connection_id, update.clone())?;
2882 }
2883 }
2884
2885 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2886 if !role.can_see_channel(channel.visibility) {
2887 continue;
2888 }
2889
2890 let update = proto::UpdateChannels {
2891 channels: vec![channel.to_proto()],
2892 ..Default::default()
2893 };
2894 session.peer.send(connection_id, update.clone())?;
2895 }
2896
2897 Ok(())
2898}
2899
2900/// Delete a channel
2901async fn delete_channel(
2902 request: proto::DeleteChannel,
2903 response: Response<proto::DeleteChannel>,
2904 session: MessageContext,
2905) -> Result<()> {
2906 let db = session.db().await;
2907
2908 let channel_id = request.channel_id;
2909 let (root_channel, removed_channels) = db
2910 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2911 .await?;
2912 response.send(proto::Ack {})?;
2913
2914 // Notify members of removed channels
2915 let mut update = proto::UpdateChannels::default();
2916 update
2917 .delete_channels
2918 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2919
2920 let connection_pool = session.connection_pool().await;
2921 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2922 session.peer.send(connection_id, update.clone())?;
2923 }
2924
2925 Ok(())
2926}
2927
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // Record the invite; the DB returns the channel plus notifications to send.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Surface the invitation on all of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2964
2965/// remove someone from a channel
2966async fn remove_channel_member(
2967 request: proto::RemoveChannelMember,
2968 response: Response<proto::RemoveChannelMember>,
2969 session: MessageContext,
2970) -> Result<()> {
2971 let db = session.db().await;
2972 let channel_id = ChannelId::from_proto(request.channel_id);
2973 let member_id = UserId::from_proto(request.user_id);
2974
2975 let RemoveChannelMemberResult {
2976 membership_update,
2977 notification_id,
2978 } = db
2979 .remove_channel_member(channel_id, member_id, session.user_id())
2980 .await?;
2981
2982 let mut connection_pool = session.connection_pool().await;
2983 notify_membership_updated(
2984 &mut connection_pool,
2985 membership_update,
2986 member_id,
2987 &session.peer,
2988 );
2989 for connection_id in connection_pool.user_connection_ids(member_id) {
2990 if let Some(notification_id) = notification_id {
2991 session
2992 .peer
2993 .send(
2994 connection_id,
2995 proto::DeleteNotification {
2996 notification_id: notification_id.to_proto(),
2997 },
2998 )
2999 .trace_err();
3000 }
3001 }
3002
3003 response.send(proto::Ack {})?;
3004 Ok(())
3005}
3006
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // `collect::<Vec<_>>()` ends the immutable borrow of the pool so the loop
    // body can call the mutating subscribe/unsubscribe methods below.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can see the channel under its new visibility get the
        // channel (and a subscription); everyone else is unsubscribed and has
        // it deleted client-side.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3053
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // An existing membership changed: push the full membership update.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // Only a pending invite changed: refresh the invitee's invitation.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3101
3102/// Change the name of a channel
3103async fn rename_channel(
3104 request: proto::RenameChannel,
3105 response: Response<proto::RenameChannel>,
3106 session: MessageContext,
3107) -> Result<()> {
3108 let db = session.db().await;
3109 let channel_id = ChannelId::from_proto(request.channel_id);
3110 let channel_model = db
3111 .rename_channel(channel_id, session.user_id(), &request.name)
3112 .await?;
3113 let root_id = channel_model.root_id();
3114 let channel = Channel::from_model(channel_model);
3115
3116 response.send(proto::RenameChannelResponse {
3117 channel: Some(channel.to_proto()),
3118 })?;
3119
3120 let connection_pool = session.connection_pool().await;
3121 let update = proto::UpdateChannels {
3122 channels: vec![channel.to_proto()],
3123 ..Default::default()
3124 };
3125 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3126 if role.can_see_channel(channel.visibility) {
3127 session.peer.send(connection_id, update.clone())?;
3128 }
3129 }
3130
3131 Ok(())
3132}
3133
3134/// Move a channel to a new parent.
3135async fn move_channel(
3136 request: proto::MoveChannel,
3137 response: Response<proto::MoveChannel>,
3138 session: MessageContext,
3139) -> Result<()> {
3140 let channel_id = ChannelId::from_proto(request.channel_id);
3141 let to = ChannelId::from_proto(request.to);
3142
3143 let (root_id, channels) = session
3144 .db()
3145 .await
3146 .move_channel(channel_id, to, session.user_id())
3147 .await?;
3148
3149 let connection_pool = session.connection_pool().await;
3150 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3151 let channels = channels
3152 .iter()
3153 .filter_map(|channel| {
3154 if role.can_see_channel(channel.visibility) {
3155 Some(channel.to_proto())
3156 } else {
3157 None
3158 }
3159 })
3160 .collect::<Vec<_>>();
3161 if channels.is_empty() {
3162 continue;
3163 }
3164
3165 let update = proto::UpdateChannels {
3166 channels,
3167 ..Default::default()
3168 };
3169
3170 session.peer.send(connection_id, update.clone())?;
3171 }
3172
3173 response.send(Ack {})?;
3174 Ok(())
3175}
3176
3177async fn reorder_channel(
3178 request: proto::ReorderChannel,
3179 response: Response<proto::ReorderChannel>,
3180 session: MessageContext,
3181) -> Result<()> {
3182 let channel_id = ChannelId::from_proto(request.channel_id);
3183 let direction = request.direction();
3184
3185 let updated_channels = session
3186 .db()
3187 .await
3188 .reorder_channel(channel_id, direction, session.user_id())
3189 .await?;
3190
3191 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3192 let connection_pool = session.connection_pool().await;
3193 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3194 let channels = updated_channels
3195 .iter()
3196 .filter_map(|channel| {
3197 if role.can_see_channel(channel.visibility) {
3198 Some(channel.to_proto())
3199 } else {
3200 None
3201 }
3202 })
3203 .collect::<Vec<_>>();
3204
3205 if channels.is_empty() {
3206 continue;
3207 }
3208
3209 let update = proto::UpdateChannels {
3210 channels,
3211 ..Default::default()
3212 };
3213
3214 session.peer.send(connection_id, update.clone())?;
3215 }
3216 }
3217
3218 response.send(Ack {})?;
3219 Ok(())
3220}
3221
3222/// Get the list of channel members
3223async fn get_channel_members(
3224 request: proto::GetChannelMembers,
3225 response: Response<proto::GetChannelMembers>,
3226 session: MessageContext,
3227) -> Result<()> {
3228 let db = session.db().await;
3229 let channel_id = ChannelId::from_proto(request.channel_id);
3230 let limit = if request.limit == 0 {
3231 u16::MAX as u64
3232 } else {
3233 request.limit
3234 };
3235 let (members, users) = db
3236 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3237 .await?;
3238 response.send(proto::GetChannelMembersResponse { members, users })?;
3239 Ok(())
3240}
3241
3242/// Accept or decline a channel invitation.
3243async fn respond_to_channel_invite(
3244 request: proto::RespondToChannelInvite,
3245 response: Response<proto::RespondToChannelInvite>,
3246 session: MessageContext,
3247) -> Result<()> {
3248 let db = session.db().await;
3249 let channel_id = ChannelId::from_proto(request.channel_id);
3250 let RespondToChannelInvite {
3251 membership_update,
3252 notifications,
3253 } = db
3254 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3255 .await?;
3256
3257 let mut connection_pool = session.connection_pool().await;
3258 if let Some(membership_update) = membership_update {
3259 notify_membership_updated(
3260 &mut connection_pool,
3261 membership_update,
3262 session.user_id(),
3263 &session.peer,
3264 );
3265 } else {
3266 let update = proto::UpdateChannels {
3267 remove_channel_invitations: vec![channel_id.to_proto()],
3268 ..Default::default()
3269 };
3270
3271 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3272 session.peer.send(connection_id, update.clone())?;
3273 }
3274 };
3275
3276 send_notifications(&connection_pool, &session.peer, notifications);
3277
3278 response.send(proto::Ack {})?;
3279
3280 Ok(())
3281}
3282
3283/// Join the channels' room
3284async fn join_channel(
3285 request: proto::JoinChannel,
3286 response: Response<proto::JoinChannel>,
3287 session: MessageContext,
3288) -> Result<()> {
3289 let channel_id = ChannelId::from_proto(request.channel_id);
3290 join_channel_internal(channel_id, Box::new(response), session).await
3291}
3292
/// Abstraction over the two response types that can complete a channel join
/// (`JoinChannel` and `JoinRoom`); both reply with a `proto::JoinRoomResponse`.
/// This lets `join_channel_internal` serve both RPC entry points.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3306
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` paths (see `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection for the user, joins the channel's
/// room, replies with the room state (plus a LiveKit token when configured),
/// and then notifies memberships, room participants, channel subscribers,
/// and the user's contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the DB itself, so release our
            // guard first and reacquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token if a LiveKit client is configured. Guests get
        // a non-publishing token; everyone else can publish. Token failures
        // are logged and yield no connection info rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // If joining changed the user's channel memberships, fan that out to
        // all of their connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Tell channel subscribers about the room's updated participant list.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3400
3401/// Start editing the channel notes
3402async fn join_channel_buffer(
3403 request: proto::JoinChannelBuffer,
3404 response: Response<proto::JoinChannelBuffer>,
3405 session: MessageContext,
3406) -> Result<()> {
3407 let db = session.db().await;
3408 let channel_id = ChannelId::from_proto(request.channel_id);
3409
3410 let open_response = db
3411 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3412 .await?;
3413
3414 let collaborators = open_response.collaborators.clone();
3415 response.send(open_response)?;
3416
3417 let update = UpdateChannelBufferCollaborators {
3418 channel_id: channel_id.to_proto(),
3419 collaborators: collaborators.clone(),
3420 };
3421 channel_buffer_updated(
3422 session.connection_id,
3423 collaborators
3424 .iter()
3425 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3426 &update,
3427 &session.peer,
3428 );
3429
3430 Ok(())
3431}
3432
/// Edit the channel notes
///
/// Persists the client's buffer operations, relays the same operations to
/// the buffer's other collaborators, and tells the remaining channel
/// subscribers only that a newer buffer version exists.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Returns the connections currently collaborating on the buffer plus the
    // buffer's epoch and version after the edit.
    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to the buffer's collaborators.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers that are not collaborating on the buffer get only
    // the latest version marker, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3483
3484/// Rejoin the channel notes after a connection blip
3485async fn rejoin_channel_buffers(
3486 request: proto::RejoinChannelBuffers,
3487 response: Response<proto::RejoinChannelBuffers>,
3488 session: MessageContext,
3489) -> Result<()> {
3490 let db = session.db().await;
3491 let buffers = db
3492 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3493 .await?;
3494
3495 for rejoined_buffer in &buffers {
3496 let collaborators_to_notify = rejoined_buffer
3497 .buffer
3498 .collaborators
3499 .iter()
3500 .filter_map(|c| Some(c.peer_id?.into()));
3501 channel_buffer_updated(
3502 session.connection_id,
3503 collaborators_to_notify,
3504 &proto::UpdateChannelBufferCollaborators {
3505 channel_id: rejoined_buffer.buffer.channel_id,
3506 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3507 },
3508 &session.peer,
3509 );
3510 }
3511
3512 response.send(proto::RejoinChannelBuffersResponse {
3513 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3514 })?;
3515
3516 Ok(())
3517}
3518
3519/// Stop editing the channel notes
3520async fn leave_channel_buffer(
3521 request: proto::LeaveChannelBuffer,
3522 response: Response<proto::LeaveChannelBuffer>,
3523 session: MessageContext,
3524) -> Result<()> {
3525 let db = session.db().await;
3526 let channel_id = ChannelId::from_proto(request.channel_id);
3527
3528 let left_buffer = db
3529 .leave_channel_buffer(channel_id, session.connection_id)
3530 .await?;
3531
3532 response.send(Ack {})?;
3533
3534 channel_buffer_updated(
3535 session.connection_id,
3536 left_buffer.connections,
3537 &proto::UpdateChannelBufferCollaborators {
3538 channel_id: channel_id.to_proto(),
3539 collaborators: left_buffer.collaborators,
3540 },
3541 &session.peer,
3542 );
3543
3544 Ok(())
3545}
3546
/// Send `message` to each of the given collaborator connections via
/// `broadcast`. `sender_id` identifies the originating connection and is
/// passed as `broadcast`'s first argument (NOTE(review): presumably so the
/// sender is excluded from delivery — confirm in `broadcast`).
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3557
3558fn send_notifications(
3559 connection_pool: &ConnectionPool,
3560 peer: &Peer,
3561 notifications: db::NotificationBatch,
3562) {
3563 for (user_id, notification) in notifications {
3564 for connection_id in connection_pool.user_connection_ids(user_id) {
3565 if let Err(error) = peer.send(
3566 connection_id,
3567 proto::AddNotification {
3568 notification: Some(notification.clone()),
3569 },
3570 ) {
3571 tracing::error!(
3572 "failed to send notification to {:?} {}",
3573 connection_id,
3574 error
3575 );
3576 }
3577 }
3578 }
3579}
3580
/// Send a message to the channel
///
/// Validates and persists the message, relays the full message to the chat's
/// participants, pushes only the latest-message id to other channel
/// subscribers, and delivers any notifications created by the message.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    // Build the wire representation once and reuse it for both the broadcast
    // and the sender's response.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Relay the full message to the chat's participant connections.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers that don't have the chat open only receive the
    // newest message id, not the message contents.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3673
/// Delete a channel message
///
/// Removes the message from the database, then tells the other participant
/// connections to drop the message and to delete any notifications that were
/// tied to it.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    // `existing_notification_ids` are the ids of notifications associated
    // with the removed message, as reported by the database.
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            // Forward the original request so the client removes the message,
            // then clean up the associated notifications on that client.
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
3709
/// Edit an existing channel message.
///
/// Persists the new body and mentions, acks the sender, then pushes the
/// updated message — along with deletions/updates of mention notifications
/// affected by the edit — to the other chat participants.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: `message_id` is shadowed below by the id the database returns.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Rebuild the full message payload as clients should now see it,
    // including the `edited_at` marker.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // Notifications for mentions that the edit removed.
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // Notifications for mentions whose contents changed.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3793
3794/// Mark a channel message as read
3795async fn acknowledge_channel_message(
3796 request: proto::AckChannelMessage,
3797 session: MessageContext,
3798) -> Result<()> {
3799 let channel_id = ChannelId::from_proto(request.channel_id);
3800 let message_id = MessageId::from_proto(request.message_id);
3801 let notifications = session
3802 .db()
3803 .await
3804 .observe_channel_message(channel_id, session.user_id(), message_id)
3805 .await?;
3806 send_notifications(
3807 &*session.connection_pool().await,
3808 &session.peer,
3809 notifications,
3810 );
3811 Ok(())
3812}
3813
3814/// Mark a buffer version as synced
3815async fn acknowledge_buffer_version(
3816 request: proto::AckBufferOperation,
3817 session: MessageContext,
3818) -> Result<()> {
3819 let buffer_id = BufferId::from_proto(request.buffer_id);
3820 session
3821 .db()
3822 .await
3823 .observe_buffer_version(
3824 buffer_id,
3825 session.user_id(),
3826 request.epoch as i32,
3827 &request.version,
3828 )
3829 .await?;
3830 Ok(())
3831}
3832
3833/// Get a Supermaven API key for the user
3834async fn get_supermaven_api_key(
3835 _request: proto::GetSupermavenApiKey,
3836 response: Response<proto::GetSupermavenApiKey>,
3837 session: MessageContext,
3838) -> Result<()> {
3839 let user_id: String = session.user_id().to_string();
3840 if !session.is_staff() {
3841 return Err(anyhow!("supermaven not enabled for this account"))?;
3842 }
3843
3844 let email = session.email().context("user must have an email")?;
3845
3846 let supermaven_admin_api = session
3847 .supermaven_client
3848 .as_ref()
3849 .context("supermaven not configured")?;
3850
3851 let result = supermaven_admin_api
3852 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3853 .await?;
3854
3855 response.send(proto::GetSupermavenApiKeyResponse {
3856 api_key: result.api_key,
3857 })?;
3858
3859 Ok(())
3860}
3861
3862/// Start receiving chat updates for a channel
3863async fn join_channel_chat(
3864 request: proto::JoinChannelChat,
3865 response: Response<proto::JoinChannelChat>,
3866 session: MessageContext,
3867) -> Result<()> {
3868 let channel_id = ChannelId::from_proto(request.channel_id);
3869
3870 let db = session.db().await;
3871 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3872 .await?;
3873 let messages = db
3874 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3875 .await?;
3876 response.send(proto::JoinChannelChatResponse {
3877 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3878 messages,
3879 })?;
3880 Ok(())
3881}
3882
3883/// Stop receiving chat updates for a channel
3884async fn leave_channel_chat(
3885 request: proto::LeaveChannelChat,
3886 session: MessageContext,
3887) -> Result<()> {
3888 let channel_id = ChannelId::from_proto(request.channel_id);
3889 session
3890 .db()
3891 .await
3892 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3893 .await?;
3894 Ok(())
3895}
3896
3897/// Retrieve the chat history for a channel
3898async fn get_channel_messages(
3899 request: proto::GetChannelMessages,
3900 response: Response<proto::GetChannelMessages>,
3901 session: MessageContext,
3902) -> Result<()> {
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 let messages = session
3905 .db()
3906 .await
3907 .get_channel_messages(
3908 channel_id,
3909 session.user_id(),
3910 MESSAGE_COUNT_PER_PAGE,
3911 Some(MessageId::from_proto(request.before_message_id)),
3912 )
3913 .await?;
3914 response.send(proto::GetChannelMessagesResponse {
3915 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3916 messages,
3917 })?;
3918 Ok(())
3919}
3920
3921/// Retrieve specific chat messages
3922async fn get_channel_messages_by_id(
3923 request: proto::GetChannelMessagesById,
3924 response: Response<proto::GetChannelMessagesById>,
3925 session: MessageContext,
3926) -> Result<()> {
3927 let message_ids = request
3928 .message_ids
3929 .iter()
3930 .map(|id| MessageId::from_proto(*id))
3931 .collect::<Vec<_>>();
3932 let messages = session
3933 .db()
3934 .await
3935 .get_channel_messages_by_id(session.user_id(), &message_ids)
3936 .await?;
3937 response.send(proto::GetChannelMessagesResponse {
3938 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3939 messages,
3940 })?;
3941 Ok(())
3942}
3943
3944/// Retrieve the current users notifications
3945async fn get_notifications(
3946 request: proto::GetNotifications,
3947 response: Response<proto::GetNotifications>,
3948 session: MessageContext,
3949) -> Result<()> {
3950 let notifications = session
3951 .db()
3952 .await
3953 .get_notifications(
3954 session.user_id(),
3955 NOTIFICATION_COUNT_PER_PAGE,
3956 request.before_id.map(db::NotificationId::from_proto),
3957 )
3958 .await?;
3959 response.send(proto::GetNotificationsResponse {
3960 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3961 notifications,
3962 })?;
3963 Ok(())
3964}
3965
3966/// Mark notifications as read
3967async fn mark_notification_as_read(
3968 request: proto::MarkNotificationRead,
3969 response: Response<proto::MarkNotificationRead>,
3970 session: MessageContext,
3971) -> Result<()> {
3972 let database = &session.db().await;
3973 let notifications = database
3974 .mark_notification_as_read_by_id(
3975 session.user_id(),
3976 NotificationId::from_proto(request.notification_id),
3977 )
3978 .await?;
3979 send_notifications(
3980 &*session.connection_pool().await,
3981 &session.peer,
3982 notifications,
3983 );
3984 response.send(proto::Ack {})?;
3985 Ok(())
3986}
3987
3988/// Accept the terms of service (tos) on behalf of the current user
3989async fn accept_terms_of_service(
3990 _request: proto::AcceptTermsOfService,
3991 response: Response<proto::AcceptTermsOfService>,
3992 session: MessageContext,
3993) -> Result<()> {
3994 let db = session.db().await;
3995
3996 let accepted_tos_at = Utc::now();
3997 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
3998 .await?;
3999
4000 response.send(proto::AcceptTermsOfServiceResponse {
4001 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4002 })?;
4003
4004 Ok(())
4005}
4006
4007fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4008 let message = match message {
4009 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4010 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4011 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4012 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4013 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4014 code: frame.code.into(),
4015 reason: frame.reason.as_str().to_owned().into(),
4016 })),
4017 // We should never receive a frame while reading the message, according
4018 // to the `tungstenite` maintainers:
4019 //
4020 // > It cannot occur when you read messages from the WebSocket, but it
4021 // > can be used when you want to send the raw frames (e.g. you want to
4022 // > send the frames to the WebSocket without composing the full message first).
4023 // >
4024 // > — https://github.com/snapview/tungstenite-rs/issues/268
4025 TungsteniteMessage::Frame(_) => {
4026 bail!("received an unexpected frame while reading the message")
4027 }
4028 };
4029
4030 Ok(message)
4031}
4032
4033fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4034 match message {
4035 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4036 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4037 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4038 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4039 AxumMessage::Close(frame) => {
4040 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4041 code: frame.code.into(),
4042 reason: frame.reason.as_ref().into(),
4043 }))
4044 }
4045 }
4046}
4047
/// Apply a channel-membership change for `user_id`: sync the connection
/// pool's channel subscriptions, then push both the new membership list and
/// the corresponding channel updates to all of the user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the pool's in-memory channel subscriptions in sync with the
    // database result.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): near-duplicate of `build_update_user_channels`, except the
    // observed buffer-version / message-id fields are left at their defaults
    // here — confirm whether that omission is intentional before unifying.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged via `trace_err` rather than propagated, so the
    // remaining connections still receive their updates.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
4088
4089fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4090 proto::UpdateUserChannels {
4091 channel_memberships: channels
4092 .channel_memberships
4093 .iter()
4094 .map(|m| proto::ChannelMembership {
4095 channel_id: m.channel_id.to_proto(),
4096 role: m.role.into(),
4097 })
4098 .collect(),
4099 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4100 observed_channel_message_id: channels.observed_channel_messages.clone(),
4101 }
4102}
4103
4104fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4105 let mut update = proto::UpdateChannels::default();
4106
4107 for channel in channels.channels {
4108 update.channels.push(channel.to_proto());
4109 }
4110
4111 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4112 update.latest_channel_message_ids = channels.latest_channel_messages;
4113
4114 for (channel_id, participants) in channels.channel_participants {
4115 update
4116 .channel_participants
4117 .push(proto::ChannelParticipants {
4118 channel_id: channel_id.to_proto(),
4119 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4120 });
4121 }
4122
4123 for channel in channels.invited_channels {
4124 update.channel_invitations.push(channel.to_proto());
4125 }
4126
4127 update
4128}
4129
4130fn build_initial_contacts_update(
4131 contacts: Vec<db::Contact>,
4132 pool: &ConnectionPool,
4133) -> proto::UpdateContacts {
4134 let mut update = proto::UpdateContacts::default();
4135
4136 for contact in contacts {
4137 match contact {
4138 db::Contact::Accepted { user_id, busy } => {
4139 update.contacts.push(contact_for_user(user_id, busy, pool));
4140 }
4141 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4142 db::Contact::Incoming { user_id } => {
4143 update
4144 .incoming_requests
4145 .push(proto::IncomingContactRequest {
4146 requester_id: user_id.to_proto(),
4147 })
4148 }
4149 }
4150 }
4151
4152 update
4153}
4154
4155fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4156 proto::Contact {
4157 user_id: user_id.to_proto(),
4158 online: pool.is_user_online(user_id),
4159 busy,
4160 }
4161}
4162
4163fn room_updated(room: &proto::Room, peer: &Peer) {
4164 broadcast(
4165 None,
4166 room.participants
4167 .iter()
4168 .filter_map(|participant| Some(participant.peer_id?.into())),
4169 |peer_id| {
4170 peer.send(
4171 peer_id,
4172 proto::RoomUpdated {
4173 room: Some(room.clone()),
4174 },
4175 )
4176 },
4177 );
4178}
4179
4180fn channel_updated(
4181 channel: &db::channel::Model,
4182 room: &proto::Room,
4183 peer: &Peer,
4184 pool: &ConnectionPool,
4185) {
4186 let participants = room
4187 .participants
4188 .iter()
4189 .map(|p| p.user_id)
4190 .collect::<Vec<_>>();
4191
4192 broadcast(
4193 None,
4194 pool.channel_connection_ids(channel.root_id())
4195 .filter_map(|(channel_id, role)| {
4196 role.can_see_channel(channel.visibility)
4197 .then_some(channel_id)
4198 }),
4199 |peer_id| {
4200 peer.send(
4201 peer_id,
4202 proto::UpdateChannels {
4203 channel_participants: vec![proto::ChannelParticipants {
4204 channel_id: channel.id.to_proto(),
4205 participant_user_ids: participants.clone(),
4206 }],
4207 ..Default::default()
4208 },
4209 )
4210 },
4211 );
4212}
4213
4214async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4215 let db = session.db().await;
4216
4217 let contacts = db.get_contacts(user_id).await?;
4218 let busy = db.is_user_busy(user_id).await?;
4219
4220 let pool = session.connection_pool().await;
4221 let updated_contact = contact_for_user(user_id, busy, &pool);
4222 for contact in contacts {
4223 if let db::Contact::Accepted {
4224 user_id: contact_user_id,
4225 ..
4226 } = contact
4227 {
4228 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4229 session
4230 .peer
4231 .send(
4232 contact_conn_id,
4233 proto::UpdateContacts {
4234 contacts: vec![updated_contact.clone()],
4235 remove_contacts: Default::default(),
4236 incoming_requests: Default::default(),
4237 remove_incoming_requests: Default::default(),
4238 outgoing_requests: Default::default(),
4239 remove_outgoing_requests: Default::default(),
4240 },
4241 )
4242 .trace_err();
4243 }
4244 }
4245 }
4246 Ok(())
4247}
4248
/// Remove `connection_id` from the room it is in, if any.
///
/// Notifies the projects the connection had open, the remaining room
/// participants, channel subscribers (for channel rooms), callees whose
/// incoming calls were canceled, and the affected users' contacts; finally
/// performs best-effort LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared outside the `if let` so the values extracted from `left_room`
    // outlive that scope and remain usable by the notification code below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Notify collaborators of each project this connection left.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // For channel rooms, refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Tell each user whose incoming call was canceled, on all of their
        // connections; their contact entries need refreshing too.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup; failures are only logged.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4322
4323async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4324 let left_channel_buffers = session
4325 .db()
4326 .await
4327 .leave_channel_buffers(session.connection_id)
4328 .await?;
4329
4330 for left_buffer in left_channel_buffers {
4331 channel_buffer_updated(
4332 session.connection_id,
4333 left_buffer.connections,
4334 &proto::UpdateChannelBufferCollaborators {
4335 channel_id: left_buffer.channel_id.to_proto(),
4336 collaborators: left_buffer.collaborators,
4337 },
4338 &session.peer,
4339 );
4340 }
4341
4342 Ok(())
4343}
4344
4345fn project_left(project: &db::LeftProject, session: &Session) {
4346 for connection_id in &project.connection_ids {
4347 if project.should_unshare {
4348 session
4349 .peer
4350 .send(
4351 *connection_id,
4352 proto::UnshareProject {
4353 project_id: project.id.to_proto(),
4354 },
4355 )
4356 .trace_err();
4357 } else {
4358 session
4359 .peer
4360 .send(
4361 *connection_id,
4362 proto::RemoveProjectCollaborator {
4363 project_id: project.id.to_proto(),
4364 peer_id: Some(session.connection_id.into()),
4365 },
4366 )
4367 .trace_err();
4368 }
4369 }
4370}
4371
/// Extension trait for result-like values: log any error and collapse
/// the value into an `Option`.
pub trait ResultExt {
    type Ok;

    /// Returns `Some(ok)` on success; on failure logs the error and
    /// returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4377
4378impl<T, E> ResultExt for Result<T, E>
4379where
4380 E: std::fmt::Debug,
4381{
4382 type Ok = T;
4383
4384 #[track_caller]
4385 fn trace_err(self) -> Option<T> {
4386 match self {
4387 Ok(value) => Some(value),
4388 Err(error) => {
4389 tracing::error!("{:?}", error);
4390 None
4391 }
4392 }
4393 }
4394}