1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
/// Grace period for a disconnected client to reconnect; presumably consumed
/// by session-cleanup logic outside this chunk — confirm at call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently active connections (incremented/decremented by
// `ConnectionGuard`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names used to record per-message timing metrics.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`: receives the
/// envelope, the connection's `Session`, and the message span, and yields the
/// future that processes the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
111impl Drop for ConnectionGuard {
112 fn drop(&mut self) {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 }
115}
116
/// One-shot responder for a request of type `R`, handed to request handlers
/// by `add_request_handler`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Flipped to true by `send`, so the request wrapper can detect handlers
    // that returned Ok without ever responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request as
    /// responded. Consumes `self`, so a request is answered at most once.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
130
131#[derive(Clone, Debug)]
132pub enum Principal {
133 User(User),
134 Impersonated { user: User, admin: User },
135}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// Per-message context passed to handlers: the connection's `Session` plus
/// the tracing span of the message currently being processed.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` anywhere a `Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
166
impl MessageContext {
    /// Forwards `request` to the connection `receiver_id` on behalf of this
    /// connection, logging the round trip and recording the time spent
    /// waiting on the receiving host in the span's `host_waiting_ms` field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs once on completion (success or error); elapsed micros are
            // converted to fractional milliseconds.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state shared with every message handler for that
/// connection.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; access via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
204
impl Session {
    /// Acquires the shared database handle.
    // The test-only yields presumably give the test executor extra
    // interleaving points around the lock — confirm against the test setup.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the connection pool. The returned guard is `!Send`, which keeps
    /// it from being held across an await point.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether this connection belongs to staff. An `Impersonated` principal
    /// always counts as staff (only admins can impersonate).
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
            Principal::Impersonated { .. } => true,
        }
    }

    /// Id of the acting user (the impersonated user when impersonating).
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
            Principal::Impersonated { user, .. } => user.id,
        }
    }
}
240
241impl Debug for Session {
242 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243 let mut result = f.debug_struct("Session");
244 match &self.principal {
245 Principal::User(user) => {
246 result.field("user", &user.github_login);
247 }
248 Principal::Impersonated { user, admin } => {
249 result.field("user", &user.github_login);
250 result.field("impersonator", &admin.github_login);
251 }
252 }
253 result.field("connection_id", &self.connection_id).finish()
254 }
255}
256
/// Newtype over the shared [`Database`], stored behind the async mutex in
/// `Session::db`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
266
/// The collaboration RPC server: owns the peer, the pool of client
/// connections, and the table of per-message handlers.
pub struct Server {
    // Behind a mutex so tests can swap it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` on shutdown; every `handle_connection` task watches it.
    teardown: watch::Sender<bool>,
}
275
/// Lock guard over the connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, preventing it from being held across awaits.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
280
/// Serializable view of server state (peer plus locked connection pool);
/// presumably exposed by an admin/debug endpoint — confirm at call sites.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized through `Deref` because the guard itself isn't `Serialize`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
287
288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
289where
290 S: Serializer,
291 T: Deref<Target = U>,
292 U: Serialize,
293{
294 Serialize::serialize(value.deref(), serializer)
295}
296
297impl Server {
    /// Builds a server with the given id and application state, and registers
    /// the handler for every message type the server understands.
    ///
    /// Request handlers must respond (or error) exactly once; message
    /// handlers are fire-and-forget. Registering two handlers for one message
    /// type panics in `add_handler`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // Room, call, and project lifecycle.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests are forwarded to the host for any
            // participant; mutating ones presumably require write capability
            // (enforced in the forwarding helpers, not visible here).
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            // NOTE(review): LspExtRunnables is registered as *mutating* amid
            // otherwise read-only LSP-extension requests — confirm intended.
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            // Host-originated notifications fanned out to all guests.
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Contacts, channels, chat, and notifications.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles modify the working
            // tree yet are forwarded as read-only — verify this matches the
            // intended capability check.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread);

        Arc::new(server)
    }
475
    /// Starts background maintenance: spawns a detached task that waits
    /// `CLEANUP_TIMEOUT` (so pods from a previous deploy can finish their
    /// graceful shutdown), then removes the stale state they left behind —
    /// stale chat participants, channel-buffer collaborators, room
    /// participants — notifying affected clients and finally deleting stale
    /// server rows and empty LiveKit rooms.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this same call is repeated after the room
                // loop below — confirm whether both invocations are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // tell the remaining collaborators about the new set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Evict stale participants, then broadcast the
                        // refreshed room (and channel, if it has one).
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming call vanished with the
                        // stale room that the call was canceled. Scoped block
                        // limits how long the pool lock is held.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to the contacts of every
                        // affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // A room left with no participants also loses its
                        // LiveKit room.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
646
647 pub fn teardown(&self) {
648 self.peer.teardown();
649 self.connection_pool.lock().reset();
650 let _ = self.teardown.send(true);
651 }
652
    /// Test-only: tears the server down and brings it back up under a new
    /// `ServerId`, clearing the teardown signal so connections are accepted
    /// again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
