1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use serde::{Serialize, Serializer};
54use std::{
55 any::TypeId,
56 future::Future,
57 marker::PhantomData,
58 mem,
59 net::SocketAddr,
60 ops::{Deref, DerefMut},
61 rc::Rc,
62 sync::{
63 Arc, OnceLock,
64 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
65 },
66 time::{Duration, Instant},
67};
68use tokio::sync::{Semaphore, watch};
69use tower::ServiceBuilder;
70use tracing::{
71 Instrument,
72 field::{self},
73 info_span, instrument,
74};
75
/// How long to wait for a disconnected client to reconnect before cleaning up
/// its state (usage is outside this chunk — confirm at call sites).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Current number of open connections; incremented in `ConnectionGuard::try_acquire`
/// and decremented when the guard drops.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record per-message timing.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one message type, stored in `Server::handlers`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
93
94pub struct ConnectionGuard;
95
96impl ConnectionGuard {
97 pub fn try_acquire() -> Result<Self, ()> {
98 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
99 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
100 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
101 tracing::error!(
102 "too many concurrent connections: {}",
103 current_connections + 1
104 );
105 return Err(());
106 }
107 Ok(ConnectionGuard)
108 }
109}
110
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Release the slot claimed in `try_acquire`.
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}
116
/// Handle for replying to a single RPC request of type `R`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Set to true by `send`; checked by the `add_request_handler` wrapper to
    // detect handlers that return Ok without ever responding.
    responded: Arc<AtomicBool>,
}
122
123impl<R: RequestMessage> Response<R> {
124 fn send(self, payload: R::Response) -> Result<()> {
125 self.responded.store(true, SeqCst);
126 self.peer.respond(self.receipt, payload)?;
127 Ok(())
128 }
129}
130
/// The authenticated identity behind a connection: either a plain user, or a
/// user being impersonated by an admin.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
    Impersonated { user: User, admin: User },
}
136
137impl Principal {
138 fn update_span(&self, span: &tracing::Span) {
139 match &self {
140 Principal::User(user) => {
141 span.record("user_id", user.id.0);
142 span.record("login", &user.github_login);
143 }
144 Principal::Impersonated { user, admin } => {
145 span.record("user_id", user.id.0);
146 span.record("login", &user.github_login);
147 span.record("impersonator", &admin.github_login);
148 }
149 }
150 }
151}
152
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}
158
impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` anywhere a `&Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
166
impl MessageContext {
    /// Forwards `request` to the peer identified by `receiver_id` on behalf of
    /// this session's connection, logging start/completion and recording how
    /// long the receiving host took in the `host_waiting_ms` span field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                // Record elapsed time in fractional milliseconds, whether the
                // forwarded request succeeded or failed.
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
188
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle, guarded by an async mutex (acquired via `Session::db`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the whole server; lock via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
204
impl Session {
    /// Acquires the database handle. In tests, yields to the scheduler around
    /// the lock to surface ordering assumptions.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool. The returned guard is deliberately
    /// `!Send` (see `ConnectionPoolGuard`) so it cannot be held across awaits.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether this session has staff privileges: admins, and any
    /// impersonation session.
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
            Principal::Impersonated { .. } => true,
        }
    }

