1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// Grace period clients get to reconnect; consumers of this constant live
/// outside this chunk.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously counted connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Global count of connections currently holding a `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields used to record message timings (milliseconds).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler for one message type: receives the envelope, the
// session it arrived on, and the span to record timing fields into.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
/// Handle given to request handlers for replying exactly once to the request
/// identified by `receipt`.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Set to true by `send`; the request dispatcher checks it to detect
    // handlers that returned Ok without ever responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the reply and marks the request as responded.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
/// The authenticated identity behind a connection. Currently always a user.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
}

impl Principal {
    /// Records this identity's `user_id` and `login` on `span` so they show
    /// up on every event emitted inside it.
    fn update_span(&self, span: &tracing::Span) {
        match &self {
            Principal::User(user) => {
                span.record("user_id", user.id.0);
                span.record("login", &user.github_login);
            }
        }
    }
}
145
/// Per-message context passed to handlers: the connection's `Session` plus
/// the tracing span created for this particular message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Derefs to the session so handlers can use `ctx.db()` etc. directly.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}

impl MessageContext {
    /// Forwards `request` to `receiver_id` on behalf of this connection,
    /// recording how long the receiver took in the `host_waiting_ms` span
    /// field and logging start/finish/error events.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        // Logged eagerly, when the future is constructed rather than polled.
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                // Runs on completion regardless of outcome; micros / 1000
                // yields fractional milliseconds.
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// Shared per-connection state; cheap to clone and handed to every message
/// handler running on behalf of this connection.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; access via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Acquires the async database lock. In test builds this yields before
    /// and after taking the lock — presumably to add interleaving points for
    /// the test executor; confirm against the test harness.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool. The returned guard is `!Send` (see
    /// `ConnectionPoolGuard::_not_send`) so it cannot be held across `.await`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the authenticated user is an admin.
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The id of the authenticated user.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collab RPC server: owns the peer, the connection pool, and the table
/// of message handlers populated in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One type-erased handler per message `TypeId`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` to every connection task when the server shuts down.
    teardown: watch::Sender<bool>,
}
262
/// Guard over the locked `ConnectionPool`. The `PhantomData<Rc<()>>` makes
/// the guard `!Send` (since `Rc` is `!Send`), preventing it from being held
/// across `.await` points.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers every RPC message handler.
    ///
    /// Handler kinds:
    /// - `add_request_handler`: request/response messages;
    /// - `add_message_handler`: fire-and-forget messages;
    /// - `forward_read_only_project_request` / `forward_mutating_project_request`:
    ///   proxy project messages to the host — the split presumably gates what
    ///   read-only guests may send; confirm in the forwarding helpers (defined
    ///   outside this chunk).
    ///
    /// # Panics
    /// Panics (via `add_handler`) if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-originated broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            // Notifications.
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Agent threads and search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
455
    /// Spawns the background task that cleans up resources left behind by a
    /// previous incarnation of this server.
    ///
    /// After waiting `CLEANUP_TIMEOUT` (so terminated pods can finish their
    /// graceful shutdown), it clears stale channel-chat participants,
    /// channel-buffer collaborators, and room participants — notifying the
    /// affected clients — then deletes old worktree entries and stale server
    /// rows. All failures are logged via `trace_err` and do not abort cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // push the refreshed collaborator list to the remaining
                    // connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants, then broadcast the
                        // refreshed room (and channel, if any) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify every connection of each user whose incoming
                        // call was canceled by the cleanup.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push updated busy/contact state to the contacts of
                        // every affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no
                        // participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): this repeats the identical call made before
                // room cleanup above — possibly intentional to catch
                // participants that became stale in the meantime; confirm.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
626
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and broadcasts the teardown signal to all connection tasks.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
632
    /// Test-only: tears the server down and re-initializes it under a new id,
    /// then clears the teardown flag so connections can be accepted again.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
640
    /// Test-only: the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
645
646 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
649 Fut: 'static + Send + Future<Output = Result<()>>,
650 M: EnvelopedMessage,
651 {
652 let prev_handler = self.handlers.insert(
653 TypeId::of::<M>(),
654 Box::new(move |envelope, session, span| {
655 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
656 let received_at = envelope.received_at;
657 tracing::info!("message received");
658 let start_time = Instant::now();
659 let future = (handler)(
660 *envelope,
661 MessageContext {
662 session,
663 span: span.clone(),
664 },
665 );
666 async move {
667 let result = future.await;
668 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
669 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
670 let queue_duration_ms = total_duration_ms - processing_duration_ms;
671 span.record(TOTAL_DURATION_MS, total_duration_ms);
672 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
673 span.record(QUEUE_DURATION_MS, queue_duration_ms);
674 match result {
675 Err(error) => {
676 tracing::error!(?error, "error handling message")
677 }
678 Ok(()) => tracing::info!("finished handling message"),
679 }
680 }
681 .boxed()
682 }),
683 );
684 if prev_handler.is_some() {
685 panic!("registered a handler for the same message twice");
686 }
687 self
688 }
689
690 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
691 where
692 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
693 Fut: 'static + Send + Future<Output = Result<()>>,
694 M: EnvelopedMessage,
695 {
696 self.add_handler(move |envelope, session| handler(envelope.payload, session));
697 self
698 }
699
    /// Registers a handler for a request/response message type.
    ///
    /// The wrapper constructs a `Response<M>` for the handler to reply
    /// through and enforces the protocol contract:
    /// - returning `Ok(())` without calling `Response::send` is turned into
    ///   an error ("handler did not send a response");
    /// - handler errors are relayed to the requester via `respond_with_error`
    ///   before being propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Flipped to true by `Response::send`.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors carry their own proto encoding;
                        // everything else is wrapped as a generic internal
                        // error for the client.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
738
    /// Builds the future that services one client connection: registers it
    /// with the peer, sends the initial client update, then pumps incoming
    /// messages through the registered handlers until the connection closes
    /// or the server tears down.
    ///
    /// `send_connection_id`, when provided, receives the assigned connection
    /// id once known. `connection_guard` (the concurrency-limit slot) is
    /// dropped as soon as the initial update succeeds.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Connection is established; dropping the guard decrements the
            // global connection counter.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    // Wait for a free handler slot before pulling the next
                    // message off the connection.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // finishes, so it counts toward the limit.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
909
    /// Sends the initial payload to a freshly-established connection: a
    /// `Hello` message carrying the peer id, followed by (for user
    /// principals) the contact list, channel subscriptions, and any
    /// pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller, if one asked for it;
        // a dropped receiver is not an error.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // First-ever connection: prompt the client to show contacts.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contacts
                // update while holding the pool lock, so the update reflects
                // a consistent view of who is online.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Replay any call that was ringing for this user before they
                // connected.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
965}
966
/// Read access to the pooled connections through the guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
974
/// Mutable access to the pooled connections through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
980
/// Validates pool invariants when the guard is released, so corruption is
/// detected close to the mutation that introduced it.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Only compiled into test-support builds; production drops do nothing.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
987
988fn broadcast<F>(
989 sender_id: Option<ConnectionId>,
990 receiver_ids: impl IntoIterator<Item = ConnectionId>,
991 mut f: F,
992) where
993 F: FnMut(ConnectionId) -> anyhow::Result<()>,
994{
995 for receiver_id in receiver_ids {
996 if Some(receiver_id) != sender_id
997 && let Err(error) = f(receiver_id)
998 {
999 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000 }
1001 }
1002}
1003
1004pub struct ProtocolVersion(u32);
1005
1006impl Header for ProtocolVersion {
1007 fn name() -> &'static HeaderName {
1008 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1009 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1010 }
1011
1012 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1013 where
1014 Self: Sized,
1015 I: Iterator<Item = &'i axum::http::HeaderValue>,
1016 {
1017 let version = values
1018 .next()
1019 .ok_or_else(axum::headers::Error::invalid)?
1020 .to_str()
1021 .map_err(|_| axum::headers::Error::invalid())?
1022 .parse()
1023 .map_err(|_| axum::headers::Error::invalid())?;
1024 Ok(Self(version))
1025 }
1026
1027 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1028 values.extend([self.0.to_string().parse().unwrap()]);
1029 }
1030}
1031
1032pub struct AppVersionHeader(Version);
1033impl Header for AppVersionHeader {
1034 fn name() -> &'static HeaderName {
1035 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1036 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1037 }
1038
1039 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1040 where
1041 Self: Sized,
1042 I: Iterator<Item = &'i axum::http::HeaderValue>,
1043 {
1044 let version = values
1045 .next()
1046 .ok_or_else(axum::headers::Error::invalid)?
1047 .to_str()
1048 .map_err(|_| axum::headers::Error::invalid())?
1049 .parse()
1050 .map_err(|_| axum::headers::Error::invalid())?;
1051 Ok(Self(version))
1052 }
1053
1054 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1055 values.extend([self.0.to_string().parse().unwrap()]);
1056 }
1057}
1058
1059#[derive(Debug)]
1060pub struct ReleaseChannelHeader(String);
1061
1062impl Header for ReleaseChannelHeader {
1063 fn name() -> &'static HeaderName {
1064 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1065 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1066 }
1067
1068 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069 where
1070 Self: Sized,
1071 I: Iterator<Item = &'i axum::http::HeaderValue>,
1072 {
1073 Ok(Self(
1074 values
1075 .next()
1076 .ok_or_else(axum::headers::Error::invalid)?
1077 .to_str()
1078 .map_err(|_| axum::headers::Error::invalid())?
1079 .to_owned(),
1080 ))
1081 }
1082
1083 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084 values.extend([self.0.parse().unwrap()]);
1085 }
1086}
1087
/// Builds the collab server's router: the authenticated `/rpc` websocket
/// endpoint plus the `/metrics` endpoint.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Added after the auth layer so metrics are served without
        // credentials.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1099