660
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
665
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and span timing (`total_duration_ms`,
    /// `processing_duration_ms`, `queue_duration_ms`).
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Safe: the dispatch loop looks handlers up by the payload's
                // TypeId, so the envelope is always a TypedEnvelope<M> here.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = receipt-to-finish; queue = total minus the time
                    // the handler itself ran.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
709
710 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
711 where
712 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
713 Fut: 'static + Send + Future<Output = Result<()>>,
714 M: EnvelopedMessage,
715 {
716 self.add_handler(move |envelope, session| handler(envelope.payload, session));
717 self
718 }
719
    /// Registers a handler for the request message `M`, enforcing the
    /// request/response contract: a handler that returns `Ok` without having
    /// sent a response becomes an error, and a handler that returns `Err`
    /// gets its error relayed back to the requester.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared with the Response so we can tell afterwards whether
                // `Response::send` was ever called.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Relay the failure to the requester, preserving the
                        // error code for internal errors.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
758
    /// Returns the future that runs a client connection for its whole life:
    /// registers it with the peer, sends the initial client update, then
    /// dispatches incoming messages to the registered handlers until the
    /// connection closes or the server tears down, finally running
    /// `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release the slot in the
            // global concurrent-connection limit.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler permit is free,
                // backpressuring the connection at MAX_CONCURRENT_HANDLERS.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Priority order: teardown signal, I/O errors, draining
                // running foreground handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler
                                // finishes, keeping the concurrency count
                                // accurate.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-running foreground handlers, then clean up
            // this connection's server-side state.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
931
    /// Sends the initial state a freshly-connected client needs: the `Hello`
    /// handshake, first-connection prompt, contact snapshot, channel
    /// subscriptions, and any call that was already ringing for the user.
    ///
    /// `send_connection_id` is an optional one-shot channel that, when
    /// provided, is fired with the connection id right after the hello is
    /// sent so the caller can learn which id was assigned.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The Hello message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // Receiver may have been dropped; ignore the send error.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: ask the client to show
                // the contacts UI, and persist that we did so.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contacts update
                // under the same pool lock so the snapshot cannot race with
                // other connection changes.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a ring that was already in progress when this
                // client connected.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
987
988 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
989 ServerSnapshot {
990 connection_pool: ConnectionPoolGuard {
991 guard: self.connection_pool.lock(),
992 _not_send: PhantomData,
993 },
994 peer: &self.peer,
995 }
996 }
997}
998
999impl Deref for ConnectionPoolGuard<'_> {
1000 type Target = ConnectionPool;
1001
1002 fn deref(&self) -> &Self::Target {
1003 &self.guard
1004 }
1005}
1006
1007impl DerefMut for ConnectionPoolGuard<'_> {
1008 fn deref_mut(&mut self) -> &mut Self::Target {
1009 &mut self.guard
1010 }
1011}
1012
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released so corruption is caught close to its source.
        #[cfg(test)]
        self.check_invariants();
    }
}
1019
1020fn broadcast<F>(
1021 sender_id: Option<ConnectionId>,
1022 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1023 mut f: F,
1024) where
1025 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1026{
1027 for receiver_id in receiver_ids {
1028 if Some(receiver_id) != sender_id
1029 && let Err(error) = f(receiver_id)
1030 {
1031 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1032 }
1033 }
1034}
1035
/// Typed wrapper around the numeric `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
1037
1038impl Header for ProtocolVersion {
1039 fn name() -> &'static HeaderName {
1040 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1041 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1042 }
1043
1044 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1045 where
1046 Self: Sized,
1047 I: Iterator<Item = &'i axum::http::HeaderValue>,
1048 {
1049 let version = values
1050 .next()
1051 .ok_or_else(axum::headers::Error::invalid)?
1052 .to_str()
1053 .map_err(|_| axum::headers::Error::invalid())?
1054 .parse()
1055 .map_err(|_| axum::headers::Error::invalid())?;
1056 Ok(Self(version))
1057 }
1058
1059 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1060 values.extend([self.0.to_string().parse().unwrap()]);
1061 }
1062}
1063
/// Typed wrapper around the `x-zed-app-version` header, parsed as semver.
pub struct AppVersionHeader(Version);
1065impl Header for AppVersionHeader {
1066 fn name() -> &'static HeaderName {
1067 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1068 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1069 }
1070
1071 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1072 where
1073 Self: Sized,
1074 I: Iterator<Item = &'i axum::http::HeaderValue>,
1075 {
1076 let version = values
1077 .next()
1078 .ok_or_else(axum::headers::Error::invalid)?
1079 .to_str()
1080 .map_err(|_| axum::headers::Error::invalid())?
1081 .parse()
1082 .map_err(|_| axum::headers::Error::invalid())?;
1083 Ok(Self(version))
1084 }
1085
1086 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1087 values.extend([self.0.to_string().parse().unwrap()]);
1088 }
1089}
1090
/// Typed wrapper around the raw `x-zed-release-channel` header string.
#[derive(Debug)]
pub struct ReleaseChannelHeader(String);
1093
1094impl Header for ReleaseChannelHeader {
1095 fn name() -> &'static HeaderName {
1096 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1097 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1098 }
1099
1100 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1101 where
1102 Self: Sized,
1103 I: Iterator<Item = &'i axum::http::HeaderValue>,
1104 {
1105 Ok(Self(
1106 values
1107 .next()
1108 .ok_or_else(axum::headers::Error::invalid)?
1109 .to_str()
1110 .map_err(|_| axum::headers::Error::invalid())?
1111 .to_owned(),
1112 ))
1113 }
1114
1115 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1116 values.extend([self.0.parse().unwrap()]);
1117 }
1118}
1119
1120pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1121 Router::new()
1122 .route("/rpc", get(handle_websocket_request))
1123 .layer(
1124 ServiceBuilder::new()
1125 .layer(Extension(server.app_state.clone()))
1126 .layer(middleware::from_fn(auth::validate_header)),
1127 )
1128 .route("/metrics", get(handle_metrics))
1129 .layer(Extension(server))
1130}
1131
1132pub async fn handle_websocket_request(
1133 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1134 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1135 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1136 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1137 Extension(server): Extension<Arc<Server>>,
1138 Extension(principal): Extension<Principal>,
1139 user_agent: Option<TypedHeader<UserAgent>>,
1140 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1141 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1142 ws: WebSocketUpgrade,
1143) -> axum::response::Response {
1144 if protocol_version != rpc::PROTOCOL_VERSION {
1145 return (
1146 StatusCode::UPGRADE_REQUIRED,
1147 "client must be upgraded".to_string(),
1148 )
1149 .into_response();
1150 }
1151
1152 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1153 return (
1154 StatusCode::UPGRADE_REQUIRED,
1155 "no version header found".to_string(),
1156 )
1157 .into_response();
1158 };
1159
1160 let release_channel = release_channel_header.map(|header| header.0.0);
1161
1162 if !version.can_collaborate() {
1163 return (
1164 StatusCode::UPGRADE_REQUIRED,
1165 "client must be upgraded".to_string(),
1166 )
1167 .into_response();
1168 }
1169
1170 let socket_address = socket_address.to_string();
1171
1172 // Acquire connection guard before WebSocket upgrade
1173 let connection_guard = match ConnectionGuard::try_acquire() {
1174 Ok(guard) => guard,
1175 Err(()) => {
1176 return (
1177 StatusCode::SERVICE_UNAVAILABLE,
1178 "Too many concurrent connections",
1179 )
1180 .into_response();
1181 }
1182 };
1183
1184 ws.on_upgrade(move |socket| {
1185 let socket = socket
1186 .map_ok(to_tungstenite_message)
1187 .err_into()
1188 .with(|message| async move { to_axum_message(message) });
1189 let connection = Connection::new(Box::pin(socket));
1190 async move {
1191 server
1192 .handle_connection(
1193 connection,
1194 socket_address,
1195 principal,
1196 version,
1197 release_channel,
1198 user_agent.map(|header| header.to_string()),
1199 country_code_header.map(|header| header.to_string()),
1200 system_id_header.map(|header| header.to_string()),
1201 None,
1202 Executor::Production,
1203 Some(connection_guard),
1204 )
1205 .await;
1206 }
1207 })
1208}
1209
1210pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1211 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1212 let connections_metric = CONNECTIONS_METRIC
1213 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1214
1215 let connections = server
1216 .connection_pool
1217 .lock()
1218 .connections()
1219 .filter(|connection| !connection.admin)
1220 .count();
1221 connections_metric.set(connections as _);
1222
1223 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1224 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1225 register_int_gauge!(
1226 "shared_projects",
1227 "number of open projects with one or more guests"
1228 )
1229 .unwrap()
1230 });
1231
1232 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1233 shared_projects_metric.set(shared_projects as _);
1234
1235 let encoder = prometheus::TextEncoder::new();
1236 let metric_families = prometheus::gather();
1237 let encoded_metrics = encoder
1238 .encode_to_string(&metric_families)
1239 .map_err(|err| anyhow!("{err}"))?;
1240 Ok(encoded_metrics)
1241}
1242
/// Cleans up after a client's connection drops.
///
/// The connection itself is removed immediately, but room and channel-buffer
/// resources are only released after `RECONNECT_TIMEOUT` elapses, so a brief
/// network blip (followed by a rejoin) does not end a call. A signal on
/// `teardown` aborts the grace-period cleanup entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Drop the peer connection and deregister it from the in-memory pool.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Record the loss in the database; best-effort (errors only traced).
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period expired without teardown: release everything the
        // session held.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last
            // connection; another device may still be ringing.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1289