    /// The id of the user this session acts as (the impersonated user when
    /// impersonating).
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
            Principal::Impersonated { user, .. } => user.id,
        }
    }
}
240
241impl Debug for Session {
242 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
243 let mut result = f.debug_struct("Session");
244 match &self.principal {
245 Principal::User(user) => {
246 result.field("user", &user.github_login);
247 }
248 Principal::Impersonated { user, admin } => {
249 result.field("user", &user.github_login);
250 result.field("impersonator", &admin.github_login);
251 }
252 }
253 result.field("connection_id", &self.connection_id).finish()
254 }
255}
256
257struct DbHandle(Arc<Database>);
258
259impl Deref for DbHandle {
260 type Target = Database;
261
262 fn deref(&self) -> &Self::Target {
263 self.0.as_ref()
264 }
265}
266
/// The collab RPC server: owns the peer, the connection pool, and the table of
/// message handlers that `handle_connection` dispatches into.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One type-erased handler per message `TypeId`, populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcast flag flipped to `true` when the server is torn down.
    teardown: watch::Sender<bool>,
}
275
/// Guard over the locked `ConnectionPool`. The `PhantomData<Rc<()>>` makes the
/// guard `!Send`, which prevents holding the lock across `.await` points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
280
/// Serializable snapshot of server state (presumably for a diagnostics or
/// admin endpoint — confirm at call sites).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool itself, seeing through the lock guard.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
287
288pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
289where
290 S: Serializer,
291 T: Deref<Target = U>,
292 U: Serialize,
293{
294 Serialize::serialize(value.deref(), serializer)
295}
296
297impl Server {
    /// Builds the server and registers its complete message-handler table.
    ///
    /// Every handler is registered exactly once here; `add_handler` panics if
    /// the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // Rooms, calls, and participant state.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree synchronization.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the host: read-only.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Project requests forwarded to the host: mutating.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-to-guests broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads and project search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
476
    /// Spawns the background cleanup task for state left behind by a previous
    /// server generation: after `CLEANUP_TIMEOUT`, stale rooms, channel-buffer
    /// collaborators, chat participants, worktree entries, and server records
    /// are purged, and affected clients are notified of the changes.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Remove stale collaborators from channel buffers and tell
                    // the remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Drop participants that belonged to the old server
                        // generation and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose incoming calls were canceled by
                        // the cleanup.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed busy state to all
                        // of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the call made before the stale-room
                // sweep — presumably to catch participants that went stale while
                // the sweep ran; confirm both calls are intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
647
648 pub fn teardown(&self) {
649 self.peer.teardown();
650 self.connection_pool.lock().reset();
651 let _ = self.teardown.send(true);
652 }
653
654 #[cfg(test)]
655 pub fn reset(&self, id: ServerId) {
656 self.teardown();
657 *self.id.lock() = id;
658 self.peer.reset(id.0 as u32);
659 let _ = self.teardown.send(false);
660 }
661
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
666
    /// Core registration primitive: wraps `handler` in a type-erased closure
    /// that downcasts the envelope, runs the handler, and records queue vs.
    /// processing durations on the per-message span.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The registry is keyed by `TypeId::of::<M>()`, so this
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since the message arrived; queue = the part
                    // of that spent before this wrapper began running.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
710
711 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
712 where
713 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
714 Fut: 'static + Send + Future<Output = Result<()>>,
715 M: EnvelopedMessage,
716 {
717 self.add_handler(move |envelope, session| handler(envelope.payload, session));
718 self
719 }
720
721 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
722 where
723 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
724 Fut: Send + Future<Output = Result<()>>,
725 M: RequestMessage,
726 {
727 let handler = Arc::new(handler);
728 self.add_handler(move |envelope, session| {
729 let receipt = envelope.receipt();
730 let handler = handler.clone();
731 async move {
732 let peer = session.peer.clone();
733 let responded = Arc::new(AtomicBool::default());
734 let response = Response {
735 peer: peer.clone(),
736 responded: responded.clone(),
737 receipt,
738 };
739 match (handler)(envelope.payload, response, session).await {
740 Ok(()) => {
741 if responded.load(std::sync::atomic::Ordering::SeqCst) {
742 Ok(())
743 } else {
744 Err(anyhow!("handler did not send a response"))?
745 }
746 }
747 Err(error) => {
748 let proto_err = match &error {
749 Error::Internal(err) => err.to_proto(),
750 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
751 };
752 peer.respond_with_error(receipt, proto_err)?;
753 Err(error)
754 }
755 }
756 }
757 })
758 }
759
    /// Runs the full lifecycle of one client connection: registers it with the
    /// peer, sends the initial client update, dispatches incoming messages to
    /// the registered handlers (bounded by a semaphore), and signs the user
    /// out when the connection closes or the server tears down.
    ///
    /// `connection_guard` holds the connection-count slot; it is dropped once
    /// the initial update has been sent successfully.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Declare all span fields up front; several are filled in lazily below.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; release the admission slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler permit is free,
                // back-pressuring clients that flood the connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
932
    /// Sends the initial batch of state to a freshly connected client: the
    /// `Hello` handshake, the contact list, channel subscriptions, and any
    /// call that is currently ringing this user.
    ///
    /// `send_connection_id` is an optional oneshot through which the caller
    /// can learn the assigned connection id once the hello has been sent.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiver may already have been dropped; ignore the error.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    // First-ever connection for this user: send the
                    // `ShowContacts` message once, then persist the flag so
                    // we never send it again.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the contact list while
                    // holding the pool lock, so the initial update reflects a
                    // consistent view of who is online.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // If someone is currently calling this user, replay the
                // incoming-call message on the new connection.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Propagate contact-state updates now that this user is online.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
988
989 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
990 ServerSnapshot {
991 connection_pool: ConnectionPoolGuard {
992 guard: self.connection_pool.lock(),
993 _not_send: PhantomData,
994 },
995 peer: &self.peer,
996 }
997 }
998}
999
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    /// Grants read access to the pool through the guard.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1007
impl DerefMut for ConnectionPoolGuard<'_> {
    /// Grants mutable access to the pool through the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1013
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, validate pool invariants every time the guard is
        // released, so corruption is caught close to where it was introduced.
        #[cfg(test)]
        self.check_invariants();
    }
}
1020
1021fn broadcast<F>(
1022 sender_id: Option<ConnectionId>,
1023 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1024 mut f: F,
1025) where
1026 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1027{
1028 for receiver_id in receiver_ids {
1029 if Some(receiver_id) != sender_id
1030 && let Err(error) = f(receiver_id)
1031 {
1032 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1033 }
1034 }
1035}
1036
1037pub struct ProtocolVersion(u32);
1038
1039impl Header for ProtocolVersion {
1040 fn name() -> &'static HeaderName {
1041 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1042 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1043 }
1044
1045 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1046 where
1047 Self: Sized,
1048 I: Iterator<Item = &'i axum::http::HeaderValue>,
1049 {
1050 let version = values
1051 .next()
1052 .ok_or_else(axum::headers::Error::invalid)?
1053 .to_str()
1054 .map_err(|_| axum::headers::Error::invalid())?
1055 .parse()
1056 .map_err(|_| axum::headers::Error::invalid())?;
1057 Ok(Self(version))
1058 }
1059
1060 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1061 values.extend([self.0.to_string().parse().unwrap()]);
1062 }
1063}
1064
1065pub struct AppVersionHeader(Version);
1066impl Header for AppVersionHeader {
1067 fn name() -> &'static HeaderName {
1068 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1069 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1070 }
1071
1072 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1073 where
1074 Self: Sized,
1075 I: Iterator<Item = &'i axum::http::HeaderValue>,
1076 {
1077 let version = values
1078 .next()
1079 .ok_or_else(axum::headers::Error::invalid)?
1080 .to_str()
1081 .map_err(|_| axum::headers::Error::invalid())?
1082 .parse()
1083 .map_err(|_| axum::headers::Error::invalid())?;
1084 Ok(Self(version))
1085 }
1086
1087 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1088 values.extend([self.0.to_string().parse().unwrap()]);
1089 }
1090}
1091
1092#[derive(Debug)]
1093pub struct ReleaseChannelHeader(String);
1094
1095impl Header for ReleaseChannelHeader {
1096 fn name() -> &'static HeaderName {
1097 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1098 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1099 }
1100
1101 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1102 where
1103 Self: Sized,
1104 I: Iterator<Item = &'i axum::http::HeaderValue>,
1105 {
1106 Ok(Self(
1107 values
1108 .next()
1109 .ok_or_else(axum::headers::Error::invalid)?
1110 .to_str()
1111 .map_err(|_| axum::headers::Error::invalid())?
1112 .to_owned(),
1113 ))
1114 }
1115
1116 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1117 values.extend([self.0.parse().unwrap()]);
1118 }
1119}
1120
1121pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1122 Router::new()
1123 .route("/rpc", get(handle_websocket_request))
1124 .layer(
1125 ServiceBuilder::new()
1126 .layer(Extension(server.app_state.clone()))
1127 .layer(middleware::from_fn(auth::validate_header)),
1128 )
1129 .route("/metrics", get(handle_metrics))
1130 .layer(Extension(server))
1131}
1132
1133pub async fn handle_websocket_request(
1134 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1135 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1136 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1137 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1138 Extension(server): Extension<Arc<Server>>,
1139 Extension(principal): Extension<Principal>,
1140 user_agent: Option<TypedHeader<UserAgent>>,
1141 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1142 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1143 ws: WebSocketUpgrade,
1144) -> axum::response::Response {
1145 if protocol_version != rpc::PROTOCOL_VERSION {
1146 return (
1147 StatusCode::UPGRADE_REQUIRED,
1148 "client must be upgraded".to_string(),
1149 )
1150 .into_response();
1151 }
1152
1153 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1154 return (
1155 StatusCode::UPGRADE_REQUIRED,
1156 "no version header found".to_string(),
1157 )
1158 .into_response();
1159 };
1160
1161 let release_channel = release_channel_header.map(|header| header.0.0);
1162
1163 if !version.can_collaborate() {
1164 return (
1165 StatusCode::UPGRADE_REQUIRED,
1166 "client must be upgraded".to_string(),
1167 )
1168 .into_response();
1169 }
1170
1171 let socket_address = socket_address.to_string();
1172
1173 // Acquire connection guard before WebSocket upgrade
1174 let connection_guard = match ConnectionGuard::try_acquire() {
1175 Ok(guard) => guard,
1176 Err(()) => {
1177 return (
1178 StatusCode::SERVICE_UNAVAILABLE,
1179 "Too many concurrent connections",
1180 )
1181 .into_response();
1182 }
1183 };
1184
1185 ws.on_upgrade(move |socket| {
1186 let socket = socket
1187 .map_ok(to_tungstenite_message)
1188 .err_into()
1189 .with(|message| async move { to_axum_message(message) });
1190 let connection = Connection::new(Box::pin(socket));
1191 async move {
1192 server
1193 .handle_connection(
1194 connection,
1195 socket_address,
1196 principal,
1197 version,
1198 release_channel,
1199 user_agent.map(|header| header.to_string()),
1200 country_code_header.map(|header| header.to_string()),
1201 system_id_header.map(|header| header.to_string()),
1202 None,
1203 Executor::Production,
1204 Some(connection_guard),
1205 )
1206 .await;
1207 }
1208 })
1209}
1210
1211pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1212 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1213 let connections_metric = CONNECTIONS_METRIC
1214 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1215
1216 let connections = server
1217 .connection_pool
1218 .lock()
1219 .connections()
1220 .filter(|connection| !connection.admin)
1221 .count();
1222 connections_metric.set(connections as _);
1223
1224 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1225 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1226 register_int_gauge!(
1227 "shared_projects",
1228 "number of open projects with one or more guests"
1229 )
1230 .unwrap()
1231 });
1232
1233 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1234 shared_projects_metric.set(shared_projects as _);
1235
1236 let encoder = prometheus::TextEncoder::new();
1237 let metric_families = prometheus::gather();
1238 let encoded_metrics = encoder
1239 .encode_to_string(&metric_families)
1240 .map_err(|err| anyhow!("{err}"))?;
1241 Ok(encoded_metrics)
1242}
1243
/// Cleans up after a dropped connection: unregisters it immediately, then
/// waits up to `RECONNECT_TIMEOUT` before releasing the user's rooms and
/// channel buffers, giving a quickly-reconnecting client a grace period.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Tear down peer state and remove the connection from the pool right away.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: record the loss in the database, logging any failure.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // The grace period elapsed without a server teardown: release all of
        // this connection's resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If the user has no other live connections, decline any call
            // that is still ringing them.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server-wide teardown: skip per-connection cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1290
1291/// Acknowledges a ping from a client, used to keep the connection alive.
1292async fn ping(
1293 _: proto::Ping,
1294 response: Response<proto::Ping>,
1295 _session: MessageContext,
1296) -> Result<()> {
1297 response.send(proto::Ack {})?;
1298 Ok(())
1299}
1300
1301/// Creates a new room for calling (outside of channels)
1302async fn create_room(
1303 _request: proto::CreateRoom,
1304 response: Response<proto::CreateRoom>,
1305 session: MessageContext,
1306) -> Result<()> {
1307 let livekit_room = nanoid::nanoid!(30);
1308
1309 let live_kit_connection_info = util::maybe!(async {
1310 let live_kit = session.app_state.livekit_client.as_ref();
1311 let live_kit = live_kit?;
1312 let user_id = session.user_id().to_string();
1313
1314 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1315
1316 Some(proto::LiveKitConnectionInfo {
1317 server_url: live_kit.url().into(),
1318 token,
1319 can_publish: true,
1320 })
1321 })
1322 .await;
1323
1324 let room = session
1325 .db()
1326 .await
1327 .create_room(session.user_id(), session.connection_id, &livekit_room)
1328 .await?;
1329
1330 response.send(proto::CreateRoomResponse {
1331 room: Some(room.clone()),
1332 live_kit_connection_info,
1333 })?;
1334
1335 update_user_contacts(session.user_id(), &session).await?;
1336 Ok(())
1337}
1338
1339/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1340async fn join_room(
1341 request: proto::JoinRoom,
1342 response: Response<proto::JoinRoom>,
1343 session: MessageContext,
1344) -> Result<()> {
1345 let room_id = RoomId::from_proto(request.id);
1346
1347 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1348
1349 if let Some(channel_id) = channel_id {
1350 return join_channel_internal(channel_id, Box::new(response), session).await;
1351 }
1352
1353 let joined_room = {
1354 let room = session
1355 .db()
1356 .await
1357 .join_room(room_id, session.user_id(), session.connection_id)
1358 .await?;
1359 room_updated(&room.room, &session.peer);
1360 room.into_inner()
1361 };
1362
1363 for connection_id in session
1364 .connection_pool()
1365 .await
1366 .user_connection_ids(session.user_id())
1367 {
1368 session
1369 .peer
1370 .send(
1371 connection_id,
1372 proto::CallCanceled {
1373 room_id: room_id.to_proto(),
1374 },
1375 )
1376 .trace_err();
1377 }
1378
1379 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1380 {
1381 live_kit
1382 .room_token(
1383 &joined_room.room.livekit_room,
1384 &session.user_id().to_string(),
1385 )
1386 .trace_err()
1387 .map(|token| proto::LiveKitConnectionInfo {
1388 server_url: live_kit.url().into(),
1389 token,
1390 can_publish: true,
1391 })
1392 } else {
1393 None
1394 };
1395
1396 response.send(proto::JoinRoomResponse {
1397 room: Some(joined_room.room),
1398 channel_id: None,
1399 live_kit_connection_info,
1400 })?;
1401
1402 update_user_contacts(session.user_id(), &session).await?;
1403 Ok(())
1404}
1405
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Reshared projects are ones this client was hosting; rejoined projects are
/// ones it was a guest in. Other collaborators are told about the client's
/// new peer id, and missed state is streamed back to the client.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the client first so it can start reconciling local state.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point each existing collaborator at the host's new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree list to the other collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the guard, keeping only the values needed after this scope.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1497
/// Streams everything a rejoining guest missed while disconnected: first the
/// other collaborators learn the guest's new peer id, then the guest receives
/// worktree entries, diagnostics, settings files, and repository state.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell existing collaborators about this client's new connection.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so they aren't cloned per send.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The proto helper splits one large update into multiple messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are likewise split into size-bounded chunks.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1583
1584/// leave room disconnects from the room.
1585async fn leave_room(
1586 _: proto::LeaveRoom,
1587 response: Response<proto::LeaveRoom>,
1588 session: MessageContext,
1589) -> Result<()> {
1590 leave_room_for_session(&session, session.connection_id).await?;
1591 response.send(proto::Ack {})?;
1592 Ok(())
1593}
1594
1595/// Updates the permissions of someone else in the room.
1596async fn set_room_participant_role(
1597 request: proto::SetRoomParticipantRole,
1598 response: Response<proto::SetRoomParticipantRole>,
1599 session: MessageContext,
1600) -> Result<()> {
1601 let user_id = UserId::from_proto(request.user_id);
1602 let role = ChannelRole::from(request.role());
1603
1604 let (livekit_room, can_publish) = {
1605 let room = session
1606 .db()
1607 .await
1608 .set_room_participant_role(
1609 session.user_id(),
1610 RoomId::from_proto(request.room_id),
1611 user_id,
1612 role,
1613 )
1614 .await?;
1615
1616 let livekit_room = room.livekit_room.clone();
1617 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1618 room_updated(&room, &session.peer);
1619 (livekit_room, can_publish)
1620 };
1621
1622 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1623 live_kit
1624 .update_participant(
1625 livekit_room.clone(),
1626 request.user_id.to_string(),
1627 livekit_api::proto::ParticipantPermission {
1628 can_subscribe: true,
1629 can_publish,
1630 can_publish_data: can_publish,
1631 hidden: false,
1632 recorder: false,
1633 },
1634 )
1635 .await
1636 .trace_err();
1637 }
1638
1639 response.send(proto::Ack {})?;
1640 Ok(())
1641}
1642
/// Call someone else into the current room
///
/// Rings each of the callee's connections concurrently and acknowledges as
/// soon as any one of them accepts; if none do, the call is marked failed.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the call message out of the db guard so it can be sent after
        // the guard is released.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every one of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // First connection to answer wins; ack and stop ringing.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // Log the failure and keep waiting on the remaining connections.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted: roll back the ringing state in the database.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1711