/// Axum handler that upgrades `/rpc` requests to the collab websocket.
///
/// Rejects clients whose protocol version mismatches the server's, or
/// whose app version is missing or too old to collaborate, with
/// `426 Upgrade Required`. A connection slot is reserved *before* the
/// upgrade so load is shed with `503` instead of accepting sockets the
/// server cannot serve.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly; there is no
    // cross-version negotiation.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket into the tungstenite message types the
        // rpc `Connection` layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1177
/// Serves Prometheus metrics: refreshes the connection and shared-project
/// gauges, then renders every registered metric in text exposition format.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded so the gauge tracks real users.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    // Render everything registered with the default registry, including
    // the two gauges updated above.
    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1210
/// Cleans up after a dropped connection.
///
/// Immediately disconnects the peer and removes it from the pool, then
/// waits `RECONNECT_TIMEOUT` before tearing down room/channel state so a
/// quickly-reconnecting client can resume. The teardown signal
/// short-circuits the wait (e.g. during server shutdown).
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only cancel a pending incoming call if this was the user's
            // last connection; other devices may still be ringing.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1257
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    // An empty Ack is the entire reply payload.
    response.send(proto::Ack {})?;
    Ok(())
}
1267
/// Creates a new room for calling (outside of channels).
///
/// Generates a random LiveKit room name, mints a media token for the
/// caller when a LiveKit client is configured (best-effort), and records
/// the room in the database.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    let livekit_room = nanoid::nanoid!(30);

    // Media info is optional: with no LiveKit client, or on token failure,
    // the room still works without audio/video.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1305
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel are joined through the channel path so
    // channel membership rules apply.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Joining answers the call: dismiss the ringing UI on the user's
    // other connections (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; failure leaves the room usable without
    // audio/video.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1372
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database guard: all room mutation and collaborator
    // notification happens inside; the room/channel values outlive it for
    // the follow-up updates below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For projects this user re-shares as host: tell each guest about
        // the host's new peer id, then re-send current worktree metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also need their channel state broadcast.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1464
1465fn notify_rejoined_projects(
1466 rejoined_projects: &mut Vec<RejoinedProject>,
1467 session: &Session,
1468) -> Result<()> {
1469 for project in rejoined_projects.iter() {
1470 for collaborator in &project.collaborators {
1471 session
1472 .peer
1473 .send(
1474 collaborator.connection_id,
1475 proto::UpdateProjectCollaborator {
1476 project_id: project.id.to_proto(),
1477 old_peer_id: Some(project.old_connection_id.into()),
1478 new_peer_id: Some(session.connection_id.into()),
1479 },
1480 )
1481 .trace_err();
1482 }
1483 }
1484
1485 for project in rejoined_projects {
1486 for worktree in mem::take(&mut project.worktrees) {
1487 // Stream this worktree's entries.
1488 let message = proto::UpdateWorktree {
1489 project_id: project.id.to_proto(),
1490 worktree_id: worktree.id,
1491 abs_path: worktree.abs_path.clone(),
1492 root_name: worktree.root_name,
1493 updated_entries: worktree.updated_entries,
1494 removed_entries: worktree.removed_entries,
1495 scan_id: worktree.scan_id,
1496 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1497 updated_repositories: worktree.updated_repositories,
1498 removed_repositories: worktree.removed_repositories,
1499 };
1500 for update in proto::split_worktree_update(message) {
1501 session.peer.send(session.connection_id, update)?;
1502 }
1503
1504 // Stream this worktree's diagnostics.
1505 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1506 if let Some(summary) = worktree_diagnostics.next() {
1507 let message = proto::UpdateDiagnosticSummary {
1508 project_id: project.id.to_proto(),
1509 worktree_id: worktree.id,
1510 summary: Some(summary),
1511 more_summaries: worktree_diagnostics.collect(),
1512 };
1513 session.peer.send(session.connection_id, message)?;
1514 }
1515
1516 for settings_file in worktree.settings_files {
1517 session.peer.send(
1518 session.connection_id,
1519 proto::UpdateWorktreeSettings {
1520 project_id: project.id.to_proto(),
1521 worktree_id: worktree.id,
1522 path: settings_file.path,
1523 content: Some(settings_file.content),
1524 kind: Some(settings_file.kind.to_proto().into()),
1525 outside_worktree: Some(settings_file.outside_worktree),
1526 },
1527 )?;
1528 }
1529 }
1530
1531 for repository in mem::take(&mut project.updated_repositories) {
1532 for update in split_repository_update(repository) {
1533 session.peer.send(session.connection_id, update)?;
1534 }
1535 }
1536
1537 for id in mem::take(&mut project.removed_repositories) {
1538 session.peer.send(
1539 session.connection_id,
1540 proto::RemoveRepository {
1541 project_id: project.id.to_proto(),
1542 id,
1543 },
1544 )?;
1545 }
1546 }
1547
1548 Ok(())
1549}
1550
/// Disconnects the sender from their current room, then acknowledges.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: MessageContext,
) -> Result<()> {
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1561
1562/// Updates the permissions of someone else in the room.
1563async fn set_room_participant_role(
1564 request: proto::SetRoomParticipantRole,
1565 response: Response<proto::SetRoomParticipantRole>,
1566 session: MessageContext,
1567) -> Result<()> {
1568 let user_id = UserId::from_proto(request.user_id);
1569 let role = ChannelRole::from(request.role());
1570
1571 let (livekit_room, can_publish) = {
1572 let room = session
1573 .db()
1574 .await
1575 .set_room_participant_role(
1576 session.user_id(),
1577 RoomId::from_proto(request.room_id),
1578 user_id,
1579 role,
1580 )
1581 .await?;
1582
1583 let livekit_room = room.livekit_room.clone();
1584 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1585 room_updated(&room, &session.peer);
1586 (livekit_room, can_publish)
1587 };
1588
1589 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1590 live_kit
1591 .update_participant(
1592 livekit_room.clone(),
1593 request.user_id.to_string(),
1594 livekit_api::proto::ParticipantPermission {
1595 can_subscribe: true,
1596 can_publish,
1597 can_publish_data: can_publish,
1598 hidden: false,
1599 recorder: false,
1600 },
1601 )
1602 .await
1603 .trace_err();
1604 }
1605
1606 response.send(proto::Ack {})?;
1607 Ok(())
1608}
1609
/// Call someone else into the current room.
///
/// Rings every connection of the callee concurrently; the first to accept
/// wins and the request is acknowledged. If none accept, the pending call
/// is rolled back and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // First acceptance wins; remaining requests are dropped.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection accepted: mark the call as failed and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1678
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ringing UI on every connection of the callee
    // (best-effort per connection).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1716
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ringing UI on all of this user's other connections
    // (best-effort per connection).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1748
/// Updates other participants in the room with your current location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request.location.context("invalid location")?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    // Broadcast the refreshed room state to all participants, then ack.
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1767
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
        )
        .await?;
    // Reply with the newly-assigned project id, then broadcast the room
    // change to participants.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1792
1793/// Unshare a project from the room.
1794async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1795 let project_id = ProjectId::from_proto(message.project_id);
1796 unshare_project_internal(project_id, session.connection_id, &session).await
1797}
1798
/// Shared implementation of project unsharing, used by the explicit
/// `UnshareProject` message handler and other teardown paths.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // Hold the room guard only while notifying guests; the project
        // deletion (if requested) happens afterwards with a fresh handle.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharing connection) the project
        // is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1836
