1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
// Reconnect grace period. NOTE(review): consumed by code outside this chunk
// (likely room-rejoin logic) — confirm against the callers.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification queries; presumably used by `get_notifications`
// (handler registered below, body not visible in this chunk).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Live count of open connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing span field names recorded by the message-handling machinery below
// (`add_handler`, `forward_request`, `handle_connection`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler stored in `Server::handlers`; invoked with the envelope,
// the session, and the per-message tracing span.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
111impl Drop for ConnectionGuard {
112 fn drop(&mut self) {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 }
115}
116
117struct Response<R> {
118 peer: Arc<Peer>,
119 receipt: Receipt<R>,
120 responded: Arc<AtomicBool>,
121}
122
123impl<R: RequestMessage> Response<R> {
124 fn send(self, payload: R::Response) -> Result<()> {
125 self.responded.store(true, SeqCst);
126 self.peer.respond(self.receipt, payload)?;
127 Ok(())
128 }
129}
130
131#[derive(Clone, Debug)]
132pub enum Principal {
133 User(User),
134 Impersonated { user: User, admin: User },
135}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// A [`Session`] paired with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers call `Session` methods directly on the context.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
166
impl MessageContext {
    /// Forwards `request` to the peer identified by `receiver_id`, recording
    /// how long we waited on that host in the `host_waiting_ms` span field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` fires on completion regardless of success or failure,
            // so the wait time is recorded for errors too.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Microseconds converted to fractional milliseconds.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state cloned into every message handler.
#[derive(Clone)]
struct Session {
    /// The authenticated identity behind this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle; lock it via [`Session::db`] rather than directly.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide pool of live connections, shared with [`Server`].
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
204
205impl Session {
206 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
207 #[cfg(feature = "test-support")]
208 tokio::task::yield_now().await;
209 let guard = self.db.lock().await;
210 #[cfg(feature = "test-support")]
211 tokio::task::yield_now().await;
212 guard
213 }
214
215 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
216 #[cfg(feature = "test-support")]
217 tokio::task::yield_now().await;
218 let guard = self.connection_pool.lock();
219 ConnectionPoolGuard {
220 guard,
221 _not_send: PhantomData,
222 }
223 }
224
225 #[expect(dead_code)]
226 fn is_staff(&self) -> bool {
227 match &self.principal {
228 Principal::User(user) => user.admin,
229 Principal::Impersonated { .. } => true,
230 }
231 }
232
233 fn user_id(&self) -> UserId {
234 match &self.principal {
235 Principal::User(user) => user.id,
236 Principal::Impersonated { user, .. } => user.id,
237 }
238 }
239}
240
241impl Debug for Session {
242 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243 let mut result = f.debug_struct("Session");
244 match &self.principal {
245 Principal::User(user) => {
246 result.field("user", &user.github_login);
247 }
248 Principal::Impersonated { user, admin } => {
249 result.field("user", &user.github_login);
250 result.field("impersonator", &admin.github_login);
251 }
252 }
253 result.field("connection_id", &self.connection_id).finish()
254 }
255}
256
257struct DbHandle(Arc<Database>);
258
259impl Deref for DbHandle {
260 type Target = Database;
261
262 fn deref(&self) -> &Self::Target {
263 self.0.as_ref()
264 }
265}
266
/// The collab RPC server: owns the peer, the shared connection pool, and the
/// table of message handlers registered in [`Server::new`].
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message `TypeId`; registering the same type twice
    /// panics (see `add_handler`).
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel flipped to `true` on teardown; each connection task
    /// subscribes and exits when it changes.
    teardown: watch::Sender<bool>,
}

/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` marker
/// makes this type `!Send`, so the guard cannot be moved to another thread.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
280
281#[derive(Serialize)]
282pub struct ServerSnapshot<'a> {
283 peer: &'a Peer,
284 #[serde(serialize_with = "serialize_deref")]
285 connection_pool: ConnectionPoolGuard<'a>,
286}
287
288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
289where
290 S: Serializer,
291 T: Deref<Target = U>,
292 U: Serialize,
293{
294 Serialize::serialize(value.deref(), serializer)
295}
296
297impl Server {
    /// Builds the server and registers a handler for every RPC message type.
    ///
    /// `add_request_handler` registers handlers that must respond;
    /// `add_message_handler` registers fire-and-forget handlers. Registering
    /// the same message type twice panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state broadcast from host to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Agent threads and project search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
477
    /// Spawns the one-shot background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` (giving pods from a previous deploy time to
    /// shut down), the task removes database state left behind by stale
    /// servers — chat participants, channel-buffer collaborators, room
    /// participants, old worktree entries, and stale server rows — and
    /// notifies affected, still-connected clients.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Clear stale collaborators from each channel buffer and
                    // push the refreshed collaborator list to the remaining
                    // connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Drop stale participants from the room and notify
                        // everyone still in it.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell users whose incoming calls were canceled by the
                        // cleanup, on every one of their connections.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute busy status for affected users and push a
                        // contact update to each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room only once it is empty.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
648
649 pub fn teardown(&self) {
650 self.peer.teardown();
651 self.connection_pool.lock().reset();
652 let _ = self.teardown.send(true);
653 }
654
655 #[cfg(feature = "test-support")]
656 pub fn reset(&self, id: ServerId) {
657 self.teardown();
658 *self.id.lock() = id;
659 self.peer.reset(id.0 as u32);
660 let _ = self.teardown.send(false);
661 }
662
    /// Test-only: the current [`ServerId`] (changes across [`Server::reset`]).
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
667
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and span-duration bookkeeping.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The handler was looked up by this exact `TypeId`, so the
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Fractional milliseconds: total time since arrival,
                    // time spent in the handler, and their difference
                    // (time spent queued before the handler ran).
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
711
712 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
713 where
714 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
715 Fut: 'static + Send + Future<Output = Result<()>>,
716 M: EnvelopedMessage,
717 {
718 self.add_handler(move |envelope, session| handler(envelope.payload, session));
719 self
720 }
721
722 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
723 where
724 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
725 Fut: Send + Future<Output = Result<()>>,
726 M: RequestMessage,
727 {
728 let handler = Arc::new(handler);
729 self.add_handler(move |envelope, session| {
730 let receipt = envelope.receipt();
731 let handler = handler.clone();
732 async move {
733 let peer = session.peer.clone();
734 let responded = Arc::new(AtomicBool::default());
735 let response = Response {
736 peer: peer.clone(),
737 responded: responded.clone(),
738 receipt,
739 };
740 match (handler)(envelope.payload, response, session).await {
741 Ok(()) => {
742 if responded.load(std::sync::atomic::Ordering::SeqCst) {
743 Ok(())
744 } else {
745 Err(anyhow!("handler did not send a response"))?
746 }
747 }
748 Err(error) => {
749 let proto_err = match &error {
750 Error::Internal(err) => err.to_proto(),
751 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
752 };
753 peer.respond_with_error(receipt, proto_err)?;
754 Err(error)
755 }
756 }
757 }
758 })
759 }
760
    /// Drives a single authenticated client connection to completion.
    ///
    /// Registers the connection with the peer, sends the initial client
    /// update, then loops over incoming messages until the connection closes,
    /// I/O fails, or the server tears down; finally runs `connection_lost`
    /// cleanup. `connection_guard` holds a slot in the global connection
    /// count and is released once the handshake succeeds.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection-count guard only needs to cover the handshake.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Semaphore bounds in-flight handlers; a permit is held from
            // message receipt until the handler future completes.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // `select_biased!`: teardown wins over I/O completion, which
                // wins over handler completion, which wins over new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only when the handler
                                // future finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
933
    /// Sends the messages a freshly-connected client needs right away:
    /// a `Hello` carrying its peer id, the initial contact list, channel
    /// subscriptions (for supporting versions), and any pending incoming
    /// call. Also registers the connection in the connection pool.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Let the caller observe the assigned connection id (ignore a
        // dropped receiver).
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: prompt the contacts UI
                // and record that the prompt was shown.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact update
                // under a single pool lock so the snapshot is consistent.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a ring that was already in progress for this user.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
989
990 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
991 ServerSnapshot {
992 connection_pool: ConnectionPoolGuard {
993 guard: self.connection_pool.lock(),
994 _not_send: PhantomData,
995 },
996 peer: &self.peer,
997 }
998 }
999}
1000
1001impl Deref for ConnectionPoolGuard<'_> {
1002 type Target = ConnectionPool;
1003
1004 fn deref(&self) -> &Self::Target {
1005 &self.guard
1006 }
1007}
1008
1009impl DerefMut for ConnectionPoolGuard<'_> {
1010 fn deref_mut(&mut self) -> &mut Self::Target {
1011 &mut self.guard
1012 }
1013}
1014
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time the guard is
        // released so corruption is caught at the point it was introduced.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