1712/// Cancel an outgoing call.
1713async fn cancel_call(
1714 request: proto::CancelCall,
1715 response: Response<proto::CancelCall>,
1716 session: MessageContext,
1717) -> Result<()> {
1718 let called_user_id = UserId::from_proto(request.called_user_id);
1719 let room_id = RoomId::from_proto(request.room_id);
1720 {
1721 let room = session
1722 .db()
1723 .await
1724 .cancel_call(room_id, session.connection_id, called_user_id)
1725 .await?;
1726 room_updated(&room, &session.peer);
1727 }
1728
1729 for connection_id in session
1730 .connection_pool()
1731 .await
1732 .user_connection_ids(called_user_id)
1733 {
1734 session
1735 .peer
1736 .send(
1737 connection_id,
1738 proto::CallCanceled {
1739 room_id: room_id.to_proto(),
1740 },
1741 )
1742 .trace_err();
1743 }
1744 response.send(proto::Ack {})?;
1745
1746 update_user_contacts(called_user_id, &session).await?;
1747 Ok(())
1748}
1749
1750/// Decline an incoming call.
1751async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1752 let room_id = RoomId::from_proto(message.room_id);
1753 {
1754 let room = session
1755 .db()
1756 .await
1757 .decline_call(Some(room_id), session.user_id())
1758 .await?
1759 .context("declining call")?;
1760 room_updated(&room, &session.peer);
1761 }
1762
1763 for connection_id in session
1764 .connection_pool()
1765 .await
1766 .user_connection_ids(session.user_id())
1767 {
1768 session
1769 .peer
1770 .send(
1771 connection_id,
1772 proto::CallCanceled {
1773 room_id: room_id.to_proto(),
1774 },
1775 )
1776 .trace_err();
1777 }
1778 update_user_contacts(session.user_id(), &session).await?;
1779 Ok(())
1780}
1781
1782/// Updates other participants in the room with your current location.
1783async fn update_participant_location(
1784 request: proto::UpdateParticipantLocation,
1785 response: Response<proto::UpdateParticipantLocation>,
1786 session: MessageContext,
1787) -> Result<()> {
1788 let room_id = RoomId::from_proto(request.room_id);
1789 let location = request.location.context("invalid location")?;
1790
1791 let db = session.db().await;
1792 let room = db
1793 .update_room_participant_location(room_id, session.connection_id, location)
1794 .await?;
1795
1796 room_updated(&room, &session.peer);
1797 response.send(proto::Ack {})?;
1798 Ok(())
1799}
1800
1801/// Share a project into the room.
1802async fn share_project(
1803 request: proto::ShareProject,
1804 response: Response<proto::ShareProject>,
1805 session: MessageContext,
1806) -> Result<()> {
1807 let (project_id, room) = &*session
1808 .db()
1809 .await
1810 .share_project(
1811 RoomId::from_proto(request.room_id),
1812 session.connection_id,
1813 &request.worktrees,
1814 request.is_ssh_project,
1815 request.windows_paths.unwrap_or(false),
1816 )
1817 .await?;
1818 response.send(proto::ShareProjectResponse {
1819 project_id: project_id.to_proto(),
1820 })?;
1821 room_updated(room, &session.peer);
1822
1823 Ok(())
1824}
1825
1826/// Unshare a project from the room.
1827async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1828 let project_id = ProjectId::from_proto(message.project_id);
1829 unshare_project_internal(project_id, session.connection_id, &session).await
1830}
1831
/// Shared unshare implementation: notifies guests, broadcasts the room
/// update, and optionally deletes the project row afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the initiator) that the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The deletion happens outside the scope above, after `room_guard` has
    // been dropped, before re-acquiring the database.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1869