1837/// Join someone elses shared project.
1838async fn join_project(
1839 request: proto::JoinProject,
1840 response: Response<proto::JoinProject>,
1841 session: MessageContext,
1842) -> Result<()> {
1843 let project_id = ProjectId::from_proto(request.project_id);
1844
1845 tracing::info!(%project_id, "join project");
1846
1847 let db = session.db().await;
1848 let (project, replica_id) = &mut *db
1849 .join_project(
1850 project_id,
1851 session.connection_id,
1852 session.user_id(),
1853 request.committer_name.clone(),
1854 request.committer_email.clone(),
1855 )
1856 .await?;
1857 drop(db);
1858 tracing::info!(%project_id, "join remote project");
1859 let collaborators = project
1860 .collaborators
1861 .iter()
1862 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1863 .map(|collaborator| collaborator.to_proto())
1864 .collect::<Vec<_>>();
1865 let project_id = project.id;
1866 let guest_user_id = session.user_id();
1867
1868 let worktrees = project
1869 .worktrees
1870 .iter()
1871 .map(|(id, worktree)| proto::WorktreeMetadata {
1872 id: *id,
1873 root_name: worktree.root_name.clone(),
1874 visible: worktree.visible,
1875 abs_path: worktree.abs_path.clone(),
1876 })
1877 .collect::<Vec<_>>();
1878
1879 let add_project_collaborator = proto::AddProjectCollaborator {
1880 project_id: project_id.to_proto(),
1881 collaborator: Some(proto::Collaborator {
1882 peer_id: Some(session.connection_id.into()),
1883 replica_id: replica_id.0 as u32,
1884 user_id: guest_user_id.to_proto(),
1885 is_host: false,
1886 committer_name: request.committer_name.clone(),
1887 committer_email: request.committer_email.clone(),
1888 }),
1889 };
1890
1891 for collaborator in &collaborators {
1892 session
1893 .peer
1894 .send(
1895 collaborator.peer_id.unwrap().into(),
1896 add_project_collaborator.clone(),
1897 )
1898 .trace_err();
1899 }
1900
1901 // First, we send the metadata associated with each worktree.
1902 let (language_servers, language_server_capabilities) = project
1903 .language_servers
1904 .clone()
1905 .into_iter()
1906 .map(|server| (server.server, server.capabilities))
1907 .unzip();
1908 response.send(proto::JoinProjectResponse {
1909 project_id: project.id.0 as u64,
1910 worktrees,
1911 replica_id: replica_id.0 as u32,
1912 collaborators,
1913 language_servers,
1914 language_server_capabilities,
1915 role: project.role.into(),
1916 windows_paths: project.path_style == PathStyle::Windows,
1917 })?;
1918
1919 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1920 // Stream this worktree's entries.
1921 let message = proto::UpdateWorktree {
1922 project_id: project_id.to_proto(),
1923 worktree_id,
1924 abs_path: worktree.abs_path.clone(),
1925 root_name: worktree.root_name,
1926 updated_entries: worktree.entries,
1927 removed_entries: Default::default(),
1928 scan_id: worktree.scan_id,
1929 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1930 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1931 removed_repositories: Default::default(),
1932 };
1933 for update in proto::split_worktree_update(message) {
1934 session.peer.send(session.connection_id, update.clone())?;
1935 }
1936
1937 // Stream this worktree's diagnostics.
1938 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1939 if let Some(summary) = worktree_diagnostics.next() {
1940 let message = proto::UpdateDiagnosticSummary {
1941 project_id: project.id.to_proto(),
1942 worktree_id: worktree.id,
1943 summary: Some(summary),
1944 more_summaries: worktree_diagnostics.collect(),
1945 };
1946 session.peer.send(session.connection_id, message)?;
1947 }
1948
1949 for settings_file in worktree.settings_files {
1950 session.peer.send(
1951 session.connection_id,
1952 proto::UpdateWorktreeSettings {
1953 project_id: project_id.to_proto(),
1954 worktree_id: worktree.id,
1955 path: settings_file.path,
1956 content: Some(settings_file.content),
1957 kind: Some(settings_file.kind.to_proto() as i32),
1958 outside_worktree: Some(settings_file.outside_worktree),
1959 },
1960 )?;
1961 }
1962 }
1963
1964 for repository in mem::take(&mut project.repositories) {
1965 for update in split_repository_update(repository) {
1966 session.peer.send(session.connection_id, update)?;
1967 }
1968 }
1969
1970 for language_server in &project.language_servers {
1971 session.peer.send(
1972 session.connection_id,
1973 proto::UpdateLanguageServer {
1974 project_id: project_id.to_proto(),
1975 server_name: Some(language_server.server.name.clone()),
1976 language_server_id: language_server.server.id,
1977 variant: Some(
1978 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1979 proto::LspDiskBasedDiagnosticsUpdated {},
1980 ),
1981 ),
1982 },
1983 )?;
1984 }
1985
1986 Ok(())
1987}
1988
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell remaining collaborators the guest left; if the project belonged
    // to a room, broadcast the refreshed room state too.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2008
