1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
/// Grace period for a disconnected client to reconnect before its state is
/// discarded. NOTE(review): consumers of this constant live outside this
/// file — confirm exact semantics against callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for paginated notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names used to record per-message timing breakdowns
// (see `Server::add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one protocol message type; stored in
/// `Server::handlers` keyed by the message payload's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
111impl Drop for ConnectionGuard {
112 fn drop(&mut self) {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 }
115}
116
/// Handle a request handler uses to reply to one specific request.
/// Tracks whether a reply was actually sent so the dispatcher in
/// `add_request_handler` can flag handlers that never respond.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the dispatcher; set to true by `send`.
    responded: Arc<AtomicBool>,
}
122
impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the reply to the originating request.
    /// The `responded` flag is set before attempting the send, so it remains
    /// true even if the underlying `respond` call fails.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
130
/// The authenticated identity behind a connection: either a plain user, or an
/// admin impersonating another user.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
    Impersonated { user: User, admin: User },
}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// Everything a message handler receives: the sender's session plus the
/// tracing span created for this particular message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}
158
// Handlers mostly operate on the session, so let the context deref to it.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
166
impl MessageContext {
    /// Forwards `request` to another connection (e.g. the project host) and
    /// returns a future resolving to that peer's response. The time the remote
    /// peer took is recorded on this message's span under `HOST_WAITING_MS`.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs on completion, success or failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Microsecond resolution reported as fractional milliseconds.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquired via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    // Kept so handlers could spawn work on the server's executor.
    _executor: Executor,
}
204
205impl Session {
206 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
207 #[cfg(feature = "test-support")]
208 tokio::task::yield_now().await;
209 let guard = self.db.lock().await;
210 #[cfg(feature = "test-support")]
211 tokio::task::yield_now().await;
212 guard
213 }
214
215 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
216 #[cfg(feature = "test-support")]
217 tokio::task::yield_now().await;
218 let guard = self.connection_pool.lock();
219 ConnectionPoolGuard {
220 guard,
221 _not_send: PhantomData,
222 }
223 }
224
225 #[expect(dead_code)]
226 fn is_staff(&self) -> bool {
227 match &self.principal {
228 Principal::User(user) => user.admin,
229 Principal::Impersonated { .. } => true,
230 }
231 }
232
233 fn user_id(&self) -> UserId {
234 match &self.principal {
235 Principal::User(user) => user.id,
236 Principal::Impersonated { user, .. } => user.id,
237 }
238 }
239}
240
241impl Debug for Session {
242 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243 let mut result = f.debug_struct("Session");
244 match &self.principal {
245 Principal::User(user) => {
246 result.field("user", &user.github_login);
247 }
248 Principal::Impersonated { user, admin } => {
249 result.field("user", &user.github_login);
250 result.field("impersonator", &admin.github_login);
251 }
252 }
253 result.field("connection_id", &self.connection_id).finish()
254 }
255}
256
/// Newtype over the shared `Database` so a session can expose it through an
/// async mutex while handlers call `Database` methods directly via `Deref`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
266
/// The collab RPC server: owns the peer, the pool of live connections, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Mutable so tests can re-identify the server via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Message-type-id -> type-erased handler; populated once at construction.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` on teardown so connection tasks exit their loops.
    teardown: watch::Sender<bool>,
}
275
/// Guard over the connection-pool lock. The `PhantomData<Rc<()>>` marker makes
/// the guard `!Send` (since `Rc` is `!Send`), so it cannot be moved into a
/// spawned `Send` future while the lock is held.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
280
/// Serializable point-in-time view of the server, pairing the peer with a
/// locked view of the connection pool (serialized through `serialize_deref`).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
287
288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
289where
290 S: Serializer,
291 T: Deref<Target = U>,
292 U: Serialize,
293{
294 Serialize::serialize(value.deref(), serializer)
295}
296
297impl Server {
    /// Builds the server and registers every message handler exactly once;
    /// `add_handler` panics on a duplicate registration, so the full handler
    /// table is fixed at construction time.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Connection / room lifecycle.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and state sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests, forwarded to the host for guests of
            // any capability level.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): LspExtRunnables is registered as *mutating* amid a
            // run of read-only LSP ext requests — confirm this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests, forwarded to the host only for guests
            // with write capability.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state fan-out from the host to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, and channels.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following / presence.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles mutate the working
            // tree but are registered as read-only — confirm intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // LSP logs and agent threads.
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
480
    /// Spawns the background cleanup task: after `CLEANUP_TIMEOUT` (giving
    /// pods from a previous deployment time to shut down), it removes stale
    /// rooms, channel buffers, chat participants, worktree entries, and server
    /// records left behind by other server ids, notifying affected clients.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this same call is repeated after the stale-room
                // loop below — presumably a deliberate before/after sweep, but
                // confirm it isn't an accidental duplicate.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from channel buffers and push
                    // the refreshed collaborator lists to remaining clients.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants whose server is gone, then
                        // broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees of now-canceled calls, on every one of
                        // their connections. Scoped block keeps the pool lock
                        // from being held across the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push updated busy-state to each affected user's
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
651
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals all connection tasks to exit via the teardown
    /// watch channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
657
    /// Test-only: tears everything down and brings the server back up under a
    /// new server id, re-arming the teardown channel.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
665
    /// Test-only accessor for the current server id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
670
671 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
672 where
673 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
674 Fut: 'static + Send + Future<Output = Result<()>>,
675 M: EnvelopedMessage,
676 {
677 let prev_handler = self.handlers.insert(
678 TypeId::of::<M>(),
679 Box::new(move |envelope, session, span| {
680 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
681 let received_at = envelope.received_at;
682 tracing::info!("message received");
683 let start_time = Instant::now();
684 let future = (handler)(
685 *envelope,
686 MessageContext {
687 session,
688 span: span.clone(),
689 },
690 );
691 async move {
692 let result = future.await;
693 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
694 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
695 let queue_duration_ms = total_duration_ms - processing_duration_ms;
696 span.record(TOTAL_DURATION_MS, total_duration_ms);
697 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
698 span.record(QUEUE_DURATION_MS, queue_duration_ms);
699 match result {
700 Err(error) => {
701 tracing::error!(?error, "error handling message")
702 }
703 Ok(()) => tracing::info!("finished handling message"),
704 }
705 }
706 .boxed()
707 }),
708 );
709 if prev_handler.is_some() {
710 panic!("registered a handler for the same message twice");
711 }
712 self
713 }
714
715 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
716 where
717 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
718 Fut: 'static + Send + Future<Output = Result<()>>,
719 M: EnvelopedMessage,
720 {
721 self.add_handler(move |envelope, session| handler(envelope.payload, session));
722 self
723 }
724
    /// Registers a request/response handler for `M`. Enforces the request
    /// protocol around the handler: if it returns `Ok` without having called
    /// `Response::send`, that's an error; if it returns `Err`, the error is
    /// converted to a proto error and sent back to the requester.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`, checked after the
                // handler completes.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry proto details; anything
                        // else is wrapped as a generic internal error.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
763
    /// Returns the future that owns one client connection for its lifetime:
    /// registers it with the peer, sends the initial client update, then runs
    /// the message-dispatch loop until the connection closes, I/O fails, or
    /// the server tears down — finally running `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release the admission slot
            // held since the HTTP upgrade.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Don't pull the next message until a handler slot is free;
                // the permit is released when the handler finishes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Drives in-flight foreground handlers to completion.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
936
    /// Sends the initial batch of state to a newly established client
    /// connection: the `Hello` handshake, the contact list, channel
    /// subscriptions, and any call that is currently ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id back to the caller, if requested.
        // The receiver may have been dropped, so the send result is ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt the client to
                // show the contacts UI, then persist that this happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and build the initial contacts
                    // update while holding the pool lock, so the online-state
                    // snapshot sent to the client is internally consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Older clients subscribe to channels implicitly; newer ones
                // request subscriptions themselves (version-gated).
                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a call that was already ringing when this
                // connection was established.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
992
993 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
994 ServerSnapshot {
995 connection_pool: ConnectionPoolGuard {
996 guard: self.connection_pool.lock(),
997 _not_send: PhantomData,
998 },
999 peer: &self.peer,
1000 }
1001 }
1002}
1003
/// Lets a `ConnectionPoolGuard` be used anywhere a `&ConnectionPool` is
/// expected, delegating to the inner lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1011
/// Mutable access to the underlying `ConnectionPool` through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1017
/// In test builds, verify the pool's internal invariants every time a
/// guard is released, so corruption is caught close to where it happened.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
1024
1025fn broadcast<F>(
1026 sender_id: Option<ConnectionId>,
1027 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1028 mut f: F,
1029) where
1030 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1031{
1032 for receiver_id in receiver_ids {
1033 if Some(receiver_id) != sender_id
1034 && let Err(error) = f(receiver_id)
1035 {
1036 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1037 }
1038 }
1039}
1040
1041pub struct ProtocolVersion(u32);
1042
1043impl Header for ProtocolVersion {
1044 fn name() -> &'static HeaderName {
1045 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1046 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1047 }
1048
1049 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1050 where
1051 Self: Sized,
1052 I: Iterator<Item = &'i axum::http::HeaderValue>,
1053 {
1054 let version = values
1055 .next()
1056 .ok_or_else(axum::headers::Error::invalid)?
1057 .to_str()
1058 .map_err(|_| axum::headers::Error::invalid())?
1059 .parse()
1060 .map_err(|_| axum::headers::Error::invalid())?;
1061 Ok(Self(version))
1062 }
1063
1064 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1065 values.extend([self.0.to_string().parse().unwrap()]);
1066 }
1067}
1068
/// Typed `x-zed-app-version` header carrying the client's app version as
/// a semver [`Version`].
pub struct AppVersionHeader(Version);
impl Header for AppVersionHeader {
    fn name() -> &'static HeaderName {
        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i axum::http::HeaderValue>,
    {
        // Missing, non-UTF-8, and non-semver values all collapse into the
        // same generic "invalid header" error.
        let version = values
            .next()
            .ok_or_else(axum::headers::Error::invalid)?
            .to_str()
            .map_err(|_| axum::headers::Error::invalid())?
            .parse()
            .map_err(|_| axum::headers::Error::invalid())?;
        Ok(Self(version))
    }

    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
        // Semver's Display output contains only header-safe characters.
        values.extend([self.0.to_string().parse().unwrap()]);
    }
}
1095
1096#[derive(Debug)]
1097pub struct ReleaseChannelHeader(String);
1098
1099impl Header for ReleaseChannelHeader {
1100 fn name() -> &'static HeaderName {
1101 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1102 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1103 }
1104
1105 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1106 where
1107 Self: Sized,
1108 I: Iterator<Item = &'i axum::http::HeaderValue>,
1109 {
1110 Ok(Self(
1111 values
1112 .next()
1113 .ok_or_else(axum::headers::Error::invalid)?
1114 .to_str()
1115 .map_err(|_| axum::headers::Error::invalid())?
1116 .to_owned(),
1117 ))
1118 }
1119
1120 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1121 values.extend([self.0.parse().unwrap()]);
1122 }
1123}
1124
1125pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1126 Router::new()
1127 .route("/rpc", get(handle_websocket_request))
1128 .layer(
1129 ServiceBuilder::new()
1130 .layer(Extension(server.app_state.clone()))
1131 .layer(middleware::from_fn(auth::validate_header)),
1132 )
1133 .route("/metrics", get(handle_metrics))
1134 .layer(Extension(server))
1135}
1136
/// Axum handler for the `/rpc` websocket endpoint.
///
/// Validates the client's protocol and app versions, reserves a
/// connection slot, then upgrades to a websocket and hands the adapted
/// stream to [`Server::handle_connection`].
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Clients speaking a different RPC protocol cannot interoperate at all.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app-version header is mandatory; its absence means a client too
    // old to send it.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // Version is present but below the minimum supported for collaboration.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    // (so we can refuse with a proper HTTP status instead of dropping the
    // socket after upgrade).
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's websocket message types to the tungstenite types
        // the rpc::Connection layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1214