1290/// Acknowledges a ping from a client, used to keep the connection alive.
1291async fn ping(
1292 _: proto::Ping,
1293 response: Response<proto::Ping>,
1294 _session: MessageContext,
1295) -> Result<()> {
1296 response.send(proto::Ack {})?;
1297 Ok(())
1298}
1299
1300/// Creates a new room for calling (outside of channels)
1301async fn create_room(
1302 _request: proto::CreateRoom,
1303 response: Response<proto::CreateRoom>,
1304 session: MessageContext,
1305) -> Result<()> {
1306 let livekit_room = nanoid::nanoid!(30);
1307
1308 let live_kit_connection_info = util::maybe!(async {
1309 let live_kit = session.app_state.livekit_client.as_ref();
1310 let live_kit = live_kit?;
1311 let user_id = session.user_id().to_string();
1312
1313 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1314
1315 Some(proto::LiveKitConnectionInfo {
1316 server_url: live_kit.url().into(),
1317 token,
1318 can_publish: true,
1319 })
1320 })
1321 .await;
1322
1323 let room = session
1324 .db()
1325 .await
1326 .create_room(session.user_id(), session.connection_id, &livekit_room)
1327 .await?;
1328
1329 response.send(proto::CreateRoomResponse {
1330 room: Some(room.clone()),
1331 live_kit_connection_info,
1332 })?;
1333
1334 update_user_contacts(session.user_id(), &session).await?;
1335 Ok(())
1336}
1337
1338/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1339async fn join_room(
1340 request: proto::JoinRoom,
1341 response: Response<proto::JoinRoom>,
1342 session: MessageContext,
1343) -> Result<()> {
1344 let room_id = RoomId::from_proto(request.id);
1345
1346 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1347
1348 if let Some(channel_id) = channel_id {
1349 return join_channel_internal(channel_id, Box::new(response), session).await;
1350 }
1351
1352 let joined_room = {
1353 let room = session
1354 .db()
1355 .await
1356 .join_room(room_id, session.user_id(), session.connection_id)
1357 .await?;
1358 room_updated(&room.room, &session.peer);
1359 room.into_inner()
1360 };
1361
1362 for connection_id in session
1363 .connection_pool()
1364 .await
1365 .user_connection_ids(session.user_id())
1366 {
1367 session
1368 .peer
1369 .send(
1370 connection_id,
1371 proto::CallCanceled {
1372 room_id: room_id.to_proto(),
1373 },
1374 )
1375 .trace_err();
1376 }
1377
1378 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1379 {
1380 live_kit
1381 .room_token(
1382 &joined_room.room.livekit_room,
1383 &session.user_id().to_string(),
1384 )
1385 .trace_err()
1386 .map(|token| proto::LiveKitConnectionInfo {
1387 server_url: live_kit.url().into(),
1388 token,
1389 can_publish: true,
1390 })
1391 } else {
1392 None
1393 };
1394
1395 response.send(proto::JoinRoomResponse {
1396 room: Some(joined_room.room),
1397 channel_id: None,
1398 live_kit_connection_info,
1399 })?;
1400
1401 update_user_contacts(session.user_id(), &session).await?;
1402 Ok(())
1403}
1404
1405/// Rejoin room is used to reconnect to a room after connection errors.
1406async fn rejoin_room(
1407 request: proto::RejoinRoom,
1408 response: Response<proto::RejoinRoom>,
1409 session: MessageContext,
1410) -> Result<()> {
1411 let room;
1412 let channel;
1413 {
1414 let mut rejoined_room = session
1415 .db()
1416 .await
1417 .rejoin_room(request, session.user_id(), session.connection_id)
1418 .await?;
1419
1420 response.send(proto::RejoinRoomResponse {
1421 room: Some(rejoined_room.room.clone()),
1422 reshared_projects: rejoined_room
1423 .reshared_projects
1424 .iter()
1425 .map(|project| proto::ResharedProject {
1426 id: project.id.to_proto(),
1427 collaborators: project
1428 .collaborators
1429 .iter()
1430 .map(|collaborator| collaborator.to_proto())
1431 .collect(),
1432 })
1433 .collect(),
1434 rejoined_projects: rejoined_room
1435 .rejoined_projects
1436 .iter()
1437 .map(|rejoined_project| rejoined_project.to_proto())
1438 .collect(),
1439 })?;
1440 room_updated(&rejoined_room.room, &session.peer);
1441
1442 for project in &rejoined_room.reshared_projects {
1443 for collaborator in &project.collaborators {
1444 session
1445 .peer
1446 .send(
1447 collaborator.connection_id,
1448 proto::UpdateProjectCollaborator {
1449 project_id: project.id.to_proto(),
1450 old_peer_id: Some(project.old_connection_id.into()),
1451 new_peer_id: Some(session.connection_id.into()),
1452 },
1453 )
1454 .trace_err();
1455 }
1456
1457 broadcast(
1458 Some(session.connection_id),
1459 project
1460 .collaborators
1461 .iter()
1462 .map(|collaborator| collaborator.connection_id),
1463 |connection_id| {
1464 session.peer.forward_send(
1465 session.connection_id,
1466 connection_id,
1467 proto::UpdateProject {
1468 project_id: project.id.to_proto(),
1469 worktrees: project.worktrees.clone(),
1470 },
1471 )
1472 },
1473 );
1474 }
1475
1476 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1477
1478 let rejoined_room = rejoined_room.into_inner();
1479
1480 room = rejoined_room.room;
1481 channel = rejoined_room.channel;
1482 }
1483
1484 if let Some(channel) = channel {
1485 channel_updated(
1486 &channel,
1487 &room,
1488 &session.peer,
1489 &*session.connection_pool().await,
1490 );
1491 }
1492
1493 update_user_contacts(session.user_id(), &session).await?;
1494 Ok(())
1495}
1496
/// Streams restored project state after a rejoin.
///
/// First informs each project's collaborators of the rejoining client's new
/// peer id, then streams worktree entries, diagnostics, settings files, and
/// repository updates back to the rejoining client itself.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            // Best-effort: a collaborator may have disconnected concurrently.
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the payload out so it isn't re-sent if this
        // struct is inspected again later.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into multiple messages before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1581
1582/// leave room disconnects from the room.
1583async fn leave_room(
1584 _: proto::LeaveRoom,
1585 response: Response<proto::LeaveRoom>,
1586 session: MessageContext,
1587) -> Result<()> {
1588 leave_room_for_session(&session, session.connection_id).await?;
1589 response.send(proto::Ack {})?;
1590 Ok(())
1591}
1592
1593/// Updates the permissions of someone else in the room.
1594async fn set_room_participant_role(
1595 request: proto::SetRoomParticipantRole,
1596 response: Response<proto::SetRoomParticipantRole>,
1597 session: MessageContext,
1598) -> Result<()> {
1599 let user_id = UserId::from_proto(request.user_id);
1600 let role = ChannelRole::from(request.role());
1601
1602 let (livekit_room, can_publish) = {
1603 let room = session
1604 .db()
1605 .await
1606 .set_room_participant_role(
1607 session.user_id(),
1608 RoomId::from_proto(request.room_id),
1609 user_id,
1610 role,
1611 )
1612 .await?;
1613
1614 let livekit_room = room.livekit_room.clone();
1615 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1616 room_updated(&room, &session.peer);
1617 (livekit_room, can_publish)
1618 };
1619
1620 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1621 live_kit
1622 .update_participant(
1623 livekit_room.clone(),
1624 request.user_id.to_string(),
1625 livekit_api::proto::ParticipantPermission {
1626 can_subscribe: true,
1627 can_publish,
1628 can_publish_data: can_publish,
1629 hidden: false,
1630 recorder: false,
1631 },
1632 )
1633 .await
1634 .trace_err();
1635 }
1636
1637 response.send(proto::Ack {})?;
1638 Ok(())
1639}
1640
/// Call someone else into the current room
///
/// Rings every connected device of the callee concurrently; the first
/// connection that responds successfully acks this request and the handler
/// returns. If no connection responds, the ring is rolled back in the
/// database and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may call each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Take ownership of the call payload out of the guard.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // One device answered; the remaining ring futures are
                // dropped when `calls` goes out of scope.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No device answered: undo the pending-call state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1709