/// Updates other participants with changes to the project.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Forward the update verbatim to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2037
2038/// Updates other participants with changes to the worktree
2039async fn update_worktree(
2040 request: proto::UpdateWorktree,
2041 response: Response<proto::UpdateWorktree>,
2042 session: MessageContext,
2043) -> Result<()> {
2044 let guest_connection_ids = session
2045 .db()
2046 .await
2047 .update_worktree(&request, session.connection_id)
2048 .await?;
2049
2050 broadcast(
2051 Some(session.connection_id),
2052 guest_connection_ids.iter().copied(),
2053 |connection_id| {
2054 session
2055 .peer
2056 .forward_send(session.connection_id, connection_id, request.clone())
2057 },
2058 );
2059 response.send(proto::Ack {})?;
2060 Ok(())
2061}
2062
2063async fn update_repository(
2064 request: proto::UpdateRepository,
2065 response: Response<proto::UpdateRepository>,
2066 session: MessageContext,
2067) -> Result<()> {
2068 let guest_connection_ids = session
2069 .db()
2070 .await
2071 .update_repository(&request, session.connection_id)
2072 .await?;
2073
2074 broadcast(
2075 Some(session.connection_id),
2076 guest_connection_ids.iter().copied(),
2077 |connection_id| {
2078 session
2079 .peer
2080 .forward_send(session.connection_id, connection_id, request.clone())
2081 },
2082 );
2083 response.send(proto::Ack {})?;
2084 Ok(())
2085}
2086
2087async fn remove_repository(
2088 request: proto::RemoveRepository,
2089 response: Response<proto::RemoveRepository>,
2090 session: MessageContext,
2091) -> Result<()> {
2092 let guest_connection_ids = session
2093 .db()
2094 .await
2095 .remove_repository(&request, session.connection_id)
2096 .await?;
2097
2098 broadcast(
2099 Some(session.connection_id),
2100 guest_connection_ids.iter().copied(),
2101 |connection_id| {
2102 session
2103 .peer
2104 .forward_send(session.connection_id, connection_id, request.clone())
2105 },
2106 );
2107 response.send(proto::Ack {})?;
2108 Ok(())
2109}
2110
2111/// Updates other participants with changes to the diagnostics
2112async fn update_diagnostic_summary(
2113 message: proto::UpdateDiagnosticSummary,
2114 session: MessageContext,
2115) -> Result<()> {
2116 let guest_connection_ids = session
2117 .db()
2118 .await
2119 .update_diagnostic_summary(&message, session.connection_id)
2120 .await?;
2121
2122 broadcast(
2123 Some(session.connection_id),
2124 guest_connection_ids.iter().copied(),
2125 |connection_id| {
2126 session
2127 .peer
2128 .forward_send(session.connection_id, connection_id, message.clone())
2129 },
2130 );
2131
2132 Ok(())
2133}
2134
2135/// Updates other participants with changes to the worktree settings
2136async fn update_worktree_settings(
2137 message: proto::UpdateWorktreeSettings,
2138 session: MessageContext,
2139) -> Result<()> {
2140 let guest_connection_ids = session
2141 .db()
2142 .await
2143 .update_worktree_settings(&message, session.connection_id)
2144 .await?;
2145
2146 broadcast(
2147 Some(session.connection_id),
2148 guest_connection_ids.iter().copied(),
2149 |connection_id| {
2150 session
2151 .peer
2152 .forward_send(session.connection_id, connection_id, message.clone())
2153 },
2154 );
2155
2156 Ok(())
2157}
2158
2159/// Notify other participants that a language server has started.
2160async fn start_language_server(
2161 request: proto::StartLanguageServer,
2162 session: MessageContext,
2163) -> Result<()> {
2164 let guest_connection_ids = session
2165 .db()
2166 .await
2167 .start_language_server(&request, session.connection_id)
2168 .await?;
2169
2170 broadcast(
2171 Some(session.connection_id),
2172 guest_connection_ids.iter().copied(),
2173 |connection_id| {
2174 session
2175 .peer
2176 .forward_send(session.connection_id, connection_id, request.clone())
2177 },
2178 );
2179 Ok(())
2180}
2181
2182/// Notify other participants that a language server has changed.
2183async fn update_language_server(
2184 request: proto::UpdateLanguageServer,
2185 session: MessageContext,
2186) -> Result<()> {
2187 let project_id = ProjectId::from_proto(request.project_id);
2188 let db = session.db().await;
2189
2190 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2191 && let Some(capabilities) = update.capabilities.clone()
2192 {
2193 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2194 .await?;
2195 }
2196
2197 let project_connection_ids = db
2198 .project_connection_ids(project_id, session.connection_id, true)
2199 .await?;
2200 broadcast(
2201 Some(session.connection_id),
2202 project_connection_ids.iter().copied(),
2203 |connection_id| {
2204 session
2205 .peer
2206 .forward_send(session.connection_id, connection_id, request.clone())
2207 },
2208 );
2209 Ok(())
2210}
2211
2212/// forward a project request to the host. These requests should be read only
2213/// as guests are allowed to send them.
2214async fn forward_read_only_project_request<T>(
2215 request: T,
2216 response: Response<T>,
2217 session: MessageContext,
2218) -> Result<()>
2219where
2220 T: EntityMessage + RequestMessage,
2221{
2222 let project_id = ProjectId::from_proto(request.remote_entity_id());
2223 let host_connection_id = session
2224 .db()
2225 .await
2226 .host_for_read_only_project_request(project_id, session.connection_id)
2227 .await?;
2228 let payload = session.forward_request(host_connection_id, request).await?;
2229 response.send(payload)?;
2230 Ok(())
2231}
2232
2233/// forward a project request to the host. These requests are disallowed
2234/// for guests.
2235async fn forward_mutating_project_request<T>(
2236 request: T,
2237 response: Response<T>,
2238 session: MessageContext,
2239) -> Result<()>
2240where
2241 T: EntityMessage + RequestMessage,
2242{
2243 let project_id = ProjectId::from_proto(request.remote_entity_id());
2244
2245 let host_connection_id = session
2246 .db()
2247 .await
2248 .host_for_mutating_project_request(project_id, session.connection_id)
2249 .await?;
2250 let payload = session.forward_request(host_connection_id, request).await?;
2251 response.send(payload)?;
2252 Ok(())
2253}
2254
/// Reject a request that guests are never allowed to make, replying with a
/// `Forbidden` error instead of forwarding it anywhere.
async fn disallow_guest_request<T>(
    _request: T,
    response: Response<T>,
    _session: MessageContext,
) -> Result<()>
where
    T: RequestMessage,
{
    // Respond with an error directly through the peer rather than via
    // `response.send`, since there is no success payload to produce.
    response.peer.respond_with_error(
        response.receipt,
        ErrorCode::Forbidden
            .message("request is not allowed for guests".to_string())
            .to_proto(),
    )?;
    // Mark the response as handled so the `Response` wrapper knows a reply
    // was already sent. NOTE(review): inferred from the flag's name — confirm
    // against how `Response` uses `responded` elsewhere in this file.
    response.responded.store(true, SeqCst);
    Ok(())
}
2272
2273async fn lsp_query(
2274 request: proto::LspQuery,
2275 response: Response<proto::LspQuery>,
2276 session: MessageContext,
2277) -> Result<()> {
2278 let (name, should_write) = request.query_name_and_write_permissions();
2279 tracing::Span::current().record("lsp_query_request", name);
2280 tracing::info!("lsp_query message received");
2281 if should_write {
2282 forward_mutating_project_request(request, response, session).await
2283 } else {
2284 forward_read_only_project_request(request, response, session).await
2285 }
2286}
2287
2288/// Notify other participants that a new buffer has been created
2289async fn create_buffer_for_peer(
2290 request: proto::CreateBufferForPeer,
2291 session: MessageContext,
2292) -> Result<()> {
2293 session
2294 .db()
2295 .await
2296 .check_user_is_project_host(
2297 ProjectId::from_proto(request.project_id),
2298 session.connection_id,
2299 )
2300 .await?;
2301 let peer_id = request.peer_id.context("invalid peer id")?;
2302 session
2303 .peer
2304 .forward_send(session.connection_id, peer_id.into(), request)?;
2305 Ok(())
2306}
2307
2308/// Notify other participants that a new image has been created
2309async fn create_image_for_peer(
2310 request: proto::CreateImageForPeer,
2311 session: MessageContext,
2312) -> Result<()> {
2313 session
2314 .db()
2315 .await
2316 .check_user_is_project_host(
2317 ProjectId::from_proto(request.project_id),
2318 session.connection_id,
2319 )
2320 .await?;
2321 let peer_id = request.peer_id.context("invalid peer id")?;
2322 session
2323 .peer
2324 .forward_send(session.connection_id, peer_id.into(), request)?;
2325 Ok(())
2326}
2327
2328/// Notify other participants that a buffer has been updated. This is
2329/// allowed for guests as long as the update is limited to selections.
2330async fn update_buffer(
2331 request: proto::UpdateBuffer,
2332 response: Response<proto::UpdateBuffer>,
2333 session: MessageContext,
2334) -> Result<()> {
2335 let project_id = ProjectId::from_proto(request.project_id);
2336 let mut capability = Capability::ReadOnly;
2337
2338 for op in request.operations.iter() {
2339 match op.variant {
2340 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2341 Some(_) => capability = Capability::ReadWrite,
2342 }
2343 }
2344
2345 let host = {
2346 let guard = session
2347 .db()
2348 .await
2349 .connections_for_buffer_update(project_id, session.connection_id, capability)
2350 .await?;
2351
2352 let (host, guests) = &*guard;
2353
2354 broadcast(
2355 Some(session.connection_id),
2356 guests.clone(),
2357 |connection_id| {
2358 session
2359 .peer
2360 .forward_send(session.connection_id, connection_id, request.clone())
2361 },
2362 );
2363
2364 *host
2365 };
2366
2367 if host != session.connection_id {
2368 session.forward_request(host, request.clone()).await?;
2369 }
2370
2371 response.send(proto::Ack {})?;
2372 Ok(())
2373}
2374
/// Relays a context operation to the other participants in the project,
/// classifying it as read-only or mutating to enforce guest permissions.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Selection-only buffer operations (and empty variants) are read-only;
    // every other operation kind counts as a write.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no payload: treat conservatively as a write.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike buffer updates, the host is included in this broadcast.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2416
2417async fn forward_project_search_chunk(
2418 message: proto::FindSearchCandidatesChunk,
2419 response: Response<proto::FindSearchCandidatesChunk>,
2420 session: MessageContext,
2421) -> Result<()> {
2422 let peer_id = message.peer_id.context("missing peer_id")?;
2423 let payload = session
2424 .peer
2425 .forward_request(session.connection_id, peer_id.into(), message)
2426 .await?;
2427 response.send(payload)?;
2428 Ok(())
2429}
2430
2431/// Notify other participants that a project has been updated.
2432async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2433 request: T,
2434 session: MessageContext,
2435) -> Result<()> {
2436 let project_id = ProjectId::from_proto(request.remote_entity_id());
2437 let project_connection_ids = session
2438 .db()
2439 .await
2440 .project_connection_ids(project_id, session.connection_id, false)
2441 .await?;
2442
2443 broadcast(
2444 Some(session.connection_id),
2445 project_connection_ids.iter().copied(),
2446 |connection_id| {
2447 session
2448 .peer
2449 .forward_send(session.connection_id, connection_id, request.clone())
2450 },
2451 );
2452 Ok(())
2453}
2454
2455/// Start following another user in a call.
2456async fn follow(
2457 request: proto::Follow,
2458 response: Response<proto::Follow>,
2459 session: MessageContext,
2460) -> Result<()> {
2461 let room_id = RoomId::from_proto(request.room_id);
2462 let project_id = request.project_id.map(ProjectId::from_proto);
2463 let leader_id = request.leader_id.context("invalid leader id")?.into();
2464 let follower_id = session.connection_id;
2465
2466 session
2467 .db()
2468 .await
2469 .check_room_participants(room_id, leader_id, session.connection_id)
2470 .await?;
2471
2472 let response_payload = session.forward_request(leader_id, request).await?;
2473 response.send(response_payload)?;
2474
2475 if let Some(project_id) = project_id {
2476 let room = session
2477 .db()
2478 .await
2479 .follow(room_id, project_id, leader_id, follower_id)
2480 .await?;
2481 room_updated(&room, &session.peer);
2482 }
2483
2484 Ok(())
2485}
2486
2487/// Stop following another user in a call.
2488async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2489 let room_id = RoomId::from_proto(request.room_id);
2490 let project_id = request.project_id.map(ProjectId::from_proto);
2491 let leader_id = request.leader_id.context("invalid leader id")?.into();
2492 let follower_id = session.connection_id;
2493
2494 session
2495 .db()
2496 .await
2497 .check_room_participants(room_id, leader_id, session.connection_id)
2498 .await?;
2499
2500 session
2501 .peer
2502 .forward_send(session.connection_id, leader_id, request)?;
2503
2504 if let Some(project_id) = project_id {
2505 let room = session
2506 .db()
2507 .await
2508 .unfollow(room_id, project_id, leader_id, follower_id)
2509 .await?;
2510 room_updated(&room, &session.peer);
2511 }
2512
2513 Ok(())
2514}
2515
2516/// Notify everyone following you of your current location.
2517async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2518 let room_id = RoomId::from_proto(request.room_id);
2519 let database = session.db.lock().await;
2520
2521 let connection_ids = if let Some(project_id) = request.project_id {
2522 let project_id = ProjectId::from_proto(project_id);
2523 database
2524 .project_connection_ids(project_id, session.connection_id, true)
2525 .await?
2526 } else {
2527 database
2528 .room_connection_ids(room_id, session.connection_id)
2529 .await?
2530 };
2531
2532 // For now, don't send view update messages back to that view's current leader.
2533 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2534 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2535 _ => None,
2536 });
2537
2538 for connection_id in connection_ids.iter().cloned() {
2539 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2540 session
2541 .peer
2542 .forward_send(session.connection_id, connection_id, request.clone())?;
2543 }
2544 }
2545 Ok(())
2546}
2547
2548/// Get public data about users.
2549async fn get_users(
2550 request: proto::GetUsers,
2551 response: Response<proto::GetUsers>,
2552 session: MessageContext,
2553) -> Result<()> {
2554 let user_ids = request
2555 .user_ids
2556 .into_iter()
2557 .map(UserId::from_proto)
2558 .collect();
2559 let users = session
2560 .db()
2561 .await
2562 .get_users_by_ids(user_ids)
2563 .await?
2564 .into_iter()
2565 .map(|user| proto::User {
2566 id: user.id.to_proto(),
2567 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2568 github_login: user.github_login,
2569 name: user.name,
2570 })
2571 .collect();
2572 response.send(proto::UsersResponse { users })?;
2573 Ok(())
2574}
2575
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries (1-2 bytes) are treated as an exact login lookup;
    // longer ones go through fuzzy search, capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the caller from their own search results.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2607