/// Axum handler for `/metrics`: refreshes the connection and
/// shared-project gauges, then renders all registered Prometheus metrics
/// in text exposition format.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    // Gauges are registered lazily on first scrape and reused afterwards.
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded so staff activity doesn't skew the
    // user-facing numbers.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Encode everything in the global registry, not just the two gauges
    // updated above.
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1247
/// Tears down a dropped client connection.
///
/// Immediately removes the connection from the peer and pool, then waits
/// up to `RECONNECT_TIMEOUT` before releasing the user's rooms, channel
/// buffers, and pending calls — giving the client a window to reconnect
/// and rejoin without losing state. A teardown signal skips the wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: mark the connection lost in the database; failure is
    // logged rather than aborting the teardown.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a ringing call if the user has no other live
            // connections that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1294
1295/// Acknowledges a ping from a client, used to keep the connection alive.
1296async fn ping(
1297 _: proto::Ping,
1298 response: Response<proto::Ping>,
1299 _session: MessageContext,
1300) -> Result<()> {
1301 response.send(proto::Ack {})?;
1302 Ok(())
1303}
1304
/// Creates a new room for calling (outside of channels)
///
/// Generates a random LiveKit room name, mints a join token for the
/// caller when a LiveKit client is configured, persists the room, and
/// refreshes the caller's contacts so they appear busy.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random identifier used as the LiveKit room name.
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: if LiveKit is unconfigured or token minting fails,
    // the room is still created, just without audio/video info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1342
1343/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1344async fn join_room(
1345 request: proto::JoinRoom,
1346 response: Response<proto::JoinRoom>,
1347 session: MessageContext,
1348) -> Result<()> {
1349 let room_id = RoomId::from_proto(request.id);
1350
1351 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1352
1353 if let Some(channel_id) = channel_id {
1354 return join_channel_internal(channel_id, Box::new(response), session).await;
1355 }
1356
1357 let joined_room = {
1358 let room = session
1359 .db()
1360 .await
1361 .join_room(room_id, session.user_id(), session.connection_id)
1362 .await?;
1363 room_updated(&room.room, &session.peer);
1364 room.into_inner()
1365 };
1366
1367 for connection_id in session
1368 .connection_pool()
1369 .await
1370 .user_connection_ids(session.user_id())
1371 {
1372 session
1373 .peer
1374 .send(
1375 connection_id,
1376 proto::CallCanceled {
1377 room_id: room_id.to_proto(),
1378 },
1379 )
1380 .trace_err();
1381 }
1382
1383 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1384 {
1385 live_kit
1386 .room_token(
1387 &joined_room.room.livekit_room,
1388 &session.user_id().to_string(),
1389 )
1390 .trace_err()
1391 .map(|token| proto::LiveKitConnectionInfo {
1392 server_url: live_kit.url().into(),
1393 token,
1394 can_publish: true,
1395 })
1396 } else {
1397 None
1398 };
1399
1400 response.send(proto::JoinRoomResponse {
1401 room: Some(joined_room.room),
1402 channel_id: None,
1403 live_kit_connection_info,
1404 })?;
1405
1406 update_user_contacts(session.user_id(), &session).await?;
1407 Ok(())
1408}
1409
1410/// Rejoin room is used to reconnect to a room after connection errors.
1411async fn rejoin_room(
1412 request: proto::RejoinRoom,
1413 response: Response<proto::RejoinRoom>,
1414 session: MessageContext,
1415) -> Result<()> {
1416 let room;
1417 let channel;
1418 {
1419 let mut rejoined_room = session
1420 .db()
1421 .await
1422 .rejoin_room(request, session.user_id(), session.connection_id)
1423 .await?;
1424
1425 response.send(proto::RejoinRoomResponse {
1426 room: Some(rejoined_room.room.clone()),
1427 reshared_projects: rejoined_room
1428 .reshared_projects
1429 .iter()
1430 .map(|project| proto::ResharedProject {
1431 id: project.id.to_proto(),
1432 collaborators: project
1433 .collaborators
1434 .iter()
1435 .map(|collaborator| collaborator.to_proto())
1436 .collect(),
1437 })
1438 .collect(),
1439 rejoined_projects: rejoined_room
1440 .rejoined_projects
1441 .iter()
1442 .map(|rejoined_project| rejoined_project.to_proto())
1443 .collect(),
1444 })?;
1445 room_updated(&rejoined_room.room, &session.peer);
1446
1447 for project in &rejoined_room.reshared_projects {
1448 for collaborator in &project.collaborators {
1449 session
1450 .peer
1451 .send(
1452 collaborator.connection_id,
1453 proto::UpdateProjectCollaborator {
1454 project_id: project.id.to_proto(),
1455 old_peer_id: Some(project.old_connection_id.into()),
1456 new_peer_id: Some(session.connection_id.into()),
1457 },
1458 )
1459 .trace_err();
1460 }
1461
1462 broadcast(
1463 Some(session.connection_id),
1464 project
1465 .collaborators
1466 .iter()
1467 .map(|collaborator| collaborator.connection_id),
1468 |connection_id| {
1469 session.peer.forward_send(
1470 session.connection_id,
1471 connection_id,
1472 proto::UpdateProject {
1473 project_id: project.id.to_proto(),
1474 worktrees: project.worktrees.clone(),
1475 },
1476 )
1477 },
1478 );
1479 }
1480
1481 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1482
1483 let rejoined_room = rejoined_room.into_inner();
1484
1485 room = rejoined_room.room;
1486 channel = rejoined_room.channel;
1487 }
1488
1489 if let Some(channel) = channel {
1490 channel_updated(
1491 &channel,
1492 &room,
1493 &session.peer,
1494 &*session.connection_pool().await,
1495 );
1496 }
1497
1498 update_user_contacts(session.user_id(), &session).await?;
1499 Ok(())
1500}
1501
/// Streams the state of each rejoined project back to the rejoining
/// client, and tells the other collaborators about its new peer id.
///
/// Worktree contents, diagnostics, settings files, and repository state
/// are sent as a series of (possibly split) messages to keep individual
/// payloads bounded.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: update every collaborator's view of this peer's id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream each project's accumulated state back to the
    // rejoining client. `mem::take` moves the data out so it isn't resent.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Forward each settings file found in the worktree.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are split to respect message size limits.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1587