1710/// Cancel an outgoing call.
1711async fn cancel_call(
1712 request: proto::CancelCall,
1713 response: Response<proto::CancelCall>,
1714 session: MessageContext,
1715) -> Result<()> {
1716 let called_user_id = UserId::from_proto(request.called_user_id);
1717 let room_id = RoomId::from_proto(request.room_id);
1718 {
1719 let room = session
1720 .db()
1721 .await
1722 .cancel_call(room_id, session.connection_id, called_user_id)
1723 .await?;
1724 room_updated(&room, &session.peer);
1725 }
1726
1727 for connection_id in session
1728 .connection_pool()
1729 .await
1730 .user_connection_ids(called_user_id)
1731 {
1732 session
1733 .peer
1734 .send(
1735 connection_id,
1736 proto::CallCanceled {
1737 room_id: room_id.to_proto(),
1738 },
1739 )
1740 .trace_err();
1741 }
1742 response.send(proto::Ack {})?;
1743
1744 update_user_contacts(called_user_id, &session).await?;
1745 Ok(())
1746}
1747
1748/// Decline an incoming call.
1749async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1750 let room_id = RoomId::from_proto(message.room_id);
1751 {
1752 let room = session
1753 .db()
1754 .await
1755 .decline_call(Some(room_id), session.user_id())
1756 .await?
1757 .context("declining call")?;
1758 room_updated(&room, &session.peer);
1759 }
1760
1761 for connection_id in session
1762 .connection_pool()
1763 .await
1764 .user_connection_ids(session.user_id())
1765 {
1766 session
1767 .peer
1768 .send(
1769 connection_id,
1770 proto::CallCanceled {
1771 room_id: room_id.to_proto(),
1772 },
1773 )
1774 .trace_err();
1775 }
1776 update_user_contacts(session.user_id(), &session).await?;
1777 Ok(())
1778}
1779
1780/// Updates other participants in the room with your current location.
1781async fn update_participant_location(
1782 request: proto::UpdateParticipantLocation,
1783 response: Response<proto::UpdateParticipantLocation>,
1784 session: MessageContext,
1785) -> Result<()> {
1786 let room_id = RoomId::from_proto(request.room_id);
1787 let location = request.location.context("invalid location")?;
1788
1789 let db = session.db().await;
1790 let room = db
1791 .update_room_participant_location(room_id, session.connection_id, location)
1792 .await?;
1793
1794 room_updated(&room, &session.peer);
1795 response.send(proto::Ack {})?;
1796 Ok(())
1797}
1798
1799/// Share a project into the room.
1800async fn share_project(
1801 request: proto::ShareProject,
1802 response: Response<proto::ShareProject>,
1803 session: MessageContext,
1804) -> Result<()> {
1805 let (project_id, room) = &*session
1806 .db()
1807 .await
1808 .share_project(
1809 RoomId::from_proto(request.room_id),
1810 session.connection_id,
1811 &request.worktrees,
1812 request.is_ssh_project,
1813 request.windows_paths.unwrap_or(false),
1814 )
1815 .await?;
1816 response.send(proto::ShareProjectResponse {
1817 project_id: project_id.to_proto(),
1818 })?;
1819 room_updated(room, &session.peer);
1820
1821 Ok(())
1822}
1823
1824/// Unshare a project from the room.
1825async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1826 let project_id = ProjectId::from_proto(message.project_id);
1827 unshare_project_internal(project_id, session.connection_id, &session).await
1828}
1829
/// Unshares `project_id` on behalf of `connection_id`: notifies guests,
/// refreshes the room, and (when the database says so) deletes the project.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // The guard returned by the database is held only while guests are
        // notified; it is dropped at the end of this block, before the
        // delete below acquires the database again.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the initiator) the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1867
/// Join someone elses shared project.
///
/// Registers the guest in the database, announces the new collaborator to
/// existing participants, answers with the project metadata, and then
/// streams the full project state (worktree entries, diagnostics, settings,
/// repositories, language-server notifications) to the joining client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joiner itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to existing participants (best-effort).
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // The update is split into multiple messages before sending.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Replay a disk-based-diagnostics-updated notification for each
    // language server in the project.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2018