2608/// Send a contact request to another user.
2609async fn request_contact(
2610 request: proto::RequestContact,
2611 response: Response<proto::RequestContact>,
2612 session: MessageContext,
2613) -> Result<()> {
2614 let requester_id = session.user_id();
2615 let responder_id = UserId::from_proto(request.responder_id);
2616 if requester_id == responder_id {
2617 return Err(anyhow!("cannot add yourself as a contact"))?;
2618 }
2619
2620 let notifications = session
2621 .db()
2622 .await
2623 .send_contact_request(requester_id, responder_id)
2624 .await?;
2625
2626 // Update outgoing contact requests of requester
2627 let mut update = proto::UpdateContacts::default();
2628 update.outgoing_requests.push(responder_id.to_proto());
2629 for connection_id in session
2630 .connection_pool()
2631 .await
2632 .user_connection_ids(requester_id)
2633 {
2634 session.peer.send(connection_id, update.clone())?;
2635 }
2636
2637 // Update incoming contact requests of responder
2638 let mut update = proto::UpdateContacts::default();
2639 update
2640 .incoming_requests
2641 .push(proto::IncomingContactRequest {
2642 requester_id: requester_id.to_proto(),
2643 });
2644 let connection_pool = session.connection_pool().await;
2645 for connection_id in connection_pool.user_connection_ids(responder_id) {
2646 session.peer.send(connection_id, update.clone())?;
2647 }
2648
2649 send_notifications(&connection_pool, &session.peer, notifications);
2650
2651 response.send(proto::Ack {})?;
2652 Ok(())
2653}
2654
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; the request itself is not
        // accepted or declined here.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact entries pushed below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2712
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If the database deleted a related notification, also clear it on
        // each of the responder's connections.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2763
/// Whether the server should push channel data to this client on connect
/// rather than waiting for an explicit `SubscribeToChannels` request.
/// NOTE(review): the 0.139 cutoff is taken from the comparison below —
/// confirm against the client release that introduced `SubscribeToChannels`.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2767
2768async fn subscribe_to_channels(
2769 _: proto::SubscribeToChannels,
2770 session: MessageContext,
2771) -> Result<()> {
2772 subscribe_user_to_channels(session.user_id(), &session).await?;
2773 Ok(())
2774}
2775
2776async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2777 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2778 let mut pool = session.connection_pool().await;
2779 for membership in &channels_for_user.channel_memberships {
2780 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2781 }
2782 session.peer.send(
2783 session.connection_id,
2784 build_update_user_channels(&channels_for_user),
2785 )?;
2786 session.peer.send(
2787 session.connection_id,
2788 build_channels_update(channels_for_user),
2789 )?;
2790 Ok(())
2791}
2792
2793/// Creates a new channel.
2794async fn create_channel(
2795 request: proto::CreateChannel,
2796 response: Response<proto::CreateChannel>,
2797 session: MessageContext,
2798) -> Result<()> {
2799 let db = session.db().await;
2800
2801 let parent_id = request.parent_id.map(ChannelId::from_proto);
2802 let (channel, membership) = db
2803 .create_channel(&request.name, parent_id, session.user_id())
2804 .await?;
2805
2806 let root_id = channel.root_id();
2807 let channel = Channel::from_model(channel);
2808
2809 response.send(proto::CreateChannelResponse {
2810 channel: Some(channel.to_proto()),
2811 parent_id: request.parent_id,
2812 })?;
2813
2814 let mut connection_pool = session.connection_pool().await;
2815 if let Some(membership) = membership {
2816 connection_pool.subscribe_to_channel(
2817 membership.user_id,
2818 membership.channel_id,
2819 membership.role,
2820 );
2821 let update = proto::UpdateUserChannels {
2822 channel_memberships: vec![proto::ChannelMembership {
2823 channel_id: membership.channel_id.to_proto(),
2824 role: membership.role.into(),
2825 }],
2826 ..Default::default()
2827 };
2828 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2829 session.peer.send(connection_id, update.clone())?;
2830 }
2831 }
2832
2833 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2834 if !role.can_see_channel(channel.visibility) {
2835 continue;
2836 }
2837
2838 let update = proto::UpdateChannels {
2839 channels: vec![channel.to_proto()],
2840 ..Default::default()
2841 };
2842 session.peer.send(connection_id, update.clone())?;
2843 }
2844
2845 Ok(())
2846}
2847
2848/// Delete a channel
2849async fn delete_channel(
2850 request: proto::DeleteChannel,
2851 response: Response<proto::DeleteChannel>,
2852 session: MessageContext,
2853) -> Result<()> {
2854 let db = session.db().await;
2855
2856 let channel_id = request.channel_id;
2857 let (root_channel, removed_channels) = db
2858 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2859 .await?;
2860 response.send(proto::Ack {})?;
2861
2862 // Notify members of removed channels
2863 let mut update = proto::UpdateChannels::default();
2864 update
2865 .delete_channels
2866 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2867
2868 let connection_pool = session.connection_pool().await;
2869 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2870 session.peer.send(connection_id, update.clone())?;
2871 }
2872
2873 Ok(())
2874}
2875
2876/// Invite someone to join a channel.
2877async fn invite_channel_member(
2878 request: proto::InviteChannelMember,
2879 response: Response<proto::InviteChannelMember>,
2880 session: MessageContext,
2881) -> Result<()> {
2882 let db = session.db().await;
2883 let channel_id = ChannelId::from_proto(request.channel_id);
2884 let invitee_id = UserId::from_proto(request.user_id);
2885 let InviteMemberResult {
2886 channel,
2887 notifications,
2888 } = db
2889 .invite_channel_member(
2890 channel_id,
2891 invitee_id,
2892 session.user_id(),
2893 request.role().into(),
2894 )
2895 .await?;
2896
2897 let update = proto::UpdateChannels {
2898 channel_invitations: vec![channel.to_proto()],
2899 ..Default::default()
2900 };
2901
2902 let connection_pool = session.connection_pool().await;
2903 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2904 session.peer.send(connection_id, update.clone())?;
2905 }
2906
2907 send_notifications(&connection_pool, &session.peer, notifications);
2908
2909 response.send(proto::Ack {})?;
2910 Ok(())
2911}
2912
2913/// remove someone from a channel
2914async fn remove_channel_member(
2915 request: proto::RemoveChannelMember,
2916 response: Response<proto::RemoveChannelMember>,
2917 session: MessageContext,
2918) -> Result<()> {
2919 let db = session.db().await;
2920 let channel_id = ChannelId::from_proto(request.channel_id);
2921 let member_id = UserId::from_proto(request.user_id);
2922
2923 let RemoveChannelMemberResult {
2924 membership_update,
2925 notification_id,
2926 } = db
2927 .remove_channel_member(channel_id, member_id, session.user_id())
2928 .await?;
2929
2930 let mut connection_pool = session.connection_pool().await;
2931 notify_membership_updated(
2932 &mut connection_pool,
2933 membership_update,
2934 member_id,
2935 &session.peer,
2936 );
2937 for connection_id in connection_pool.user_connection_ids(member_id) {
2938 if let Some(notification_id) = notification_id {
2939 session
2940 .peer
2941 .send(
2942 connection_id,
2943 proto::DeleteNotification {
2944 notification_id: notification_id.to_proto(),
2945 },
2946 )
2947 .trace_err();
2948 }
2949 }
2950
2951 response.send(proto::Ack {})?;
2952 Ok(())
2953}
2954
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user ids first: the iterator borrows the pool, which is
    // mutated below when users are (un)subscribed from the channel.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can still see the channel get the updated record;
        // everyone else is told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3001