1588/// leave room disconnects from the room.
1589async fn leave_room(
1590 _: proto::LeaveRoom,
1591 response: Response<proto::LeaveRoom>,
1592 session: MessageContext,
1593) -> Result<()> {
1594 leave_room_for_session(&session, session.connection_id).await?;
1595 response.send(proto::Ack {})?;
1596 Ok(())
1597}
1598
1599/// Updates the permissions of someone else in the room.
1600async fn set_room_participant_role(
1601 request: proto::SetRoomParticipantRole,
1602 response: Response<proto::SetRoomParticipantRole>,
1603 session: MessageContext,
1604) -> Result<()> {
1605 let user_id = UserId::from_proto(request.user_id);
1606 let role = ChannelRole::from(request.role());
1607
1608 let (livekit_room, can_publish) = {
1609 let room = session
1610 .db()
1611 .await
1612 .set_room_participant_role(
1613 session.user_id(),
1614 RoomId::from_proto(request.room_id),
1615 user_id,
1616 role,
1617 )
1618 .await?;
1619
1620 let livekit_room = room.livekit_room.clone();
1621 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1622 room_updated(&room, &session.peer);
1623 (livekit_room, can_publish)
1624 };
1625
1626 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1627 live_kit
1628 .update_participant(
1629 livekit_room.clone(),
1630 request.user_id.to_string(),
1631 livekit_api::proto::ParticipantPermission {
1632 can_subscribe: true,
1633 can_publish,
1634 can_publish_data: can_publish,
1635 hidden: false,
1636 recorder: false,
1637 },
1638 )
1639 .await
1640 .trace_err();
1641 }
1642
1643 response.send(proto::Ack {})?;
1644 Ok(())
1645}
1646
1647/// Call someone else into the current room
1648async fn call(
1649 request: proto::Call,
1650 response: Response<proto::Call>,
1651 session: MessageContext,
1652) -> Result<()> {
1653 let room_id = RoomId::from_proto(request.room_id);
1654 let calling_user_id = session.user_id();
1655 let calling_connection_id = session.connection_id;
1656 let called_user_id = UserId::from_proto(request.called_user_id);
1657 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1658 if !session
1659 .db()
1660 .await
1661 .has_contact(calling_user_id, called_user_id)
1662 .await?
1663 {
1664 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1665 }
1666
1667 let incoming_call = {
1668 let (room, incoming_call) = &mut *session
1669 .db()
1670 .await
1671 .call(
1672 room_id,
1673 calling_user_id,
1674 calling_connection_id,
1675 called_user_id,
1676 initial_project_id,
1677 )
1678 .await?;
1679 room_updated(room, &session.peer);
1680 mem::take(incoming_call)
1681 };
1682 update_user_contacts(called_user_id, &session).await?;
1683
1684 let mut calls = session
1685 .connection_pool()
1686 .await
1687 .user_connection_ids(called_user_id)
1688 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1689 .collect::<FuturesUnordered<_>>();
1690
1691 while let Some(call_response) = calls.next().await {
1692 match call_response.as_ref() {
1693 Ok(_) => {
1694 response.send(proto::Ack {})?;
1695 return Ok(());
1696 }
1697 Err(_) => {
1698 call_response.trace_err();
1699 }
1700 }
1701 }
1702
1703 {
1704 let room = session
1705 .db()
1706 .await
1707 .call_failed(room_id, called_user_id)
1708 .await?;
1709 room_updated(&room, &session.peer);
1710 }
1711 update_user_contacts(called_user_id, &session).await?;
1712
1713 Err(anyhow!("failed to ring user"))?
1714}
1715
1716/// Cancel an outgoing call.
1717async fn cancel_call(
1718 request: proto::CancelCall,
1719 response: Response<proto::CancelCall>,
1720 session: MessageContext,
1721) -> Result<()> {
1722 let called_user_id = UserId::from_proto(request.called_user_id);
1723 let room_id = RoomId::from_proto(request.room_id);
1724 {
1725 let room = session
1726 .db()
1727 .await
1728 .cancel_call(room_id, session.connection_id, called_user_id)
1729 .await?;
1730 room_updated(&room, &session.peer);
1731 }
1732
1733 for connection_id in session
1734 .connection_pool()
1735 .await
1736 .user_connection_ids(called_user_id)
1737 {
1738 session
1739 .peer
1740 .send(
1741 connection_id,
1742 proto::CallCanceled {
1743 room_id: room_id.to_proto(),
1744 },
1745 )
1746 .trace_err();
1747 }
1748 response.send(proto::Ack {})?;
1749
1750 update_user_contacts(called_user_id, &session).await?;
1751 Ok(())
1752}
1753
1754/// Decline an incoming call.
1755async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1756 let room_id = RoomId::from_proto(message.room_id);
1757 {
1758 let room = session
1759 .db()
1760 .await
1761 .decline_call(Some(room_id), session.user_id())
1762 .await?
1763 .context("declining call")?;
1764 room_updated(&room, &session.peer);
1765 }
1766
1767 for connection_id in session
1768 .connection_pool()
1769 .await
1770 .user_connection_ids(session.user_id())
1771 {
1772 session
1773 .peer
1774 .send(
1775 connection_id,
1776 proto::CallCanceled {
1777 room_id: room_id.to_proto(),
1778 },
1779 )
1780 .trace_err();
1781 }
1782 update_user_contacts(session.user_id(), &session).await?;
1783 Ok(())
1784}
1785
1786/// Updates other participants in the room with your current location.
1787async fn update_participant_location(
1788 request: proto::UpdateParticipantLocation,
1789 response: Response<proto::UpdateParticipantLocation>,
1790 session: MessageContext,
1791) -> Result<()> {
1792 let room_id = RoomId::from_proto(request.room_id);
1793 let location = request.location.context("invalid location")?;
1794
1795 let db = session.db().await;
1796 let room = db
1797 .update_room_participant_location(room_id, session.connection_id, location)
1798 .await?;
1799
1800 room_updated(&room, &session.peer);
1801 response.send(proto::Ack {})?;
1802 Ok(())
1803}
1804
1805/// Share a project into the room.
1806async fn share_project(
1807 request: proto::ShareProject,
1808 response: Response<proto::ShareProject>,
1809 session: MessageContext,
1810) -> Result<()> {
1811 let (project_id, room) = &*session
1812 .db()
1813 .await
1814 .share_project(
1815 RoomId::from_proto(request.room_id),
1816 session.connection_id,
1817 &request.worktrees,
1818 request.is_ssh_project,
1819 request.windows_paths.unwrap_or(false),
1820 )
1821 .await?;
1822 response.send(proto::ShareProjectResponse {
1823 project_id: project_id.to_proto(),
1824 })?;
1825 room_updated(room, &session.peer);
1826
1827 Ok(())
1828}
1829
1830/// Unshare a project from the room.
1831async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1832 let project_id = ProjectId::from_proto(message.project_id);
1833 unshare_project_internal(project_id, session.connection_id, &session).await
1834}
1835
/// Removes a shared project on behalf of `connection_id`, notifies its
/// guests, refreshes the room, and — when the database flags it — deletes
/// the project record entirely.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // The room guard lives for this block so notifications go out
        // while the unshare state is consistent; the delete flag is copied
        // out before the guard is released.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharing connection itself) that
        // the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Deletion takes a fresh db guard after the room guard is dropped.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1873