1021
1022fn broadcast<F>(
1023 sender_id: Option<ConnectionId>,
1024 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1025 mut f: F,
1026) where
1027 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1028{
1029 for receiver_id in receiver_ids {
1030 if Some(receiver_id) != sender_id
1031 && let Err(error) = f(receiver_id)
1032 {
1033 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1034 }
1035 }
1036}
1037
/// Typed wrapper for the `x-zed-protocol-version` header: the client's RPC
/// protocol version, compared against `rpc::PROTOCOL_VERSION` on connect.
pub struct ProtocolVersion(u32);
1039
1040impl Header for ProtocolVersion {
1041 fn name() -> &'static HeaderName {
1042 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1043 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1044 }
1045
1046 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1047 where
1048 Self: Sized,
1049 I: Iterator<Item = &'i axum::http::HeaderValue>,
1050 {
1051 let version = values
1052 .next()
1053 .ok_or_else(axum::headers::Error::invalid)?
1054 .to_str()
1055 .map_err(|_| axum::headers::Error::invalid())?
1056 .parse()
1057 .map_err(|_| axum::headers::Error::invalid())?;
1058 Ok(Self(version))
1059 }
1060
1061 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1062 values.extend([self.0.to_string().parse().unwrap()]);
1063 }
1064}
1065
/// Typed wrapper for the `x-zed-app-version` header: the client's semantic
/// app version, used to gate collaboration features.
pub struct AppVersionHeader(Version);
impl Header for AppVersionHeader {
    fn name() -> &'static HeaderName {
        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i axum::http::HeaderValue>,
    {
        // A missing value, non-UTF-8 bytes, or an unparsable semver string
        // all surface as `invalid`.
        let version = values
            .next()
            .ok_or_else(axum::headers::Error::invalid)?
            .to_str()
            .map_err(|_| axum::headers::Error::invalid())?
            .parse()
            .map_err(|_| axum::headers::Error::invalid())?;
        Ok(Self(version))
    }

    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
        values.extend([self.0.to_string().parse().unwrap()]);
    }
}
1092
1093#[derive(Debug)]
1094pub struct ReleaseChannelHeader(String);
1095
1096impl Header for ReleaseChannelHeader {
1097 fn name() -> &'static HeaderName {
1098 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1099 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1100 }
1101
1102 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1103 where
1104 Self: Sized,
1105 I: Iterator<Item = &'i axum::http::HeaderValue>,
1106 {
1107 Ok(Self(
1108 values
1109 .next()
1110 .ok_or_else(axum::headers::Error::invalid)?
1111 .to_str()
1112 .map_err(|_| axum::headers::Error::invalid())?
1113 .to_owned(),
1114 ))
1115 }
1116
1117 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1118 values.extend([self.0.parse().unwrap()]);
1119 }
1120}
1121
/// Builds this module's router: `/rpc` (the websocket endpoint, behind
/// header-based auth) and `/metrics` (Prometheus scrape endpoint).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer, so `/metrics` is not auth-gated.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1133
/// Upgrades an authenticated HTTP request to the collab RPC websocket.
///
/// Rejects clients whose RPC protocol version mismatches, whose app-version
/// header is missing, or whose version cannot collaborate (all 426), and
/// enforces the global concurrent-connection limit (503) *before* the
/// upgrade so we never accept a socket we cannot serve.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC wire protocol must match exactly.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // Clients that predate the app-version header are too old to serve.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket stream/sink to the tungstenite message
        // types that `rpc::Connection` expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1211
1212pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1213 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1214 let connections_metric = CONNECTIONS_METRIC
1215 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1216
1217 let connections = server
1218 .connection_pool
1219 .lock()
1220 .connections()
1221 .filter(|connection| !connection.admin)
1222 .count();
1223 connections_metric.set(connections as _);
1224
1225 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1226 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1227 register_int_gauge!(
1228 "shared_projects",
1229 "number of open projects with one or more guests"
1230 )
1231 .unwrap()
1232 });
1233
1234 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1235 shared_projects_metric.set(shared_projects as _);
1236
1237 let encoder = prometheus::TextEncoder::new();
1238 let metric_families = prometheus::gather();
1239 let encoded_metrics = encoder
1240 .encode_to_string(&metric_families)
1241 .map_err(|err| anyhow!("{err}"))?;
1242 Ok(encoded_metrics)
1243}
1244
/// Cleans up server-side state after a client's connection drops.
///
/// The connection itself is removed immediately, but room/channel-buffer
/// teardown waits up to `RECONNECT_TIMEOUT` so a quickly-reconnecting
/// client can resume its session. A server teardown signal skips the wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort DB bookkeeping; failure is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only cancel a pending incoming call if no other device of this
            // user is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1291
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    // The reply payload is an empty ack; arrival is all that matters.
    response.send(proto::Ack {})?;
    Ok(())
}
1301
1302/// Creates a new room for calling (outside of channels)
1303async fn create_room(
1304 _request: proto::CreateRoom,
1305 response: Response<proto::CreateRoom>,
1306 session: MessageContext,
1307) -> Result<()> {
1308 let livekit_room = nanoid::nanoid!(30);
1309
1310 let live_kit_connection_info = util::maybe!(async {
1311 let live_kit = session.app_state.livekit_client.as_ref();
1312 let live_kit = live_kit?;
1313 let user_id = session.user_id().to_string();
1314
1315 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1316
1317 Some(proto::LiveKitConnectionInfo {
1318 server_url: live_kit.url().into(),
1319 token,
1320 can_publish: true,
1321 })
1322 })
1323 .await;
1324
1325 let room = session
1326 .db()
1327 .await
1328 .create_room(session.user_id(), session.connection_id, &livekit_room)
1329 .await?;
1330
1331 response.send(proto::CreateRoomResponse {
1332 room: Some(room.clone()),
1333 live_kit_connection_info,
1334 })?;
1335
1336 update_user_contacts(session.user_id(), &session).await?;
1337 Ok(())
1338}
1339
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms are joined via the channel path so channel
    // membership rules apply.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the incoming-call notification on all of this user's devices.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; `None` leaves the participant without media.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1406
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // The DB guard lives for this whole scope so the state we broadcast
        // matches what was just persisted.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Projects this user hosts: tell each collaborator the host's peer
        // id changed, then re-broadcast the current worktree list.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Projects this user is a guest in: stream their current state back.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also broadcast a channel-level update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1498
/// Streams the current state of each rejoined project back to the
/// reconnecting guest and informs other collaborators of its new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell every collaborator that this user's connection id changed.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // `split_worktree_update` chunks the update (presumably to bound
            // message size) — send every chunk.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Stream repository state changes accumulated while disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1584
1585/// leave room disconnects from the room.
1586async fn leave_room(
1587 _: proto::LeaveRoom,
1588 response: Response<proto::LeaveRoom>,
1589 session: MessageContext,
1590) -> Result<()> {
1591 leave_room_for_session(&session, session.connection_id).await?;
1592 response.send(proto::Ack {})?;
1593 Ok(())
1594}
1595
1596/// Updates the permissions of someone else in the room.
1597async fn set_room_participant_role(
1598 request: proto::SetRoomParticipantRole,
1599 response: Response<proto::SetRoomParticipantRole>,
1600 session: MessageContext,
1601) -> Result<()> {
1602 let user_id = UserId::from_proto(request.user_id);
1603 let role = ChannelRole::from(request.role());
1604
1605 let (livekit_room, can_publish) = {
1606 let room = session
1607 .db()
1608 .await
1609 .set_room_participant_role(
1610 session.user_id(),
1611 RoomId::from_proto(request.room_id),
1612 user_id,
1613 role,
1614 )
1615 .await?;
1616
1617 let livekit_room = room.livekit_room.clone();
1618 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1619 room_updated(&room, &session.peer);
1620 (livekit_room, can_publish)
1621 };
1622
1623 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1624 live_kit
1625 .update_participant(
1626 livekit_room.clone(),
1627 request.user_id.to_string(),
1628 livekit_api::proto::ParticipantPermission {
1629 can_subscribe: true,
1630 can_publish,
1631 can_publish_data: can_publish,
1632 hidden: false,
1633 recorder: false,
1634 },
1635 )
1636 .await
1637 .trace_err();
1638 }
1639
1640 response.send(proto::Ack {})?;
1641 Ok(())
1642}
1643
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be rung.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the room; take the incoming-call payload out of
    // the guard so it can be sent after the DB guard is dropped.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every device of the callee concurrently; the first successful
    // response wins and acks this request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No device answered: mark the call failed and report the error.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1712
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scope the DB guard so it is released before notifying connections.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ring on every device of the callee.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1750
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Scope the DB guard so it is released before the notification loop.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ring on all of the declining user's own devices.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1782
1783/// Updates other participants in the room with your current location.
1784async fn update_participant_location(
1785 request: proto::UpdateParticipantLocation,
1786 response: Response<proto::UpdateParticipantLocation>,
1787 session: MessageContext,
1788) -> Result<()> {
1789 let room_id = RoomId::from_proto(request.room_id);
1790 let location = request.location.context("invalid location")?;
1791
1792 let db = session.db().await;
1793 let room = db
1794 .update_room_participant_location(room_id, session.connection_id, location)
1795 .await?;
1796
1797 room_updated(&room, &session.peer);
1798 response.send(proto::Ack {})?;
1799 Ok(())
1800}
1801
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Deref the returned guard to borrow the new project id and the
    // updated room for the rest of this function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1826