2019/// Leave someone elses shared project.
2020async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2021 let sender_id = session.connection_id;
2022 let project_id = ProjectId::from_proto(request.project_id);
2023 let db = session.db().await;
2024
2025 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2026 tracing::info!(
2027 %project_id,
2028 "leave project"
2029 );
2030
2031 project_left(project, &session);
2032 if let Some(room) = room {
2033 room_updated(room, &session.peer);
2034 }
2035
2036 Ok(())
2037}
2038
/// Updates other participants with changes to the project
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Persist the worktree metadata change; the DB tells us which guests to notify.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Relay the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    // `room` is `None` when the project isn't part of an active call.
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2067
/// Updates other participants with changes to the worktree
async fn update_worktree(
    request: proto::UpdateWorktree,
    response: Response<proto::UpdateWorktree>,
    session: MessageContext,
) -> Result<()> {
    // Apply the worktree update in the database first; it returns the set of
    // guest connections that should receive the relayed message.
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree(&request, session.connection_id)
        .await?;

    // Forward the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
2092
/// Persists a repository update and relays it to the project's guests.
async fn update_repository(
    request: proto::UpdateRepository,
    response: Response<proto::UpdateRepository>,
    session: MessageContext,
) -> Result<()> {
    // Record the repository state in the database; returns guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_repository(&request, session.connection_id)
        .await?;

    // Forward the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
2116
/// Removes a repository from the project state and notifies the guests.
async fn remove_repository(
    request: proto::RemoveRepository,
    response: Response<proto::RemoveRepository>,
    session: MessageContext,
) -> Result<()> {
    // Delete the repository row; returns the guest connections to notify.
    let guest_connection_ids = session
        .db()
        .await
        .remove_repository(&request, session.connection_id)
        .await?;

    // Forward the original message to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
2140
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
    message: proto::UpdateDiagnosticSummary,
    session: MessageContext,
) -> Result<()> {
    // Persist the summary so late joiners can be served from the database.
    let guest_connection_ids = session
        .db()
        .await
        .update_diagnostic_summary(&message, session.connection_id)
        .await?;

    // Relay the summary to every guest except the sender. No ack is sent:
    // this message type is fire-and-forget.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2164
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
    message: proto::UpdateWorktreeSettings,
    session: MessageContext,
) -> Result<()> {
    // Persist the settings-file change; returns the guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree_settings(&message, session.connection_id)
        .await?;

    // Relay to every guest except the sender; fire-and-forget (no ack).
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2188
/// Notify other participants that a language server has started.
async fn start_language_server(
    request: proto::StartLanguageServer,
    session: MessageContext,
) -> Result<()> {
    // Record the server in the database; returns the guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .start_language_server(&request, session.connection_id)
        .await?;

    // Relay to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2211
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist capability changes reported via `MetadataUpdated` before
    // relaying the message.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    // Relay the update to every other project connection.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2241
/// forward a project request to the host. These requests should be read only
/// as guests are allowed to send them.
///
/// Resolves (and authorizes) the host connection for the project, forwards
/// the request, and relays the host's reply back to the original requester.
async fn forward_read_only_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // The DB lookup doubles as an access check for the sender's connection.
    let host_connection_id = session
        .db()
        .await
        .host_for_read_only_project_request(project_id, session.connection_id)
        .await?;
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2262
/// forward a project request to the host. These requests are disallowed
/// for guests.
///
/// Resolves (and authorizes) the host connection for the project, forwards
/// the request, and relays the host's reply back to the original requester.
async fn forward_mutating_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());

    // The DB lookup doubles as a write-permission check for the sender.
    let host_connection_id = session
        .db()
        .await
        .host_for_mutating_project_request(project_id, session.connection_id)
        .await?;
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2284
/// Forwards an LSP query to the project host, routing through the mutating
/// or read-only path depending on the query's declared write permissions.
async fn lsp_query(
    request: proto::LspQuery,
    response: Response<proto::LspQuery>,
    session: MessageContext,
) -> Result<()> {
    let (name, should_write) = request.query_name_and_write_permissions();
    // Record the concrete query name on the current tracing span for observability.
    tracing::Span::current().record("lsp_query_request", name);
    tracing::info!("lsp_query message received");
    if should_write {
        forward_mutating_project_request(request, response, session).await
    } else {
        forward_read_only_project_request(request, response, session).await
    }
}
2299
/// Notify other participants that a new buffer has been created
///
/// Only the project host may send this; the message is forwarded verbatim to
/// the peer named in the request.
async fn create_buffer_for_peer(
    request: proto::CreateBufferForPeer,
    session: MessageContext,
) -> Result<()> {
    // Reject senders that aren't the host of the referenced project.
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    let peer_id = request.peer_id.context("invalid peer id")?;
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
2319
/// Notify other participants that a new image has been created
///
/// Only the project host may send this; the message is forwarded verbatim to
/// the peer named in the request.
async fn create_image_for_peer(
    request: proto::CreateImageForPeer,
    session: MessageContext,
) -> Result<()> {
    // Reject senders that aren't the host of the referenced project.
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    let peer_id = request.peer_id.context("invalid peer id")?;
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
2339
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates are permitted from read-only guests; any other
    // operation escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the DB guard to this block so it is dropped before we await the
    // host's acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Relay the update to all guests except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // When the sender isn't the host, wait for the host to confirm receipt
    // before acknowledging to the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2386
/// Relays a context (assistant) operation to the project's collaborators.
///
/// Like `update_buffer`, selection-only buffer operations may come from
/// read-only guests; anything else requires write access.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Derive the capability the sender needs for this operation.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the fan-out here and
    // no acknowledgement round-trip is performed.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2428
/// Notify other participants that a project has been updated.
///
/// Generic relay used for host-originated project messages: looks up the
/// project's other connections and forwards the message to each of them.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id, false)
        .await?;

    // Forward to everyone in the project except the sender.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2452
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay the answer back.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Project-scoped follows are recorded in the DB and announced to the room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2484
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both the leader and the follower are participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader it lost a follower (fire-and-forget, no reply expected).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Project-scoped follows are removed from the DB and the room is refreshed.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2513
2514/// Notify everyone following you of your current location.
2515async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2516 let room_id = RoomId::from_proto(request.room_id);
2517 let database = session.db.lock().await;
2518
2519 let connection_ids = if let Some(project_id) = request.project_id {
2520 let project_id = ProjectId::from_proto(project_id);
2521 database
2522 .project_connection_ids(project_id, session.connection_id, true)
2523 .await?
2524 } else {
2525 database
2526 .room_connection_ids(room_id, session.connection_id)
2527 .await?
2528 };
2529
2530 // For now, don't send view update messages back to that view's current leader.
2531 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2532 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2533 _ => None,
2534 });
2535
2536 for connection_id in connection_ids.iter().cloned() {
2537 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2538 session
2539 .peer
2540 .forward_send(session.connection_id, connection_id, request.clone())?;
2541 }
2542 }
2543 Ok(())
2544}
2545
/// Get public data about users.
async fn get_users(
    request: proto::GetUsers,
    response: Response<proto::GetUsers>,
    session: MessageContext,
) -> Result<()> {
    let user_ids = request
        .user_ids
        .into_iter()
        .map(UserId::from_proto)
        .collect();
    // Look the users up by id and convert them to their wire representation.
    let users = session
        .db()
        .await
        .get_users_by_ids(user_ids)
        .await?
        .into_iter()
        .map(|user| proto::User {
            id: user.id.to_proto(),
            // Avatars are served directly by GitHub based on the login.
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2573
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries do an exact login lookup; longer queries go through
    // fuzzy search, capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requesting user to themselves.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2605
/// Send a contact request to another user.
///
/// Persists the request, then pushes contact-list updates to every live
/// connection of both the requester and the responder, and finally delivers
/// the generated notifications.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2652
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included so the new contact entries render correctly.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2710
/// Remove a contact.
///
/// Works for both an accepted contact and a still-pending request; the DB
/// result distinguishes the two so the right list is updated on each side.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the original contact-request notification, if any.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2761
2762fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2763 version.0.minor < 139
2764}
2765
2766async fn subscribe_to_channels(
2767 _: proto::SubscribeToChannels,
2768 session: MessageContext,
2769) -> Result<()> {
2770 subscribe_user_to_channels(session.user_id(), &session).await?;
2771 Ok(())
2772}
2773
/// Subscribe all of `user_id`'s channel memberships in the connection pool
/// and send this connection its initial channel state.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register the user for future updates on each channel they belong to.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send the per-user membership summary first, then the full channels payload.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2790
2791/// Creates a new channel.
2792async fn create_channel(
2793 request: proto::CreateChannel,
2794 response: Response<proto::CreateChannel>,
2795 session: MessageContext,
2796) -> Result<()> {
2797 let db = session.db().await;
2798
2799 let parent_id = request.parent_id.map(ChannelId::from_proto);
2800 let (channel, membership) = db
2801 .create_channel(&request.name, parent_id, session.user_id())
2802 .await?;
2803
2804 let root_id = channel.root_id();
2805 let channel = Channel::from_model(channel);
2806
2807 response.send(proto::CreateChannelResponse {
2808 channel: Some(channel.to_proto()),
2809 parent_id: request.parent_id,
2810 })?;
2811
2812 let mut connection_pool = session.connection_pool().await;
2813 if let Some(membership) = membership {
2814 connection_pool.subscribe_to_channel(
2815 membership.user_id,
2816 membership.channel_id,
2817 membership.role,
2818 );
2819 let update = proto::UpdateUserChannels {
2820 channel_memberships: vec![proto::ChannelMembership {
2821 channel_id: membership.channel_id.to_proto(),
2822 role: membership.role.into(),
2823 }],
2824 ..Default::default()
2825 };
2826 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2827 session.peer.send(connection_id, update.clone())?;
2828 }
2829 }
2830
2831 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2832 if !role.can_see_channel(channel.visibility) {
2833 continue;
2834 }
2835
2836 let update = proto::UpdateChannels {
2837 channels: vec![channel.to_proto()],
2838 ..Default::default()
2839 };
2840 session.peer.send(connection_id, update.clone())?;
2841 }
2842
2843 Ok(())
2844}
2845
/// Delete a channel
///
/// Deleting a channel also removes its descendants; every member connected
/// to the root channel's tree is told which channel ids disappeared.
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    // Ack early: the deletion is already committed at this point.
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2873
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Show the pending invitation on all of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2910
2911/// remove someone from a channel
2912async fn remove_channel_member(
2913 request: proto::RemoveChannelMember,
2914 response: Response<proto::RemoveChannelMember>,
2915 session: MessageContext,
2916) -> Result<()> {
2917 let db = session.db().await;
2918 let channel_id = ChannelId::from_proto(request.channel_id);
2919 let member_id = UserId::from_proto(request.user_id);
2920
2921 let RemoveChannelMemberResult {
2922 membership_update,
2923 notification_id,
2924 } = db
2925 .remove_channel_member(channel_id, member_id, session.user_id())
2926 .await?;
2927
2928 let mut connection_pool = session.connection_pool().await;
2929 notify_membership_updated(
2930 &mut connection_pool,
2931 membership_update,
2932 member_id,
2933 &session.peer,
2934 );
2935 for connection_id in connection_pool.user_connection_ids(member_id) {
2936 if let Some(notification_id) = notification_id {
2937 session
2938 .peer
2939 .send(
2940 connection_id,
2941 proto::DeleteNotification {
2942 notification_id: notification_id.to_proto(),
2943 },
2944 )
2945 .trace_err();
2946 }
2947 }
2948
2949 response.send(proto::Ack {})?;
2950 Ok(())
2951}
2952
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list first: the loop body mutates the pool
    // (subscribe/unsubscribe), so we can't hold its iterator borrow.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can see the channel under its new visibility get the
        // updated channel; everyone else is told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2999
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an actual member (membership updated) or only an
    // invitee (invitation updated); each case is broadcast differently.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3047
/// Change the name of a channel
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    // Announce the rename to everyone in the root channel's tree who is
    // allowed to see this channel.
    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3079
3080/// Move a channel to a new parent.
3081async fn move_channel(
3082 request: proto::MoveChannel,
3083 response: Response<proto::MoveChannel>,
3084 session: MessageContext,
3085) -> Result<()> {
3086 let channel_id = ChannelId::from_proto(request.channel_id);
3087 let to = ChannelId::from_proto(request.to);
3088
3089 let (root_id, channels) = session
3090 .db()
3091 .await
3092 .move_channel(channel_id, to, session.user_id())
3093 .await?;
3094
3095 let connection_pool = session.connection_pool().await;
3096 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3097 let channels = channels
3098 .iter()
3099 .filter_map(|channel| {
3100 if role.can_see_channel(channel.visibility) {
3101 Some(channel.to_proto())
3102 } else {
3103 None
3104 }
3105 })
3106 .collect::<Vec<_>>();
3107 if channels.is_empty() {
3108 continue;
3109 }
3110
3111 let update = proto::UpdateChannels {
3112 channels,
3113 ..Default::default()
3114 };
3115
3116 session.peer.send(connection_id, update.clone())?;
3117 }
3118
3119 response.send(Ack {})?;
3120 Ok(())
3121}
3122
3123async fn reorder_channel(
3124 request: proto::ReorderChannel,
3125 response: Response<proto::ReorderChannel>,
3126 session: MessageContext,
3127) -> Result<()> {
3128 let channel_id = ChannelId::from_proto(request.channel_id);
3129 let direction = request.direction();
3130
3131 let updated_channels = session
3132 .db()
3133 .await
3134 .reorder_channel(channel_id, direction, session.user_id())
3135 .await?;
3136
3137 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3138 let connection_pool = session.connection_pool().await;
3139 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3140 let channels = updated_channels
3141 .iter()
3142 .filter_map(|channel| {
3143 if role.can_see_channel(channel.visibility) {
3144 Some(channel.to_proto())
3145 } else {
3146 None
3147 }
3148 })
3149 .collect::<Vec<_>>();
3150
3151 if channels.is_empty() {
3152 continue;
3153 }
3154
3155 let update = proto::UpdateChannels {
3156 channels,
3157 ..Default::default()
3158 };
3159
3160 session.peer.send(connection_id, update.clone())?;
3161 }
3162 }
3163
3164 response.send(Ack {})?;
3165 Ok(())
3166}
3167
3168/// Get the list of channel members
3169async fn get_channel_members(
3170 request: proto::GetChannelMembers,
3171 response: Response<proto::GetChannelMembers>,
3172 session: MessageContext,
3173) -> Result<()> {
3174 let db = session.db().await;
3175 let channel_id = ChannelId::from_proto(request.channel_id);
3176 let limit = if request.limit == 0 {
3177 u16::MAX as u64
3178 } else {
3179 request.limit
3180 };
3181 let (members, users) = db
3182 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3183 .await?;
3184 response.send(proto::GetChannelMembersResponse { members, users })?;
3185 Ok(())
3186}
3187
3188/// Accept or decline a channel invitation.
3189async fn respond_to_channel_invite(
3190 request: proto::RespondToChannelInvite,
3191 response: Response<proto::RespondToChannelInvite>,
3192 session: MessageContext,
3193) -> Result<()> {
3194 let db = session.db().await;
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196 let RespondToChannelInvite {
3197 membership_update,
3198 notifications,
3199 } = db
3200 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3201 .await?;
3202
3203 let mut connection_pool = session.connection_pool().await;
3204 if let Some(membership_update) = membership_update {
3205 notify_membership_updated(
3206 &mut connection_pool,
3207 membership_update,
3208 session.user_id(),
3209 &session.peer,
3210 );
3211 } else {
3212 let update = proto::UpdateChannels {
3213 remove_channel_invitations: vec![channel_id.to_proto()],
3214 ..Default::default()
3215 };
3216
3217 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3218 session.peer.send(connection_id, update.clone())?;
3219 }
3220 };
3221
3222 send_notifications(&connection_pool, &session.peer, notifications);
3223
3224 response.send(proto::Ack {})?;
3225
3226 Ok(())
3227}
3228
3229/// Join the channels' room
3230async fn join_channel(
3231 request: proto::JoinChannel,
3232 response: Response<proto::JoinChannel>,
3233 session: MessageContext,
3234) -> Result<()> {
3235 let channel_id = ChannelId::from_proto(request.channel_id);
3236 join_channel_internal(channel_id, Box::new(response), session).await
3237}
3238
/// Abstraction over the two response types (`JoinChannel` and `JoinRoom`)
/// that can carry a `JoinRoomResponse`, allowing `join_channel_internal`
/// to serve both RPCs.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call targets the inherent `Response::send`; an
        // unqualified `self.send(...)` could resolve to this trait method
        // and recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call targets the inherent `Response::send`; an
        // unqualified `self.send(...)` could resolve to this trait method
        // and recurse.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3252