1874/// Join someone elses shared project.
1875async fn join_project(
1876 request: proto::JoinProject,
1877 response: Response<proto::JoinProject>,
1878 session: MessageContext,
1879) -> Result<()> {
1880 let project_id = ProjectId::from_proto(request.project_id);
1881
1882 tracing::info!(%project_id, "join project");
1883
1884 let db = session.db().await;
1885 let (project, replica_id) = &mut *db
1886 .join_project(
1887 project_id,
1888 session.connection_id,
1889 session.user_id(),
1890 request.committer_name.clone(),
1891 request.committer_email.clone(),
1892 )
1893 .await?;
1894 drop(db);
1895 tracing::info!(%project_id, "join remote project");
1896 let collaborators = project
1897 .collaborators
1898 .iter()
1899 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1900 .map(|collaborator| collaborator.to_proto())
1901 .collect::<Vec<_>>();
1902 let project_id = project.id;
1903 let guest_user_id = session.user_id();
1904
1905 let worktrees = project
1906 .worktrees
1907 .iter()
1908 .map(|(id, worktree)| proto::WorktreeMetadata {
1909 id: *id,
1910 root_name: worktree.root_name.clone(),
1911 visible: worktree.visible,
1912 abs_path: worktree.abs_path.clone(),
1913 })
1914 .collect::<Vec<_>>();
1915
1916 let add_project_collaborator = proto::AddProjectCollaborator {
1917 project_id: project_id.to_proto(),
1918 collaborator: Some(proto::Collaborator {
1919 peer_id: Some(session.connection_id.into()),
1920 replica_id: replica_id.0 as u32,
1921 user_id: guest_user_id.to_proto(),
1922 is_host: false,
1923 committer_name: request.committer_name.clone(),
1924 committer_email: request.committer_email.clone(),
1925 }),
1926 };
1927
1928 for collaborator in &collaborators {
1929 session
1930 .peer
1931 .send(
1932 collaborator.peer_id.unwrap().into(),
1933 add_project_collaborator.clone(),
1934 )
1935 .trace_err();
1936 }
1937
1938 // First, we send the metadata associated with each worktree.
1939 let (language_servers, language_server_capabilities) = project
1940 .language_servers
1941 .clone()
1942 .into_iter()
1943 .map(|server| (server.server, server.capabilities))
1944 .unzip();
1945 response.send(proto::JoinProjectResponse {
1946 project_id: project.id.0 as u64,
1947 worktrees,
1948 replica_id: replica_id.0 as u32,
1949 collaborators,
1950 language_servers,
1951 language_server_capabilities,
1952 role: project.role.into(),
1953 windows_paths: project.path_style == PathStyle::Windows,
1954 })?;
1955
1956 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1957 // Stream this worktree's entries.
1958 let message = proto::UpdateWorktree {
1959 project_id: project_id.to_proto(),
1960 worktree_id,
1961 abs_path: worktree.abs_path.clone(),
1962 root_name: worktree.root_name,
1963 updated_entries: worktree.entries,
1964 removed_entries: Default::default(),
1965 scan_id: worktree.scan_id,
1966 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1967 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1968 removed_repositories: Default::default(),
1969 };
1970 for update in proto::split_worktree_update(message) {
1971 session.peer.send(session.connection_id, update.clone())?;
1972 }
1973
1974 // Stream this worktree's diagnostics.
1975 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1976 if let Some(summary) = worktree_diagnostics.next() {
1977 let message = proto::UpdateDiagnosticSummary {
1978 project_id: project.id.to_proto(),
1979 worktree_id: worktree.id,
1980 summary: Some(summary),
1981 more_summaries: worktree_diagnostics.collect(),
1982 };
1983 session.peer.send(session.connection_id, message)?;
1984 }
1985
1986 for settings_file in worktree.settings_files {
1987 session.peer.send(
1988 session.connection_id,
1989 proto::UpdateWorktreeSettings {
1990 project_id: project_id.to_proto(),
1991 worktree_id: worktree.id,
1992 path: settings_file.path,
1993 content: Some(settings_file.content),
1994 kind: Some(settings_file.kind.to_proto() as i32),
1995 outside_worktree: Some(settings_file.outside_worktree),
1996 },
1997 )?;
1998 }
1999 }
2000
2001 for repository in mem::take(&mut project.repositories) {
2002 for update in split_repository_update(repository) {
2003 session.peer.send(session.connection_id, update)?;
2004 }
2005 }
2006
2007 for language_server in &project.language_servers {
2008 session.peer.send(
2009 session.connection_id,
2010 proto::UpdateLanguageServer {
2011 project_id: project_id.to_proto(),
2012 server_name: Some(language_server.server.name.clone()),
2013 language_server_id: language_server.server.id,
2014 variant: Some(
2015 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2016 proto::LspDiskBasedDiagnosticsUpdated {},
2017 ),
2018 ),
2019 },
2020 )?;
2021 }
2022
2023 Ok(())
2024}
2025
2026/// Leave someone elses shared project.
2027async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2028 let sender_id = session.connection_id;
2029 let project_id = ProjectId::from_proto(request.project_id);
2030 let db = session.db().await;
2031
2032 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2033 tracing::info!(
2034 %project_id,
2035 "leave project"
2036 );
2037
2038 project_left(project, &session);
2039 if let Some(room) = room {
2040 room_updated(room, &session.peer);
2041 }
2042
2043 Ok(())
2044}
2045
2046/// Updates other participants with changes to the project
2047async fn update_project(
2048 request: proto::UpdateProject,
2049 response: Response<proto::UpdateProject>,
2050 session: MessageContext,
2051) -> Result<()> {
2052 let project_id = ProjectId::from_proto(request.project_id);
2053 let (room, guest_connection_ids) = &*session
2054 .db()
2055 .await
2056 .update_project(project_id, session.connection_id, &request.worktrees)
2057 .await?;
2058 broadcast(
2059 Some(session.connection_id),
2060 guest_connection_ids.iter().copied(),
2061 |connection_id| {
2062 session
2063 .peer
2064 .forward_send(session.connection_id, connection_id, request.clone())
2065 },
2066 );
2067 if let Some(room) = room {
2068 room_updated(room, &session.peer);
2069 }
2070 response.send(proto::Ack {})?;
2071
2072 Ok(())
2073}
2074
2075/// Updates other participants with changes to the worktree
2076async fn update_worktree(
2077 request: proto::UpdateWorktree,
2078 response: Response<proto::UpdateWorktree>,
2079 session: MessageContext,
2080) -> Result<()> {
2081 let guest_connection_ids = session
2082 .db()
2083 .await
2084 .update_worktree(&request, session.connection_id)
2085 .await?;
2086
2087 broadcast(
2088 Some(session.connection_id),
2089 guest_connection_ids.iter().copied(),
2090 |connection_id| {
2091 session
2092 .peer
2093 .forward_send(session.connection_id, connection_id, request.clone())
2094 },
2095 );
2096 response.send(proto::Ack {})?;
2097 Ok(())
2098}
2099
2100async fn update_repository(
2101 request: proto::UpdateRepository,
2102 response: Response<proto::UpdateRepository>,
2103 session: MessageContext,
2104) -> Result<()> {
2105 let guest_connection_ids = session
2106 .db()
2107 .await
2108 .update_repository(&request, session.connection_id)
2109 .await?;
2110
2111 broadcast(
2112 Some(session.connection_id),
2113 guest_connection_ids.iter().copied(),
2114 |connection_id| {
2115 session
2116 .peer
2117 .forward_send(session.connection_id, connection_id, request.clone())
2118 },
2119 );
2120 response.send(proto::Ack {})?;
2121 Ok(())
2122}
2123
2124async fn remove_repository(
2125 request: proto::RemoveRepository,
2126 response: Response<proto::RemoveRepository>,
2127 session: MessageContext,
2128) -> Result<()> {
2129 let guest_connection_ids = session
2130 .db()
2131 .await
2132 .remove_repository(&request, session.connection_id)
2133 .await?;
2134
2135 broadcast(
2136 Some(session.connection_id),
2137 guest_connection_ids.iter().copied(),
2138 |connection_id| {
2139 session
2140 .peer
2141 .forward_send(session.connection_id, connection_id, request.clone())
2142 },
2143 );
2144 response.send(proto::Ack {})?;
2145 Ok(())
2146}
2147
2148/// Updates other participants with changes to the diagnostics
2149async fn update_diagnostic_summary(
2150 message: proto::UpdateDiagnosticSummary,
2151 session: MessageContext,
2152) -> Result<()> {
2153 let guest_connection_ids = session
2154 .db()
2155 .await
2156 .update_diagnostic_summary(&message, session.connection_id)
2157 .await?;
2158
2159 broadcast(
2160 Some(session.connection_id),
2161 guest_connection_ids.iter().copied(),
2162 |connection_id| {
2163 session
2164 .peer
2165 .forward_send(session.connection_id, connection_id, message.clone())
2166 },
2167 );
2168
2169 Ok(())
2170}
2171
2172/// Updates other participants with changes to the worktree settings
2173async fn update_worktree_settings(
2174 message: proto::UpdateWorktreeSettings,
2175 session: MessageContext,
2176) -> Result<()> {
2177 let guest_connection_ids = session
2178 .db()
2179 .await
2180 .update_worktree_settings(&message, session.connection_id)
2181 .await?;
2182
2183 broadcast(
2184 Some(session.connection_id),
2185 guest_connection_ids.iter().copied(),
2186 |connection_id| {
2187 session
2188 .peer
2189 .forward_send(session.connection_id, connection_id, message.clone())
2190 },
2191 );
2192
2193 Ok(())
2194}
2195
2196/// Notify other participants that a language server has started.
2197async fn start_language_server(
2198 request: proto::StartLanguageServer,
2199 session: MessageContext,
2200) -> Result<()> {
2201 let guest_connection_ids = session
2202 .db()
2203 .await
2204 .start_language_server(&request, session.connection_id)
2205 .await?;
2206
2207 broadcast(
2208 Some(session.connection_id),
2209 guest_connection_ids.iter().copied(),
2210 |connection_id| {
2211 session
2212 .peer
2213 .forward_send(session.connection_id, connection_id, request.clone())
2214 },
2215 );
2216 Ok(())
2217}
2218
2219/// Notify other participants that a language server has changed.
2220async fn update_language_server(
2221 request: proto::UpdateLanguageServer,
2222 session: MessageContext,
2223) -> Result<()> {
2224 let project_id = ProjectId::from_proto(request.project_id);
2225 let db = session.db().await;
2226
2227 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2228 && let Some(capabilities) = update.capabilities.clone()
2229 {
2230 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2231 .await?;
2232 }
2233
2234 let project_connection_ids = db
2235 .project_connection_ids(project_id, session.connection_id, true)
2236 .await?;
2237 broadcast(
2238 Some(session.connection_id),
2239 project_connection_ids.iter().copied(),
2240 |connection_id| {
2241 session
2242 .peer
2243 .forward_send(session.connection_id, connection_id, request.clone())
2244 },
2245 );
2246 Ok(())
2247}
2248
2249/// forward a project request to the host. These requests should be read only
2250/// as guests are allowed to send them.
2251async fn forward_read_only_project_request<T>(
2252 request: T,
2253 response: Response<T>,
2254 session: MessageContext,
2255) -> Result<()>
2256where
2257 T: EntityMessage + RequestMessage,
2258{
2259 let project_id = ProjectId::from_proto(request.remote_entity_id());
2260 let host_connection_id = session
2261 .db()
2262 .await
2263 .host_for_read_only_project_request(project_id, session.connection_id)
2264 .await?;
2265 let payload = session.forward_request(host_connection_id, request).await?;
2266 response.send(payload)?;
2267 Ok(())
2268}
2269
2270/// forward a project request to the host. These requests are disallowed
2271/// for guests.
2272async fn forward_mutating_project_request<T>(
2273 request: T,
2274 response: Response<T>,
2275 session: MessageContext,
2276) -> Result<()>
2277where
2278 T: EntityMessage + RequestMessage,
2279{
2280 let project_id = ProjectId::from_proto(request.remote_entity_id());
2281
2282 let host_connection_id = session
2283 .db()
2284 .await
2285 .host_for_mutating_project_request(project_id, session.connection_id)
2286 .await?;
2287 let payload = session.forward_request(host_connection_id, request).await?;
2288 response.send(payload)?;
2289 Ok(())
2290}
2291
2292async fn lsp_query(
2293 request: proto::LspQuery,
2294 response: Response<proto::LspQuery>,
2295 session: MessageContext,
2296) -> Result<()> {
2297 let (name, should_write) = request.query_name_and_write_permissions();
2298 tracing::Span::current().record("lsp_query_request", name);
2299 tracing::info!("lsp_query message received");
2300 if should_write {
2301 forward_mutating_project_request(request, response, session).await
2302 } else {
2303 forward_read_only_project_request(request, response, session).await
2304 }
2305}
2306
2307/// Notify other participants that a new buffer has been created
2308async fn create_buffer_for_peer(
2309 request: proto::CreateBufferForPeer,
2310 session: MessageContext,
2311) -> Result<()> {
2312 session
2313 .db()
2314 .await
2315 .check_user_is_project_host(
2316 ProjectId::from_proto(request.project_id),
2317 session.connection_id,
2318 )
2319 .await?;
2320 let peer_id = request.peer_id.context("invalid peer id")?;
2321 session
2322 .peer
2323 .forward_send(session.connection_id, peer_id.into(), request)?;
2324 Ok(())
2325}
2326
2327/// Notify other participants that a new image has been created
2328async fn create_image_for_peer(
2329 request: proto::CreateImageForPeer,
2330 session: MessageContext,
2331) -> Result<()> {
2332 session
2333 .db()
2334 .await
2335 .check_user_is_project_host(
2336 ProjectId::from_proto(request.project_id),
2337 session.connection_id,
2338 )
2339 .await?;
2340 let peer_id = request.peer_id.context("invalid peer id")?;
2341 session
2342 .peer
2343 .forward_send(session.connection_id, peer_id.into(), request)?;
2344 Ok(())
2345}
2346
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations are permitted for read-only guests; any other
    // operation kind requires write access to the project.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard so it is dropped before we await the host's
    // acknowledgement below; only the host's connection id escapes the block.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the update out to all guests (the sender is excluded).
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If the sender isn't the host, wait for the host to apply the update
    // before acknowledging the request.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2393