3002/// Alter the role for a user in the channel.
3003async fn set_channel_member_role(
3004 request: proto::SetChannelMemberRole,
3005 response: Response<proto::SetChannelMemberRole>,
3006 session: MessageContext,
3007) -> Result<()> {
3008 let db = session.db().await;
3009 let channel_id = ChannelId::from_proto(request.channel_id);
3010 let member_id = UserId::from_proto(request.user_id);
3011 let result = db
3012 .set_channel_member_role(
3013 channel_id,
3014 session.user_id(),
3015 member_id,
3016 request.role().into(),
3017 )
3018 .await?;
3019
3020 match result {
3021 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3022 let mut connection_pool = session.connection_pool().await;
3023 notify_membership_updated(
3024 &mut connection_pool,
3025 membership_update,
3026 member_id,
3027 &session.peer,
3028 )
3029 }
3030 db::SetMemberRoleResult::InviteUpdated(channel) => {
3031 let update = proto::UpdateChannels {
3032 channel_invitations: vec![channel.to_proto()],
3033 ..Default::default()
3034 };
3035
3036 for connection_id in session
3037 .connection_pool()
3038 .await
3039 .user_connection_ids(member_id)
3040 {
3041 session.peer.send(connection_id, update.clone())?;
3042 }
3043 }
3044 }
3045
3046 response.send(proto::Ack {})?;
3047 Ok(())
3048}
3049
3050/// Change the name of a channel
3051async fn rename_channel(
3052 request: proto::RenameChannel,
3053 response: Response<proto::RenameChannel>,
3054 session: MessageContext,
3055) -> Result<()> {
3056 let db = session.db().await;
3057 let channel_id = ChannelId::from_proto(request.channel_id);
3058 let channel_model = db
3059 .rename_channel(channel_id, session.user_id(), &request.name)
3060 .await?;
3061 let root_id = channel_model.root_id();
3062 let channel = Channel::from_model(channel_model);
3063
3064 response.send(proto::RenameChannelResponse {
3065 channel: Some(channel.to_proto()),
3066 })?;
3067
3068 let connection_pool = session.connection_pool().await;
3069 let update = proto::UpdateChannels {
3070 channels: vec![channel.to_proto()],
3071 ..Default::default()
3072 };
3073 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3074 if role.can_see_channel(channel.visibility) {
3075 session.peer.send(connection_id, update.clone())?;
3076 }
3077 }
3078
3079 Ok(())
3080}
3081
3082/// Move a channel to a new parent.
3083async fn move_channel(
3084 request: proto::MoveChannel,
3085 response: Response<proto::MoveChannel>,
3086 session: MessageContext,
3087) -> Result<()> {
3088 let channel_id = ChannelId::from_proto(request.channel_id);
3089 let to = ChannelId::from_proto(request.to);
3090
3091 let (root_id, channels) = session
3092 .db()
3093 .await
3094 .move_channel(channel_id, to, session.user_id())
3095 .await?;
3096
3097 let connection_pool = session.connection_pool().await;
3098 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3099 let channels = channels
3100 .iter()
3101 .filter_map(|channel| {
3102 if role.can_see_channel(channel.visibility) {
3103 Some(channel.to_proto())
3104 } else {
3105 None
3106 }
3107 })
3108 .collect::<Vec<_>>();
3109 if channels.is_empty() {
3110 continue;
3111 }
3112
3113 let update = proto::UpdateChannels {
3114 channels,
3115 ..Default::default()
3116 };
3117
3118 session.peer.send(connection_id, update.clone())?;
3119 }
3120
3121 response.send(Ack {})?;
3122 Ok(())
3123}
3124
3125async fn reorder_channel(
3126 request: proto::ReorderChannel,
3127 response: Response<proto::ReorderChannel>,
3128 session: MessageContext,
3129) -> Result<()> {
3130 let channel_id = ChannelId::from_proto(request.channel_id);
3131 let direction = request.direction();
3132
3133 let updated_channels = session
3134 .db()
3135 .await
3136 .reorder_channel(channel_id, direction, session.user_id())
3137 .await?;
3138
3139 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3140 let connection_pool = session.connection_pool().await;
3141 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3142 let channels = updated_channels
3143 .iter()
3144 .filter_map(|channel| {
3145 if role.can_see_channel(channel.visibility) {
3146 Some(channel.to_proto())
3147 } else {
3148 None
3149 }
3150 })
3151 .collect::<Vec<_>>();
3152
3153 if channels.is_empty() {
3154 continue;
3155 }
3156
3157 let update = proto::UpdateChannels {
3158 channels,
3159 ..Default::default()
3160 };
3161
3162 session.peer.send(connection_id, update.clone())?;
3163 }
3164 }
3165
3166 response.send(Ack {})?;
3167 Ok(())
3168}
3169
3170/// Get the list of channel members
3171async fn get_channel_members(
3172 request: proto::GetChannelMembers,
3173 response: Response<proto::GetChannelMembers>,
3174 session: MessageContext,
3175) -> Result<()> {
3176 let db = session.db().await;
3177 let channel_id = ChannelId::from_proto(request.channel_id);
3178 let limit = if request.limit == 0 {
3179 u16::MAX as u64
3180 } else {
3181 request.limit
3182 };
3183 let (members, users) = db
3184 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3185 .await?;
3186 response.send(proto::GetChannelMembersResponse { members, users })?;
3187 Ok(())
3188}
3189
3190/// Accept or decline a channel invitation.
3191async fn respond_to_channel_invite(
3192 request: proto::RespondToChannelInvite,
3193 response: Response<proto::RespondToChannelInvite>,
3194 session: MessageContext,
3195) -> Result<()> {
3196 let db = session.db().await;
3197 let channel_id = ChannelId::from_proto(request.channel_id);
3198 let RespondToChannelInvite {
3199 membership_update,
3200 notifications,
3201 } = db
3202 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3203 .await?;
3204
3205 let mut connection_pool = session.connection_pool().await;
3206 if let Some(membership_update) = membership_update {
3207 notify_membership_updated(
3208 &mut connection_pool,
3209 membership_update,
3210 session.user_id(),
3211 &session.peer,
3212 );
3213 } else {
3214 let update = proto::UpdateChannels {
3215 remove_channel_invitations: vec![channel_id.to_proto()],
3216 ..Default::default()
3217 };
3218
3219 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3220 session.peer.send(connection_id, update.clone())?;
3221 }
3222 };
3223
3224 send_notifications(&connection_pool, &session.peer, notifications);
3225
3226 response.send(proto::Ack {})?;
3227
3228 Ok(())
3229}
3230
3231/// Join the channels' room
3232async fn join_channel(
3233 request: proto::JoinChannel,
3234 response: Response<proto::JoinChannel>,
3235 session: MessageContext,
3236) -> Result<()> {
3237 let channel_id = ChannelId::from_proto(request.channel_id);
3238 join_channel_internal(channel_id, Box::new(response), session).await
3239}
3240
/// Abstraction over the two response types that can complete a channel
/// join — `JoinChannel` and `JoinRoom` — both of which reply with a
/// `proto::JoinRoomResponse`. Lets `join_channel_internal` serve both
/// handlers with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3254
/// Shared implementation of joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers via `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so we
            // must release our guard before calling it, then re-acquire.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a LiveKit client is configured: guests
        // get a non-publishing token, every other role may publish. Token
        // minting failures are logged and yield no connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before broadcasting state to others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Joining may have changed the user's channel memberships; if so,
        // notify all of the user's connections.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's other participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Refresh the channel's participant list for everyone who can see it.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3348