1870/// Join someone elses shared project.
1871async fn join_project(
1872 request: proto::JoinProject,
1873 response: Response<proto::JoinProject>,
1874 session: MessageContext,
1875) -> Result<()> {
1876 let project_id = ProjectId::from_proto(request.project_id);
1877
1878 tracing::info!(%project_id, "join project");
1879
1880 let db = session.db().await;
1881 let (project, replica_id) = &mut *db
1882 .join_project(
1883 project_id,
1884 session.connection_id,
1885 session.user_id(),
1886 request.committer_name.clone(),
1887 request.committer_email.clone(),
1888 )
1889 .await?;
1890 drop(db);
1891 tracing::info!(%project_id, "join remote project");
1892 let collaborators = project
1893 .collaborators
1894 .iter()
1895 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1896 .map(|collaborator| collaborator.to_proto())
1897 .collect::<Vec<_>>();
1898 let project_id = project.id;
1899 let guest_user_id = session.user_id();
1900
1901 let worktrees = project
1902 .worktrees
1903 .iter()
1904 .map(|(id, worktree)| proto::WorktreeMetadata {
1905 id: *id,
1906 root_name: worktree.root_name.clone(),
1907 visible: worktree.visible,
1908 abs_path: worktree.abs_path.clone(),
1909 })
1910 .collect::<Vec<_>>();
1911
1912 let add_project_collaborator = proto::AddProjectCollaborator {
1913 project_id: project_id.to_proto(),
1914 collaborator: Some(proto::Collaborator {
1915 peer_id: Some(session.connection_id.into()),
1916 replica_id: replica_id.0 as u32,
1917 user_id: guest_user_id.to_proto(),
1918 is_host: false,
1919 committer_name: request.committer_name.clone(),
1920 committer_email: request.committer_email.clone(),
1921 }),
1922 };
1923
1924 for collaborator in &collaborators {
1925 session
1926 .peer
1927 .send(
1928 collaborator.peer_id.unwrap().into(),
1929 add_project_collaborator.clone(),
1930 )
1931 .trace_err();
1932 }
1933
1934 // First, we send the metadata associated with each worktree.
1935 let (language_servers, language_server_capabilities) = project
1936 .language_servers
1937 .clone()
1938 .into_iter()
1939 .map(|server| (server.server, server.capabilities))
1940 .unzip();
1941 response.send(proto::JoinProjectResponse {
1942 project_id: project.id.0 as u64,
1943 worktrees,
1944 replica_id: replica_id.0 as u32,
1945 collaborators,
1946 language_servers,
1947 language_server_capabilities,
1948 role: project.role.into(),
1949 windows_paths: project.path_style == PathStyle::Windows,
1950 })?;
1951
1952 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1953 // Stream this worktree's entries.
1954 let message = proto::UpdateWorktree {
1955 project_id: project_id.to_proto(),
1956 worktree_id,
1957 abs_path: worktree.abs_path.clone(),
1958 root_name: worktree.root_name,
1959 updated_entries: worktree.entries,
1960 removed_entries: Default::default(),
1961 scan_id: worktree.scan_id,
1962 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1963 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1964 removed_repositories: Default::default(),
1965 };
1966 for update in proto::split_worktree_update(message) {
1967 session.peer.send(session.connection_id, update.clone())?;
1968 }
1969
1970 // Stream this worktree's diagnostics.
1971 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1972 if let Some(summary) = worktree_diagnostics.next() {
1973 let message = proto::UpdateDiagnosticSummary {
1974 project_id: project.id.to_proto(),
1975 worktree_id: worktree.id,
1976 summary: Some(summary),
1977 more_summaries: worktree_diagnostics.collect(),
1978 };
1979 session.peer.send(session.connection_id, message)?;
1980 }
1981
1982 for settings_file in worktree.settings_files {
1983 session.peer.send(
1984 session.connection_id,
1985 proto::UpdateWorktreeSettings {
1986 project_id: project_id.to_proto(),
1987 worktree_id: worktree.id,
1988 path: settings_file.path,
1989 content: Some(settings_file.content),
1990 kind: Some(settings_file.kind.to_proto() as i32),
1991 outside_worktree: Some(settings_file.outside_worktree),
1992 },
1993 )?;
1994 }
1995 }
1996
1997 for repository in mem::take(&mut project.repositories) {
1998 for update in split_repository_update(repository) {
1999 session.peer.send(session.connection_id, update)?;
2000 }
2001 }
2002
2003 for language_server in &project.language_servers {
2004 session.peer.send(
2005 session.connection_id,
2006 proto::UpdateLanguageServer {
2007 project_id: project_id.to_proto(),
2008 server_name: Some(language_server.server.name.clone()),
2009 language_server_id: language_server.server.id,
2010 variant: Some(
2011 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2012 proto::LspDiskBasedDiagnosticsUpdated {},
2013 ),
2014 ),
2015 },
2016 )?;
2017 }
2018
2019 Ok(())
2020}
2021
2022/// Leave someone elses shared project.
2023async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2024 let sender_id = session.connection_id;
2025 let project_id = ProjectId::from_proto(request.project_id);
2026 let db = session.db().await;
2027
2028 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2029 tracing::info!(
2030 %project_id,
2031 "leave project"
2032 );
2033
2034 project_left(project, &session);
2035 if let Some(room) = room {
2036 room_updated(room, &session.peer);
2037 }
2038
2039 Ok(())
2040}
2041
2042/// Updates other participants with changes to the project
2043async fn update_project(
2044 request: proto::UpdateProject,
2045 response: Response<proto::UpdateProject>,
2046 session: MessageContext,
2047) -> Result<()> {
2048 let project_id = ProjectId::from_proto(request.project_id);
2049 let (room, guest_connection_ids) = &*session
2050 .db()
2051 .await
2052 .update_project(project_id, session.connection_id, &request.worktrees)
2053 .await?;
2054 broadcast(
2055 Some(session.connection_id),
2056 guest_connection_ids.iter().copied(),
2057 |connection_id| {
2058 session
2059 .peer
2060 .forward_send(session.connection_id, connection_id, request.clone())
2061 },
2062 );
2063 if let Some(room) = room {
2064 room_updated(room, &session.peer);
2065 }
2066 response.send(proto::Ack {})?;
2067
2068 Ok(())
2069}
2070
2071/// Updates other participants with changes to the worktree
2072async fn update_worktree(
2073 request: proto::UpdateWorktree,
2074 response: Response<proto::UpdateWorktree>,
2075 session: MessageContext,
2076) -> Result<()> {
2077 let guest_connection_ids = session
2078 .db()
2079 .await
2080 .update_worktree(&request, session.connection_id)
2081 .await?;
2082
2083 broadcast(
2084 Some(session.connection_id),
2085 guest_connection_ids.iter().copied(),
2086 |connection_id| {
2087 session
2088 .peer
2089 .forward_send(session.connection_id, connection_id, request.clone())
2090 },
2091 );
2092 response.send(proto::Ack {})?;
2093 Ok(())
2094}
2095
2096async fn update_repository(
2097 request: proto::UpdateRepository,
2098 response: Response<proto::UpdateRepository>,
2099 session: MessageContext,
2100) -> Result<()> {
2101 let guest_connection_ids = session
2102 .db()
2103 .await
2104 .update_repository(&request, session.connection_id)
2105 .await?;
2106
2107 broadcast(
2108 Some(session.connection_id),
2109 guest_connection_ids.iter().copied(),
2110 |connection_id| {
2111 session
2112 .peer
2113 .forward_send(session.connection_id, connection_id, request.clone())
2114 },
2115 );
2116 response.send(proto::Ack {})?;
2117 Ok(())
2118}
2119
2120async fn remove_repository(
2121 request: proto::RemoveRepository,
2122 response: Response<proto::RemoveRepository>,
2123 session: MessageContext,
2124) -> Result<()> {
2125 let guest_connection_ids = session
2126 .db()
2127 .await
2128 .remove_repository(&request, session.connection_id)
2129 .await?;
2130
2131 broadcast(
2132 Some(session.connection_id),
2133 guest_connection_ids.iter().copied(),
2134 |connection_id| {
2135 session
2136 .peer
2137 .forward_send(session.connection_id, connection_id, request.clone())
2138 },
2139 );
2140 response.send(proto::Ack {})?;
2141 Ok(())
2142}
2143
2144/// Updates other participants with changes to the diagnostics
2145async fn update_diagnostic_summary(
2146 message: proto::UpdateDiagnosticSummary,
2147 session: MessageContext,
2148) -> Result<()> {
2149 let guest_connection_ids = session
2150 .db()
2151 .await
2152 .update_diagnostic_summary(&message, session.connection_id)
2153 .await?;
2154
2155 broadcast(
2156 Some(session.connection_id),
2157 guest_connection_ids.iter().copied(),
2158 |connection_id| {
2159 session
2160 .peer
2161 .forward_send(session.connection_id, connection_id, message.clone())
2162 },
2163 );
2164
2165 Ok(())
2166}
2167
2168/// Updates other participants with changes to the worktree settings
2169async fn update_worktree_settings(
2170 message: proto::UpdateWorktreeSettings,
2171 session: MessageContext,
2172) -> Result<()> {
2173 let guest_connection_ids = session
2174 .db()
2175 .await
2176 .update_worktree_settings(&message, session.connection_id)
2177 .await?;
2178
2179 broadcast(
2180 Some(session.connection_id),
2181 guest_connection_ids.iter().copied(),
2182 |connection_id| {
2183 session
2184 .peer
2185 .forward_send(session.connection_id, connection_id, message.clone())
2186 },
2187 );
2188
2189 Ok(())
2190}
2191
2192/// Notify other participants that a language server has started.
2193async fn start_language_server(
2194 request: proto::StartLanguageServer,
2195 session: MessageContext,
2196) -> Result<()> {
2197 let guest_connection_ids = session
2198 .db()
2199 .await
2200 .start_language_server(&request, session.connection_id)
2201 .await?;
2202
2203 broadcast(
2204 Some(session.connection_id),
2205 guest_connection_ids.iter().copied(),
2206 |connection_id| {
2207 session
2208 .peer
2209 .forward_send(session.connection_id, connection_id, request.clone())
2210 },
2211 );
2212 Ok(())
2213}
2214
/// Notify other participants that a language server has changed.
///
/// If the update carries new server capabilities, they are persisted
/// before the message is fanned out.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // The guard is held across both the capability write and the
    // connection-id lookup below so they observe a consistent state.
    let db = session.db().await;

    // Persist new server capabilities when the metadata update includes them.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the `true` flag presumably includes the host in the
    // returned connections (cf. `false` in broadcast_project_message_from_host)
    // — confirm against `project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2244