2394async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2395 let project_id = ProjectId::from_proto(message.project_id);
2396
2397 let operation = message.operation.as_ref().context("invalid operation")?;
2398 let capability = match operation.variant.as_ref() {
2399 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2400 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2401 match buffer_op.variant {
2402 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2403 Capability::ReadOnly
2404 }
2405 _ => Capability::ReadWrite,
2406 }
2407 } else {
2408 Capability::ReadWrite
2409 }
2410 }
2411 Some(_) => Capability::ReadWrite,
2412 None => Capability::ReadOnly,
2413 };
2414
2415 let guard = session
2416 .db()
2417 .await
2418 .connections_for_buffer_update(project_id, session.connection_id, capability)
2419 .await?;
2420
2421 let (host, guests) = &*guard;
2422
2423 broadcast(
2424 Some(session.connection_id),
2425 guests.iter().chain([host]).copied(),
2426 |connection_id| {
2427 session
2428 .peer
2429 .forward_send(session.connection_id, connection_id, message.clone())
2430 },
2431 );
2432
2433 Ok(())
2434}
2435
2436async fn forward_project_search_chunk(
2437 message: proto::FindSearchCandidatesChunk,
2438 response: Response<proto::FindSearchCandidatesChunk>,
2439 session: MessageContext,
2440) -> Result<()> {
2441 let peer_id = message.peer_id.context("missing peer_id")?;
2442 let payload = session
2443 .peer
2444 .forward_request(session.connection_id, peer_id.into(), message)
2445 .await?;
2446 response.send(payload)?;
2447 Ok(())
2448}
2449
2450/// Notify other participants that a project has been updated.
2451async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2452 request: T,
2453 session: MessageContext,
2454) -> Result<()> {
2455 let project_id = ProjectId::from_proto(request.remote_entity_id());
2456 let project_connection_ids = session
2457 .db()
2458 .await
2459 .project_connection_ids(project_id, session.connection_id, false)
2460 .await?;
2461
2462 broadcast(
2463 Some(session.connection_id),
2464 project_connection_ids.iter().copied(),
2465 |connection_id| {
2466 session
2467 .peer
2468 .forward_send(session.connection_id, connection_id, request.clone())
2469 },
2470 );
2471 Ok(())
2472}
2473
2474/// Start following another user in a call.
2475async fn follow(
2476 request: proto::Follow,
2477 response: Response<proto::Follow>,
2478 session: MessageContext,
2479) -> Result<()> {
2480 let room_id = RoomId::from_proto(request.room_id);
2481 let project_id = request.project_id.map(ProjectId::from_proto);
2482 let leader_id = request.leader_id.context("invalid leader id")?.into();
2483 let follower_id = session.connection_id;
2484
2485 session
2486 .db()
2487 .await
2488 .check_room_participants(room_id, leader_id, session.connection_id)
2489 .await?;
2490
2491 let response_payload = session.forward_request(leader_id, request).await?;
2492 response.send(response_payload)?;
2493
2494 if let Some(project_id) = project_id {
2495 let room = session
2496 .db()
2497 .await
2498 .follow(room_id, project_id, leader_id, follower_id)
2499 .await?;
2500 room_updated(&room, &session.peer);
2501 }
2502
2503 Ok(())
2504}
2505
2506/// Stop following another user in a call.
2507async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2508 let room_id = RoomId::from_proto(request.room_id);
2509 let project_id = request.project_id.map(ProjectId::from_proto);
2510 let leader_id = request.leader_id.context("invalid leader id")?.into();
2511 let follower_id = session.connection_id;
2512
2513 session
2514 .db()
2515 .await
2516 .check_room_participants(room_id, leader_id, session.connection_id)
2517 .await?;
2518
2519 session
2520 .peer
2521 .forward_send(session.connection_id, leader_id, request)?;
2522
2523 if let Some(project_id) = project_id {
2524 let room = session
2525 .db()
2526 .await
2527 .unfollow(room_id, project_id, leader_id, follower_id)
2528 .await?;
2529 room_updated(&room, &session.peer);
2530 }
2531
2532 Ok(())
2533}
2534
2535/// Notify everyone following you of your current location.
2536async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2537 let room_id = RoomId::from_proto(request.room_id);
2538 let database = session.db.lock().await;
2539
2540 let connection_ids = if let Some(project_id) = request.project_id {
2541 let project_id = ProjectId::from_proto(project_id);
2542 database
2543 .project_connection_ids(project_id, session.connection_id, true)
2544 .await?
2545 } else {
2546 database
2547 .room_connection_ids(room_id, session.connection_id)
2548 .await?
2549 };
2550
2551 // For now, don't send view update messages back to that view's current leader.
2552 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2553 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2554 _ => None,
2555 });
2556
2557 for connection_id in connection_ids.iter().cloned() {
2558 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2559 session
2560 .peer
2561 .forward_send(session.connection_id, connection_id, request.clone())?;
2562 }
2563 }
2564 Ok(())
2565}
2566
2567/// Get public data about users.
2568async fn get_users(
2569 request: proto::GetUsers,
2570 response: Response<proto::GetUsers>,
2571 session: MessageContext,
2572) -> Result<()> {
2573 let user_ids = request
2574 .user_ids
2575 .into_iter()
2576 .map(UserId::from_proto)
2577 .collect();
2578 let users = session
2579 .db()
2580 .await
2581 .get_users_by_ids(user_ids)
2582 .await?
2583 .into_iter()
2584 .map(|user| proto::User {
2585 id: user.id.to_proto(),
2586 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2587 github_login: user.github_login,
2588 name: user.name,
2589 })
2590 .collect();
2591 response.send(proto::UsersResponse { users })?;
2592 Ok(())
2593}
2594
2595/// Search for users (to invite) buy Github login
2596async fn fuzzy_search_users(
2597 request: proto::FuzzySearchUsers,
2598 response: Response<proto::FuzzySearchUsers>,
2599 session: MessageContext,
2600) -> Result<()> {
2601 let query = request.query;
2602 let users = match query.len() {
2603 0 => vec![],
2604 1 | 2 => session
2605 .db()
2606 .await
2607 .get_user_by_github_login(&query)
2608 .await?
2609 .into_iter()
2610 .collect(),
2611 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2612 };
2613 let users = users
2614 .into_iter()
2615 .filter(|user| user.id != session.user_id())
2616 .map(|user| proto::User {
2617 id: user.id.to_proto(),
2618 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2619 github_login: user.github_login,
2620 name: user.name,
2621 })
2622 .collect();
2623 response.send(proto::UsersResponse { users })?;
2624 Ok(())
2625}
2626
2627/// Send a contact request to another user.
2628async fn request_contact(
2629 request: proto::RequestContact,
2630 response: Response<proto::RequestContact>,
2631 session: MessageContext,
2632) -> Result<()> {
2633 let requester_id = session.user_id();
2634 let responder_id = UserId::from_proto(request.responder_id);
2635 if requester_id == responder_id {
2636 return Err(anyhow!("cannot add yourself as a contact"))?;
2637 }
2638
2639 let notifications = session
2640 .db()
2641 .await
2642 .send_contact_request(requester_id, responder_id)
2643 .await?;
2644
2645 // Update outgoing contact requests of requester
2646 let mut update = proto::UpdateContacts::default();
2647 update.outgoing_requests.push(responder_id.to_proto());
2648 for connection_id in session
2649 .connection_pool()
2650 .await
2651 .user_connection_ids(requester_id)
2652 {
2653 session.peer.send(connection_id, update.clone())?;
2654 }
2655
2656 // Update incoming contact requests of responder
2657 let mut update = proto::UpdateContacts::default();
2658 update
2659 .incoming_requests
2660 .push(proto::IncomingContactRequest {
2661 requester_id: requester_id.to_proto(),
2662 });
2663 let connection_pool = session.connection_pool().await;
2664 for connection_id in connection_pool.user_connection_ids(responder_id) {
2665 session.peer.send(connection_id, update.clone())?;
2666 }
2667
2668 send_notifications(&connection_pool, &session.peer, notifications);
2669
2670 response.send(proto::Ack {})?;
2671 Ok(())
2672}
2673
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; the request itself stays pending.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is sent along so contact lists can show availability.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The incoming request disappears whether it was accepted or declined.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Likewise, the requester's outgoing request is resolved either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2731
2732/// Remove a contact.
2733async fn remove_contact(
2734 request: proto::RemoveContact,
2735 response: Response<proto::RemoveContact>,
2736 session: MessageContext,
2737) -> Result<()> {
2738 let requester_id = session.user_id();
2739 let responder_id = UserId::from_proto(request.user_id);
2740 let db = session.db().await;
2741 let (contact_accepted, deleted_notification_id) =
2742 db.remove_contact(requester_id, responder_id).await?;
2743
2744 let pool = session.connection_pool().await;
2745 // Update outgoing contact requests of requester
2746 let mut update = proto::UpdateContacts::default();
2747 if contact_accepted {
2748 update.remove_contacts.push(responder_id.to_proto());
2749 } else {
2750 update
2751 .remove_outgoing_requests
2752 .push(responder_id.to_proto());
2753 }
2754 for connection_id in pool.user_connection_ids(requester_id) {
2755 session.peer.send(connection_id, update.clone())?;
2756 }
2757
2758 // Update incoming contact requests of responder
2759 let mut update = proto::UpdateContacts::default();
2760 if contact_accepted {
2761 update.remove_contacts.push(requester_id.to_proto());
2762 } else {
2763 update
2764 .remove_incoming_requests
2765 .push(requester_id.to_proto());
2766 }
2767 for connection_id in pool.user_connection_ids(responder_id) {
2768 session.peer.send(connection_id, update.clone())?;
2769 if let Some(notification_id) = deleted_notification_id {
2770 session.peer.send(
2771 connection_id,
2772 proto::DeleteNotification {
2773 notification_id: notification_id.to_proto(),
2774 },
2775 )?;
2776 }
2777 }
2778
2779 response.send(proto::Ack {})?;
2780 Ok(())
2781}
2782
2783fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2784 version.0.minor < 139
2785}
2786
/// Handle a client's explicit request to start receiving channel updates.
///
/// Newer clients send this themselves; older ones are auto-subscribed at
/// connection time (see `should_auto_subscribe_to_channels`).
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
2794
2795async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2796 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2797 let mut pool = session.connection_pool().await;
2798 for membership in &channels_for_user.channel_memberships {
2799 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2800 }
2801 session.peer.send(
2802 session.connection_id,
2803 build_update_user_channels(&channels_for_user),
2804 )?;
2805 session.peer.send(
2806 session.connection_id,
2807 build_channels_update(channels_for_user),
2808 )?;
2809 Ok(())
2810}
2811
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // NOTE(review): `membership` appears to be present only when the creator
    // gains a new membership record (handled below) — confirm against the db
    // layer's `create_channel`.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Acknowledge the creator before fanning out updates to other users.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Track the new membership in the pool and tell all of that user's
        // connections about their role in the new channel.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone in the root tree whose role
    // permits seeing it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2866