/// Shared implementation of joining a channel's call; serves both the
/// `JoinChannel` and `JoinRoom` RPCs via `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB guard before cleanup, which acquires its own,
            // then re-acquire it to perform the join.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // When media is configured, mint a LiveKit token: guests get a
        // non-publishing token, everyone else a publishing one. Token
        // failures are logged and degrade to a call without media info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joiner before broadcasting to other participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Joining a public channel may have granted a new membership.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's existing participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Refresh the channel's participant list for everyone who can see it.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3346
3347/// Start editing the channel notes
3348async fn join_channel_buffer(
3349 request: proto::JoinChannelBuffer,
3350 response: Response<proto::JoinChannelBuffer>,
3351 session: MessageContext,
3352) -> Result<()> {
3353 let db = session.db().await;
3354 let channel_id = ChannelId::from_proto(request.channel_id);
3355
3356 let open_response = db
3357 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3358 .await?;
3359
3360 let collaborators = open_response.collaborators.clone();
3361 response.send(open_response)?;
3362
3363 let update = UpdateChannelBufferCollaborators {
3364 channel_id: channel_id.to_proto(),
3365 collaborators: collaborators.clone(),
3366 };
3367 channel_buffer_updated(
3368 session.connection_id,
3369 collaborators
3370 .iter()
3371 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3372 &update,
3373 &session.peer,
3374 );
3375
3376 Ok(())
3377}
3378
3379/// Edit the channel notes
3380async fn update_channel_buffer(
3381 request: proto::UpdateChannelBuffer,
3382 session: MessageContext,
3383) -> Result<()> {
3384 let db = session.db().await;
3385 let channel_id = ChannelId::from_proto(request.channel_id);
3386
3387 let (collaborators, epoch, version) = db
3388 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3389 .await?;
3390
3391 channel_buffer_updated(
3392 session.connection_id,
3393 collaborators.clone(),
3394 &proto::UpdateChannelBuffer {
3395 channel_id: channel_id.to_proto(),
3396 operations: request.operations,
3397 },
3398 &session.peer,
3399 );
3400
3401 let pool = &*session.connection_pool().await;
3402
3403 let non_collaborators =
3404 pool.channel_connection_ids(channel_id)
3405 .filter_map(|(connection_id, _)| {
3406 if collaborators.contains(&connection_id) {
3407 None
3408 } else {
3409 Some(connection_id)
3410 }
3411 });
3412
3413 broadcast(None, non_collaborators, |peer_id| {
3414 session.peer.send(
3415 peer_id,
3416 proto::UpdateChannels {
3417 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3418 channel_id: channel_id.to_proto(),
3419 epoch: epoch as u64,
3420 version: version.clone(),
3421 }],
3422 ..Default::default()
3423 },
3424 )
3425 });
3426
3427 Ok(())
3428}
3429
3430/// Rejoin the channel notes after a connection blip
3431async fn rejoin_channel_buffers(
3432 request: proto::RejoinChannelBuffers,
3433 response: Response<proto::RejoinChannelBuffers>,
3434 session: MessageContext,
3435) -> Result<()> {
3436 let db = session.db().await;
3437 let buffers = db
3438 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3439 .await?;
3440
3441 for rejoined_buffer in &buffers {
3442 let collaborators_to_notify = rejoined_buffer
3443 .buffer
3444 .collaborators
3445 .iter()
3446 .filter_map(|c| Some(c.peer_id?.into()));
3447 channel_buffer_updated(
3448 session.connection_id,
3449 collaborators_to_notify,
3450 &proto::UpdateChannelBufferCollaborators {
3451 channel_id: rejoined_buffer.buffer.channel_id,
3452 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3453 },
3454 &session.peer,
3455 );
3456 }
3457
3458 response.send(proto::RejoinChannelBuffersResponse {
3459 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3460 })?;
3461
3462 Ok(())
3463}
3464
3465/// Stop editing the channel notes
3466async fn leave_channel_buffer(
3467 request: proto::LeaveChannelBuffer,
3468 response: Response<proto::LeaveChannelBuffer>,
3469 session: MessageContext,
3470) -> Result<()> {
3471 let db = session.db().await;
3472 let channel_id = ChannelId::from_proto(request.channel_id);
3473
3474 let left_buffer = db
3475 .leave_channel_buffer(channel_id, session.connection_id)
3476 .await?;
3477
3478 response.send(Ack {})?;
3479
3480 channel_buffer_updated(
3481 session.connection_id,
3482 left_buffer.connections,
3483 &proto::UpdateChannelBufferCollaborators {
3484 channel_id: channel_id.to_proto(),
3485 collaborators: left_buffer.collaborators,
3486 },
3487 &session.peer,
3488 );
3489
3490 Ok(())
3491}
3492
3493fn channel_buffer_updated<T: EnvelopedMessage>(
3494 sender_id: ConnectionId,
3495 collaborators: impl IntoIterator<Item = ConnectionId>,
3496 message: &T,
3497 peer: &Peer,
3498) {
3499 broadcast(Some(sender_id), collaborators, |peer_id| {
3500 peer.send(peer_id, message.clone())
3501 });
3502}
3503
3504fn send_notifications(
3505 connection_pool: &ConnectionPool,
3506 peer: &Peer,
3507 notifications: db::NotificationBatch,
3508) {
3509 for (user_id, notification) in notifications {
3510 for connection_id in connection_pool.user_connection_ids(user_id) {
3511 if let Err(error) = peer.send(
3512 connection_id,
3513 proto::AddNotification {
3514 notification: Some(notification.clone()),
3515 },
3516 ) {
3517 tracing::error!(
3518 "failed to send notification to {:?} {}",
3519 connection_id,
3520 error
3521 );
3522 }
3523 }
3524 }
3525}
3526
3527/// Send a message to the channel
3528async fn send_channel_message(
3529 _request: proto::SendChannelMessage,
3530 _response: Response<proto::SendChannelMessage>,
3531 _session: MessageContext,
3532) -> Result<()> {
3533 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3534}
3535
3536/// Delete a channel message
3537async fn remove_channel_message(
3538 _request: proto::RemoveChannelMessage,
3539 _response: Response<proto::RemoveChannelMessage>,
3540 _session: MessageContext,
3541) -> Result<()> {
3542 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3543}
3544
3545async fn update_channel_message(
3546 _request: proto::UpdateChannelMessage,
3547 _response: Response<proto::UpdateChannelMessage>,
3548 _session: MessageContext,
3549) -> Result<()> {
3550 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3551}
3552
3553/// Mark a channel message as read
3554async fn acknowledge_channel_message(
3555 _request: proto::AckChannelMessage,
3556 _session: MessageContext,
3557) -> Result<()> {
3558 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3559}
3560
3561/// Mark a buffer version as synced
3562async fn acknowledge_buffer_version(
3563 request: proto::AckBufferOperation,
3564 session: MessageContext,
3565) -> Result<()> {
3566 let buffer_id = BufferId::from_proto(request.buffer_id);
3567 session
3568 .db()
3569 .await
3570 .observe_buffer_version(
3571 buffer_id,
3572 session.user_id(),
3573 request.epoch as i32,
3574 &request.version,
3575 )
3576 .await?;
3577 Ok(())
3578}
3579
3580/// Start receiving chat updates for a channel
3581async fn join_channel_chat(
3582 _request: proto::JoinChannelChat,
3583 _response: Response<proto::JoinChannelChat>,
3584 _session: MessageContext,
3585) -> Result<()> {
3586 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3587}
3588
3589/// Stop receiving chat updates for a channel
3590async fn leave_channel_chat(
3591 _request: proto::LeaveChannelChat,
3592 _session: MessageContext,
3593) -> Result<()> {
3594 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3595}
3596
3597/// Retrieve the chat history for a channel
3598async fn get_channel_messages(
3599 _request: proto::GetChannelMessages,
3600 _response: Response<proto::GetChannelMessages>,
3601 _session: MessageContext,
3602) -> Result<()> {
3603 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3604}
3605
3606/// Retrieve specific chat messages
3607async fn get_channel_messages_by_id(
3608 _request: proto::GetChannelMessagesById,
3609 _response: Response<proto::GetChannelMessagesById>,
3610 _session: MessageContext,
3611) -> Result<()> {
3612 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3613}
3614
3615/// Retrieve the current users notifications
3616async fn get_notifications(
3617 request: proto::GetNotifications,
3618 response: Response<proto::GetNotifications>,
3619 session: MessageContext,
3620) -> Result<()> {
3621 let notifications = session
3622 .db()
3623 .await
3624 .get_notifications(
3625 session.user_id(),
3626 NOTIFICATION_COUNT_PER_PAGE,
3627 request.before_id.map(db::NotificationId::from_proto),
3628 )
3629 .await?;
3630 response.send(proto::GetNotificationsResponse {
3631 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3632 notifications,
3633 })?;
3634 Ok(())
3635}
3636
3637/// Mark notifications as read
3638async fn mark_notification_as_read(
3639 request: proto::MarkNotificationRead,
3640 response: Response<proto::MarkNotificationRead>,
3641 session: MessageContext,
3642) -> Result<()> {
3643 let database = &session.db().await;
3644 let notifications = database
3645 .mark_notification_as_read_by_id(
3646 session.user_id(),
3647 NotificationId::from_proto(request.notification_id),
3648 )
3649 .await?;
3650 send_notifications(
3651 &*session.connection_pool().await,
3652 &session.peer,
3653 notifications,
3654 );
3655 response.send(proto::Ack {})?;
3656 Ok(())
3657}
3658
3659fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3660 let message = match message {
3661 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3662 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3663 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3664 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3665 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3666 code: frame.code.into(),
3667 reason: frame.reason.as_str().to_owned().into(),
3668 })),
3669 // We should never receive a frame while reading the message, according
3670 // to the `tungstenite` maintainers:
3671 //
3672 // > It cannot occur when you read messages from the WebSocket, but it
3673 // > can be used when you want to send the raw frames (e.g. you want to
3674 // > send the frames to the WebSocket without composing the full message first).
3675 // >
3676 // > — https://github.com/snapview/tungstenite-rs/issues/268
3677 TungsteniteMessage::Frame(_) => {
3678 bail!("received an unexpected frame while reading the message")
3679 }
3680 };
3681
3682 Ok(message)
3683}
3684
3685fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3686 match message {
3687 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3688 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3689 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3690 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3691 AxumMessage::Close(frame) => {
3692 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3693 code: frame.code.into(),
3694 reason: frame.reason.as_ref().into(),
3695 }))
3696 }
3697 }
3698}
3699
/// Apply a membership change to the in-memory connection pool and push the
/// resulting channel updates to all of the affected user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the pool's channel subscriptions in sync with the new memberships.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this duplicates the membership mapping in
    // `build_update_user_channels`, but leaves
    // `observed_channel_buffer_version` defaulted — presumably intentional;
    // confirm before unifying the two.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The membership change resolves any pending invitation to this channel.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged, not propagated; stale connections are
    // cleaned up elsewhere.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3740