1827/// Unshare a project from the room.
1828async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1829 let project_id = ProjectId::from_proto(message.project_id);
1830 unshare_project_internal(project_id, session.connection_id, &session).await
1831}
1832
/// Core unshare logic: notifies guests and the room, then deletes the
/// project row when the DB says it should be removed.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scope the DB guard: notifications go out while the room state is
    // borrowed; only the `delete` flag escapes the block.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Delete after the guard above is dropped, via a fresh DB handle.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1870
1871/// Join someone elses shared project.
1872async fn join_project(
1873 request: proto::JoinProject,
1874 response: Response<proto::JoinProject>,
1875 session: MessageContext,
1876) -> Result<()> {
1877 let project_id = ProjectId::from_proto(request.project_id);
1878
1879 tracing::info!(%project_id, "join project");
1880
1881 let db = session.db().await;
1882 let (project, replica_id) = &mut *db
1883 .join_project(
1884 project_id,
1885 session.connection_id,
1886 session.user_id(),
1887 request.committer_name.clone(),
1888 request.committer_email.clone(),
1889 )
1890 .await?;
1891 drop(db);
1892 tracing::info!(%project_id, "join remote project");
1893 let collaborators = project
1894 .collaborators
1895 .iter()
1896 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1897 .map(|collaborator| collaborator.to_proto())
1898 .collect::<Vec<_>>();
1899 let project_id = project.id;
1900 let guest_user_id = session.user_id();
1901
1902 let worktrees = project
1903 .worktrees
1904 .iter()
1905 .map(|(id, worktree)| proto::WorktreeMetadata {
1906 id: *id,
1907 root_name: worktree.root_name.clone(),
1908 visible: worktree.visible,
1909 abs_path: worktree.abs_path.clone(),
1910 })
1911 .collect::<Vec<_>>();
1912
1913 let add_project_collaborator = proto::AddProjectCollaborator {
1914 project_id: project_id.to_proto(),
1915 collaborator: Some(proto::Collaborator {
1916 peer_id: Some(session.connection_id.into()),
1917 replica_id: replica_id.0 as u32,
1918 user_id: guest_user_id.to_proto(),
1919 is_host: false,
1920 committer_name: request.committer_name.clone(),
1921 committer_email: request.committer_email.clone(),
1922 }),
1923 };
1924
1925 for collaborator in &collaborators {
1926 session
1927 .peer
1928 .send(
1929 collaborator.peer_id.unwrap().into(),
1930 add_project_collaborator.clone(),
1931 )
1932 .trace_err();
1933 }
1934
1935 // First, we send the metadata associated with each worktree.
1936 let (language_servers, language_server_capabilities) = project
1937 .language_servers
1938 .clone()
1939 .into_iter()
1940 .map(|server| (server.server, server.capabilities))
1941 .unzip();
1942 response.send(proto::JoinProjectResponse {
1943 project_id: project.id.0 as u64,
1944 worktrees,
1945 replica_id: replica_id.0 as u32,
1946 collaborators,
1947 language_servers,
1948 language_server_capabilities,
1949 role: project.role.into(),
1950 windows_paths: project.path_style == PathStyle::Windows,
1951 })?;
1952
1953 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1954 // Stream this worktree's entries.
1955 let message = proto::UpdateWorktree {
1956 project_id: project_id.to_proto(),
1957 worktree_id,
1958 abs_path: worktree.abs_path.clone(),
1959 root_name: worktree.root_name,
1960 updated_entries: worktree.entries,
1961 removed_entries: Default::default(),
1962 scan_id: worktree.scan_id,
1963 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1964 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1965 removed_repositories: Default::default(),
1966 };
1967 for update in proto::split_worktree_update(message) {
1968 session.peer.send(session.connection_id, update.clone())?;
1969 }
1970
1971 // Stream this worktree's diagnostics.
1972 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1973 if let Some(summary) = worktree_diagnostics.next() {
1974 let message = proto::UpdateDiagnosticSummary {
1975 project_id: project.id.to_proto(),
1976 worktree_id: worktree.id,
1977 summary: Some(summary),
1978 more_summaries: worktree_diagnostics.collect(),
1979 };
1980 session.peer.send(session.connection_id, message)?;
1981 }
1982
1983 for settings_file in worktree.settings_files {
1984 session.peer.send(
1985 session.connection_id,
1986 proto::UpdateWorktreeSettings {
1987 project_id: project_id.to_proto(),
1988 worktree_id: worktree.id,
1989 path: settings_file.path,
1990 content: Some(settings_file.content),
1991 kind: Some(settings_file.kind.to_proto() as i32),
1992 outside_worktree: Some(settings_file.outside_worktree),
1993 },
1994 )?;
1995 }
1996 }
1997
1998 for repository in mem::take(&mut project.repositories) {
1999 for update in split_repository_update(repository) {
2000 session.peer.send(session.connection_id, update)?;
2001 }
2002 }
2003
2004 for language_server in &project.language_servers {
2005 session.peer.send(
2006 session.connection_id,
2007 proto::UpdateLanguageServer {
2008 project_id: project_id.to_proto(),
2009 server_name: Some(language_server.server.name.clone()),
2010 language_server_id: language_server.server.id,
2011 variant: Some(
2012 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2013 proto::LspDiskBasedDiagnosticsUpdated {},
2014 ),
2015 ),
2016 },
2017 )?;
2018 }
2019
2020 Ok(())
2021}
2022
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Deref the guard to borrow the (optional) room and the project state.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Notify remaining collaborators; update the room if the project
    // belonged to one.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2042