3349/// Start editing the channel notes
3350async fn join_channel_buffer(
3351 request: proto::JoinChannelBuffer,
3352 response: Response<proto::JoinChannelBuffer>,
3353 session: MessageContext,
3354) -> Result<()> {
3355 let db = session.db().await;
3356 let channel_id = ChannelId::from_proto(request.channel_id);
3357
3358 let open_response = db
3359 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3360 .await?;
3361
3362 let collaborators = open_response.collaborators.clone();
3363 response.send(open_response)?;
3364
3365 let update = UpdateChannelBufferCollaborators {
3366 channel_id: channel_id.to_proto(),
3367 collaborators: collaborators.clone(),
3368 };
3369 channel_buffer_updated(
3370 session.connection_id,
3371 collaborators
3372 .iter()
3373 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3374 &update,
3375 &session.peer,
3376 );
3377
3378 Ok(())
3379}
3380
3381/// Edit the channel notes
3382async fn update_channel_buffer(
3383 request: proto::UpdateChannelBuffer,
3384 session: MessageContext,
3385) -> Result<()> {
3386 let db = session.db().await;
3387 let channel_id = ChannelId::from_proto(request.channel_id);
3388
3389 let (collaborators, epoch, version) = db
3390 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3391 .await?;
3392
3393 channel_buffer_updated(
3394 session.connection_id,
3395 collaborators.clone(),
3396 &proto::UpdateChannelBuffer {
3397 channel_id: channel_id.to_proto(),
3398 operations: request.operations,
3399 },
3400 &session.peer,
3401 );
3402
3403 let pool = &*session.connection_pool().await;
3404
3405 let non_collaborators =
3406 pool.channel_connection_ids(channel_id)
3407 .filter_map(|(connection_id, _)| {
3408 if collaborators.contains(&connection_id) {
3409 None
3410 } else {
3411 Some(connection_id)
3412 }
3413 });
3414
3415 broadcast(None, non_collaborators, |peer_id| {
3416 session.peer.send(
3417 peer_id,
3418 proto::UpdateChannels {
3419 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3420 channel_id: channel_id.to_proto(),
3421 epoch: epoch as u64,
3422 version: version.clone(),
3423 }],
3424 ..Default::default()
3425 },
3426 )
3427 });
3428
3429 Ok(())
3430}
3431
3432/// Rejoin the channel notes after a connection blip
3433async fn rejoin_channel_buffers(
3434 request: proto::RejoinChannelBuffers,
3435 response: Response<proto::RejoinChannelBuffers>,
3436 session: MessageContext,
3437) -> Result<()> {
3438 let db = session.db().await;
3439 let buffers = db
3440 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3441 .await?;
3442
3443 for rejoined_buffer in &buffers {
3444 let collaborators_to_notify = rejoined_buffer
3445 .buffer
3446 .collaborators
3447 .iter()
3448 .filter_map(|c| Some(c.peer_id?.into()));
3449 channel_buffer_updated(
3450 session.connection_id,
3451 collaborators_to_notify,
3452 &proto::UpdateChannelBufferCollaborators {
3453 channel_id: rejoined_buffer.buffer.channel_id,
3454 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3455 },
3456 &session.peer,
3457 );
3458 }
3459
3460 response.send(proto::RejoinChannelBuffersResponse {
3461 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3462 })?;
3463
3464 Ok(())
3465}
3466
3467/// Stop editing the channel notes
3468async fn leave_channel_buffer(
3469 request: proto::LeaveChannelBuffer,
3470 response: Response<proto::LeaveChannelBuffer>,
3471 session: MessageContext,
3472) -> Result<()> {
3473 let db = session.db().await;
3474 let channel_id = ChannelId::from_proto(request.channel_id);
3475
3476 let left_buffer = db
3477 .leave_channel_buffer(channel_id, session.connection_id)
3478 .await?;
3479
3480 response.send(Ack {})?;
3481
3482 channel_buffer_updated(
3483 session.connection_id,
3484 left_buffer.connections,
3485 &proto::UpdateChannelBufferCollaborators {
3486 channel_id: channel_id.to_proto(),
3487 collaborators: left_buffer.collaborators,
3488 },
3489 &session.peer,
3490 );
3491
3492 Ok(())
3493}
3494
3495fn channel_buffer_updated<T: EnvelopedMessage>(
3496 sender_id: ConnectionId,
3497 collaborators: impl IntoIterator<Item = ConnectionId>,
3498 message: &T,
3499 peer: &Peer,
3500) {
3501 broadcast(Some(sender_id), collaborators, |peer_id| {
3502 peer.send(peer_id, message.clone())
3503 });
3504}
3505
3506fn send_notifications(
3507 connection_pool: &ConnectionPool,
3508 peer: &Peer,
3509 notifications: db::NotificationBatch,
3510) {
3511 for (user_id, notification) in notifications {
3512 for connection_id in connection_pool.user_connection_ids(user_id) {
3513 if let Err(error) = peer.send(
3514 connection_id,
3515 proto::AddNotification {
3516 notification: Some(notification.clone()),
3517 },
3518 ) {
3519 tracing::error!(
3520 "failed to send notification to {:?} {}",
3521 connection_id,
3522 error
3523 );
3524 }
3525 }
3526 }
3527}
3528
3529/// Send a message to the channel
3530async fn send_channel_message(
3531 _request: proto::SendChannelMessage,
3532 _response: Response<proto::SendChannelMessage>,
3533 _session: MessageContext,
3534) -> Result<()> {
3535 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3536}
3537
3538/// Delete a channel message
3539async fn remove_channel_message(
3540 _request: proto::RemoveChannelMessage,
3541 _response: Response<proto::RemoveChannelMessage>,
3542 _session: MessageContext,
3543) -> Result<()> {
3544 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3545}
3546
3547async fn update_channel_message(
3548 _request: proto::UpdateChannelMessage,
3549 _response: Response<proto::UpdateChannelMessage>,
3550 _session: MessageContext,
3551) -> Result<()> {
3552 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3553}
3554
3555/// Mark a channel message as read
3556async fn acknowledge_channel_message(
3557 _request: proto::AckChannelMessage,
3558 _session: MessageContext,
3559) -> Result<()> {
3560 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3561}
3562
3563/// Mark a buffer version as synced
3564async fn acknowledge_buffer_version(
3565 request: proto::AckBufferOperation,
3566 session: MessageContext,
3567) -> Result<()> {
3568 let buffer_id = BufferId::from_proto(request.buffer_id);
3569 session
3570 .db()
3571 .await
3572 .observe_buffer_version(
3573 buffer_id,
3574 session.user_id(),
3575 request.epoch as i32,
3576 &request.version,
3577 )
3578 .await?;
3579 Ok(())
3580}
3581
3582/// Start receiving chat updates for a channel
3583async fn join_channel_chat(
3584 _request: proto::JoinChannelChat,
3585 _response: Response<proto::JoinChannelChat>,
3586 _session: MessageContext,
3587) -> Result<()> {
3588 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3589}
3590
3591/// Stop receiving chat updates for a channel
3592async fn leave_channel_chat(
3593 _request: proto::LeaveChannelChat,
3594 _session: MessageContext,
3595) -> Result<()> {
3596 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Retrieve the chat history for a channel
3600async fn get_channel_messages(
3601 _request: proto::GetChannelMessages,
3602 _response: Response<proto::GetChannelMessages>,
3603 _session: MessageContext,
3604) -> Result<()> {
3605 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3606}
3607
3608/// Retrieve specific chat messages
3609async fn get_channel_messages_by_id(
3610 _request: proto::GetChannelMessagesById,
3611 _response: Response<proto::GetChannelMessagesById>,
3612 _session: MessageContext,
3613) -> Result<()> {
3614 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3615}
3616
3617/// Retrieve the current users notifications
3618async fn get_notifications(
3619 request: proto::GetNotifications,
3620 response: Response<proto::GetNotifications>,
3621 session: MessageContext,
3622) -> Result<()> {
3623 let notifications = session
3624 .db()
3625 .await
3626 .get_notifications(
3627 session.user_id(),
3628 NOTIFICATION_COUNT_PER_PAGE,
3629 request.before_id.map(db::NotificationId::from_proto),
3630 )
3631 .await?;
3632 response.send(proto::GetNotificationsResponse {
3633 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3634 notifications,
3635 })?;
3636 Ok(())
3637}
3638
3639/// Mark notifications as read
3640async fn mark_notification_as_read(
3641 request: proto::MarkNotificationRead,
3642 response: Response<proto::MarkNotificationRead>,
3643 session: MessageContext,
3644) -> Result<()> {
3645 let database = &session.db().await;
3646 let notifications = database
3647 .mark_notification_as_read_by_id(
3648 session.user_id(),
3649 NotificationId::from_proto(request.notification_id),
3650 )
3651 .await?;
3652 send_notifications(
3653 &*session.connection_pool().await,
3654 &session.peer,
3655 notifications,
3656 );
3657 response.send(proto::Ack {})?;
3658 Ok(())
3659}
3660
3661fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3662 let message = match message {
3663 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3664 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3665 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3666 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3667 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3668 code: frame.code.into(),
3669 reason: frame.reason.as_str().to_owned().into(),
3670 })),
3671 // We should never receive a frame while reading the message, according
3672 // to the `tungstenite` maintainers:
3673 //
3674 // > It cannot occur when you read messages from the WebSocket, but it
3675 // > can be used when you want to send the raw frames (e.g. you want to
3676 // > send the frames to the WebSocket without composing the full message first).
3677 // >
3678 // > — https://github.com/snapview/tungstenite-rs/issues/268
3679 TungsteniteMessage::Frame(_) => {
3680 bail!("received an unexpected frame while reading the message")
3681 }
3682 };
3683
3684 Ok(message)
3685}
3686
3687fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3688 match message {
3689 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3690 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3691 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3692 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3693 AxumMessage::Close(frame) => {
3694 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3695 code: frame.code.into(),
3696 reason: frame.reason.as_ref().into(),
3697 }))
3698 }
3699 }
3700}
3701
3702fn notify_membership_updated(
3703 connection_pool: &mut ConnectionPool,
3704 result: MembershipUpdated,
3705 user_id: UserId,
3706 peer: &Peer,
3707) {
3708 for membership in &result.new_channels.channel_memberships {
3709 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3710 }
3711 for channel_id in &result.removed_channels {
3712 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3713 }
3714
3715 let user_channels_update = proto::UpdateUserChannels {
3716 channel_memberships: result
3717 .new_channels
3718 .channel_memberships
3719 .iter()
3720 .map(|cm| proto::ChannelMembership {
3721 channel_id: cm.channel_id.to_proto(),
3722 role: cm.role.into(),
3723 })
3724 .collect(),
3725 ..Default::default()
3726 };
3727
3728 let mut update = build_channels_update(result.new_channels);
3729 update.delete_channels = result
3730 .removed_channels
3731 .into_iter()
3732 .map(|id| id.to_proto())
3733 .collect();
3734 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3735
3736 for connection_id in connection_pool.user_connection_ids(user_id) {
3737 peer.send(connection_id, user_channels_update.clone())
3738 .trace_err();
3739 peer.send(connection_id, update.clone()).trace_err();
3740 }
3741}
3742
3743fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3744 proto::UpdateUserChannels {
3745 channel_memberships: channels
3746 .channel_memberships
3747 .iter()
3748 .map(|m| proto::ChannelMembership {
3749 channel_id: m.channel_id.to_proto(),
3750 role: m.role.into(),
3751 })
3752 .collect(),
3753 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3754 }
3755}
3756
3757fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3758 let mut update = proto::UpdateChannels::default();
3759
3760 for channel in channels.channels {
3761 update.channels.push(channel.to_proto());
3762 }
3763
3764 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3765
3766 for (channel_id, participants) in channels.channel_participants {
3767 update
3768 .channel_participants
3769 .push(proto::ChannelParticipants {
3770 channel_id: channel_id.to_proto(),
3771 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3772 });
3773 }
3774
3775 for channel in channels.invited_channels {
3776 update.channel_invitations.push(channel.to_proto());
3777 }
3778
3779 update
3780}
3781
3782fn build_initial_contacts_update(
3783 contacts: Vec<db::Contact>,
3784 pool: &ConnectionPool,
3785) -> proto::UpdateContacts {
3786 let mut update = proto::UpdateContacts::default();
3787
3788 for contact in contacts {
3789 match contact {
3790 db::Contact::Accepted { user_id, busy } => {
3791 update.contacts.push(contact_for_user(user_id, busy, pool));
3792 }
3793 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3794 db::Contact::Incoming { user_id } => {
3795 update
3796 .incoming_requests
3797 .push(proto::IncomingContactRequest {
3798 requester_id: user_id.to_proto(),
3799 })
3800 }
3801 }
3802 }
3803
3804 update
3805}
3806
/// Build a `proto::Contact` for `user_id`, taking the online flag from the
/// connection pool and the busy flag from the caller.
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}
3814
3815fn room_updated(room: &proto::Room, peer: &Peer) {
3816 broadcast(
3817 None,
3818 room.participants
3819 .iter()
3820 .filter_map(|participant| Some(participant.peer_id?.into())),
3821 |peer_id| {
3822 peer.send(
3823 peer_id,
3824 proto::RoomUpdated {
3825 room: Some(room.clone()),
3826 },
3827 )
3828 },
3829 );
3830}
3831
3832fn channel_updated(
3833 channel: &db::channel::Model,
3834 room: &proto::Room,
3835 peer: &Peer,
3836 pool: &ConnectionPool,
3837) {
3838 let participants = room
3839 .participants
3840 .iter()
3841 .map(|p| p.user_id)
3842 .collect::<Vec<_>>();
3843
3844 broadcast(
3845 None,
3846 pool.channel_connection_ids(channel.root_id())
3847 .filter_map(|(channel_id, role)| {
3848 role.can_see_channel(channel.visibility)
3849 .then_some(channel_id)
3850 }),
3851 |peer_id| {
3852 peer.send(
3853 peer_id,
3854 proto::UpdateChannels {
3855 channel_participants: vec![proto::ChannelParticipants {
3856 channel_id: channel.id.to_proto(),
3857 participant_user_ids: participants.clone(),
3858 }],
3859 ..Default::default()
3860 },
3861 )
3862 },
3863 );
3864}
3865
3866async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3867 let db = session.db().await;
3868
3869 let contacts = db.get_contacts(user_id).await?;
3870 let busy = db.is_user_busy(user_id).await?;
3871
3872 let pool = session.connection_pool().await;
3873 let updated_contact = contact_for_user(user_id, busy, &pool);
3874 for contact in contacts {
3875 if let db::Contact::Accepted {
3876 user_id: contact_user_id,
3877 ..
3878 } = contact
3879 {
3880 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3881 session
3882 .peer
3883 .send(
3884 contact_conn_id,
3885 proto::UpdateContacts {
3886 contacts: vec![updated_contact.clone()],
3887 remove_contacts: Default::default(),
3888 incoming_requests: Default::default(),
3889 remove_incoming_requests: Default::default(),
3890 outgoing_requests: Default::default(),
3891 remove_outgoing_requests: Default::default(),
3892 },
3893 )
3894 .trace_err();
3895 }
3896 }
3897 }
3898 Ok(())
3899}
3900
/// Remove `connection_id` from the room it is in (if any), then fan out the
/// resulting state changes: project departures, room and channel updates,
/// canceled calls, contact presence, and LiveKit participant removal.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so the values extracted inside the `if let` arm are
    // available to the fan-out code below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Notify collaborators of every project this connection was part of.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move what the fan-out needs out of `left_room` before it drops.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to do.
        return Ok(());
    }

    // If the room belongs to a channel, refresh its participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Tell every connection of each user whose call was canceled by this
        // departure, and queue a contact-status refresh for them.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Keep LiveKit in sync: remove the participant, and delete the LiveKit
    // room when the database reports the room itself was deleted. Failures
    // are logged and ignored.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3974