2245/// forward a project request to the host. These requests should be read only
2246/// as guests are allowed to send them.
2247async fn forward_read_only_project_request<T>(
2248 request: T,
2249 response: Response<T>,
2250 session: MessageContext,
2251) -> Result<()>
2252where
2253 T: EntityMessage + RequestMessage,
2254{
2255 let project_id = ProjectId::from_proto(request.remote_entity_id());
2256 let host_connection_id = session
2257 .db()
2258 .await
2259 .host_for_read_only_project_request(project_id, session.connection_id)
2260 .await?;
2261 let payload = session.forward_request(host_connection_id, request).await?;
2262 response.send(payload)?;
2263 Ok(())
2264}
2265
2266/// forward a project request to the host. These requests are disallowed
2267/// for guests.
2268async fn forward_mutating_project_request<T>(
2269 request: T,
2270 response: Response<T>,
2271 session: MessageContext,
2272) -> Result<()>
2273where
2274 T: EntityMessage + RequestMessage,
2275{
2276 let project_id = ProjectId::from_proto(request.remote_entity_id());
2277
2278 let host_connection_id = session
2279 .db()
2280 .await
2281 .host_for_mutating_project_request(project_id, session.connection_id)
2282 .await?;
2283 let payload = session.forward_request(host_connection_id, request).await?;
2284 response.send(payload)?;
2285 Ok(())
2286}
2287
2288async fn lsp_query(
2289 request: proto::LspQuery,
2290 response: Response<proto::LspQuery>,
2291 session: MessageContext,
2292) -> Result<()> {
2293 let (name, should_write) = request.query_name_and_write_permissions();
2294 tracing::Span::current().record("lsp_query_request", name);
2295 tracing::info!("lsp_query message received");
2296 if should_write {
2297 forward_mutating_project_request(request, response, session).await
2298 } else {
2299 forward_read_only_project_request(request, response, session).await
2300 }
2301}
2302
2303/// Notify other participants that a new buffer has been created
2304async fn create_buffer_for_peer(
2305 request: proto::CreateBufferForPeer,
2306 session: MessageContext,
2307) -> Result<()> {
2308 session
2309 .db()
2310 .await
2311 .check_user_is_project_host(
2312 ProjectId::from_proto(request.project_id),
2313 session.connection_id,
2314 )
2315 .await?;
2316 let peer_id = request.peer_id.context("invalid peer id")?;
2317 session
2318 .peer
2319 .forward_send(session.connection_id, peer_id.into(), request)?;
2320 Ok(())
2321}
2322
2323/// Notify other participants that a new image has been created
2324async fn create_image_for_peer(
2325 request: proto::CreateImageForPeer,
2326 session: MessageContext,
2327) -> Result<()> {
2328 session
2329 .db()
2330 .await
2331 .check_user_is_project_host(
2332 ProjectId::from_proto(request.project_id),
2333 session.connection_id,
2334 )
2335 .await?;
2336 let peer_id = request.peer_id.context("invalid peer id")?;
2337 session
2338 .peer
2339 .forward_send(session.connection_id, peer_id.into(), request)?;
2340 Ok(())
2341}
2342
2343/// Notify other participants that a buffer has been updated. This is
2344/// allowed for guests as long as the update is limited to selections.
2345async fn update_buffer(
2346 request: proto::UpdateBuffer,
2347 response: Response<proto::UpdateBuffer>,
2348 session: MessageContext,
2349) -> Result<()> {
2350 let project_id = ProjectId::from_proto(request.project_id);
2351 let mut capability = Capability::ReadOnly;
2352
2353 for op in request.operations.iter() {
2354 match op.variant {
2355 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2356 Some(_) => capability = Capability::ReadWrite,
2357 }
2358 }
2359
2360 let host = {
2361 let guard = session
2362 .db()
2363 .await
2364 .connections_for_buffer_update(project_id, session.connection_id, capability)
2365 .await?;
2366
2367 let (host, guests) = &*guard;
2368
2369 broadcast(
2370 Some(session.connection_id),
2371 guests.clone(),
2372 |connection_id| {
2373 session
2374 .peer
2375 .forward_send(session.connection_id, connection_id, request.clone())
2376 },
2377 );
2378
2379 *host
2380 };
2381
2382 if host != session.connection_id {
2383 session.forward_request(host, request.clone()).await?;
2384 }
2385
2386 response.send(proto::Ack {})?;
2387 Ok(())
2388}
2389
/// Rebroadcast a shared-context update to the other collaborators on the
/// project. Like buffer updates, read-only guests may only send
/// selection changes.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Derive the required capability from the operation: only a
    // selection-only buffer operation (or no variant at all) is permitted
    // with read-only access; everything else needs write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no payload is treated as mutating.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the fan-out here
    // and no acknowledgement is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2431
2432async fn forward_project_search_chunk(
2433 message: proto::FindSearchCandidatesChunk,
2434 response: Response<proto::FindSearchCandidatesChunk>,
2435 session: MessageContext,
2436) -> Result<()> {
2437 let peer_id = message.peer_id.context("missing peer_id")?;
2438 let payload = session
2439 .peer
2440 .forward_request(session.connection_id, peer_id.into(), message)
2441 .await?;
2442 response.send(payload)?;
2443 Ok(())
2444}
2445
2446/// Notify other participants that a project has been updated.
2447async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2448 request: T,
2449 session: MessageContext,
2450) -> Result<()> {
2451 let project_id = ProjectId::from_proto(request.remote_entity_id());
2452 let project_connection_ids = session
2453 .db()
2454 .await
2455 .project_connection_ids(project_id, session.connection_id, false)
2456 .await?;
2457
2458 broadcast(
2459 Some(session.connection_id),
2460 project_connection_ids.iter().copied(),
2461 |connection_id| {
2462 session
2463 .peer
2464 .forward_send(session.connection_id, connection_id, request.clone())
2465 },
2466 );
2467 Ok(())
2468}
2469
/// Start following another user in a call.
///
/// Verifies room membership, relays the follow request to the leader, and
/// (for project-scoped follows) records the relationship so the room
/// state reflects it.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Make sure both the leader and the follower are actually in the room
    // before forwarding anything.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for their current state and relay it to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are recorded in the database and
    // reflected in the room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2501
/// Stop following another user in a call.
///
/// Mirrors [`follow`]: membership is checked first, the leader is told,
/// and project-scoped follows are removed from the room state.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both parties are in the room before forwarding anything.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, this is fire-and-forget: no response is relayed.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Only project-scoped follows were recorded, so only those need to be
    // removed from the room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2530
2531/// Notify everyone following you of your current location.
2532async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2533 let room_id = RoomId::from_proto(request.room_id);
2534 let database = session.db.lock().await;
2535
2536 let connection_ids = if let Some(project_id) = request.project_id {
2537 let project_id = ProjectId::from_proto(project_id);
2538 database
2539 .project_connection_ids(project_id, session.connection_id, true)
2540 .await?
2541 } else {
2542 database
2543 .room_connection_ids(room_id, session.connection_id)
2544 .await?
2545 };
2546
2547 // For now, don't send view update messages back to that view's current leader.
2548 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2549 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2550 _ => None,
2551 });
2552
2553 for connection_id in connection_ids.iter().cloned() {
2554 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2555 session
2556 .peer
2557 .forward_send(session.connection_id, connection_id, request.clone())?;
2558 }
2559 }
2560 Ok(())
2561}
2562
2563/// Get public data about users.
2564async fn get_users(
2565 request: proto::GetUsers,
2566 response: Response<proto::GetUsers>,
2567 session: MessageContext,
2568) -> Result<()> {
2569 let user_ids = request
2570 .user_ids
2571 .into_iter()
2572 .map(UserId::from_proto)
2573 .collect();
2574 let users = session
2575 .db()
2576 .await
2577 .get_users_by_ids(user_ids)
2578 .await?
2579 .into_iter()
2580 .map(|user| proto::User {
2581 id: user.id.to_proto(),
2582 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2583 github_login: user.github_login,
2584 name: user.name,
2585 })
2586 .collect();
2587 response.send(proto::UsersResponse { users })?;
2588 Ok(())
2589}
2590
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Queries of 1-2 characters use an exact login lookup; longer queries
    // go through fuzzy search capped at 10 results. Empty queries return
    // nothing.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requesting user to themselves.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2622
2623/// Send a contact request to another user.
2624async fn request_contact(
2625 request: proto::RequestContact,
2626 response: Response<proto::RequestContact>,
2627 session: MessageContext,
2628) -> Result<()> {
2629 let requester_id = session.user_id();
2630 let responder_id = UserId::from_proto(request.responder_id);
2631 if requester_id == responder_id {
2632 return Err(anyhow!("cannot add yourself as a contact"))?;
2633 }
2634
2635 let notifications = session
2636 .db()
2637 .await
2638 .send_contact_request(requester_id, responder_id)
2639 .await?;
2640
2641 // Update outgoing contact requests of requester
2642 let mut update = proto::UpdateContacts::default();
2643 update.outgoing_requests.push(responder_id.to_proto());
2644 for connection_id in session
2645 .connection_pool()
2646 .await
2647 .user_connection_ids(requester_id)
2648 {
2649 session.peer.send(connection_id, update.clone())?;
2650 }
2651
2652 // Update incoming contact requests of responder
2653 let mut update = proto::UpdateContacts::default();
2654 update
2655 .incoming_requests
2656 .push(proto::IncomingContactRequest {
2657 requester_id: requester_id.to_proto(),
2658 });
2659 let connection_pool = session.connection_pool().await;
2660 for connection_id in connection_pool.user_connection_ids(responder_id) {
2661 session.peer.send(connection_id, update.clone())?;
2662 }
2663
2664 send_notifications(&connection_pool, &session.peer, notifications);
2665
2666 response.send(proto::Ack {})?;
2667 Ok(())
2668}
2669
/// Accept or decline a contact request
///
/// Dismissal only clears the notification; accept/decline updates both
/// users' contact lists on every connection and delivers notifications.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included so each side renders the other's
        // availability correctly in the new contact entry.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The pending incoming request disappears whether accepted or declined.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Likewise, the requester's outgoing request is resolved either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2727