3741fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3742 proto::UpdateUserChannels {
3743 channel_memberships: channels
3744 .channel_memberships
3745 .iter()
3746 .map(|m| proto::ChannelMembership {
3747 channel_id: m.channel_id.to_proto(),
3748 role: m.role.into(),
3749 })
3750 .collect(),
3751 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3752 }
3753}
3754
3755fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3756 let mut update = proto::UpdateChannels::default();
3757
3758 for channel in channels.channels {
3759 update.channels.push(channel.to_proto());
3760 }
3761
3762 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3763
3764 for (channel_id, participants) in channels.channel_participants {
3765 update
3766 .channel_participants
3767 .push(proto::ChannelParticipants {
3768 channel_id: channel_id.to_proto(),
3769 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3770 });
3771 }
3772
3773 for channel in channels.invited_channels {
3774 update.channel_invitations.push(channel.to_proto());
3775 }
3776
3777 update
3778}
3779
3780fn build_initial_contacts_update(
3781 contacts: Vec<db::Contact>,
3782 pool: &ConnectionPool,
3783) -> proto::UpdateContacts {
3784 let mut update = proto::UpdateContacts::default();
3785
3786 for contact in contacts {
3787 match contact {
3788 db::Contact::Accepted { user_id, busy } => {
3789 update.contacts.push(contact_for_user(user_id, busy, pool));
3790 }
3791 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3792 db::Contact::Incoming { user_id } => {
3793 update
3794 .incoming_requests
3795 .push(proto::IncomingContactRequest {
3796 requester_id: user_id.to_proto(),
3797 })
3798 }
3799 }
3800 }
3801
3802 update
3803}
3804
3805fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3806 proto::Contact {
3807 user_id: user_id.to_proto(),
3808 online: pool.is_user_online(user_id),
3809 busy,
3810 }
3811}
3812
3813fn room_updated(room: &proto::Room, peer: &Peer) {
3814 broadcast(
3815 None,
3816 room.participants
3817 .iter()
3818 .filter_map(|participant| Some(participant.peer_id?.into())),
3819 |peer_id| {
3820 peer.send(
3821 peer_id,
3822 proto::RoomUpdated {
3823 room: Some(room.clone()),
3824 },
3825 )
3826 },
3827 );
3828}
3829
3830fn channel_updated(
3831 channel: &db::channel::Model,
3832 room: &proto::Room,
3833 peer: &Peer,
3834 pool: &ConnectionPool,
3835) {
3836 let participants = room
3837 .participants
3838 .iter()
3839 .map(|p| p.user_id)
3840 .collect::<Vec<_>>();
3841
3842 broadcast(
3843 None,
3844 pool.channel_connection_ids(channel.root_id())
3845 .filter_map(|(channel_id, role)| {
3846 role.can_see_channel(channel.visibility)
3847 .then_some(channel_id)
3848 }),
3849 |peer_id| {
3850 peer.send(
3851 peer_id,
3852 proto::UpdateChannels {
3853 channel_participants: vec![proto::ChannelParticipants {
3854 channel_id: channel.id.to_proto(),
3855 participant_user_ids: participants.clone(),
3856 }],
3857 ..Default::default()
3858 },
3859 )
3860 },
3861 );
3862}
3863
3864async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3865 let db = session.db().await;
3866
3867 let contacts = db.get_contacts(user_id).await?;
3868 let busy = db.is_user_busy(user_id).await?;
3869
3870 let pool = session.connection_pool().await;
3871 let updated_contact = contact_for_user(user_id, busy, &pool);
3872 for contact in contacts {
3873 if let db::Contact::Accepted {
3874 user_id: contact_user_id,
3875 ..
3876 } = contact
3877 {
3878 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3879 session
3880 .peer
3881 .send(
3882 contact_conn_id,
3883 proto::UpdateContacts {
3884 contacts: vec![updated_contact.clone()],
3885 remove_contacts: Default::default(),
3886 incoming_requests: Default::default(),
3887 remove_incoming_requests: Default::default(),
3888 outgoing_requests: Default::default(),
3889 remove_outgoing_requests: Default::default(),
3890 },
3891 )
3892 .trace_err();
3893 }
3894 }
3895 }
3896 Ok(())
3897}
3898
/// Remove `connection_id` from the room it is in (if any), notifying left
/// projects, the owning channel, canceled callees, contacts, and LiveKit.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Extracted from `left_room` inside the `if let` below; declared here so
    // they remain usable after that scope ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare or drop collaborator status for each project the
        // connection participated in.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        // Tell the remaining participants the room changed.
        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in any room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh that channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Users who were being rung by this session get their calls canceled
        // and are queued for a contact-status refresh.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Mirror the departure in LiveKit, deleting the media room when the
    // last participant has left. Failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3972