2043/// Updates other participants with changes to the project
2044async fn update_project(
2045 request: proto::UpdateProject,
2046 response: Response<proto::UpdateProject>,
2047 session: MessageContext,
2048) -> Result<()> {
2049 let project_id = ProjectId::from_proto(request.project_id);
2050 let (room, guest_connection_ids) = &*session
2051 .db()
2052 .await
2053 .update_project(project_id, session.connection_id, &request.worktrees)
2054 .await?;
2055 broadcast(
2056 Some(session.connection_id),
2057 guest_connection_ids.iter().copied(),
2058 |connection_id| {
2059 session
2060 .peer
2061 .forward_send(session.connection_id, connection_id, request.clone())
2062 },
2063 );
2064 if let Some(room) = room {
2065 room_updated(room, &session.peer);
2066 }
2067 response.send(proto::Ack {})?;
2068
2069 Ok(())
2070}
2071
2072/// Updates other participants with changes to the worktree
2073async fn update_worktree(
2074 request: proto::UpdateWorktree,
2075 response: Response<proto::UpdateWorktree>,
2076 session: MessageContext,
2077) -> Result<()> {
2078 let guest_connection_ids = session
2079 .db()
2080 .await
2081 .update_worktree(&request, session.connection_id)
2082 .await?;
2083
2084 broadcast(
2085 Some(session.connection_id),
2086 guest_connection_ids.iter().copied(),
2087 |connection_id| {
2088 session
2089 .peer
2090 .forward_send(session.connection_id, connection_id, request.clone())
2091 },
2092 );
2093 response.send(proto::Ack {})?;
2094 Ok(())
2095}
2096
2097async fn update_repository(
2098 request: proto::UpdateRepository,
2099 response: Response<proto::UpdateRepository>,
2100 session: MessageContext,
2101) -> Result<()> {
2102 let guest_connection_ids = session
2103 .db()
2104 .await
2105 .update_repository(&request, session.connection_id)
2106 .await?;
2107
2108 broadcast(
2109 Some(session.connection_id),
2110 guest_connection_ids.iter().copied(),
2111 |connection_id| {
2112 session
2113 .peer
2114 .forward_send(session.connection_id, connection_id, request.clone())
2115 },
2116 );
2117 response.send(proto::Ack {})?;
2118 Ok(())
2119}
2120
2121async fn remove_repository(
2122 request: proto::RemoveRepository,
2123 response: Response<proto::RemoveRepository>,
2124 session: MessageContext,
2125) -> Result<()> {
2126 let guest_connection_ids = session
2127 .db()
2128 .await
2129 .remove_repository(&request, session.connection_id)
2130 .await?;
2131
2132 broadcast(
2133 Some(session.connection_id),
2134 guest_connection_ids.iter().copied(),
2135 |connection_id| {
2136 session
2137 .peer
2138 .forward_send(session.connection_id, connection_id, request.clone())
2139 },
2140 );
2141 response.send(proto::Ack {})?;
2142 Ok(())
2143}
2144
2145/// Updates other participants with changes to the diagnostics
2146async fn update_diagnostic_summary(
2147 message: proto::UpdateDiagnosticSummary,
2148 session: MessageContext,
2149) -> Result<()> {
2150 let guest_connection_ids = session
2151 .db()
2152 .await
2153 .update_diagnostic_summary(&message, session.connection_id)
2154 .await?;
2155
2156 broadcast(
2157 Some(session.connection_id),
2158 guest_connection_ids.iter().copied(),
2159 |connection_id| {
2160 session
2161 .peer
2162 .forward_send(session.connection_id, connection_id, message.clone())
2163 },
2164 );
2165
2166 Ok(())
2167}
2168
2169/// Updates other participants with changes to the worktree settings
2170async fn update_worktree_settings(
2171 message: proto::UpdateWorktreeSettings,
2172 session: MessageContext,
2173) -> Result<()> {
2174 let guest_connection_ids = session
2175 .db()
2176 .await
2177 .update_worktree_settings(&message, session.connection_id)
2178 .await?;
2179
2180 broadcast(
2181 Some(session.connection_id),
2182 guest_connection_ids.iter().copied(),
2183 |connection_id| {
2184 session
2185 .peer
2186 .forward_send(session.connection_id, connection_id, message.clone())
2187 },
2188 );
2189
2190 Ok(())
2191}
2192
2193/// Notify other participants that a language server has started.
2194async fn start_language_server(
2195 request: proto::StartLanguageServer,
2196 session: MessageContext,
2197) -> Result<()> {
2198 let guest_connection_ids = session
2199 .db()
2200 .await
2201 .start_language_server(&request, session.connection_id)
2202 .await?;
2203
2204 broadcast(
2205 Some(session.connection_id),
2206 guest_connection_ids.iter().copied(),
2207 |connection_id| {
2208 session
2209 .peer
2210 .forward_send(session.connection_id, connection_id, request.clone())
2211 },
2212 );
2213 Ok(())
2214}
2215
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Metadata updates may carry new server capabilities; persist them in
    // the database before broadcasting.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Forward the update verbatim to every other connection in the project.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2245
2246/// forward a project request to the host. These requests should be read only
2247/// as guests are allowed to send them.
2248async fn forward_read_only_project_request<T>(
2249 request: T,
2250 response: Response<T>,
2251 session: MessageContext,
2252) -> Result<()>
2253where
2254 T: EntityMessage + RequestMessage,
2255{
2256 let project_id = ProjectId::from_proto(request.remote_entity_id());
2257 let host_connection_id = session
2258 .db()
2259 .await
2260 .host_for_read_only_project_request(project_id, session.connection_id)
2261 .await?;
2262 let payload = session.forward_request(host_connection_id, request).await?;
2263 response.send(payload)?;
2264 Ok(())
2265}
2266
2267/// forward a project request to the host. These requests are disallowed
2268/// for guests.
2269async fn forward_mutating_project_request<T>(
2270 request: T,
2271 response: Response<T>,
2272 session: MessageContext,
2273) -> Result<()>
2274where
2275 T: EntityMessage + RequestMessage,
2276{
2277 let project_id = ProjectId::from_proto(request.remote_entity_id());
2278
2279 let host_connection_id = session
2280 .db()
2281 .await
2282 .host_for_mutating_project_request(project_id, session.connection_id)
2283 .await?;
2284 let payload = session.forward_request(host_connection_id, request).await?;
2285 response.send(payload)?;
2286 Ok(())
2287}
2288
/// Routes an LSP query to the project host, choosing the read-only or
/// mutating forwarding path based on the query's write permissions.
async fn lsp_query(
    request: proto::LspQuery,
    response: Response<proto::LspQuery>,
    session: MessageContext,
) -> Result<()> {
    // Record the query name on the current tracing span so logs can
    // distinguish query types.
    let (name, should_write) = request.query_name_and_write_permissions();
    tracing::Span::current().record("lsp_query_request", name);
    tracing::info!("lsp_query message received");
    if should_write {
        forward_mutating_project_request(request, response, session).await
    } else {
        forward_read_only_project_request(request, response, session).await
    }
}
2303
2304/// Notify other participants that a new buffer has been created
2305async fn create_buffer_for_peer(
2306 request: proto::CreateBufferForPeer,
2307 session: MessageContext,
2308) -> Result<()> {
2309 session
2310 .db()
2311 .await
2312 .check_user_is_project_host(
2313 ProjectId::from_proto(request.project_id),
2314 session.connection_id,
2315 )
2316 .await?;
2317 let peer_id = request.peer_id.context("invalid peer id")?;
2318 session
2319 .peer
2320 .forward_send(session.connection_id, peer_id.into(), request)?;
2321 Ok(())
2322}
2323
2324/// Notify other participants that a new image has been created
2325async fn create_image_for_peer(
2326 request: proto::CreateImageForPeer,
2327 session: MessageContext,
2328) -> Result<()> {
2329 session
2330 .db()
2331 .await
2332 .check_user_is_project_host(
2333 ProjectId::from_proto(request.project_id),
2334 session.connection_id,
2335 )
2336 .await?;
2337 let peer_id = request.peer_id.context("invalid peer id")?;
2338 session
2339 .peer
2340 .forward_send(session.connection_id, peer_id.into(), request)?;
2341 Ok(())
2342}
2343
2344/// Notify other participants that a buffer has been updated. This is
2345/// allowed for guests as long as the update is limited to selections.
2346async fn update_buffer(
2347 request: proto::UpdateBuffer,
2348 response: Response<proto::UpdateBuffer>,
2349 session: MessageContext,
2350) -> Result<()> {
2351 let project_id = ProjectId::from_proto(request.project_id);
2352 let mut capability = Capability::ReadOnly;
2353
2354 for op in request.operations.iter() {
2355 match op.variant {
2356 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2357 Some(_) => capability = Capability::ReadWrite,
2358 }
2359 }
2360
2361 let host = {
2362 let guard = session
2363 .db()
2364 .await
2365 .connections_for_buffer_update(project_id, session.connection_id, capability)
2366 .await?;
2367
2368 let (host, guests) = &*guard;
2369
2370 broadcast(
2371 Some(session.connection_id),
2372 guests.clone(),
2373 |connection_id| {
2374 session
2375 .peer
2376 .forward_send(session.connection_id, connection_id, request.clone())
2377 },
2378 );
2379
2380 *host
2381 };
2382
2383 if host != session.connection_id {
2384 session.forward_request(host, request.clone()).await?;
2385 }
2386
2387 response.send(proto::Ack {})?;
2388 Ok(())
2389}
2390
/// Relays a context (assistant conversation) operation to the project's
/// host and guests, enforcing write access for content-mutating operations.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Classify the operation: selection-only buffer updates are permitted for
    // read-only participants; anything that can mutate content requires write
    // access. A missing variant is treated as read-only.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // No inner operation: conservatively require write access.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    // Verify the sender holds the required capability and fetch the host and
    // guest connections for the project.
    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike update_buffer, the host is included in this broadcast rather
    // than being sent a separate request.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2432