3975async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3976 let left_channel_buffers = session
3977 .db()
3978 .await
3979 .leave_channel_buffers(session.connection_id)
3980 .await?;
3981
3982 for left_buffer in left_channel_buffers {
3983 channel_buffer_updated(
3984 session.connection_id,
3985 left_buffer.connections,
3986 &proto::UpdateChannelBufferCollaborators {
3987 channel_id: left_buffer.channel_id.to_proto(),
3988 collaborators: left_buffer.collaborators,
3989 },
3990 &session.peer,
3991 );
3992 }
3993
3994 Ok(())
3995}
3996
3997fn project_left(project: &db::LeftProject, session: &Session) {
3998 for connection_id in &project.connection_ids {
3999 if project.should_unshare {
4000 session
4001 .peer
4002 .send(
4003 *connection_id,
4004 proto::UnshareProject {
4005 project_id: project.id.to_proto(),
4006 },
4007 )
4008 .trace_err();
4009 } else {
4010 session
4011 .peer
4012 .send(
4013 *connection_id,
4014 proto::RemoveProjectCollaborator {
4015 project_id: project.id.to_proto(),
4016 peer_id: Some(session.connection_id.into()),
4017 },
4018 )
4019 .trace_err();
4020 }
4021 }
4022}
4023
4024async fn share_agent_thread(
4025 request: proto::ShareAgentThread,
4026 response: Response<proto::ShareAgentThread>,
4027 session: MessageContext,
4028) -> Result<()> {
4029 let user_id = session.user_id();
4030
4031 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4032 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4033
4034 session
4035 .db()
4036 .await
4037 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4038 .await?;
4039
4040 response.send(proto::Ack {})?;
4041
4042 Ok(())
4043}
4044
4045async fn get_shared_agent_thread(
4046 request: proto::GetSharedAgentThread,
4047 response: Response<proto::GetSharedAgentThread>,
4048 session: MessageContext,
4049) -> Result<()> {
4050 let share_id = SharedThreadId::from_proto(request.session_id)
4051 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4052
4053 let result = session.db().await.get_shared_thread(share_id).await?;
4054
4055 match result {
4056 Some((thread, username)) => {
4057 response.send(proto::GetSharedAgentThreadResponse {
4058 title: thread.title,
4059 thread_data: thread.data,
4060 sharer_username: username,
4061 created_at: thread.created_at.and_utc().to_rfc3339(),
4062 })?;
4063 }
4064 None => {
4065 return Err(anyhow!("Shared thread not found").into());
4066 }
4067 }
4068
4069 Ok(())
4070}
4071
/// Extension trait for logging errors instead of propagating them.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) and convert the result into an `Option`,
    /// discarding the error value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4077
4078impl<T, E> ResultExt for Result<T, E>
4079where
4080 E: std::fmt::Debug,
4081{
4082 type Ok = T;
4083
4084 #[track_caller]
4085 fn trace_err(self) -> Option<T> {
4086 match self {
4087 Ok(value) => Some(value),
4088 Err(error) => {
4089 tracing::error!("{:?}", error);
4090 None
4091 }
4092 }
4093 }
4094}