2867/// Delete a channel
2868async fn delete_channel(
2869 request: proto::DeleteChannel,
2870 response: Response<proto::DeleteChannel>,
2871 session: MessageContext,
2872) -> Result<()> {
2873 let db = session.db().await;
2874
2875 let channel_id = request.channel_id;
2876 let (root_channel, removed_channels) = db
2877 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2878 .await?;
2879 response.send(proto::Ack {})?;
2880
2881 // Notify members of removed channels
2882 let mut update = proto::UpdateChannels::default();
2883 update
2884 .delete_channels
2885 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2886
2887 let connection_pool = session.connection_pool().await;
2888 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2889 session.peer.send(connection_id, update.clone())?;
2890 }
2891
2892 Ok(())
2893}
2894
2895/// Invite someone to join a channel.
2896async fn invite_channel_member(
2897 request: proto::InviteChannelMember,
2898 response: Response<proto::InviteChannelMember>,
2899 session: MessageContext,
2900) -> Result<()> {
2901 let db = session.db().await;
2902 let channel_id = ChannelId::from_proto(request.channel_id);
2903 let invitee_id = UserId::from_proto(request.user_id);
2904 let InviteMemberResult {
2905 channel,
2906 notifications,
2907 } = db
2908 .invite_channel_member(
2909 channel_id,
2910 invitee_id,
2911 session.user_id(),
2912 request.role().into(),
2913 )
2914 .await?;
2915
2916 let update = proto::UpdateChannels {
2917 channel_invitations: vec![channel.to_proto()],
2918 ..Default::default()
2919 };
2920
2921 let connection_pool = session.connection_pool().await;
2922 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2923 session.peer.send(connection_id, update.clone())?;
2924 }
2925
2926 send_notifications(&connection_pool, &session.peer, notifications);
2927
2928 response.send(proto::Ack {})?;
2929 Ok(())
2930}
2931
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Unsubscribe the removed member and tell their connections about the
    // membership change.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the original invite notification, if one exists. Send failures
    // are only logged (`trace_err`) so the removal itself still succeeds.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2973
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the loop body mutates the pool's
    // subscriptions, so we can't keep borrowing `channel_user_ids` while
    // iterating.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can still see the channel get the updated record;
        // everyone else is unsubscribed and told to delete it locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3020
3021/// Alter the role for a user in the channel.
3022async fn set_channel_member_role(
3023 request: proto::SetChannelMemberRole,
3024 response: Response<proto::SetChannelMemberRole>,
3025 session: MessageContext,
3026) -> Result<()> {
3027 let db = session.db().await;
3028 let channel_id = ChannelId::from_proto(request.channel_id);
3029 let member_id = UserId::from_proto(request.user_id);
3030 let result = db
3031 .set_channel_member_role(
3032 channel_id,
3033 session.user_id(),
3034 member_id,
3035 request.role().into(),
3036 )
3037 .await?;
3038
3039 match result {
3040 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3041 let mut connection_pool = session.connection_pool().await;
3042 notify_membership_updated(
3043 &mut connection_pool,
3044 membership_update,
3045 member_id,
3046 &session.peer,
3047 )
3048 }
3049 db::SetMemberRoleResult::InviteUpdated(channel) => {
3050 let update = proto::UpdateChannels {
3051 channel_invitations: vec![channel.to_proto()],
3052 ..Default::default()
3053 };
3054
3055 for connection_id in session
3056 .connection_pool()
3057 .await
3058 .user_connection_ids(member_id)
3059 {
3060 session.peer.send(connection_id, update.clone())?;
3061 }
3062 }
3063 }
3064
3065 response.send(proto::Ack {})?;
3066 Ok(())
3067}
3068
3069/// Change the name of a channel
3070async fn rename_channel(
3071 request: proto::RenameChannel,
3072 response: Response<proto::RenameChannel>,
3073 session: MessageContext,
3074) -> Result<()> {
3075 let db = session.db().await;
3076 let channel_id = ChannelId::from_proto(request.channel_id);
3077 let channel_model = db
3078 .rename_channel(channel_id, session.user_id(), &request.name)
3079 .await?;
3080 let root_id = channel_model.root_id();
3081 let channel = Channel::from_model(channel_model);
3082
3083 response.send(proto::RenameChannelResponse {
3084 channel: Some(channel.to_proto()),
3085 })?;
3086
3087 let connection_pool = session.connection_pool().await;
3088 let update = proto::UpdateChannels {
3089 channels: vec![channel.to_proto()],
3090 ..Default::default()
3091 };
3092 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3093 if role.can_see_channel(channel.visibility) {
3094 session.peer.send(connection_id, update.clone())?;
3095 }
3096 }
3097
3098 Ok(())
3099}
3100
3101/// Move a channel to a new parent.
3102async fn move_channel(
3103 request: proto::MoveChannel,
3104 response: Response<proto::MoveChannel>,
3105 session: MessageContext,
3106) -> Result<()> {
3107 let channel_id = ChannelId::from_proto(request.channel_id);
3108 let to = ChannelId::from_proto(request.to);
3109
3110 let (root_id, channels) = session
3111 .db()
3112 .await
3113 .move_channel(channel_id, to, session.user_id())
3114 .await?;
3115
3116 let connection_pool = session.connection_pool().await;
3117 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3118 let channels = channels
3119 .iter()
3120 .filter_map(|channel| {
3121 if role.can_see_channel(channel.visibility) {
3122 Some(channel.to_proto())
3123 } else {
3124 None
3125 }
3126 })
3127 .collect::<Vec<_>>();
3128 if channels.is_empty() {
3129 continue;
3130 }
3131
3132 let update = proto::UpdateChannels {
3133 channels,
3134 ..Default::default()
3135 };
3136
3137 session.peer.send(connection_id, update.clone())?;
3138 }
3139
3140 response.send(Ack {})?;
3141 Ok(())
3142}
3143
3144async fn reorder_channel(
3145 request: proto::ReorderChannel,
3146 response: Response<proto::ReorderChannel>,
3147 session: MessageContext,
3148) -> Result<()> {
3149 let channel_id = ChannelId::from_proto(request.channel_id);
3150 let direction = request.direction();
3151
3152 let updated_channels = session
3153 .db()
3154 .await
3155 .reorder_channel(channel_id, direction, session.user_id())
3156 .await?;
3157
3158 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3159 let connection_pool = session.connection_pool().await;
3160 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3161 let channels = updated_channels
3162 .iter()
3163 .filter_map(|channel| {
3164 if role.can_see_channel(channel.visibility) {
3165 Some(channel.to_proto())
3166 } else {
3167 None
3168 }
3169 })
3170 .collect::<Vec<_>>();
3171
3172 if channels.is_empty() {
3173 continue;
3174 }
3175
3176 let update = proto::UpdateChannels {
3177 channels,
3178 ..Default::default()
3179 };
3180
3181 session.peer.send(connection_id, update.clone())?;
3182 }
3183 }
3184
3185 response.send(Ack {})?;
3186 Ok(())
3187}
3188
3189/// Get the list of channel members
3190async fn get_channel_members(
3191 request: proto::GetChannelMembers,
3192 response: Response<proto::GetChannelMembers>,
3193 session: MessageContext,
3194) -> Result<()> {
3195 let db = session.db().await;
3196 let channel_id = ChannelId::from_proto(request.channel_id);
3197 let limit = if request.limit == 0 {
3198 u16::MAX as u64
3199 } else {
3200 request.limit
3201 };
3202 let (members, users) = db
3203 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3204 .await?;
3205 response.send(proto::GetChannelMembersResponse { members, users })?;
3206 Ok(())
3207}
3208
3209/// Accept or decline a channel invitation.
3210async fn respond_to_channel_invite(
3211 request: proto::RespondToChannelInvite,
3212 response: Response<proto::RespondToChannelInvite>,
3213 session: MessageContext,
3214) -> Result<()> {
3215 let db = session.db().await;
3216 let channel_id = ChannelId::from_proto(request.channel_id);
3217 let RespondToChannelInvite {
3218 membership_update,
3219 notifications,
3220 } = db
3221 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3222 .await?;
3223
3224 let mut connection_pool = session.connection_pool().await;
3225 if let Some(membership_update) = membership_update {
3226 notify_membership_updated(
3227 &mut connection_pool,
3228 membership_update,
3229 session.user_id(),
3230 &session.peer,
3231 );
3232 } else {
3233 let update = proto::UpdateChannels {
3234 remove_channel_invitations: vec![channel_id.to_proto()],
3235 ..Default::default()
3236 };
3237
3238 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3239 session.peer.send(connection_id, update.clone())?;
3240 }
3241 };
3242
3243 send_notifications(&connection_pool, &session.peer, notifications);
3244
3245 response.send(proto::Ack {})?;
3246
3247 Ok(())
3248}
3249
3250/// Join the channels' room
3251async fn join_channel(
3252 request: proto::JoinChannel,
3253 response: Response<proto::JoinChannel>,
3254 session: MessageContext,
3255) -> Result<()> {
3256 let channel_id = ChannelId::from_proto(request.channel_id);
3257 join_channel_internal(channel_id, Box::new(response), session).await
3258}
3259
/// Abstraction over the request types that can initiate a channel join
/// (`JoinChannel` and `JoinRoom`), both of which reply with a
/// `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send` for this request type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the inherent `Response::send` for this request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3273
/// Shared implementation of joining a channel's call, used by both the
/// `JoinChannel` and `JoinRoom` handlers (see `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection left by an earlier session, joins the
/// channel's room in the database, mints a LiveKit token when audio/video is
/// configured, replies to the caller, and then fans out room/channel/contact
/// updates to everyone affected.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so the
            // guard must be released first and re-taken afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured. Guests get a
        // token via `guest_token` with `can_publish = false`; every other
        // role gets a regular room token with `can_publish = true`. Token
        // minting failures are logged and result in no connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the other room participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the updated participant list to everyone who can see the
    // channel, then refresh contacts with the user's new busy status.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3367