2433async fn forward_project_search_chunk(
2434 message: proto::FindSearchCandidatesChunk,
2435 response: Response<proto::FindSearchCandidatesChunk>,
2436 session: MessageContext,
2437) -> Result<()> {
2438 let peer_id = message.peer_id.context("missing peer_id")?;
2439 let payload = session
2440 .peer
2441 .forward_request(session.connection_id, peer_id.into(), message)
2442 .await?;
2443 response.send(payload)?;
2444 Ok(())
2445}
2446
2447/// Notify other participants that a project has been updated.
2448async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2449 request: T,
2450 session: MessageContext,
2451) -> Result<()> {
2452 let project_id = ProjectId::from_proto(request.remote_entity_id());
2453 let project_connection_ids = session
2454 .db()
2455 .await
2456 .project_connection_ids(project_id, session.connection_id, false)
2457 .await?;
2458
2459 broadcast(
2460 Some(session.connection_id),
2461 project_connection_ids.iter().copied(),
2462 |connection_id| {
2463 session
2464 .peer
2465 .forward_send(session.connection_id, connection_id, request.clone())
2466 },
2467 );
2468 Ok(())
2469}
2470
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Ensure both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Forward the follow request to the leader and relay its reply back to
    // the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Record the follow relationship (only when tied to a project) and tell
    // the room about the updated state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2502
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Ensure both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader this connection stopped following (no reply expected,
    // unlike `follow` which waits for the leader's response).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Remove the follow relationship (only recorded when tied to a project)
    // and tell the room about the updated state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2531
2532/// Notify everyone following you of your current location.
2533async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2534 let room_id = RoomId::from_proto(request.room_id);
2535 let database = session.db.lock().await;
2536
2537 let connection_ids = if let Some(project_id) = request.project_id {
2538 let project_id = ProjectId::from_proto(project_id);
2539 database
2540 .project_connection_ids(project_id, session.connection_id, true)
2541 .await?
2542 } else {
2543 database
2544 .room_connection_ids(room_id, session.connection_id)
2545 .await?
2546 };
2547
2548 // For now, don't send view update messages back to that view's current leader.
2549 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2550 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2551 _ => None,
2552 });
2553
2554 for connection_id in connection_ids.iter().cloned() {
2555 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2556 session
2557 .peer
2558 .forward_send(session.connection_id, connection_id, request.clone())?;
2559 }
2560 }
2561 Ok(())
2562}
2563
2564/// Get public data about users.
2565async fn get_users(
2566 request: proto::GetUsers,
2567 response: Response<proto::GetUsers>,
2568 session: MessageContext,
2569) -> Result<()> {
2570 let user_ids = request
2571 .user_ids
2572 .into_iter()
2573 .map(UserId::from_proto)
2574 .collect();
2575 let users = session
2576 .db()
2577 .await
2578 .get_users_by_ids(user_ids)
2579 .await?
2580 .into_iter()
2581 .map(|user| proto::User {
2582 id: user.id.to_proto(),
2583 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584 github_login: user.github_login,
2585 name: user.name,
2586 })
2587 .collect();
2588 response.send(proto::UsersResponse { users })?;
2589 Ok(())
2590}
2591
/// Search for users (to invite) by Github login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries use an exact login lookup; longer ones use fuzzy
    // search capped at 10 results. NOTE(review): `len()` counts bytes, not
    // characters, so multi-byte input takes the fuzzy path sooner —
    // presumably fine since GitHub logins are ASCII; confirm.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the requesting user and convert rows to the wire format.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2623
2624/// Send a contact request to another user.
2625async fn request_contact(
2626 request: proto::RequestContact,
2627 response: Response<proto::RequestContact>,
2628 session: MessageContext,
2629) -> Result<()> {
2630 let requester_id = session.user_id();
2631 let responder_id = UserId::from_proto(request.responder_id);
2632 if requester_id == responder_id {
2633 return Err(anyhow!("cannot add yourself as a contact"))?;
2634 }
2635
2636 let notifications = session
2637 .db()
2638 .await
2639 .send_contact_request(requester_id, responder_id)
2640 .await?;
2641
2642 // Update outgoing contact requests of requester
2643 let mut update = proto::UpdateContacts::default();
2644 update.outgoing_requests.push(responder_id.to_proto());
2645 for connection_id in session
2646 .connection_pool()
2647 .await
2648 .user_connection_ids(requester_id)
2649 {
2650 session.peer.send(connection_id, update.clone())?;
2651 }
2652
2653 // Update incoming contact requests of responder
2654 let mut update = proto::UpdateContacts::default();
2655 update
2656 .incoming_requests
2657 .push(proto::IncomingContactRequest {
2658 requester_id: requester_id.to_proto(),
2659 });
2660 let connection_pool = session.connection_pool().await;
2661 for connection_id in connection_pool.user_connection_ids(responder_id) {
2662 session.peer.send(connection_id, update.clone())?;
2663 }
2664
2665 send_notifications(&connection_pool, &session.peer, notifications);
2666
2667 response.send(proto::Ack {})?;
2668 Ok(())
2669}
2670
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact payloads built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The request is resolved either way, so remove it from the
        // responder's incoming list.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Mirror image: drop the request from the requester's outgoing list.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2728
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also clear the contact-request notification on each of the
        // responder's connections, if one was deleted.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2779
/// Whether the server should subscribe this client to its channels without
/// waiting for an explicit `SubscribeToChannels` message.
/// NOTE(review): only the minor version is compared — presumably all
/// relevant releases share the same major version; confirm.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2783
2784async fn subscribe_to_channels(
2785 _: proto::SubscribeToChannels,
2786 session: MessageContext,
2787) -> Result<()> {
2788 subscribe_user_to_channels(session.user_id(), &session).await?;
2789 Ok(())
2790}
2791
/// Subscribes all of a user's connections to their channels and sends the
/// initial channel state to the current connection.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register the memberships in the connection pool so later channel
    // updates are routed to this user.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send memberships first, then the full channels payload (the second
    // send consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2808
2809/// Creates a new channel.
2810async fn create_channel(
2811 request: proto::CreateChannel,
2812 response: Response<proto::CreateChannel>,
2813 session: MessageContext,
2814) -> Result<()> {
2815 let db = session.db().await;
2816
2817 let parent_id = request.parent_id.map(ChannelId::from_proto);
2818 let (channel, membership) = db
2819 .create_channel(&request.name, parent_id, session.user_id())
2820 .await?;
2821
2822 let root_id = channel.root_id();
2823 let channel = Channel::from_model(channel);
2824
2825 response.send(proto::CreateChannelResponse {
2826 channel: Some(channel.to_proto()),
2827 parent_id: request.parent_id,
2828 })?;
2829
2830 let mut connection_pool = session.connection_pool().await;
2831 if let Some(membership) = membership {
2832 connection_pool.subscribe_to_channel(
2833 membership.user_id,
2834 membership.channel_id,
2835 membership.role,
2836 );
2837 let update = proto::UpdateUserChannels {
2838 channel_memberships: vec![proto::ChannelMembership {
2839 channel_id: membership.channel_id.to_proto(),
2840 role: membership.role.into(),
2841 }],
2842 ..Default::default()
2843 };
2844 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2845 session.peer.send(connection_id, update.clone())?;
2846 }
2847 }
2848
2849 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2850 if !role.can_see_channel(channel.visibility) {
2851 continue;
2852 }
2853
2854 let update = proto::UpdateChannels {
2855 channels: vec![channel.to_proto()],
2856 ..Default::default()
2857 };
2858 session.peer.send(connection_id, update.clone())?;
2859 }
2860
2861 Ok(())
2862}
2863
2864/// Delete a channel
2865async fn delete_channel(
2866 request: proto::DeleteChannel,
2867 response: Response<proto::DeleteChannel>,
2868 session: MessageContext,
2869) -> Result<()> {
2870 let db = session.db().await;
2871
2872 let channel_id = request.channel_id;
2873 let (root_channel, removed_channels) = db
2874 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2875 .await?;
2876 response.send(proto::Ack {})?;
2877
2878 // Notify members of removed channels
2879 let mut update = proto::UpdateChannels::default();
2880 update
2881 .delete_channels
2882 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2883
2884 let connection_pool = session.connection_pool().await;
2885 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2886 session.peer.send(connection_id, update.clone())?;
2887 }
2888
2889 Ok(())
2890}
2891
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // Record the invite; the DB returns the channel for the invitation
    // payload plus any notifications to deliver.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Show the invitation on every one of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2928