2728/// Remove a contact.
2729async fn remove_contact(
2730 request: proto::RemoveContact,
2731 response: Response<proto::RemoveContact>,
2732 session: MessageContext,
2733) -> Result<()> {
2734 let requester_id = session.user_id();
2735 let responder_id = UserId::from_proto(request.user_id);
2736 let db = session.db().await;
2737 let (contact_accepted, deleted_notification_id) =
2738 db.remove_contact(requester_id, responder_id).await?;
2739
2740 let pool = session.connection_pool().await;
2741 // Update outgoing contact requests of requester
2742 let mut update = proto::UpdateContacts::default();
2743 if contact_accepted {
2744 update.remove_contacts.push(responder_id.to_proto());
2745 } else {
2746 update
2747 .remove_outgoing_requests
2748 .push(responder_id.to_proto());
2749 }
2750 for connection_id in pool.user_connection_ids(requester_id) {
2751 session.peer.send(connection_id, update.clone())?;
2752 }
2753
2754 // Update incoming contact requests of responder
2755 let mut update = proto::UpdateContacts::default();
2756 if contact_accepted {
2757 update.remove_contacts.push(requester_id.to_proto());
2758 } else {
2759 update
2760 .remove_incoming_requests
2761 .push(requester_id.to_proto());
2762 }
2763 for connection_id in pool.user_connection_ids(responder_id) {
2764 session.peer.send(connection_id, update.clone())?;
2765 if let Some(notification_id) = deleted_notification_id {
2766 session.peer.send(
2767 connection_id,
2768 proto::DeleteNotification {
2769 notification_id: notification_id.to_proto(),
2770 },
2771 )?;
2772 }
2773 }
2774
2775 response.send(proto::Ack {})?;
2776 Ok(())
2777}
2778
/// Whether the server should subscribe this client to its channels
/// automatically on connect. Newer clients (minor >= 139) are expected to
/// send an explicit `SubscribeToChannels` request instead — presumably;
/// note the check ignores the major version, TODO confirm.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2782
2783async fn subscribe_to_channels(
2784 _: proto::SubscribeToChannels,
2785 session: MessageContext,
2786) -> Result<()> {
2787 subscribe_user_to_channels(session.user_id(), &session).await?;
2788 Ok(())
2789}
2790
/// Subscribe `user_id` to every channel they are a member of and send the
/// initial channel state to the session's connection.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register the subscriptions in the pool first so later channel
    // updates reach this user.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Membership info is sent first, then the full channels payload
    // (which consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2807
2808/// Creates a new channel.
2809async fn create_channel(
2810 request: proto::CreateChannel,
2811 response: Response<proto::CreateChannel>,
2812 session: MessageContext,
2813) -> Result<()> {
2814 let db = session.db().await;
2815
2816 let parent_id = request.parent_id.map(ChannelId::from_proto);
2817 let (channel, membership) = db
2818 .create_channel(&request.name, parent_id, session.user_id())
2819 .await?;
2820
2821 let root_id = channel.root_id();
2822 let channel = Channel::from_model(channel);
2823
2824 response.send(proto::CreateChannelResponse {
2825 channel: Some(channel.to_proto()),
2826 parent_id: request.parent_id,
2827 })?;
2828
2829 let mut connection_pool = session.connection_pool().await;
2830 if let Some(membership) = membership {
2831 connection_pool.subscribe_to_channel(
2832 membership.user_id,
2833 membership.channel_id,
2834 membership.role,
2835 );
2836 let update = proto::UpdateUserChannels {
2837 channel_memberships: vec![proto::ChannelMembership {
2838 channel_id: membership.channel_id.to_proto(),
2839 role: membership.role.into(),
2840 }],
2841 ..Default::default()
2842 };
2843 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2844 session.peer.send(connection_id, update.clone())?;
2845 }
2846 }
2847
2848 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2849 if !role.can_see_channel(channel.visibility) {
2850 continue;
2851 }
2852
2853 let update = proto::UpdateChannels {
2854 channels: vec![channel.to_proto()],
2855 ..Default::default()
2856 };
2857 session.peer.send(connection_id, update.clone())?;
2858 }
2859
2860 Ok(())
2861}
2862
2863/// Delete a channel
2864async fn delete_channel(
2865 request: proto::DeleteChannel,
2866 response: Response<proto::DeleteChannel>,
2867 session: MessageContext,
2868) -> Result<()> {
2869 let db = session.db().await;
2870
2871 let channel_id = request.channel_id;
2872 let (root_channel, removed_channels) = db
2873 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2874 .await?;
2875 response.send(proto::Ack {})?;
2876
2877 // Notify members of removed channels
2878 let mut update = proto::UpdateChannels::default();
2879 update
2880 .delete_channels
2881 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2882
2883 let connection_pool = session.connection_pool().await;
2884 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2885 session.peer.send(connection_id, update.clone())?;
2886 }
2887
2888 Ok(())
2889}
2890
2891/// Invite someone to join a channel.
2892async fn invite_channel_member(
2893 request: proto::InviteChannelMember,
2894 response: Response<proto::InviteChannelMember>,
2895 session: MessageContext,
2896) -> Result<()> {
2897 let db = session.db().await;
2898 let channel_id = ChannelId::from_proto(request.channel_id);
2899 let invitee_id = UserId::from_proto(request.user_id);
2900 let InviteMemberResult {
2901 channel,
2902 notifications,
2903 } = db
2904 .invite_channel_member(
2905 channel_id,
2906 invitee_id,
2907 session.user_id(),
2908 request.role().into(),
2909 )
2910 .await?;
2911
2912 let update = proto::UpdateChannels {
2913 channel_invitations: vec![channel.to_proto()],
2914 ..Default::default()
2915 };
2916
2917 let connection_pool = session.connection_pool().await;
2918 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2919 session.peer.send(connection_id, update.clone())?;
2920 }
2921
2922 send_notifications(&connection_pool, &session.peer, notifications);
2923
2924 response.send(proto::Ack {})?;
2925 Ok(())
2926}
2927
2928/// remove someone from a channel
2929async fn remove_channel_member(
2930 request: proto::RemoveChannelMember,
2931 response: Response<proto::RemoveChannelMember>,
2932 session: MessageContext,
2933) -> Result<()> {
2934 let db = session.db().await;
2935 let channel_id = ChannelId::from_proto(request.channel_id);
2936 let member_id = UserId::from_proto(request.user_id);
2937
2938 let RemoveChannelMemberResult {
2939 membership_update,
2940 notification_id,
2941 } = db
2942 .remove_channel_member(channel_id, member_id, session.user_id())
2943 .await?;
2944
2945 let mut connection_pool = session.connection_pool().await;
2946 notify_membership_updated(
2947 &mut connection_pool,
2948 membership_update,
2949 member_id,
2950 &session.peer,
2951 );
2952 for connection_id in connection_pool.user_connection_ids(member_id) {
2953 if let Some(notification_id) = notification_id {
2954 session
2955 .peer
2956 .send(
2957 connection_id,
2958 proto::DeleteNotification {
2959 notification_id: notification_id.to_proto(),
2960 },
2961 )
2962 .trace_err();
2963 }
2964 }
2965
2966 response.send(proto::Ack {})?;
2967 Ok(())
2968}
2969
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user list into a Vec first: the loop body mutates the
    // pool's subscription state, so the iterator borrow can't be held.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel are (re)subscribed and sent
        // the updated channel; everyone else is unsubscribed and told to
        // drop it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3016
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The database distinguishes between changing an accepted member's
    // role and changing the role on a still-pending invitation.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // Pending invite: resend the invitation carrying the new role.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3064