3368/// Start editing the channel notes
3369async fn join_channel_buffer(
3370 request: proto::JoinChannelBuffer,
3371 response: Response<proto::JoinChannelBuffer>,
3372 session: MessageContext,
3373) -> Result<()> {
3374 let db = session.db().await;
3375 let channel_id = ChannelId::from_proto(request.channel_id);
3376
3377 let open_response = db
3378 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3379 .await?;
3380
3381 let collaborators = open_response.collaborators.clone();
3382 response.send(open_response)?;
3383
3384 let update = UpdateChannelBufferCollaborators {
3385 channel_id: channel_id.to_proto(),
3386 collaborators: collaborators.clone(),
3387 };
3388 channel_buffer_updated(
3389 session.connection_id,
3390 collaborators
3391 .iter()
3392 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3393 &update,
3394 &session.peer,
3395 );
3396
3397 Ok(())
3398}
3399
3400/// Edit the channel notes
3401async fn update_channel_buffer(
3402 request: proto::UpdateChannelBuffer,
3403 session: MessageContext,
3404) -> Result<()> {
3405 let db = session.db().await;
3406 let channel_id = ChannelId::from_proto(request.channel_id);
3407
3408 let (collaborators, epoch, version) = db
3409 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3410 .await?;
3411
3412 channel_buffer_updated(
3413 session.connection_id,
3414 collaborators.clone(),
3415 &proto::UpdateChannelBuffer {
3416 channel_id: channel_id.to_proto(),
3417 operations: request.operations,
3418 },
3419 &session.peer,
3420 );
3421
3422 let pool = &*session.connection_pool().await;
3423
3424 let non_collaborators =
3425 pool.channel_connection_ids(channel_id)
3426 .filter_map(|(connection_id, _)| {
3427 if collaborators.contains(&connection_id) {
3428 None
3429 } else {
3430 Some(connection_id)
3431 }
3432 });
3433
3434 broadcast(None, non_collaborators, |peer_id| {
3435 session.peer.send(
3436 peer_id,
3437 proto::UpdateChannels {
3438 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3439 channel_id: channel_id.to_proto(),
3440 epoch: epoch as u64,
3441 version: version.clone(),
3442 }],
3443 ..Default::default()
3444 },
3445 )
3446 });
3447
3448 Ok(())
3449}
3450
3451/// Rejoin the channel notes after a connection blip
3452async fn rejoin_channel_buffers(
3453 request: proto::RejoinChannelBuffers,
3454 response: Response<proto::RejoinChannelBuffers>,
3455 session: MessageContext,
3456) -> Result<()> {
3457 let db = session.db().await;
3458 let buffers = db
3459 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3460 .await?;
3461
3462 for rejoined_buffer in &buffers {
3463 let collaborators_to_notify = rejoined_buffer
3464 .buffer
3465 .collaborators
3466 .iter()
3467 .filter_map(|c| Some(c.peer_id?.into()));
3468 channel_buffer_updated(
3469 session.connection_id,
3470 collaborators_to_notify,
3471 &proto::UpdateChannelBufferCollaborators {
3472 channel_id: rejoined_buffer.buffer.channel_id,
3473 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3474 },
3475 &session.peer,
3476 );
3477 }
3478
3479 response.send(proto::RejoinChannelBuffersResponse {
3480 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3481 })?;
3482
3483 Ok(())
3484}
3485
3486/// Stop editing the channel notes
3487async fn leave_channel_buffer(
3488 request: proto::LeaveChannelBuffer,
3489 response: Response<proto::LeaveChannelBuffer>,
3490 session: MessageContext,
3491) -> Result<()> {
3492 let db = session.db().await;
3493 let channel_id = ChannelId::from_proto(request.channel_id);
3494
3495 let left_buffer = db
3496 .leave_channel_buffer(channel_id, session.connection_id)
3497 .await?;
3498
3499 response.send(Ack {})?;
3500
3501 channel_buffer_updated(
3502 session.connection_id,
3503 left_buffer.connections,
3504 &proto::UpdateChannelBufferCollaborators {
3505 channel_id: channel_id.to_proto(),
3506 collaborators: left_buffer.collaborators,
3507 },
3508 &session.peer,
3509 );
3510
3511 Ok(())
3512}
3513
3514fn channel_buffer_updated<T: EnvelopedMessage>(
3515 sender_id: ConnectionId,
3516 collaborators: impl IntoIterator<Item = ConnectionId>,
3517 message: &T,
3518 peer: &Peer,
3519) {
3520 broadcast(Some(sender_id), collaborators, |peer_id| {
3521 peer.send(peer_id, message.clone())
3522 });
3523}
3524
3525fn send_notifications(
3526 connection_pool: &ConnectionPool,
3527 peer: &Peer,
3528 notifications: db::NotificationBatch,
3529) {
3530 for (user_id, notification) in notifications {
3531 for connection_id in connection_pool.user_connection_ids(user_id) {
3532 if let Err(error) = peer.send(
3533 connection_id,
3534 proto::AddNotification {
3535 notification: Some(notification.clone()),
3536 },
3537 ) {
3538 tracing::error!(
3539 "failed to send notification to {:?} {}",
3540 connection_id,
3541 error
3542 );
3543 }
3544 }
3545 }
3546}
3547
/// Send a message to the channel
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3556
/// Delete a channel message
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3565
/// Edit a previously sent channel message
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3573
/// Mark a channel message as read
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3581
3582/// Mark a buffer version as synced
3583async fn acknowledge_buffer_version(
3584 request: proto::AckBufferOperation,
3585 session: MessageContext,
3586) -> Result<()> {
3587 let buffer_id = BufferId::from_proto(request.buffer_id);
3588 session
3589 .db()
3590 .await
3591 .observe_buffer_version(
3592 buffer_id,
3593 session.user_id(),
3594 request.epoch as i32,
3595 &request.version,
3596 )
3597 .await?;
3598 Ok(())
3599}
3600
/// Start receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3609
/// Stop receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3617
/// Retrieve the chat history for a channel
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3626
/// Retrieve specific chat messages
///
/// Chat has been removed from Zed; this handler remains only so requests
/// from older clients fail with a descriptive error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3635
3636/// Retrieve the current users notifications
3637async fn get_notifications(
3638 request: proto::GetNotifications,
3639 response: Response<proto::GetNotifications>,
3640 session: MessageContext,
3641) -> Result<()> {
3642 let notifications = session
3643 .db()
3644 .await
3645 .get_notifications(
3646 session.user_id(),
3647 NOTIFICATION_COUNT_PER_PAGE,
3648 request.before_id.map(db::NotificationId::from_proto),
3649 )
3650 .await?;
3651 response.send(proto::GetNotificationsResponse {
3652 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3653 notifications,
3654 })?;
3655 Ok(())
3656}
3657
3658/// Mark notifications as read
3659async fn mark_notification_as_read(
3660 request: proto::MarkNotificationRead,
3661 response: Response<proto::MarkNotificationRead>,
3662 session: MessageContext,
3663) -> Result<()> {
3664 let database = &session.db().await;
3665 let notifications = database
3666 .mark_notification_as_read_by_id(
3667 session.user_id(),
3668 NotificationId::from_proto(request.notification_id),
3669 )
3670 .await?;
3671 send_notifications(
3672 &*session.connection_pool().await,
3673 &session.peer,
3674 notifications,
3675 );
3676 response.send(proto::Ack {})?;
3677 Ok(())
3678}
3679
3680fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3681 let message = match message {
3682 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3683 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3684 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3685 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3686 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3687 code: frame.code.into(),
3688 reason: frame.reason.as_str().to_owned().into(),
3689 })),
3690 // We should never receive a frame while reading the message, according
3691 // to the `tungstenite` maintainers:
3692 //
3693 // > It cannot occur when you read messages from the WebSocket, but it
3694 // > can be used when you want to send the raw frames (e.g. you want to
3695 // > send the frames to the WebSocket without composing the full message first).
3696 // >
3697 // > — https://github.com/snapview/tungstenite-rs/issues/268
3698 TungsteniteMessage::Frame(_) => {
3699 bail!("received an unexpected frame while reading the message")
3700 }
3701 };
3702
3703 Ok(message)
3704}
3705
3706fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3707 match message {
3708 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3709 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3710 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3711 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3712 AxumMessage::Close(frame) => {
3713 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3714 code: frame.code.into(),
3715 reason: frame.reason.as_ref().into(),
3716 }))
3717 }
3718 }
3719}
3720
3721fn notify_membership_updated(
3722 connection_pool: &mut ConnectionPool,
3723 result: MembershipUpdated,
3724 user_id: UserId,
3725 peer: &Peer,
3726) {
3727 for membership in &result.new_channels.channel_memberships {
3728 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3729 }
3730 for channel_id in &result.removed_channels {
3731 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3732 }
3733
3734 let user_channels_update = proto::UpdateUserChannels {
3735 channel_memberships: result
3736 .new_channels
3737 .channel_memberships
3738 .iter()
3739 .map(|cm| proto::ChannelMembership {
3740 channel_id: cm.channel_id.to_proto(),
3741 role: cm.role.into(),
3742 })
3743 .collect(),
3744 ..Default::default()
3745 };
3746
3747 let mut update = build_channels_update(result.new_channels);
3748 update.delete_channels = result
3749 .removed_channels
3750 .into_iter()
3751 .map(|id| id.to_proto())
3752 .collect();
3753 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3754
3755 for connection_id in connection_pool.user_connection_ids(user_id) {
3756 peer.send(connection_id, user_channels_update.clone())
3757 .trace_err();
3758 peer.send(connection_id, update.clone()).trace_err();
3759 }
3760}
3761
3762fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3763 proto::UpdateUserChannels {
3764 channel_memberships: channels
3765 .channel_memberships
3766 .iter()
3767 .map(|m| proto::ChannelMembership {
3768 channel_id: m.channel_id.to_proto(),
3769 role: m.role.into(),
3770 })
3771 .collect(),
3772 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3773 }
3774}
3775
3776fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3777 let mut update = proto::UpdateChannels::default();
3778
3779 for channel in channels.channels {
3780 update.channels.push(channel.to_proto());
3781 }
3782
3783 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3784
3785 for (channel_id, participants) in channels.channel_participants {
3786 update
3787 .channel_participants
3788 .push(proto::ChannelParticipants {
3789 channel_id: channel_id.to_proto(),
3790 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3791 });
3792 }
3793
3794 for channel in channels.invited_channels {
3795 update.channel_invitations.push(channel.to_proto());
3796 }
3797
3798 update
3799}
3800
3801fn build_initial_contacts_update(
3802 contacts: Vec<db::Contact>,
3803 pool: &ConnectionPool,
3804) -> proto::UpdateContacts {
3805 let mut update = proto::UpdateContacts::default();
3806
3807 for contact in contacts {
3808 match contact {
3809 db::Contact::Accepted { user_id, busy } => {
3810 update.contacts.push(contact_for_user(user_id, busy, pool));
3811 }
3812 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3813 db::Contact::Incoming { user_id } => {
3814 update
3815 .incoming_requests
3816 .push(proto::IncomingContactRequest {
3817 requester_id: user_id.to_proto(),
3818 })
3819 }
3820 }
3821 }
3822
3823 update
3824}
3825
3826fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3827 proto::Contact {
3828 user_id: user_id.to_proto(),
3829 online: pool.is_user_online(user_id),
3830 busy,
3831 }
3832}
3833
3834fn room_updated(room: &proto::Room, peer: &Peer) {
3835 broadcast(
3836 None,
3837 room.participants
3838 .iter()
3839 .filter_map(|participant| Some(participant.peer_id?.into())),
3840 |peer_id| {
3841 peer.send(
3842 peer_id,
3843 proto::RoomUpdated {
3844 room: Some(room.clone()),
3845 },
3846 )
3847 },
3848 );
3849}
3850
3851fn channel_updated(
3852 channel: &db::channel::Model,
3853 room: &proto::Room,
3854 peer: &Peer,
3855 pool: &ConnectionPool,
3856) {
3857 let participants = room
3858 .participants
3859 .iter()
3860 .map(|p| p.user_id)
3861 .collect::<Vec<_>>();
3862
3863 broadcast(
3864 None,
3865 pool.channel_connection_ids(channel.root_id())
3866 .filter_map(|(channel_id, role)| {
3867 role.can_see_channel(channel.visibility)
3868 .then_some(channel_id)
3869 }),
3870 |peer_id| {
3871 peer.send(
3872 peer_id,
3873 proto::UpdateChannels {
3874 channel_participants: vec![proto::ChannelParticipants {
3875 channel_id: channel.id.to_proto(),
3876 participant_user_ids: participants.clone(),
3877 }],
3878 ..Default::default()
3879 },
3880 )
3881 },
3882 );
3883}
3884
3885async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3886 let db = session.db().await;
3887
3888 let contacts = db.get_contacts(user_id).await?;
3889 let busy = db.is_user_busy(user_id).await?;
3890
3891 let pool = session.connection_pool().await;
3892 let updated_contact = contact_for_user(user_id, busy, &pool);
3893 for contact in contacts {
3894 if let db::Contact::Accepted {
3895 user_id: contact_user_id,
3896 ..
3897 } = contact
3898 {
3899 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3900 session
3901 .peer
3902 .send(
3903 contact_conn_id,
3904 proto::UpdateContacts {
3905 contacts: vec![updated_contact.clone()],
3906 remove_contacts: Default::default(),
3907 incoming_requests: Default::default(),
3908 remove_incoming_requests: Default::default(),
3909 outgoing_requests: Default::default(),
3910 remove_outgoing_requests: Default::default(),
3911 },
3912 )
3913 .trace_err();
3914 }
3915 }
3916 }
3917 Ok(())
3918}
3919
/// Clean up room state for a connection that has left (or lost) a call:
/// notify remaining participants, cancel pending outgoing calls, refresh the
/// affected users' contact entries, and tear down the LiveKit room when it
/// was deleted.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Extracted from `left_room` inside the `if let` below; declared here so
    // they remain available for the later notification steps.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare or detach this connection from each project it was in.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // This connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list for
    // everyone who can see that channel.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Notify each user whose incoming call was canceled by this departure,
    // on every device they're connected from.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3993