2929/// remove someone from a channel
2930async fn remove_channel_member(
2931 request: proto::RemoveChannelMember,
2932 response: Response<proto::RemoveChannelMember>,
2933 session: MessageContext,
2934) -> Result<()> {
2935 let db = session.db().await;
2936 let channel_id = ChannelId::from_proto(request.channel_id);
2937 let member_id = UserId::from_proto(request.user_id);
2938
2939 let RemoveChannelMemberResult {
2940 membership_update,
2941 notification_id,
2942 } = db
2943 .remove_channel_member(channel_id, member_id, session.user_id())
2944 .await?;
2945
2946 let mut connection_pool = session.connection_pool().await;
2947 notify_membership_updated(
2948 &mut connection_pool,
2949 membership_update,
2950 member_id,
2951 &session.peer,
2952 );
2953 for connection_id in connection_pool.user_connection_ids(member_id) {
2954 if let Some(notification_id) = notification_id {
2955 session
2956 .peer
2957 .send(
2958 connection_id,
2959 proto::DeleteNotification {
2960 notification_id: notification_id.to_proto(),
2961 },
2962 )
2963 .trace_err();
2964 }
2965 }
2966
2967 response.send(proto::Ack {})?;
2968 Ok(())
2969}
2970
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the (user, role) pairs up front: the loop body mutates the
    // pool (subscribe/unsubscribe), so we can't hold a borrowing iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can see the channel under its new visibility get the
        // updated channel (and a subscription); everyone else is
        // unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3017
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target user may hold an accepted membership or only a pending
    // invite; the DB result distinguishes the two cases.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Re-send the invitation so the member sees the updated role.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3065
3066/// Change the name of a channel
3067async fn rename_channel(
3068 request: proto::RenameChannel,
3069 response: Response<proto::RenameChannel>,
3070 session: MessageContext,
3071) -> Result<()> {
3072 let db = session.db().await;
3073 let channel_id = ChannelId::from_proto(request.channel_id);
3074 let channel_model = db
3075 .rename_channel(channel_id, session.user_id(), &request.name)
3076 .await?;
3077 let root_id = channel_model.root_id();
3078 let channel = Channel::from_model(channel_model);
3079
3080 response.send(proto::RenameChannelResponse {
3081 channel: Some(channel.to_proto()),
3082 })?;
3083
3084 let connection_pool = session.connection_pool().await;
3085 let update = proto::UpdateChannels {
3086 channels: vec![channel.to_proto()],
3087 ..Default::default()
3088 };
3089 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3090 if role.can_see_channel(channel.visibility) {
3091 session.peer.send(connection_id, update.clone())?;
3092 }
3093 }
3094
3095 Ok(())
3096}
3097
3098/// Move a channel to a new parent.
3099async fn move_channel(
3100 request: proto::MoveChannel,
3101 response: Response<proto::MoveChannel>,
3102 session: MessageContext,
3103) -> Result<()> {
3104 let channel_id = ChannelId::from_proto(request.channel_id);
3105 let to = ChannelId::from_proto(request.to);
3106
3107 let (root_id, channels) = session
3108 .db()
3109 .await
3110 .move_channel(channel_id, to, session.user_id())
3111 .await?;
3112
3113 let connection_pool = session.connection_pool().await;
3114 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3115 let channels = channels
3116 .iter()
3117 .filter_map(|channel| {
3118 if role.can_see_channel(channel.visibility) {
3119 Some(channel.to_proto())
3120 } else {
3121 None
3122 }
3123 })
3124 .collect::<Vec<_>>();
3125 if channels.is_empty() {
3126 continue;
3127 }
3128
3129 let update = proto::UpdateChannels {
3130 channels,
3131 ..Default::default()
3132 };
3133
3134 session.peer.send(connection_id, update.clone())?;
3135 }
3136
3137 response.send(Ack {})?;
3138 Ok(())
3139}
3140
3141async fn reorder_channel(
3142 request: proto::ReorderChannel,
3143 response: Response<proto::ReorderChannel>,
3144 session: MessageContext,
3145) -> Result<()> {
3146 let channel_id = ChannelId::from_proto(request.channel_id);
3147 let direction = request.direction();
3148
3149 let updated_channels = session
3150 .db()
3151 .await
3152 .reorder_channel(channel_id, direction, session.user_id())
3153 .await?;
3154
3155 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3156 let connection_pool = session.connection_pool().await;
3157 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3158 let channels = updated_channels
3159 .iter()
3160 .filter_map(|channel| {
3161 if role.can_see_channel(channel.visibility) {
3162 Some(channel.to_proto())
3163 } else {
3164 None
3165 }
3166 })
3167 .collect::<Vec<_>>();
3168
3169 if channels.is_empty() {
3170 continue;
3171 }
3172
3173 let update = proto::UpdateChannels {
3174 channels,
3175 ..Default::default()
3176 };
3177
3178 session.peer.send(connection_id, update.clone())?;
3179 }
3180 }
3181
3182 response.send(Ack {})?;
3183 Ok(())
3184}
3185
3186/// Get the list of channel members
3187async fn get_channel_members(
3188 request: proto::GetChannelMembers,
3189 response: Response<proto::GetChannelMembers>,
3190 session: MessageContext,
3191) -> Result<()> {
3192 let db = session.db().await;
3193 let channel_id = ChannelId::from_proto(request.channel_id);
3194 let limit = if request.limit == 0 {
3195 u16::MAX as u64
3196 } else {
3197 request.limit
3198 };
3199 let (members, users) = db
3200 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3201 .await?;
3202 response.send(proto::GetChannelMembersResponse { members, users })?;
3203 Ok(())
3204}
3205
3206/// Accept or decline a channel invitation.
3207async fn respond_to_channel_invite(
3208 request: proto::RespondToChannelInvite,
3209 response: Response<proto::RespondToChannelInvite>,
3210 session: MessageContext,
3211) -> Result<()> {
3212 let db = session.db().await;
3213 let channel_id = ChannelId::from_proto(request.channel_id);
3214 let RespondToChannelInvite {
3215 membership_update,
3216 notifications,
3217 } = db
3218 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3219 .await?;
3220
3221 let mut connection_pool = session.connection_pool().await;
3222 if let Some(membership_update) = membership_update {
3223 notify_membership_updated(
3224 &mut connection_pool,
3225 membership_update,
3226 session.user_id(),
3227 &session.peer,
3228 );
3229 } else {
3230 let update = proto::UpdateChannels {
3231 remove_channel_invitations: vec![channel_id.to_proto()],
3232 ..Default::default()
3233 };
3234
3235 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3236 session.peer.send(connection_id, update.clone())?;
3237 }
3238 };
3239
3240 send_notifications(&connection_pool, &session.peer, notifications);
3241
3242 response.send(proto::Ack {})?;
3243
3244 Ok(())
3245}
3246
3247/// Join the channels' room
3248async fn join_channel(
3249 request: proto::JoinChannel,
3250 response: Response<proto::JoinChannel>,
3251 session: MessageContext,
3252) -> Result<()> {
3253 let channel_id = ChannelId::from_proto(request.channel_id);
3254 join_channel_internal(channel_id, Box::new(response), session).await
3255}
3256
/// Abstraction over the two request types that resolve to a
/// `proto::JoinRoomResponse` — `JoinChannel` and `JoinRoom` — so that
/// `join_channel_internal` can serve both handlers.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified path so this delegates to `Response`'s own `send`
        // rather than resolving back to this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3270
/// Shared implementation of `JoinChannel`/`JoinRoom`: joins the calling
/// connection to the room associated with `channel_id`, responds with the
/// room state (plus a LiveKit token when configured), and notifies other
/// participants and contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database guard itself,
            // so ours must be released first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // When a LiveKit client is configured, mint a token for the caller:
        // guests get a token with `can_publish: false`, all other roles may
        // publish. A token failure is logged (`trace_err`) and results in no
        // connection info rather than an error.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before broadcasting to other participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the updated participant list to everyone who can see the
    // channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3364