3065/// Change the name of a channel
3066async fn rename_channel(
3067 request: proto::RenameChannel,
3068 response: Response<proto::RenameChannel>,
3069 session: MessageContext,
3070) -> Result<()> {
3071 let db = session.db().await;
3072 let channel_id = ChannelId::from_proto(request.channel_id);
3073 let channel_model = db
3074 .rename_channel(channel_id, session.user_id(), &request.name)
3075 .await?;
3076 let root_id = channel_model.root_id();
3077 let channel = Channel::from_model(channel_model);
3078
3079 response.send(proto::RenameChannelResponse {
3080 channel: Some(channel.to_proto()),
3081 })?;
3082
3083 let connection_pool = session.connection_pool().await;
3084 let update = proto::UpdateChannels {
3085 channels: vec![channel.to_proto()],
3086 ..Default::default()
3087 };
3088 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3089 if role.can_see_channel(channel.visibility) {
3090 session.peer.send(connection_id, update.clone())?;
3091 }
3092 }
3093
3094 Ok(())
3095}
3096
3097/// Move a channel to a new parent.
3098async fn move_channel(
3099 request: proto::MoveChannel,
3100 response: Response<proto::MoveChannel>,
3101 session: MessageContext,
3102) -> Result<()> {
3103 let channel_id = ChannelId::from_proto(request.channel_id);
3104 let to = ChannelId::from_proto(request.to);
3105
3106 let (root_id, channels) = session
3107 .db()
3108 .await
3109 .move_channel(channel_id, to, session.user_id())
3110 .await?;
3111
3112 let connection_pool = session.connection_pool().await;
3113 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3114 let channels = channels
3115 .iter()
3116 .filter_map(|channel| {
3117 if role.can_see_channel(channel.visibility) {
3118 Some(channel.to_proto())
3119 } else {
3120 None
3121 }
3122 })
3123 .collect::<Vec<_>>();
3124 if channels.is_empty() {
3125 continue;
3126 }
3127
3128 let update = proto::UpdateChannels {
3129 channels,
3130 ..Default::default()
3131 };
3132
3133 session.peer.send(connection_id, update.clone())?;
3134 }
3135
3136 response.send(Ack {})?;
3137 Ok(())
3138}
3139
3140async fn reorder_channel(
3141 request: proto::ReorderChannel,
3142 response: Response<proto::ReorderChannel>,
3143 session: MessageContext,
3144) -> Result<()> {
3145 let channel_id = ChannelId::from_proto(request.channel_id);
3146 let direction = request.direction();
3147
3148 let updated_channels = session
3149 .db()
3150 .await
3151 .reorder_channel(channel_id, direction, session.user_id())
3152 .await?;
3153
3154 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3155 let connection_pool = session.connection_pool().await;
3156 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3157 let channels = updated_channels
3158 .iter()
3159 .filter_map(|channel| {
3160 if role.can_see_channel(channel.visibility) {
3161 Some(channel.to_proto())
3162 } else {
3163 None
3164 }
3165 })
3166 .collect::<Vec<_>>();
3167
3168 if channels.is_empty() {
3169 continue;
3170 }
3171
3172 let update = proto::UpdateChannels {
3173 channels,
3174 ..Default::default()
3175 };
3176
3177 session.peer.send(connection_id, update.clone())?;
3178 }
3179 }
3180
3181 response.send(Ack {})?;
3182 Ok(())
3183}
3184
3185/// Get the list of channel members
3186async fn get_channel_members(
3187 request: proto::GetChannelMembers,
3188 response: Response<proto::GetChannelMembers>,
3189 session: MessageContext,
3190) -> Result<()> {
3191 let db = session.db().await;
3192 let channel_id = ChannelId::from_proto(request.channel_id);
3193 let limit = if request.limit == 0 {
3194 u16::MAX as u64
3195 } else {
3196 request.limit
3197 };
3198 let (members, users) = db
3199 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3200 .await?;
3201 response.send(proto::GetChannelMembersResponse { members, users })?;
3202 Ok(())
3203}
3204
3205/// Accept or decline a channel invitation.
3206async fn respond_to_channel_invite(
3207 request: proto::RespondToChannelInvite,
3208 response: Response<proto::RespondToChannelInvite>,
3209 session: MessageContext,
3210) -> Result<()> {
3211 let db = session.db().await;
3212 let channel_id = ChannelId::from_proto(request.channel_id);
3213 let RespondToChannelInvite {
3214 membership_update,
3215 notifications,
3216 } = db
3217 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3218 .await?;
3219
3220 let mut connection_pool = session.connection_pool().await;
3221 if let Some(membership_update) = membership_update {
3222 notify_membership_updated(
3223 &mut connection_pool,
3224 membership_update,
3225 session.user_id(),
3226 &session.peer,
3227 );
3228 } else {
3229 let update = proto::UpdateChannels {
3230 remove_channel_invitations: vec![channel_id.to_proto()],
3231 ..Default::default()
3232 };
3233
3234 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3235 session.peer.send(connection_id, update.clone())?;
3236 }
3237 };
3238
3239 send_notifications(&connection_pool, &session.peer, notifications);
3240
3241 response.send(proto::Ack {})?;
3242
3243 Ok(())
3244}
3245
/// Join the channels' room
///
/// Thin wrapper that boxes the response so [`join_channel_internal`] can
/// serve both `JoinChannel` and `JoinRoom` requests.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    join_channel_internal(channel_id, Box::new(response), session).await
}
3255
/// Abstraction over the two response types that can carry a
/// `JoinRoomResponse`, letting `join_channel_internal` be shared between the
/// `JoinChannel` and `JoinRoom` RPC handlers.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3269
/// Shared implementation for joining a channel's call, used by both the
/// `JoinChannel` and `JoinRoom` handlers via [`JoinChannelInternalResponse`].
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before cleanup, which re-acquires it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured. Guests get a
        // non-publishing token. Token-minting failures are logged via
        // `trace_err` and collapse to `None` (no call media) rather than
        // failing the whole join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Respond to the caller before broadcasting to other participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // NOTE: the db guard is dropped before these broadcasts.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3363
3364/// Start editing the channel notes
3365async fn join_channel_buffer(
3366 request: proto::JoinChannelBuffer,
3367 response: Response<proto::JoinChannelBuffer>,
3368 session: MessageContext,
3369) -> Result<()> {
3370 let db = session.db().await;
3371 let channel_id = ChannelId::from_proto(request.channel_id);
3372
3373 let open_response = db
3374 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3375 .await?;
3376
3377 let collaborators = open_response.collaborators.clone();
3378 response.send(open_response)?;
3379
3380 let update = UpdateChannelBufferCollaborators {
3381 channel_id: channel_id.to_proto(),
3382 collaborators: collaborators.clone(),
3383 };
3384 channel_buffer_updated(
3385 session.connection_id,
3386 collaborators
3387 .iter()
3388 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3389 &update,
3390 &session.peer,
3391 );
3392
3393 Ok(())
3394}
3395
/// Edit the channel notes
///
/// Persists the operations, relays them to the other collaborators currently
/// editing the buffer, and sends only the new buffer version marker to the
/// channel members who are not collaborating on the buffer.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to the other collaborators (the sender's own
    // connection is excluded by `channel_buffer_updated`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Everyone else in the channel only needs the latest version marker, not
    // the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3446
3447/// Rejoin the channel notes after a connection blip
3448async fn rejoin_channel_buffers(
3449 request: proto::RejoinChannelBuffers,
3450 response: Response<proto::RejoinChannelBuffers>,
3451 session: MessageContext,
3452) -> Result<()> {
3453 let db = session.db().await;
3454 let buffers = db
3455 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3456 .await?;
3457
3458 for rejoined_buffer in &buffers {
3459 let collaborators_to_notify = rejoined_buffer
3460 .buffer
3461 .collaborators
3462 .iter()
3463 .filter_map(|c| Some(c.peer_id?.into()));
3464 channel_buffer_updated(
3465 session.connection_id,
3466 collaborators_to_notify,
3467 &proto::UpdateChannelBufferCollaborators {
3468 channel_id: rejoined_buffer.buffer.channel_id,
3469 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3470 },
3471 &session.peer,
3472 );
3473 }
3474
3475 response.send(proto::RejoinChannelBuffersResponse {
3476 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3477 })?;
3478
3479 Ok(())
3480}
3481
3482/// Stop editing the channel notes
3483async fn leave_channel_buffer(
3484 request: proto::LeaveChannelBuffer,
3485 response: Response<proto::LeaveChannelBuffer>,
3486 session: MessageContext,
3487) -> Result<()> {
3488 let db = session.db().await;
3489 let channel_id = ChannelId::from_proto(request.channel_id);
3490
3491 let left_buffer = db
3492 .leave_channel_buffer(channel_id, session.connection_id)
3493 .await?;
3494
3495 response.send(Ack {})?;
3496
3497 channel_buffer_updated(
3498 session.connection_id,
3499 left_buffer.connections,
3500 &proto::UpdateChannelBufferCollaborators {
3501 channel_id: channel_id.to_proto(),
3502 collaborators: left_buffer.collaborators,
3503 },
3504 &session.peer,
3505 );
3506
3507 Ok(())
3508}
3509
3510fn channel_buffer_updated<T: EnvelopedMessage>(
3511 sender_id: ConnectionId,
3512 collaborators: impl IntoIterator<Item = ConnectionId>,
3513 message: &T,
3514 peer: &Peer,
3515) {
3516 broadcast(Some(sender_id), collaborators, |peer_id| {
3517 peer.send(peer_id, message.clone())
3518 });
3519}
3520
3521fn send_notifications(
3522 connection_pool: &ConnectionPool,
3523 peer: &Peer,
3524 notifications: db::NotificationBatch,
3525) {
3526 for (user_id, notification) in notifications {
3527 for connection_id in connection_pool.user_connection_ids(user_id) {
3528 if let Err(error) = peer.send(
3529 connection_id,
3530 proto::AddNotification {
3531 notification: Some(notification.clone()),
3532 },
3533 ) {
3534 tracing::error!(
3535 "failed to send notification to {:?} {}",
3536 connection_id,
3537 error
3538 );
3539 }
3540 }
3541 }
3542}
3543
/// Send a message to the channel
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3552
/// Delete a channel message
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3561
/// Edit a channel message
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3569
/// Mark a channel message as read
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3577
3578/// Mark a buffer version as synced
3579async fn acknowledge_buffer_version(
3580 request: proto::AckBufferOperation,
3581 session: MessageContext,
3582) -> Result<()> {
3583 let buffer_id = BufferId::from_proto(request.buffer_id);
3584 session
3585 .db()
3586 .await
3587 .observe_buffer_version(
3588 buffer_id,
3589 session.user_id(),
3590 request.epoch as i32,
3591 &request.version,
3592 )
3593 .await?;
3594 Ok(())
3595}
3596
/// Start receiving chat updates for a channel
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3605
/// Stop receiving chat updates for a channel
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3613
/// Retrieve the chat history for a channel
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3622
/// Retrieve specific chat messages
///
/// Chat has been removed; this stub always returns an error so clients get an
/// explanatory message.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3631
3632/// Retrieve the current users notifications
3633async fn get_notifications(
3634 request: proto::GetNotifications,
3635 response: Response<proto::GetNotifications>,
3636 session: MessageContext,
3637) -> Result<()> {
3638 let notifications = session
3639 .db()
3640 .await
3641 .get_notifications(
3642 session.user_id(),
3643 NOTIFICATION_COUNT_PER_PAGE,
3644 request.before_id.map(db::NotificationId::from_proto),
3645 )
3646 .await?;
3647 response.send(proto::GetNotificationsResponse {
3648 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3649 notifications,
3650 })?;
3651 Ok(())
3652}
3653
3654/// Mark notifications as read
3655async fn mark_notification_as_read(
3656 request: proto::MarkNotificationRead,
3657 response: Response<proto::MarkNotificationRead>,
3658 session: MessageContext,
3659) -> Result<()> {
3660 let database = &session.db().await;
3661 let notifications = database
3662 .mark_notification_as_read_by_id(
3663 session.user_id(),
3664 NotificationId::from_proto(request.notification_id),
3665 )
3666 .await?;
3667 send_notifications(
3668 &*session.connection_pool().await,
3669 &session.peer,
3670 notifications,
3671 );
3672 response.send(proto::Ack {})?;
3673 Ok(())
3674}
3675
/// Convert a tungstenite websocket message into its axum equivalent.
///
/// # Errors
/// Returns an error for `TungsteniteMessage::Frame`, which has no axum
/// counterpart and should never be produced while reading messages.
fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
    let message = match message {
        TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
        TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
        TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
        TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
        TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
            code: frame.code.into(),
            reason: frame.reason.as_str().to_owned().into(),
        })),
        // We should never receive a frame while reading the message, according
        // to the `tungstenite` maintainers:
        //
        // > It cannot occur when you read messages from the WebSocket, but it
        // > can be used when you want to send the raw frames (e.g. you want to
        // > send the frames to the WebSocket without composing the full message first).
        // >
        // > — https://github.com/snapview/tungstenite-rs/issues/268
        TungsteniteMessage::Frame(_) => {
            bail!("received an unexpected frame while reading the message")
        }
    };

    Ok(message)
}
3701
3702fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3703 match message {
3704 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3705 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3706 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3707 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3708 AxumMessage::Close(frame) => {
3709 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3710 code: frame.code.into(),
3711 reason: frame.reason.as_ref().into(),
3712 }))
3713 }
3714 }
3715}
3716
/// Apply a membership change to the in-memory connection pool and push the
/// resulting channel updates to all of the affected user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Sync the pool's channel subscriptions first, so subsequent channel
    // broadcasts include (or exclude) this user correctly.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    // The channels update also deletes any removed channels and retracts the
    // invitation for the channel this membership change originated from.
    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send errors are logged rather than propagated so the remaining
    // connections still get their updates.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3757