3973async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3974 let left_channel_buffers = session
3975 .db()
3976 .await
3977 .leave_channel_buffers(session.connection_id)
3978 .await?;
3979
3980 for left_buffer in left_channel_buffers {
3981 channel_buffer_updated(
3982 session.connection_id,
3983 left_buffer.connections,
3984 &proto::UpdateChannelBufferCollaborators {
3985 channel_id: left_buffer.channel_id.to_proto(),
3986 collaborators: left_buffer.collaborators,
3987 },
3988 &session.peer,
3989 );
3990 }
3991
3992 Ok(())
3993}
3994
3995fn project_left(project: &db::LeftProject, session: &Session) {
3996 for connection_id in &project.connection_ids {
3997 if project.should_unshare {
3998 session
3999 .peer
4000 .send(
4001 *connection_id,
4002 proto::UnshareProject {
4003 project_id: project.id.to_proto(),
4004 },
4005 )
4006 .trace_err();
4007 } else {
4008 session
4009 .peer
4010 .send(
4011 *connection_id,
4012 proto::RemoveProjectCollaborator {
4013 project_id: project.id.to_proto(),
4014 peer_id: Some(session.connection_id.into()),
4015 },
4016 )
4017 .trace_err();
4018 }
4019 }
4020}
4021
4022async fn share_agent_thread(
4023 request: proto::ShareAgentThread,
4024 response: Response<proto::ShareAgentThread>,
4025 session: MessageContext,
4026) -> Result<()> {
4027 let user_id = session.user_id();
4028
4029 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4030 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4031
4032 session
4033 .db()
4034 .await
4035 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4036 .await?;
4037
4038 response.send(proto::Ack {})?;
4039
4040 Ok(())
4041}
4042
4043async fn get_shared_agent_thread(
4044 request: proto::GetSharedAgentThread,
4045 response: Response<proto::GetSharedAgentThread>,
4046 session: MessageContext,
4047) -> Result<()> {
4048 let share_id = SharedThreadId::from_proto(request.session_id)
4049 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4050
4051 let result = session.db().await.get_shared_thread(share_id).await?;
4052
4053 match result {
4054 Some((thread, username)) => {
4055 response.send(proto::GetSharedAgentThreadResponse {
4056 title: thread.title,
4057 thread_data: thread.data,
4058 sharer_username: username,
4059 created_at: thread.created_at.and_utc().to_rfc3339(),
4060 })?;
4061 }
4062 None => {
4063 return Err(anyhow!("Shared thread not found").into());
4064 }
4065 }
4066
4067 Ok(())
4068}
4069
/// Extension trait for logging-and-discarding errors: converts a result
/// into an `Option`, tracing the error case instead of propagating it.
pub trait ResultExt {
    type Ok;

    /// Returns `Some(value)` on success; logs the error via `tracing` and
    /// returns `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4075
4076impl<T, E> ResultExt for Result<T, E>
4077where
4078 E: std::fmt::Debug,
4079{
4080 type Ok = T;
4081
4082 #[track_caller]
4083 fn trace_err(self) -> Option<T> {
4084 match self {
4085 Ok(value) => Some(value),
4086 Err(error) => {
4087 tracing::error!("{:?}", error);
4088 None
4089 }
4090 }
4091 }
4092}