3365/// Start editing the channel notes
3366async fn join_channel_buffer(
3367 request: proto::JoinChannelBuffer,
3368 response: Response<proto::JoinChannelBuffer>,
3369 session: MessageContext,
3370) -> Result<()> {
3371 let db = session.db().await;
3372 let channel_id = ChannelId::from_proto(request.channel_id);
3373
3374 let open_response = db
3375 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3376 .await?;
3377
3378 let collaborators = open_response.collaborators.clone();
3379 response.send(open_response)?;
3380
3381 let update = UpdateChannelBufferCollaborators {
3382 channel_id: channel_id.to_proto(),
3383 collaborators: collaborators.clone(),
3384 };
3385 channel_buffer_updated(
3386 session.connection_id,
3387 collaborators
3388 .iter()
3389 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3390 &update,
3391 &session.peer,
3392 );
3393
3394 Ok(())
3395}
3396
3397/// Edit the channel notes
3398async fn update_channel_buffer(
3399 request: proto::UpdateChannelBuffer,
3400 session: MessageContext,
3401) -> Result<()> {
3402 let db = session.db().await;
3403 let channel_id = ChannelId::from_proto(request.channel_id);
3404
3405 let (collaborators, epoch, version) = db
3406 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3407 .await?;
3408
3409 channel_buffer_updated(
3410 session.connection_id,
3411 collaborators.clone(),
3412 &proto::UpdateChannelBuffer {
3413 channel_id: channel_id.to_proto(),
3414 operations: request.operations,
3415 },
3416 &session.peer,
3417 );
3418
3419 let pool = &*session.connection_pool().await;
3420
3421 let non_collaborators =
3422 pool.channel_connection_ids(channel_id)
3423 .filter_map(|(connection_id, _)| {
3424 if collaborators.contains(&connection_id) {
3425 None
3426 } else {
3427 Some(connection_id)
3428 }
3429 });
3430
3431 broadcast(None, non_collaborators, |peer_id| {
3432 session.peer.send(
3433 peer_id,
3434 proto::UpdateChannels {
3435 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3436 channel_id: channel_id.to_proto(),
3437 epoch: epoch as u64,
3438 version: version.clone(),
3439 }],
3440 ..Default::default()
3441 },
3442 )
3443 });
3444
3445 Ok(())
3446}
3447
3448/// Rejoin the channel notes after a connection blip
3449async fn rejoin_channel_buffers(
3450 request: proto::RejoinChannelBuffers,
3451 response: Response<proto::RejoinChannelBuffers>,
3452 session: MessageContext,
3453) -> Result<()> {
3454 let db = session.db().await;
3455 let buffers = db
3456 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3457 .await?;
3458
3459 for rejoined_buffer in &buffers {
3460 let collaborators_to_notify = rejoined_buffer
3461 .buffer
3462 .collaborators
3463 .iter()
3464 .filter_map(|c| Some(c.peer_id?.into()));
3465 channel_buffer_updated(
3466 session.connection_id,
3467 collaborators_to_notify,
3468 &proto::UpdateChannelBufferCollaborators {
3469 channel_id: rejoined_buffer.buffer.channel_id,
3470 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3471 },
3472 &session.peer,
3473 );
3474 }
3475
3476 response.send(proto::RejoinChannelBuffersResponse {
3477 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3478 })?;
3479
3480 Ok(())
3481}
3482
3483/// Stop editing the channel notes
3484async fn leave_channel_buffer(
3485 request: proto::LeaveChannelBuffer,
3486 response: Response<proto::LeaveChannelBuffer>,
3487 session: MessageContext,
3488) -> Result<()> {
3489 let db = session.db().await;
3490 let channel_id = ChannelId::from_proto(request.channel_id);
3491
3492 let left_buffer = db
3493 .leave_channel_buffer(channel_id, session.connection_id)
3494 .await?;
3495
3496 response.send(Ack {})?;
3497
3498 channel_buffer_updated(
3499 session.connection_id,
3500 left_buffer.connections,
3501 &proto::UpdateChannelBufferCollaborators {
3502 channel_id: channel_id.to_proto(),
3503 collaborators: left_buffer.collaborators,
3504 },
3505 &session.peer,
3506 );
3507
3508 Ok(())
3509}
3510
3511fn channel_buffer_updated<T: EnvelopedMessage>(
3512 sender_id: ConnectionId,
3513 collaborators: impl IntoIterator<Item = ConnectionId>,
3514 message: &T,
3515 peer: &Peer,
3516) {
3517 broadcast(Some(sender_id), collaborators, |peer_id| {
3518 peer.send(peer_id, message.clone())
3519 });
3520}
3521
3522fn send_notifications(
3523 connection_pool: &ConnectionPool,
3524 peer: &Peer,
3525 notifications: db::NotificationBatch,
3526) {
3527 for (user_id, notification) in notifications {
3528 for connection_id in connection_pool.user_connection_ids(user_id) {
3529 if let Err(error) = peer.send(
3530 connection_id,
3531 proto::AddNotification {
3532 notification: Some(notification.clone()),
3533 },
3534 ) {
3535 tracing::error!(
3536 "failed to send notification to {:?} {}",
3537 connection_id,
3538 error
3539 );
3540 }
3541 }
3542 }
3543}
3544
/// Send a message to the channel
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3553
/// Delete a channel message
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3562
/// Edit a channel message
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3570
/// Mark a channel message as read
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3578
3579/// Mark a buffer version as synced
3580async fn acknowledge_buffer_version(
3581 request: proto::AckBufferOperation,
3582 session: MessageContext,
3583) -> Result<()> {
3584 let buffer_id = BufferId::from_proto(request.buffer_id);
3585 session
3586 .db()
3587 .await
3588 .observe_buffer_version(
3589 buffer_id,
3590 session.user_id(),
3591 request.epoch as i32,
3592 &request.version,
3593 )
3594 .await?;
3595 Ok(())
3596}
3597
/// Start receiving chat updates for a channel
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3606
/// Stop receiving chat updates for a channel
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3614
/// Retrieve the chat history for a channel
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3623
/// Retrieve specific chat messages
///
/// Chat has been removed; this handler now unconditionally responds with
/// an error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3632
3633/// Retrieve the current users notifications
3634async fn get_notifications(
3635 request: proto::GetNotifications,
3636 response: Response<proto::GetNotifications>,
3637 session: MessageContext,
3638) -> Result<()> {
3639 let notifications = session
3640 .db()
3641 .await
3642 .get_notifications(
3643 session.user_id(),
3644 NOTIFICATION_COUNT_PER_PAGE,
3645 request.before_id.map(db::NotificationId::from_proto),
3646 )
3647 .await?;
3648 response.send(proto::GetNotificationsResponse {
3649 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3650 notifications,
3651 })?;
3652 Ok(())
3653}
3654
3655/// Mark notifications as read
3656async fn mark_notification_as_read(
3657 request: proto::MarkNotificationRead,
3658 response: Response<proto::MarkNotificationRead>,
3659 session: MessageContext,
3660) -> Result<()> {
3661 let database = &session.db().await;
3662 let notifications = database
3663 .mark_notification_as_read_by_id(
3664 session.user_id(),
3665 NotificationId::from_proto(request.notification_id),
3666 )
3667 .await?;
3668 send_notifications(
3669 &*session.connection_pool().await,
3670 &session.peer,
3671 notifications,
3672 );
3673 response.send(proto::Ack {})?;
3674 Ok(())
3675}
3676
3677fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3678 let message = match message {
3679 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3680 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3681 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3682 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3683 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3684 code: frame.code.into(),
3685 reason: frame.reason.as_str().to_owned().into(),
3686 })),
3687 // We should never receive a frame while reading the message, according
3688 // to the `tungstenite` maintainers:
3689 //
3690 // > It cannot occur when you read messages from the WebSocket, but it
3691 // > can be used when you want to send the raw frames (e.g. you want to
3692 // > send the frames to the WebSocket without composing the full message first).
3693 // >
3694 // > — https://github.com/snapview/tungstenite-rs/issues/268
3695 TungsteniteMessage::Frame(_) => {
3696 bail!("received an unexpected frame while reading the message")
3697 }
3698 };
3699
3700 Ok(message)
3701}
3702
3703fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3704 match message {
3705 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3706 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3707 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3708 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3709 AxumMessage::Close(frame) => {
3710 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3711 code: frame.code.into(),
3712 reason: frame.reason.as_ref().into(),
3713 }))
3714 }
3715 }
3716}
3717
3718fn notify_membership_updated(
3719 connection_pool: &mut ConnectionPool,
3720 result: MembershipUpdated,
3721 user_id: UserId,
3722 peer: &Peer,
3723) {
3724 for membership in &result.new_channels.channel_memberships {
3725 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3726 }
3727 for channel_id in &result.removed_channels {
3728 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3729 }
3730
3731 let user_channels_update = proto::UpdateUserChannels {
3732 channel_memberships: result
3733 .new_channels
3734 .channel_memberships
3735 .iter()
3736 .map(|cm| proto::ChannelMembership {
3737 channel_id: cm.channel_id.to_proto(),
3738 role: cm.role.into(),
3739 })
3740 .collect(),
3741 ..Default::default()
3742 };
3743
3744 let mut update = build_channels_update(result.new_channels);
3745 update.delete_channels = result
3746 .removed_channels
3747 .into_iter()
3748 .map(|id| id.to_proto())
3749 .collect();
3750 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3751
3752 for connection_id in connection_pool.user_connection_ids(user_id) {
3753 peer.send(connection_id, user_channels_update.clone())
3754 .trace_err();
3755 peer.send(connection_id, update.clone()).trace_err();
3756 }
3757}
3758
3759fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3760 proto::UpdateUserChannels {
3761 channel_memberships: channels
3762 .channel_memberships
3763 .iter()
3764 .map(|m| proto::ChannelMembership {
3765 channel_id: m.channel_id.to_proto(),
3766 role: m.role.into(),
3767 })
3768 .collect(),
3769 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3770 }
3771}
3772
3773fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3774 let mut update = proto::UpdateChannels::default();
3775
3776 for channel in channels.channels {
3777 update.channels.push(channel.to_proto());
3778 }
3779
3780 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3781
3782 for (channel_id, participants) in channels.channel_participants {
3783 update
3784 .channel_participants
3785 .push(proto::ChannelParticipants {
3786 channel_id: channel_id.to_proto(),
3787 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3788 });
3789 }
3790
3791 for channel in channels.invited_channels {
3792 update.channel_invitations.push(channel.to_proto());
3793 }
3794
3795 update
3796}
3797
3798fn build_initial_contacts_update(
3799 contacts: Vec<db::Contact>,
3800 pool: &ConnectionPool,
3801) -> proto::UpdateContacts {
3802 let mut update = proto::UpdateContacts::default();
3803
3804 for contact in contacts {
3805 match contact {
3806 db::Contact::Accepted { user_id, busy } => {
3807 update.contacts.push(contact_for_user(user_id, busy, pool));
3808 }
3809 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3810 db::Contact::Incoming { user_id } => {
3811 update
3812 .incoming_requests
3813 .push(proto::IncomingContactRequest {
3814 requester_id: user_id.to_proto(),
3815 })
3816 }
3817 }
3818 }
3819
3820 update
3821}
3822
3823fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3824 proto::Contact {
3825 user_id: user_id.to_proto(),
3826 online: pool.is_user_online(user_id),
3827 busy,
3828 }
3829}
3830
3831fn room_updated(room: &proto::Room, peer: &Peer) {
3832 broadcast(
3833 None,
3834 room.participants
3835 .iter()
3836 .filter_map(|participant| Some(participant.peer_id?.into())),
3837 |peer_id| {
3838 peer.send(
3839 peer_id,
3840 proto::RoomUpdated {
3841 room: Some(room.clone()),
3842 },
3843 )
3844 },
3845 );
3846}
3847
3848fn channel_updated(
3849 channel: &db::channel::Model,
3850 room: &proto::Room,
3851 peer: &Peer,
3852 pool: &ConnectionPool,
3853) {
3854 let participants = room
3855 .participants
3856 .iter()
3857 .map(|p| p.user_id)
3858 .collect::<Vec<_>>();
3859
3860 broadcast(
3861 None,
3862 pool.channel_connection_ids(channel.root_id())
3863 .filter_map(|(channel_id, role)| {
3864 role.can_see_channel(channel.visibility)
3865 .then_some(channel_id)
3866 }),
3867 |peer_id| {
3868 peer.send(
3869 peer_id,
3870 proto::UpdateChannels {
3871 channel_participants: vec![proto::ChannelParticipants {
3872 channel_id: channel.id.to_proto(),
3873 participant_user_ids: participants.clone(),
3874 }],
3875 ..Default::default()
3876 },
3877 )
3878 },
3879 );
3880}
3881
3882async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3883 let db = session.db().await;
3884
3885 let contacts = db.get_contacts(user_id).await?;
3886 let busy = db.is_user_busy(user_id).await?;
3887
3888 let pool = session.connection_pool().await;
3889 let updated_contact = contact_for_user(user_id, busy, &pool);
3890 for contact in contacts {
3891 if let db::Contact::Accepted {
3892 user_id: contact_user_id,
3893 ..
3894 } = contact
3895 {
3896 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3897 session
3898 .peer
3899 .send(
3900 contact_conn_id,
3901 proto::UpdateContacts {
3902 contacts: vec![updated_contact.clone()],
3903 remove_contacts: Default::default(),
3904 incoming_requests: Default::default(),
3905 remove_incoming_requests: Default::default(),
3906 outgoing_requests: Default::default(),
3907 remove_outgoing_requests: Default::default(),
3908 },
3909 )
3910 .trace_err();
3911 }
3912 }
3913 }
3914 Ok(())
3915}
3916
/// Remove `connection_id` from the room it currently occupies: unshare or
/// detach its projects, notify the remaining participants, cancel pending
/// calls, refresh contact presence, and clean up the LiveKit room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so the work after the `if let` block runs without
    // the database guard held (the guard is a temporary of the `if let`
    // scrutinee and is dropped at the end of that block).
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE: the LiveKit room name is taken out before `room` itself, so
        // the `room` value broadcast below has an empty `livekit_room` field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    // If this was a channel room, refresh the channel's participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            // Tell every connection of each canceled callee that the call
            // is gone, and refresh their contact state too.
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3990
3991async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3992 let left_channel_buffers = session
3993 .db()
3994 .await
3995 .leave_channel_buffers(session.connection_id)
3996 .await?;
3997
3998 for left_buffer in left_channel_buffers {
3999 channel_buffer_updated(
4000 session.connection_id,
4001 left_buffer.connections,
4002 &proto::UpdateChannelBufferCollaborators {
4003 channel_id: left_buffer.channel_id.to_proto(),
4004 collaborators: left_buffer.collaborators,
4005 },
4006 &session.peer,
4007 );
4008 }
4009
4010 Ok(())
4011}
4012
4013fn project_left(project: &db::LeftProject, session: &Session) {
4014 for connection_id in &project.connection_ids {
4015 if project.should_unshare {
4016 session
4017 .peer
4018 .send(
4019 *connection_id,
4020 proto::UnshareProject {
4021 project_id: project.id.to_proto(),
4022 },
4023 )
4024 .trace_err();
4025 } else {
4026 session
4027 .peer
4028 .send(
4029 *connection_id,
4030 proto::RemoveProjectCollaborator {
4031 project_id: project.id.to_proto(),
4032 peer_id: Some(session.connection_id.into()),
4033 },
4034 )
4035 .trace_err();
4036 }
4037 }
4038}
4039
4040async fn share_agent_thread(
4041 request: proto::ShareAgentThread,
4042 response: Response<proto::ShareAgentThread>,
4043 session: MessageContext,
4044) -> Result<()> {
4045 let user_id = session.user_id();
4046
4047 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4048 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4049
4050 session
4051 .db()
4052 .await
4053 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4054 .await?;
4055
4056 response.send(proto::Ack {})?;
4057
4058 Ok(())
4059}
4060
4061async fn get_shared_agent_thread(
4062 request: proto::GetSharedAgentThread,
4063 response: Response<proto::GetSharedAgentThread>,
4064 session: MessageContext,
4065) -> Result<()> {
4066 let share_id = SharedThreadId::from_proto(request.session_id)
4067 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4068
4069 let result = session.db().await.get_shared_thread(share_id).await?;
4070
4071 match result {
4072 Some((thread, username)) => {
4073 response.send(proto::GetSharedAgentThreadResponse {
4074 title: thread.title,
4075 thread_data: thread.data,
4076 sharer_username: username,
4077 created_at: thread.created_at.and_utc().to_rfc3339(),
4078 })?;
4079 }
4080 None => {
4081 return Err(anyhow!("Shared thread not found").into());
4082 }
4083 }
4084
4085 Ok(())
4086}
4087
/// Extension trait for logging and discarding errors.
///
/// `trace_err` converts a `Result` into an `Option`, emitting a
/// `tracing::error!` record when the value is an `Err`. It is used
/// throughout this file for best-effort sends whose failure should be
/// logged but must not abort the surrounding operation.
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}
4093
4094impl<T, E> ResultExt for Result<T, E>
4095where
4096 E: std::fmt::Debug,
4097{
4098 type Ok = T;
4099
4100 #[track_caller]
4101 fn trace_err(self) -> Option<T> {
4102 match self {
4103 Ok(value) => Some(value),
4104 Err(error) => {
4105 tracing::error!("{:?}", error);
4106 None
4107 }
4108 }
4109 }
4110}