3758fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3759 proto::UpdateUserChannels {
3760 channel_memberships: channels
3761 .channel_memberships
3762 .iter()
3763 .map(|m| proto::ChannelMembership {
3764 channel_id: m.channel_id.to_proto(),
3765 role: m.role.into(),
3766 })
3767 .collect(),
3768 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3769 }
3770}
3771
3772fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3773 let mut update = proto::UpdateChannels::default();
3774
3775 for channel in channels.channels {
3776 update.channels.push(channel.to_proto());
3777 }
3778
3779 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3780
3781 for (channel_id, participants) in channels.channel_participants {
3782 update
3783 .channel_participants
3784 .push(proto::ChannelParticipants {
3785 channel_id: channel_id.to_proto(),
3786 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3787 });
3788 }
3789
3790 for channel in channels.invited_channels {
3791 update.channel_invitations.push(channel.to_proto());
3792 }
3793
3794 update
3795}
3796
3797fn build_initial_contacts_update(
3798 contacts: Vec<db::Contact>,
3799 pool: &ConnectionPool,
3800) -> proto::UpdateContacts {
3801 let mut update = proto::UpdateContacts::default();
3802
3803 for contact in contacts {
3804 match contact {
3805 db::Contact::Accepted { user_id, busy } => {
3806 update.contacts.push(contact_for_user(user_id, busy, pool));
3807 }
3808 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3809 db::Contact::Incoming { user_id } => {
3810 update
3811 .incoming_requests
3812 .push(proto::IncomingContactRequest {
3813 requester_id: user_id.to_proto(),
3814 })
3815 }
3816 }
3817 }
3818
3819 update
3820}
3821
3822fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3823 proto::Contact {
3824 user_id: user_id.to_proto(),
3825 online: pool.is_user_online(user_id),
3826 busy,
3827 }
3828}
3829
3830fn room_updated(room: &proto::Room, peer: &Peer) {
3831 broadcast(
3832 None,
3833 room.participants
3834 .iter()
3835 .filter_map(|participant| Some(participant.peer_id?.into())),
3836 |peer_id| {
3837 peer.send(
3838 peer_id,
3839 proto::RoomUpdated {
3840 room: Some(room.clone()),
3841 },
3842 )
3843 },
3844 );
3845}
3846
3847fn channel_updated(
3848 channel: &db::channel::Model,
3849 room: &proto::Room,
3850 peer: &Peer,
3851 pool: &ConnectionPool,
3852) {
3853 let participants = room
3854 .participants
3855 .iter()
3856 .map(|p| p.user_id)
3857 .collect::<Vec<_>>();
3858
3859 broadcast(
3860 None,
3861 pool.channel_connection_ids(channel.root_id())
3862 .filter_map(|(channel_id, role)| {
3863 role.can_see_channel(channel.visibility)
3864 .then_some(channel_id)
3865 }),
3866 |peer_id| {
3867 peer.send(
3868 peer_id,
3869 proto::UpdateChannels {
3870 channel_participants: vec![proto::ChannelParticipants {
3871 channel_id: channel.id.to_proto(),
3872 participant_user_ids: participants.clone(),
3873 }],
3874 ..Default::default()
3875 },
3876 )
3877 },
3878 );
3879}
3880
3881async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3882 let db = session.db().await;
3883
3884 let contacts = db.get_contacts(user_id).await?;
3885 let busy = db.is_user_busy(user_id).await?;
3886
3887 let pool = session.connection_pool().await;
3888 let updated_contact = contact_for_user(user_id, busy, &pool);
3889 for contact in contacts {
3890 if let db::Contact::Accepted {
3891 user_id: contact_user_id,
3892 ..
3893 } = contact
3894 {
3895 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3896 session
3897 .peer
3898 .send(
3899 contact_conn_id,
3900 proto::UpdateContacts {
3901 contacts: vec![updated_contact.clone()],
3902 remove_contacts: Default::default(),
3903 incoming_requests: Default::default(),
3904 remove_incoming_requests: Default::default(),
3905 outgoing_requests: Default::default(),
3906 remove_outgoing_requests: Default::default(),
3907 },
3908 )
3909 .trace_err();
3910 }
3911 }
3912 }
3913 Ok(())
3914}
3915
/// Remove the given connection from whatever room it is in, notifying
/// project collaborators, channel members, users whose calls were canceled,
/// contacts, and LiveKit as appropriate. A no-op if the connection isn't in a
/// room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Populated inside the `if let` below; the function returns early when the
    // connection wasn't in a room.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare or drop collaborator state for each project this connection
        // participated in.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // Connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Tell users whose incoming calls were canceled by this departure; the
    // extra scope bounds the connection-pool guard.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3989
3990async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3991 let left_channel_buffers = session
3992 .db()
3993 .await
3994 .leave_channel_buffers(session.connection_id)
3995 .await?;
3996
3997 for left_buffer in left_channel_buffers {
3998 channel_buffer_updated(
3999 session.connection_id,
4000 left_buffer.connections,
4001 &proto::UpdateChannelBufferCollaborators {
4002 channel_id: left_buffer.channel_id.to_proto(),
4003 collaborators: left_buffer.collaborators,
4004 },
4005 &session.peer,
4006 );
4007 }
4008
4009 Ok(())
4010}
4011
4012fn project_left(project: &db::LeftProject, session: &Session) {
4013 for connection_id in &project.connection_ids {
4014 if project.should_unshare {
4015 session
4016 .peer
4017 .send(
4018 *connection_id,
4019 proto::UnshareProject {
4020 project_id: project.id.to_proto(),
4021 },
4022 )
4023 .trace_err();
4024 } else {
4025 session
4026 .peer
4027 .send(
4028 *connection_id,
4029 proto::RemoveProjectCollaborator {
4030 project_id: project.id.to_proto(),
4031 peer_id: Some(session.connection_id.into()),
4032 },
4033 )
4034 .trace_err();
4035 }
4036 }
4037}
4038
4039async fn share_agent_thread(
4040 request: proto::ShareAgentThread,
4041 response: Response<proto::ShareAgentThread>,
4042 session: MessageContext,
4043) -> Result<()> {
4044 let user_id = session.user_id();
4045
4046 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4047 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4048
4049 session
4050 .db()
4051 .await
4052 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4053 .await?;
4054
4055 response.send(proto::Ack {})?;
4056
4057 Ok(())
4058}
4059
4060async fn get_shared_agent_thread(
4061 request: proto::GetSharedAgentThread,
4062 response: Response<proto::GetSharedAgentThread>,
4063 session: MessageContext,
4064) -> Result<()> {
4065 let share_id = SharedThreadId::from_proto(request.session_id)
4066 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4067
4068 let result = session.db().await.get_shared_thread(share_id).await?;
4069
4070 match result {
4071 Some((thread, username)) => {
4072 response.send(proto::GetSharedAgentThreadResponse {
4073 title: thread.title,
4074 thread_data: thread.data,
4075 sharer_username: username,
4076 created_at: thread.created_at.and_utc().to_rfc3339(),
4077 })?;
4078 }
4079 None => {
4080 return Err(anyhow!("Shared thread not found").into());
4081 }
4082 }
4083
4084 Ok(())
4085}
4086
/// Extension trait for logging errors instead of propagating them.
pub trait ResultExt {
    type Ok;

    /// Consume the result, logging any error and returning `Some(value)` on
    /// success or `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4092
4093impl<T, E> ResultExt for Result<T, E>
4094where
4095 E: std::fmt::Debug,
4096{
4097 type Ok = T;
4098
4099 #[track_caller]
4100 fn trace_err(self) -> Option<T> {
4101 match self {
4102 Ok(value) => Some(value),
4103 Err(error) => {
4104 tracing::error!("{:?}", error);
4105 None
4106 }
4107 }
4108 }
4109}