3994async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3995 let left_channel_buffers = session
3996 .db()
3997 .await
3998 .leave_channel_buffers(session.connection_id)
3999 .await?;
4000
4001 for left_buffer in left_channel_buffers {
4002 channel_buffer_updated(
4003 session.connection_id,
4004 left_buffer.connections,
4005 &proto::UpdateChannelBufferCollaborators {
4006 channel_id: left_buffer.channel_id.to_proto(),
4007 collaborators: left_buffer.collaborators,
4008 },
4009 &session.peer,
4010 );
4011 }
4012
4013 Ok(())
4014}
4015
4016fn project_left(project: &db::LeftProject, session: &Session) {
4017 for connection_id in &project.connection_ids {
4018 if project.should_unshare {
4019 session
4020 .peer
4021 .send(
4022 *connection_id,
4023 proto::UnshareProject {
4024 project_id: project.id.to_proto(),
4025 },
4026 )
4027 .trace_err();
4028 } else {
4029 session
4030 .peer
4031 .send(
4032 *connection_id,
4033 proto::RemoveProjectCollaborator {
4034 project_id: project.id.to_proto(),
4035 peer_id: Some(session.connection_id.into()),
4036 },
4037 )
4038 .trace_err();
4039 }
4040 }
4041}
4042
4043async fn share_agent_thread(
4044 request: proto::ShareAgentThread,
4045 response: Response<proto::ShareAgentThread>,
4046 session: MessageContext,
4047) -> Result<()> {
4048 let user_id = session.user_id();
4049
4050 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4051 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4052
4053 session
4054 .db()
4055 .await
4056 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4057 .await?;
4058
4059 response.send(proto::Ack {})?;
4060
4061 Ok(())
4062}
4063
4064async fn get_shared_agent_thread(
4065 request: proto::GetSharedAgentThread,
4066 response: Response<proto::GetSharedAgentThread>,
4067 session: MessageContext,
4068) -> Result<()> {
4069 let share_id = SharedThreadId::from_proto(request.session_id)
4070 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4071
4072 let result = session.db().await.get_shared_thread(share_id).await?;
4073
4074 match result {
4075 Some((thread, username)) => {
4076 response.send(proto::GetSharedAgentThreadResponse {
4077 title: thread.title,
4078 thread_data: thread.data,
4079 sharer_username: username,
4080 created_at: thread.created_at.and_utc().to_rfc3339(),
4081 })?;
4082 }
4083 None => {
4084 return Err(anyhow!("Shared thread not found").into());
4085 }
4086 }
4087
4088 Ok(())
4089}
4090
/// Extension trait for logging and discarding errors.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Log the error (if any) and convert the result into an `Option` of the
    /// success value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4096
4097impl<T, E> ResultExt for Result<T, E>
4098where
4099 E: std::fmt::Debug,
4100{
4101 type Ok = T;
4102
4103 #[track_caller]
4104 fn trace_err(self) -> Option<T> {
4105 match self {
4106 Ok(value) => Some(value),
4107 Err(error) => {
4108 tracing::error!("{:?}", error);
4109 None
4110 }
4111 }
4112 }
4113}