1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// How long a disconnected client has to reconnect before its server-side
/// room/project state is torn down.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when paginating notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Hard cap on simultaneously held `ConnectionGuard`s.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Count of currently held connection slots; maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names used to record per-message timing measurements.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler: receives the envelope, the session, and the
/// span created for the message, and returns the boxed handling future.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
/// RAII token for one slot under `MAX_CONCURRENT_CONNECTIONS`.
/// Dropping the guard releases the slot.
pub struct ConnectionGuard;

impl ConnectionGuard {
    /// Attempts to reserve a connection slot, failing when the server is at
    /// capacity.
    ///
    /// The counter is incremented optimistically and rolled back on failure,
    /// so concurrent callers may transiently observe a count slightly above
    /// the limit.
    pub fn try_acquire() -> Result<Self, ()> {
        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
            // Roll back the speculative increment before rejecting.
            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
            tracing::error!(
                "too many concurrent connections: {}",
                current_connections + 1
            );
            return Err(());
        }
        Ok(ConnectionGuard)
    }
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Release the slot reserved in `try_acquire`.
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}
115
/// Handle for replying to a single RPC request.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Flipped to true once a response is sent; the request-handler wrapper
    // checks it to detect handlers that returned Ok without responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the response to the originating request and marks
    /// the request as responded-to. Consumes the handle, so a handler can
    /// respond at most once.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
/// The authenticated identity behind a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Per-message context handed to handlers: the connection's session plus the
/// tracing span created for this message. Derefs to [`Session`] so handlers
/// can use it as a drop-in session.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
impl MessageContext {
    /// Forwards `request` to the peer identified by `receiver_id` on behalf of
    /// the current connection, recording how long the receiving host took in
    /// the message span under `host_waiting_ms`.
    ///
    /// Note: the underlying forward is initiated eagerly here; only the wait
    /// for the response is deferred to the returned future.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs on completion regardless of success or failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// State shared by all message handlers for a single client connection.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex, so this connection's handlers
    /// serialize their database access.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide pool of all live connections.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Acquires this session's database handle.
    ///
    /// Under `test-support`, yields to the scheduler before and after locking
    /// so tests exercise more task interleavings.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the server-wide connection pool. The returned guard is `!Send`
    /// (see `ConnectionPoolGuard`), so it cannot be held across an `.await`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the connected user is a staff member (admin).
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The id of the user behind this connection.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
/// Newtype over the shared [`Database`], stored behind each session's async
/// mutex. Derefs to [`Database`] so handlers call query methods directly.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
253
/// The collab RPC server: owns the peer, the pool of live connections, and
/// the table of message handlers registered in [`Server::new`].
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the protobuf payload they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcast flag flipped to `true` when the server is shutting down;
    /// per-connection tasks subscribe to it and exit when it changes.
    teardown: watch::Sender<bool>,
}
262
/// Lock guard over the connection pool.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be
/// held across an `.await` point by accident.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers a handler for every RPC message type.
    ///
    /// `add_request_handler` registers request/response messages,
    /// `add_message_handler` registers fire-and-forget messages, and the
    /// `forward_*_project_request` / `broadcast_project_message_from_host`
    /// generics relay project messages between collaborators.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
453
    /// Starts the server's background cleanup task.
    ///
    /// After `CLEANUP_TIMEOUT` (giving terminated pods time to go away), the
    /// task reclaims resources still attributed to stale servers in the same
    /// environment: channel-buffer collaborators, room participants, canceled
    /// calls, contact "busy" states, empty LiveKit rooms, old worktree
    /// entries, and finally the stale server rows themselves.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop collaborators whose connections belonged to stale
                    // servers, and tell the remaining collaborators.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants and broadcast the
                        // refreshed room/channel state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // the call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-state to each affected user's
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
624
    /// Shuts the server down: tears down the peer, clears the connection pool,
    /// and flips the `teardown` flag so per-connection tasks exit their loops.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
630
    /// Test-only: tears the server down and re-initializes it under a new id,
    /// clearing the teardown flag so connections can be accepted again.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
638
    /// Test-only: the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
643
    /// Registers the type-erased handler invoked for every message of type `M`.
    ///
    /// The wrapper downcasts the envelope back to `TypedEnvelope<M>`, builds a
    /// `MessageContext`, runs the handler, and records total/processing/queue
    /// durations on the message span when the handler future completes.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map is keyed by `TypeId::of::<M>()`, so this downcast
                // cannot fail for a correctly routed envelope.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since the message was received;
                    // queue = total minus the handler's own processing time.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
687
688 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
689 where
690 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
691 Fut: 'static + Send + Future<Output = Result<()>>,
692 M: EnvelopedMessage,
693 {
694 self.add_handler(move |envelope, session| handler(envelope.payload, session));
695 self
696 }
697
698 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
699 where
700 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
701 Fut: Send + Future<Output = Result<()>>,
702 M: RequestMessage,
703 {
704 let handler = Arc::new(handler);
705 self.add_handler(move |envelope, session| {
706 let receipt = envelope.receipt();
707 let handler = handler.clone();
708 async move {
709 let peer = session.peer.clone();
710 let responded = Arc::new(AtomicBool::default());
711 let response = Response {
712 peer: peer.clone(),
713 responded: responded.clone(),
714 receipt,
715 };
716 match (handler)(envelope.payload, response, session).await {
717 Ok(()) => {
718 if responded.load(std::sync::atomic::Ordering::SeqCst) {
719 Ok(())
720 } else {
721 Err(anyhow!("handler did not send a response"))?
722 }
723 }
724 Err(error) => {
725 let proto_err = match &error {
726 Error::Internal(err) => err.to_proto(),
727 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
728 };
729 peer.respond_with_error(receipt, proto_err)?;
730 Err(error)
731 }
732 }
733 }
734 })
735 }
736
    /// Returns the future that drives a single client connection to completion:
    /// registers it with the peer, sends the initial client update, then loops
    /// dispatching incoming messages until the connection closes or the server
    /// tears down, finally running `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // NOTE(review): releasing the guard here means the concurrency cap
            // applies to connection setup rather than the connection's whole
            // lifetime — presumably intentional; confirm.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only take the next message once a handler permit is free,
                // bounding in-flight handlers per connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler future
                                // finishes, keeping the semaphore accounting
                                // accurate for background handlers too.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
907
    /// Sends the post-connect bootstrap messages to a freshly connected
    /// client: the `Hello` handshake, the initial contact list, channel
    /// subscriptions, and any call that is currently ringing for the user.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // Tell the client its own peer id before anything else.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Notify the in-process caller (if any) that the connection is live.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // First-ever connection for this user: prompt the contacts UI
                // and record the fact so it only happens once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact update
                // inside the same pool-lock scope, so the online state in the
                // update is consistent with the pool contents.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver any call that was ringing while the user connected.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
963}
964
965impl Deref for ConnectionPoolGuard<'_> {
966 type Target = ConnectionPool;
967
968 fn deref(&self) -> &Self::Target {
969 &self.guard
970 }
971}
972
973impl DerefMut for ConnectionPoolGuard<'_> {
974 fn deref_mut(&mut self) -> &mut Self::Target {
975 &mut self.guard
976 }
977}
978
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's invariants every time a guard is
        // released; a no-op in production builds.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
985
986fn broadcast<F>(
987 sender_id: Option<ConnectionId>,
988 receiver_ids: impl IntoIterator<Item = ConnectionId>,
989 mut f: F,
990) where
991 F: FnMut(ConnectionId) -> anyhow::Result<()>,
992{
993 for receiver_id in receiver_ids {
994 if Some(receiver_id) != sender_id
995 && let Err(error) = f(receiver_id)
996 {
997 tracing::error!("failed to send to {:?} {}", receiver_id, error);
998 }
999 }
1000}
1001
1002pub struct ProtocolVersion(u32);
1003
1004impl Header for ProtocolVersion {
1005 fn name() -> &'static HeaderName {
1006 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1007 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1008 }
1009
1010 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1011 where
1012 Self: Sized,
1013 I: Iterator<Item = &'i axum::http::HeaderValue>,
1014 {
1015 let version = values
1016 .next()
1017 .ok_or_else(axum::headers::Error::invalid)?
1018 .to_str()
1019 .map_err(|_| axum::headers::Error::invalid())?
1020 .parse()
1021 .map_err(|_| axum::headers::Error::invalid())?;
1022 Ok(Self(version))
1023 }
1024
1025 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1026 values.extend([self.0.to_string().parse().unwrap()]);
1027 }
1028}
1029
1030pub struct AppVersionHeader(Version);
1031impl Header for AppVersionHeader {
1032 fn name() -> &'static HeaderName {
1033 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1035 }
1036
1037 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038 where
1039 Self: Sized,
1040 I: Iterator<Item = &'i axum::http::HeaderValue>,
1041 {
1042 let version = values
1043 .next()
1044 .ok_or_else(axum::headers::Error::invalid)?
1045 .to_str()
1046 .map_err(|_| axum::headers::Error::invalid())?
1047 .parse()
1048 .map_err(|_| axum::headers::Error::invalid())?;
1049 Ok(Self(version))
1050 }
1051
1052 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053 values.extend([self.0.to_string().parse().unwrap()]);
1054 }
1055}
1056
1057#[derive(Debug)]
1058pub struct ReleaseChannelHeader(String);
1059
1060impl Header for ReleaseChannelHeader {
1061 fn name() -> &'static HeaderName {
1062 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1063 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1064 }
1065
1066 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1067 where
1068 Self: Sized,
1069 I: Iterator<Item = &'i axum::http::HeaderValue>,
1070 {
1071 Ok(Self(
1072 values
1073 .next()
1074 .ok_or_else(axum::headers::Error::invalid)?
1075 .to_str()
1076 .map_err(|_| axum::headers::Error::invalid())?
1077 .to_owned(),
1078 ))
1079 }
1080
1081 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1082 values.extend([self.0.parse().unwrap()]);
1083 }
1084}
1085
1086pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1087 Router::new()
1088 .route("/rpc", get(handle_websocket_request))
1089 .layer(
1090 ServiceBuilder::new()
1091 .layer(Extension(server.app_state.clone()))
1092 .layer(middleware::from_fn(auth::validate_header)),
1093 )
1094 .route("/metrics", get(handle_metrics))
1095 .layer(Extension(server))
1096}
1097
/// Axum handler for `/rpc`: validates protocol/app-version headers,
/// enforces the concurrent-connection limit, then upgrades to a websocket
/// and hands the connection to the `Server`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients speaking a different RPC protocol version.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app-version header is mandatory; without it we cannot decide
    // whether this client may collaborate.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's websocket message type to the tungstenite message
        // type the RPC layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1175
1176pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1177 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1178 let connections_metric = CONNECTIONS_METRIC
1179 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1180
1181 let connections = server
1182 .connection_pool
1183 .lock()
1184 .connections()
1185 .filter(|connection| !connection.admin)
1186 .count();
1187 connections_metric.set(connections as _);
1188
1189 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1190 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1191 register_int_gauge!(
1192 "shared_projects",
1193 "number of open projects with one or more guests"
1194 )
1195 .unwrap()
1196 });
1197
1198 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1199 shared_projects_metric.set(shared_projects as _);
1200
1201 let encoder = prometheus::TextEncoder::new();
1202 let metric_families = prometheus::gather();
1203 let encoded_metrics = encoder
1204 .encode_to_string(&metric_families)
1205 .map_err(|err| anyhow!("{err}"))?;
1206 Ok(encoded_metrics)
1207}
1208
/// Cleans up after a dropped connection: tears down peer/pool state
/// immediately, then waits `RECONNECT_TIMEOUT` before releasing rooms and
/// channel buffers, giving the client a chance to reconnect first.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Disconnect the peer and remove the connection from the pool right away.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort DB bookkeeping for the lost connection.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Either the reconnect window elapses (full resource cleanup) or the
    // server itself is tearing down (skip cleanup).
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if the user has no other live
            // connections that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1255
1256/// Acknowledges a ping from a client, used to keep the connection alive.
1257async fn ping(
1258 _: proto::Ping,
1259 response: Response<proto::Ping>,
1260 _session: MessageContext,
1261) -> Result<()> {
1262 response.send(proto::Ack {})?;
1263 Ok(())
1264}
1265
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random LiveKit room name; stored on the DB room row for later joins.
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: when a LiveKit client is configured, mint a token so the
    // caller can join the media room; otherwise the call proceeds without
    // audio/video connection info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1303
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms backed by a channel are joined through the channel flow instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Join in the DB and broadcast the new room state; `into_inner` takes
    // the room data out of the returned guard before the awaits below.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Joining answers the call: dismiss the incoming-call UI on all of this
    // user's connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token so the client can join the media room.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1370
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Answer the client first with the room snapshot plus which of its
        // projects were reshared (as host) or rejoined (as guest).
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each reshared project, tell the existing collaborators about
        // the host's new peer id and re-broadcast the project metadata.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Take the room data out of the guard so it can outlive this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1462
/// Streams the full state of each rejoined project (collaborator updates,
/// worktree entries, diagnostics, settings files, repositories) to the
/// rejoining connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell the other collaborators in each project about this session's new
    // peer id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into multiple messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream the worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1548
1549/// leave room disconnects from the room.
1550async fn leave_room(
1551 _: proto::LeaveRoom,
1552 response: Response<proto::LeaveRoom>,
1553 session: MessageContext,
1554) -> Result<()> {
1555 leave_room_for_session(&session, session.connection_id).await?;
1556 response.send(proto::Ack {})?;
1557 Ok(())
1558}
1559
1560/// Updates the permissions of someone else in the room.
1561async fn set_room_participant_role(
1562 request: proto::SetRoomParticipantRole,
1563 response: Response<proto::SetRoomParticipantRole>,
1564 session: MessageContext,
1565) -> Result<()> {
1566 let user_id = UserId::from_proto(request.user_id);
1567 let role = ChannelRole::from(request.role());
1568
1569 let (livekit_room, can_publish) = {
1570 let room = session
1571 .db()
1572 .await
1573 .set_room_participant_role(
1574 session.user_id(),
1575 RoomId::from_proto(request.room_id),
1576 user_id,
1577 role,
1578 )
1579 .await?;
1580
1581 let livekit_room = room.livekit_room.clone();
1582 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1583 room_updated(&room, &session.peer);
1584 (livekit_room, can_publish)
1585 };
1586
1587 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1588 live_kit
1589 .update_participant(
1590 livekit_room.clone(),
1591 request.user_id.to_string(),
1592 livekit_api::proto::ParticipantPermission {
1593 can_subscribe: true,
1594 can_publish,
1595 can_publish_data: can_publish,
1596 hidden: false,
1597 recorder: false,
1598 },
1599 )
1600 .await
1601 .trace_err();
1602 }
1603
1604 response.send(proto::Ack {})?;
1605 Ok(())
1606}
1607
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts can be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the room and take the IncomingCall message used to
    // ring the callee; the DB guard is released at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the called user concurrently; the first one
    // that acknowledges completes the call.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll back the ring state and report failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1676
1677/// Cancel an outgoing call.
1678async fn cancel_call(
1679 request: proto::CancelCall,
1680 response: Response<proto::CancelCall>,
1681 session: MessageContext,
1682) -> Result<()> {
1683 let called_user_id = UserId::from_proto(request.called_user_id);
1684 let room_id = RoomId::from_proto(request.room_id);
1685 {
1686 let room = session
1687 .db()
1688 .await
1689 .cancel_call(room_id, session.connection_id, called_user_id)
1690 .await?;
1691 room_updated(&room, &session.peer);
1692 }
1693
1694 for connection_id in session
1695 .connection_pool()
1696 .await
1697 .user_connection_ids(called_user_id)
1698 {
1699 session
1700 .peer
1701 .send(
1702 connection_id,
1703 proto::CallCanceled {
1704 room_id: room_id.to_proto(),
1705 },
1706 )
1707 .trace_err();
1708 }
1709 response.send(proto::Ack {})?;
1710
1711 update_user_contacts(called_user_id, &session).await?;
1712 Ok(())
1713}
1714
1715/// Decline an incoming call.
1716async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1717 let room_id = RoomId::from_proto(message.room_id);
1718 {
1719 let room = session
1720 .db()
1721 .await
1722 .decline_call(Some(room_id), session.user_id())
1723 .await?
1724 .context("declining call")?;
1725 room_updated(&room, &session.peer);
1726 }
1727
1728 for connection_id in session
1729 .connection_pool()
1730 .await
1731 .user_connection_ids(session.user_id())
1732 {
1733 session
1734 .peer
1735 .send(
1736 connection_id,
1737 proto::CallCanceled {
1738 room_id: room_id.to_proto(),
1739 },
1740 )
1741 .trace_err();
1742 }
1743 update_user_contacts(session.user_id(), &session).await?;
1744 Ok(())
1745}
1746
1747/// Updates other participants in the room with your current location.
1748async fn update_participant_location(
1749 request: proto::UpdateParticipantLocation,
1750 response: Response<proto::UpdateParticipantLocation>,
1751 session: MessageContext,
1752) -> Result<()> {
1753 let room_id = RoomId::from_proto(request.room_id);
1754 let location = request.location.context("invalid location")?;
1755
1756 let db = session.db().await;
1757 let room = db
1758 .update_room_participant_location(room_id, session.connection_id, location)
1759 .await?;
1760
1761 room_updated(&room, &session.peer);
1762 response.send(proto::Ack {})?;
1763 Ok(())
1764}
1765
1766/// Share a project into the room.
1767async fn share_project(
1768 request: proto::ShareProject,
1769 response: Response<proto::ShareProject>,
1770 session: MessageContext,
1771) -> Result<()> {
1772 let (project_id, room) = &*session
1773 .db()
1774 .await
1775 .share_project(
1776 RoomId::from_proto(request.room_id),
1777 session.connection_id,
1778 &request.worktrees,
1779 request.is_ssh_project,
1780 request.windows_paths.unwrap_or(false),
1781 )
1782 .await?;
1783 response.send(proto::ShareProjectResponse {
1784 project_id: project_id.to_proto(),
1785 })?;
1786 room_updated(room, &session.peer);
1787
1788 Ok(())
1789}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1793 let project_id = ProjectId::from_proto(message.project_id);
1794 unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
/// Shared unshare implementation, used for explicit unshare requests and
/// reused by connection-cleanup paths.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest the project is gone, then refresh room state.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Delete the project row outside the scope of the room guard.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1834
1835/// Join someone elses shared project.
1836async fn join_project(
1837 request: proto::JoinProject,
1838 response: Response<proto::JoinProject>,
1839 session: MessageContext,
1840) -> Result<()> {
1841 let project_id = ProjectId::from_proto(request.project_id);
1842
1843 tracing::info!(%project_id, "join project");
1844
1845 let db = session.db().await;
1846 let (project, replica_id) = &mut *db
1847 .join_project(
1848 project_id,
1849 session.connection_id,
1850 session.user_id(),
1851 request.committer_name.clone(),
1852 request.committer_email.clone(),
1853 )
1854 .await?;
1855 drop(db);
1856 tracing::info!(%project_id, "join remote project");
1857 let collaborators = project
1858 .collaborators
1859 .iter()
1860 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1861 .map(|collaborator| collaborator.to_proto())
1862 .collect::<Vec<_>>();
1863 let project_id = project.id;
1864 let guest_user_id = session.user_id();
1865
1866 let worktrees = project
1867 .worktrees
1868 .iter()
1869 .map(|(id, worktree)| proto::WorktreeMetadata {
1870 id: *id,
1871 root_name: worktree.root_name.clone(),
1872 visible: worktree.visible,
1873 abs_path: worktree.abs_path.clone(),
1874 })
1875 .collect::<Vec<_>>();
1876
1877 let add_project_collaborator = proto::AddProjectCollaborator {
1878 project_id: project_id.to_proto(),
1879 collaborator: Some(proto::Collaborator {
1880 peer_id: Some(session.connection_id.into()),
1881 replica_id: replica_id.0 as u32,
1882 user_id: guest_user_id.to_proto(),
1883 is_host: false,
1884 committer_name: request.committer_name.clone(),
1885 committer_email: request.committer_email.clone(),
1886 }),
1887 };
1888
1889 for collaborator in &collaborators {
1890 session
1891 .peer
1892 .send(
1893 collaborator.peer_id.unwrap().into(),
1894 add_project_collaborator.clone(),
1895 )
1896 .trace_err();
1897 }
1898
1899 // First, we send the metadata associated with each worktree.
1900 let (language_servers, language_server_capabilities) = project
1901 .language_servers
1902 .clone()
1903 .into_iter()
1904 .map(|server| (server.server, server.capabilities))
1905 .unzip();
1906 response.send(proto::JoinProjectResponse {
1907 project_id: project.id.0 as u64,
1908 worktrees,
1909 replica_id: replica_id.0 as u32,
1910 collaborators,
1911 language_servers,
1912 language_server_capabilities,
1913 role: project.role.into(),
1914 windows_paths: project.path_style == PathStyle::Windows,
1915 })?;
1916
1917 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1918 // Stream this worktree's entries.
1919 let message = proto::UpdateWorktree {
1920 project_id: project_id.to_proto(),
1921 worktree_id,
1922 abs_path: worktree.abs_path.clone(),
1923 root_name: worktree.root_name,
1924 updated_entries: worktree.entries,
1925 removed_entries: Default::default(),
1926 scan_id: worktree.scan_id,
1927 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1928 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1929 removed_repositories: Default::default(),
1930 };
1931 for update in proto::split_worktree_update(message) {
1932 session.peer.send(session.connection_id, update.clone())?;
1933 }
1934
1935 // Stream this worktree's diagnostics.
1936 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1937 if let Some(summary) = worktree_diagnostics.next() {
1938 let message = proto::UpdateDiagnosticSummary {
1939 project_id: project.id.to_proto(),
1940 worktree_id: worktree.id,
1941 summary: Some(summary),
1942 more_summaries: worktree_diagnostics.collect(),
1943 };
1944 session.peer.send(session.connection_id, message)?;
1945 }
1946
1947 for settings_file in worktree.settings_files {
1948 session.peer.send(
1949 session.connection_id,
1950 proto::UpdateWorktreeSettings {
1951 project_id: project_id.to_proto(),
1952 worktree_id: worktree.id,
1953 path: settings_file.path,
1954 content: Some(settings_file.content),
1955 kind: Some(settings_file.kind.to_proto() as i32),
1956 outside_worktree: Some(settings_file.outside_worktree),
1957 },
1958 )?;
1959 }
1960 }
1961
1962 for repository in mem::take(&mut project.repositories) {
1963 for update in split_repository_update(repository) {
1964 session.peer.send(session.connection_id, update)?;
1965 }
1966 }
1967
1968 for language_server in &project.language_servers {
1969 session.peer.send(
1970 session.connection_id,
1971 proto::UpdateLanguageServer {
1972 project_id: project_id.to_proto(),
1973 server_name: Some(language_server.server.name.clone()),
1974 language_server_id: language_server.server.id,
1975 variant: Some(
1976 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1977 proto::LspDiskBasedDiagnosticsUpdated {},
1978 ),
1979 ),
1980 },
1981 )?;
1982 }
1983
1984 Ok(())
1985}
1986
1987/// Leave someone elses shared project.
1988async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1989 let sender_id = session.connection_id;
1990 let project_id = ProjectId::from_proto(request.project_id);
1991 let db = session.db().await;
1992
1993 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1994 tracing::info!(
1995 %project_id,
1996 "leave project"
1997 );
1998
1999 project_left(project, &session);
2000 if let Some(room) = room {
2001 room_updated(room, &session.peer);
2002 }
2003
2004 Ok(())
2005}
2006
2007/// Updates other participants with changes to the project
2008async fn update_project(
2009 request: proto::UpdateProject,
2010 response: Response<proto::UpdateProject>,
2011 session: MessageContext,
2012) -> Result<()> {
2013 let project_id = ProjectId::from_proto(request.project_id);
2014 let (room, guest_connection_ids) = &*session
2015 .db()
2016 .await
2017 .update_project(project_id, session.connection_id, &request.worktrees)
2018 .await?;
2019 broadcast(
2020 Some(session.connection_id),
2021 guest_connection_ids.iter().copied(),
2022 |connection_id| {
2023 session
2024 .peer
2025 .forward_send(session.connection_id, connection_id, request.clone())
2026 },
2027 );
2028 if let Some(room) = room {
2029 room_updated(room, &session.peer);
2030 }
2031 response.send(proto::Ack {})?;
2032
2033 Ok(())
2034}
2035
2036/// Updates other participants with changes to the worktree
2037async fn update_worktree(
2038 request: proto::UpdateWorktree,
2039 response: Response<proto::UpdateWorktree>,
2040 session: MessageContext,
2041) -> Result<()> {
2042 let guest_connection_ids = session
2043 .db()
2044 .await
2045 .update_worktree(&request, session.connection_id)
2046 .await?;
2047
2048 broadcast(
2049 Some(session.connection_id),
2050 guest_connection_ids.iter().copied(),
2051 |connection_id| {
2052 session
2053 .peer
2054 .forward_send(session.connection_id, connection_id, request.clone())
2055 },
2056 );
2057 response.send(proto::Ack {})?;
2058 Ok(())
2059}
2060
2061async fn update_repository(
2062 request: proto::UpdateRepository,
2063 response: Response<proto::UpdateRepository>,
2064 session: MessageContext,
2065) -> Result<()> {
2066 let guest_connection_ids = session
2067 .db()
2068 .await
2069 .update_repository(&request, session.connection_id)
2070 .await?;
2071
2072 broadcast(
2073 Some(session.connection_id),
2074 guest_connection_ids.iter().copied(),
2075 |connection_id| {
2076 session
2077 .peer
2078 .forward_send(session.connection_id, connection_id, request.clone())
2079 },
2080 );
2081 response.send(proto::Ack {})?;
2082 Ok(())
2083}
2084
2085async fn remove_repository(
2086 request: proto::RemoveRepository,
2087 response: Response<proto::RemoveRepository>,
2088 session: MessageContext,
2089) -> Result<()> {
2090 let guest_connection_ids = session
2091 .db()
2092 .await
2093 .remove_repository(&request, session.connection_id)
2094 .await?;
2095
2096 broadcast(
2097 Some(session.connection_id),
2098 guest_connection_ids.iter().copied(),
2099 |connection_id| {
2100 session
2101 .peer
2102 .forward_send(session.connection_id, connection_id, request.clone())
2103 },
2104 );
2105 response.send(proto::Ack {})?;
2106 Ok(())
2107}
2108
2109/// Updates other participants with changes to the diagnostics
2110async fn update_diagnostic_summary(
2111 message: proto::UpdateDiagnosticSummary,
2112 session: MessageContext,
2113) -> Result<()> {
2114 let guest_connection_ids = session
2115 .db()
2116 .await
2117 .update_diagnostic_summary(&message, session.connection_id)
2118 .await?;
2119
2120 broadcast(
2121 Some(session.connection_id),
2122 guest_connection_ids.iter().copied(),
2123 |connection_id| {
2124 session
2125 .peer
2126 .forward_send(session.connection_id, connection_id, message.clone())
2127 },
2128 );
2129
2130 Ok(())
2131}
2132
2133/// Updates other participants with changes to the worktree settings
2134async fn update_worktree_settings(
2135 message: proto::UpdateWorktreeSettings,
2136 session: MessageContext,
2137) -> Result<()> {
2138 let guest_connection_ids = session
2139 .db()
2140 .await
2141 .update_worktree_settings(&message, session.connection_id)
2142 .await?;
2143
2144 broadcast(
2145 Some(session.connection_id),
2146 guest_connection_ids.iter().copied(),
2147 |connection_id| {
2148 session
2149 .peer
2150 .forward_send(session.connection_id, connection_id, message.clone())
2151 },
2152 );
2153
2154 Ok(())
2155}
2156
2157/// Notify other participants that a language server has started.
2158async fn start_language_server(
2159 request: proto::StartLanguageServer,
2160 session: MessageContext,
2161) -> Result<()> {
2162 let guest_connection_ids = session
2163 .db()
2164 .await
2165 .start_language_server(&request, session.connection_id)
2166 .await?;
2167
2168 broadcast(
2169 Some(session.connection_id),
2170 guest_connection_ids.iter().copied(),
2171 |connection_id| {
2172 session
2173 .peer
2174 .forward_send(session.connection_id, connection_id, request.clone())
2175 },
2176 );
2177 Ok(())
2178}
2179
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // When the server reports updated metadata with capabilities attached,
    // persist those capabilities before broadcasting.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the update out, unchanged, to every other connection in the project.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2209
2210/// forward a project request to the host. These requests should be read only
2211/// as guests are allowed to send them.
2212async fn forward_read_only_project_request<T>(
2213 request: T,
2214 response: Response<T>,
2215 session: MessageContext,
2216) -> Result<()>
2217where
2218 T: EntityMessage + RequestMessage,
2219{
2220 let project_id = ProjectId::from_proto(request.remote_entity_id());
2221 let host_connection_id = session
2222 .db()
2223 .await
2224 .host_for_read_only_project_request(project_id, session.connection_id)
2225 .await?;
2226 let payload = session.forward_request(host_connection_id, request).await?;
2227 response.send(payload)?;
2228 Ok(())
2229}
2230
2231/// forward a project request to the host. These requests are disallowed
2232/// for guests.
2233async fn forward_mutating_project_request<T>(
2234 request: T,
2235 response: Response<T>,
2236 session: MessageContext,
2237) -> Result<()>
2238where
2239 T: EntityMessage + RequestMessage,
2240{
2241 let project_id = ProjectId::from_proto(request.remote_entity_id());
2242
2243 let host_connection_id = session
2244 .db()
2245 .await
2246 .host_for_mutating_project_request(project_id, session.connection_id)
2247 .await?;
2248 let payload = session.forward_request(host_connection_id, request).await?;
2249 response.send(payload)?;
2250 Ok(())
2251}
2252
2253async fn lsp_query(
2254 request: proto::LspQuery,
2255 response: Response<proto::LspQuery>,
2256 session: MessageContext,
2257) -> Result<()> {
2258 let (name, should_write) = request.query_name_and_write_permissions();
2259 tracing::Span::current().record("lsp_query_request", name);
2260 tracing::info!("lsp_query message received");
2261 if should_write {
2262 forward_mutating_project_request(request, response, session).await
2263 } else {
2264 forward_read_only_project_request(request, response, session).await
2265 }
2266}
2267
2268/// Notify other participants that a new buffer has been created
2269async fn create_buffer_for_peer(
2270 request: proto::CreateBufferForPeer,
2271 session: MessageContext,
2272) -> Result<()> {
2273 session
2274 .db()
2275 .await
2276 .check_user_is_project_host(
2277 ProjectId::from_proto(request.project_id),
2278 session.connection_id,
2279 )
2280 .await?;
2281 let peer_id = request.peer_id.context("invalid peer id")?;
2282 session
2283 .peer
2284 .forward_send(session.connection_id, peer_id.into(), request)?;
2285 Ok(())
2286}
2287
2288/// Notify other participants that a new image has been created
2289async fn create_image_for_peer(
2290 request: proto::CreateImageForPeer,
2291 session: MessageContext,
2292) -> Result<()> {
2293 session
2294 .db()
2295 .await
2296 .check_user_is_project_host(
2297 ProjectId::from_proto(request.project_id),
2298 session.connection_id,
2299 )
2300 .await?;
2301 let peer_id = request.peer_id.context("invalid peer id")?;
2302 session
2303 .peer
2304 .forward_send(session.connection_id, peer_id.into(), request)?;
2305 Ok(())
2306}
2307
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations may come from read-only guests;
    // any other operation kind requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // This block scopes the database guard so it is released before we
    // await the host's acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests receive the update fire-and-forget.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Wait for the host to acknowledge the update before acking the sender,
    // unless the sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2354
/// Broadcasts a context (assistant conversation) operation to all project
/// participants, including the host.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    // Determine the capability this operation requires: selection-only (or
    // empty) buffer operations are permitted for read-only participants;
    // everything else needs write access.
    let operation = message.operation.as_ref().context("invalid operation")?;
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike update_buffer, the host is just another recipient here: the
    // message is fire-and-forget and no host acknowledgement is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2396
2397async fn forward_project_search_chunk(
2398 message: proto::FindSearchCandidatesChunk,
2399 response: Response<proto::FindSearchCandidatesChunk>,
2400 session: MessageContext,
2401) -> Result<()> {
2402 let peer_id = message.peer_id.context("missing peer_id")?;
2403 let payload = session
2404 .peer
2405 .forward_request(session.connection_id, peer_id.into(), message)
2406 .await?;
2407 response.send(payload)?;
2408 Ok(())
2409}
2410
2411/// Notify other participants that a project has been updated.
2412async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2413 request: T,
2414 session: MessageContext,
2415) -> Result<()> {
2416 let project_id = ProjectId::from_proto(request.remote_entity_id());
2417 let project_connection_ids = session
2418 .db()
2419 .await
2420 .project_connection_ids(project_id, session.connection_id, false)
2421 .await?;
2422
2423 broadcast(
2424 Some(session.connection_id),
2425 project_connection_ids.iter().copied(),
2426 |connection_id| {
2427 session
2428 .peer
2429 .forward_send(session.connection_id, connection_id, request.clone())
2430 },
2431 );
2432 Ok(())
2433}
2434
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify that both leader and follower are participants of this room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay the reply to the
    // follower before persisting the follow.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // If the follow is scoped to a project, record it and broadcast the
    // updated room state to all participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2466
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify that both leader and follower are participants of this room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, this is fire-and-forget: the leader is notified but
    // no reply is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // If the follow was scoped to a project, remove it and broadcast the
    // updated room state to all participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2495
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): locks `session.db` directly instead of going through the
    // `session.db()` helper used by the other handlers — presumably
    // equivalent; confirm the helper adds nothing (metrics, tracing) we'd
    // want here.
    let database = session.db.lock().await;

    // Project-scoped updates go to the project's members; otherwise they go
    // to the whole room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Forward to everyone except the omitted leader and the sender itself.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2527
2528/// Get public data about users.
2529async fn get_users(
2530 request: proto::GetUsers,
2531 response: Response<proto::GetUsers>,
2532 session: MessageContext,
2533) -> Result<()> {
2534 let user_ids = request
2535 .user_ids
2536 .into_iter()
2537 .map(UserId::from_proto)
2538 .collect();
2539 let users = session
2540 .db()
2541 .await
2542 .get_users_by_ids(user_ids)
2543 .await?
2544 .into_iter()
2545 .map(|user| proto::User {
2546 id: user.id.to_proto(),
2547 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2548 github_login: user.github_login,
2549 name: user.name,
2550 })
2551 .collect();
2552 response.send(proto::UsersResponse { users })?;
2553 Ok(())
2554}
2555
2556/// Search for users (to invite) buy Github login
2557async fn fuzzy_search_users(
2558 request: proto::FuzzySearchUsers,
2559 response: Response<proto::FuzzySearchUsers>,
2560 session: MessageContext,
2561) -> Result<()> {
2562 let query = request.query;
2563 let users = match query.len() {
2564 0 => vec![],
2565 1 | 2 => session
2566 .db()
2567 .await
2568 .get_user_by_github_login(&query)
2569 .await?
2570 .into_iter()
2571 .collect(),
2572 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2573 };
2574 let users = users
2575 .into_iter()
2576 .filter(|user| user.id != session.user_id())
2577 .map(|user| proto::User {
2578 id: user.id.to_proto(),
2579 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2580 github_login: user.github_login,
2581 name: user.name,
2582 })
2583 .collect();
2584 response.send(proto::UsersResponse { users })?;
2585 Ok(())
2586}
2587
2588/// Send a contact request to another user.
2589async fn request_contact(
2590 request: proto::RequestContact,
2591 response: Response<proto::RequestContact>,
2592 session: MessageContext,
2593) -> Result<()> {
2594 let requester_id = session.user_id();
2595 let responder_id = UserId::from_proto(request.responder_id);
2596 if requester_id == responder_id {
2597 return Err(anyhow!("cannot add yourself as a contact"))?;
2598 }
2599
2600 let notifications = session
2601 .db()
2602 .await
2603 .send_contact_request(requester_id, responder_id)
2604 .await?;
2605
2606 // Update outgoing contact requests of requester
2607 let mut update = proto::UpdateContacts::default();
2608 update.outgoing_requests.push(responder_id.to_proto());
2609 for connection_id in session
2610 .connection_pool()
2611 .await
2612 .user_connection_ids(requester_id)
2613 {
2614 session.peer.send(connection_id, update.clone())?;
2615 }
2616
2617 // Update incoming contact requests of responder
2618 let mut update = proto::UpdateContacts::default();
2619 update
2620 .incoming_requests
2621 .push(proto::IncomingContactRequest {
2622 requester_id: requester_id.to_proto(),
2623 });
2624 let connection_pool = session.connection_pool().await;
2625 for connection_id in connection_pool.user_connection_ids(responder_id) {
2626 session.peer.send(connection_id, update.clone())?;
2627 }
2628
2629 send_notifications(&connection_pool, &session.peer, notifications);
2630
2631 response.send(proto::Ack {})?;
2632 Ok(())
2633}
2634
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the contact notification; it neither accepts nor
    // declines the request.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request disappears.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // The requester's outgoing request is resolved either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2692
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If removing the contact deleted a pending notification, retract it
        // from each of the responder's connections as well.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2743
/// Whether the server should proactively subscribe this client to its
/// channels rather than waiting for an explicit `SubscribeToChannels`
/// request.
/// NOTE(review): only the minor version is compared (assumes 0.x
/// releases), and the 139 cutoff presumably matches the client release
/// that began sending `SubscribeToChannels` itself — confirm against
/// client version history.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2747
2748async fn subscribe_to_channels(
2749 _: proto::SubscribeToChannels,
2750 session: MessageContext,
2751) -> Result<()> {
2752 subscribe_user_to_channels(session.user_id(), &session).await?;
2753 Ok(())
2754}
2755
2756async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2757 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2758 let mut pool = session.connection_pool().await;
2759 for membership in &channels_for_user.channel_memberships {
2760 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2761 }
2762 session.peer.send(
2763 session.connection_id,
2764 build_update_user_channels(&channels_for_user),
2765 )?;
2766 session.peer.send(
2767 session.connection_id,
2768 build_channels_update(channels_for_user),
2769 )?;
2770 Ok(())
2771}
2772
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before broadcasting to everyone else.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // If creating the channel produced a new membership for the creator,
    // subscribe them and push the membership to all their connections.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone in the tree whose role allows
    // them to see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2827
2828/// Delete a channel
2829async fn delete_channel(
2830 request: proto::DeleteChannel,
2831 response: Response<proto::DeleteChannel>,
2832 session: MessageContext,
2833) -> Result<()> {
2834 let db = session.db().await;
2835
2836 let channel_id = request.channel_id;
2837 let (root_channel, removed_channels) = db
2838 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2839 .await?;
2840 response.send(proto::Ack {})?;
2841
2842 // Notify members of removed channels
2843 let mut update = proto::UpdateChannels::default();
2844 update
2845 .delete_channels
2846 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2847
2848 let connection_pool = session.connection_pool().await;
2849 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2850 session.peer.send(connection_id, update.clone())?;
2851 }
2852
2853 Ok(())
2854}
2855
2856/// Invite someone to join a channel.
2857async fn invite_channel_member(
2858 request: proto::InviteChannelMember,
2859 response: Response<proto::InviteChannelMember>,
2860 session: MessageContext,
2861) -> Result<()> {
2862 let db = session.db().await;
2863 let channel_id = ChannelId::from_proto(request.channel_id);
2864 let invitee_id = UserId::from_proto(request.user_id);
2865 let InviteMemberResult {
2866 channel,
2867 notifications,
2868 } = db
2869 .invite_channel_member(
2870 channel_id,
2871 invitee_id,
2872 session.user_id(),
2873 request.role().into(),
2874 )
2875 .await?;
2876
2877 let update = proto::UpdateChannels {
2878 channel_invitations: vec![channel.to_proto()],
2879 ..Default::default()
2880 };
2881
2882 let connection_pool = session.connection_pool().await;
2883 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2884 session.peer.send(connection_id, update.clone())?;
2885 }
2886
2887 send_notifications(&connection_pool, &session.peer, notifications);
2888
2889 response.send(proto::Ack {})?;
2890 Ok(())
2891}
2892
/// Remove someone from a channel.
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Propagate the membership change to all affected connections.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // If removal deleted a pending invite notification, retract it from the
    // removed member's connections; failures here are logged, not fatal.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2934
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front so the borrow of the pool ends before
    // we mutate its subscriptions inside the loop.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get (re)subscribed and sent
        // the updated channel; everyone else is unsubscribed and told to
        // delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2981
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The member may already be in the channel (membership updated) or may
    // only hold a pending invitation (invite updated); each case notifies
    // differently.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Re-send the invitation with the new role to every connection
            // the invited user has open.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3029
3030/// Change the name of a channel
3031async fn rename_channel(
3032 request: proto::RenameChannel,
3033 response: Response<proto::RenameChannel>,
3034 session: MessageContext,
3035) -> Result<()> {
3036 let db = session.db().await;
3037 let channel_id = ChannelId::from_proto(request.channel_id);
3038 let channel_model = db
3039 .rename_channel(channel_id, session.user_id(), &request.name)
3040 .await?;
3041 let root_id = channel_model.root_id();
3042 let channel = Channel::from_model(channel_model);
3043
3044 response.send(proto::RenameChannelResponse {
3045 channel: Some(channel.to_proto()),
3046 })?;
3047
3048 let connection_pool = session.connection_pool().await;
3049 let update = proto::UpdateChannels {
3050 channels: vec![channel.to_proto()],
3051 ..Default::default()
3052 };
3053 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3054 if role.can_see_channel(channel.visibility) {
3055 session.peer.send(connection_id, update.clone())?;
3056 }
3057 }
3058
3059 Ok(())
3060}
3061
3062/// Move a channel to a new parent.
3063async fn move_channel(
3064 request: proto::MoveChannel,
3065 response: Response<proto::MoveChannel>,
3066 session: MessageContext,
3067) -> Result<()> {
3068 let channel_id = ChannelId::from_proto(request.channel_id);
3069 let to = ChannelId::from_proto(request.to);
3070
3071 let (root_id, channels) = session
3072 .db()
3073 .await
3074 .move_channel(channel_id, to, session.user_id())
3075 .await?;
3076
3077 let connection_pool = session.connection_pool().await;
3078 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3079 let channels = channels
3080 .iter()
3081 .filter_map(|channel| {
3082 if role.can_see_channel(channel.visibility) {
3083 Some(channel.to_proto())
3084 } else {
3085 None
3086 }
3087 })
3088 .collect::<Vec<_>>();
3089 if channels.is_empty() {
3090 continue;
3091 }
3092
3093 let update = proto::UpdateChannels {
3094 channels,
3095 ..Default::default()
3096 };
3097
3098 session.peer.send(connection_id, update.clone())?;
3099 }
3100
3101 response.send(Ack {})?;
3102 Ok(())
3103}
3104
/// Move a channel up or down among its siblings and broadcast the new order.
async fn reorder_channel(
    request: proto::ReorderChannel,
    response: Response<proto::ReorderChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let direction = request.direction();

    let updated_channels = session
        .db()
        .await
        .reorder_channel(channel_id, direction, session.user_id())
        .await?;

    // If no channels changed position, there is nothing to broadcast; the
    // root id is derived from the first updated channel.
    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
        let connection_pool = session.connection_pool().await;
        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
            // Only send the channels this connection's role allows it to see.
            let channels = updated_channels
                .iter()
                .filter_map(|channel| {
                    if role.can_see_channel(channel.visibility) {
                        Some(channel.to_proto())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if channels.is_empty() {
                continue;
            }

            let update = proto::UpdateChannels {
                channels,
                ..Default::default()
            };

            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;
    Ok(())
}
3149
3150/// Get the list of channel members
3151async fn get_channel_members(
3152 request: proto::GetChannelMembers,
3153 response: Response<proto::GetChannelMembers>,
3154 session: MessageContext,
3155) -> Result<()> {
3156 let db = session.db().await;
3157 let channel_id = ChannelId::from_proto(request.channel_id);
3158 let limit = if request.limit == 0 {
3159 u16::MAX as u64
3160 } else {
3161 request.limit
3162 };
3163 let (members, users) = db
3164 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3165 .await?;
3166 response.send(proto::GetChannelMembersResponse { members, users })?;
3167 Ok(())
3168}
3169
3170/// Accept or decline a channel invitation.
3171async fn respond_to_channel_invite(
3172 request: proto::RespondToChannelInvite,
3173 response: Response<proto::RespondToChannelInvite>,
3174 session: MessageContext,
3175) -> Result<()> {
3176 let db = session.db().await;
3177 let channel_id = ChannelId::from_proto(request.channel_id);
3178 let RespondToChannelInvite {
3179 membership_update,
3180 notifications,
3181 } = db
3182 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3183 .await?;
3184
3185 let mut connection_pool = session.connection_pool().await;
3186 if let Some(membership_update) = membership_update {
3187 notify_membership_updated(
3188 &mut connection_pool,
3189 membership_update,
3190 session.user_id(),
3191 &session.peer,
3192 );
3193 } else {
3194 let update = proto::UpdateChannels {
3195 remove_channel_invitations: vec![channel_id.to_proto()],
3196 ..Default::default()
3197 };
3198
3199 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3200 session.peer.send(connection_id, update.clone())?;
3201 }
3202 };
3203
3204 send_notifications(&connection_pool, &session.peer, notifications);
3205
3206 response.send(proto::Ack {})?;
3207
3208 Ok(())
3209}
3210
3211/// Join the channels' room
3212async fn join_channel(
3213 request: proto::JoinChannel,
3214 response: Response<proto::JoinChannel>,
3215 session: MessageContext,
3216) -> Result<()> {
3217 let channel_id = ChannelId::from_proto(request.channel_id);
3218 join_channel_internal(channel_id, Box::new(response), session).await
3219}
3220
/// Abstraction over the two response types (`JoinChannel` and `JoinRoom`)
/// that carry a `JoinRoomResponse`, letting `join_channel_internal` serve
/// both RPCs with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3234
/// Shared implementation behind `JoinChannel` and `JoinRoom`: joins the
/// caller to the room associated with `channel_id`, responds with the room
/// state (plus LiveKit connection info when available), and notifies other
/// participants and channel subscribers.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires its own database handle, so
            // release ours before the call and re-acquire afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit access token when a LiveKit client is configured.
        // Guests get a non-publishing token; all other roles may publish.
        // Token-minting failures are logged (`trace_err`) and yield no
        // connection info rather than failing the whole join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting updates to everyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3328
3329/// Start editing the channel notes
3330async fn join_channel_buffer(
3331 request: proto::JoinChannelBuffer,
3332 response: Response<proto::JoinChannelBuffer>,
3333 session: MessageContext,
3334) -> Result<()> {
3335 let db = session.db().await;
3336 let channel_id = ChannelId::from_proto(request.channel_id);
3337
3338 let open_response = db
3339 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3340 .await?;
3341
3342 let collaborators = open_response.collaborators.clone();
3343 response.send(open_response)?;
3344
3345 let update = UpdateChannelBufferCollaborators {
3346 channel_id: channel_id.to_proto(),
3347 collaborators: collaborators.clone(),
3348 };
3349 channel_buffer_updated(
3350 session.connection_id,
3351 collaborators
3352 .iter()
3353 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3354 &update,
3355 &session.peer,
3356 );
3357
3358 Ok(())
3359}
3360
/// Edit the channel notes
///
/// Persists the caller's buffer operations, relays them to the other
/// collaborators currently editing the buffer, and sends the remaining
/// channel subscribers only the latest buffer version (not the operations).
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to the buffer's collaborators; the sender's
    // own connection id is passed through so it is excluded from the fan-out.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers who are not editing the buffer only need to know
    // the new version vector so they can tell the notes changed.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3411
3412/// Rejoin the channel notes after a connection blip
3413async fn rejoin_channel_buffers(
3414 request: proto::RejoinChannelBuffers,
3415 response: Response<proto::RejoinChannelBuffers>,
3416 session: MessageContext,
3417) -> Result<()> {
3418 let db = session.db().await;
3419 let buffers = db
3420 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3421 .await?;
3422
3423 for rejoined_buffer in &buffers {
3424 let collaborators_to_notify = rejoined_buffer
3425 .buffer
3426 .collaborators
3427 .iter()
3428 .filter_map(|c| Some(c.peer_id?.into()));
3429 channel_buffer_updated(
3430 session.connection_id,
3431 collaborators_to_notify,
3432 &proto::UpdateChannelBufferCollaborators {
3433 channel_id: rejoined_buffer.buffer.channel_id,
3434 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3435 },
3436 &session.peer,
3437 );
3438 }
3439
3440 response.send(proto::RejoinChannelBuffersResponse {
3441 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3442 })?;
3443
3444 Ok(())
3445}
3446
3447/// Stop editing the channel notes
3448async fn leave_channel_buffer(
3449 request: proto::LeaveChannelBuffer,
3450 response: Response<proto::LeaveChannelBuffer>,
3451 session: MessageContext,
3452) -> Result<()> {
3453 let db = session.db().await;
3454 let channel_id = ChannelId::from_proto(request.channel_id);
3455
3456 let left_buffer = db
3457 .leave_channel_buffer(channel_id, session.connection_id)
3458 .await?;
3459
3460 response.send(Ack {})?;
3461
3462 channel_buffer_updated(
3463 session.connection_id,
3464 left_buffer.connections,
3465 &proto::UpdateChannelBufferCollaborators {
3466 channel_id: channel_id.to_proto(),
3467 collaborators: left_buffer.collaborators,
3468 },
3469 &session.peer,
3470 );
3471
3472 Ok(())
3473}
3474
3475fn channel_buffer_updated<T: EnvelopedMessage>(
3476 sender_id: ConnectionId,
3477 collaborators: impl IntoIterator<Item = ConnectionId>,
3478 message: &T,
3479 peer: &Peer,
3480) {
3481 broadcast(Some(sender_id), collaborators, |peer_id| {
3482 peer.send(peer_id, message.clone())
3483 });
3484}
3485
3486fn send_notifications(
3487 connection_pool: &ConnectionPool,
3488 peer: &Peer,
3489 notifications: db::NotificationBatch,
3490) {
3491 for (user_id, notification) in notifications {
3492 for connection_id in connection_pool.user_connection_ids(user_id) {
3493 if let Err(error) = peer.send(
3494 connection_id,
3495 proto::AddNotification {
3496 notification: Some(notification.clone()),
3497 },
3498 ) {
3499 tracing::error!(
3500 "failed to send notification to {:?} {}",
3501 connection_id,
3502 error
3503 );
3504 }
3505 }
3506 }
3507}
3508
/// Send a message to the channel
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3517
/// Delete a channel message
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3526
/// Edit a previously sent channel message
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3534
/// Mark a channel message as read
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3542
3543/// Mark a buffer version as synced
3544async fn acknowledge_buffer_version(
3545 request: proto::AckBufferOperation,
3546 session: MessageContext,
3547) -> Result<()> {
3548 let buffer_id = BufferId::from_proto(request.buffer_id);
3549 session
3550 .db()
3551 .await
3552 .observe_buffer_version(
3553 buffer_id,
3554 session.user_id(),
3555 request.epoch as i32,
3556 &request.version,
3557 )
3558 .await?;
3559 Ok(())
3560}
3561
/// Start receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3570
/// Stop receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3578
/// Retrieve the chat history for a channel
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3587
/// Retrieve specific chat messages
///
/// Chat has been removed from Zed; this handler always responds with an error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3596
3597/// Retrieve the current users notifications
3598async fn get_notifications(
3599 request: proto::GetNotifications,
3600 response: Response<proto::GetNotifications>,
3601 session: MessageContext,
3602) -> Result<()> {
3603 let notifications = session
3604 .db()
3605 .await
3606 .get_notifications(
3607 session.user_id(),
3608 NOTIFICATION_COUNT_PER_PAGE,
3609 request.before_id.map(db::NotificationId::from_proto),
3610 )
3611 .await?;
3612 response.send(proto::GetNotificationsResponse {
3613 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3614 notifications,
3615 })?;
3616 Ok(())
3617}
3618
3619/// Mark notifications as read
3620async fn mark_notification_as_read(
3621 request: proto::MarkNotificationRead,
3622 response: Response<proto::MarkNotificationRead>,
3623 session: MessageContext,
3624) -> Result<()> {
3625 let database = &session.db().await;
3626 let notifications = database
3627 .mark_notification_as_read_by_id(
3628 session.user_id(),
3629 NotificationId::from_proto(request.notification_id),
3630 )
3631 .await?;
3632 send_notifications(
3633 &*session.connection_pool().await,
3634 &session.peer,
3635 notifications,
3636 );
3637 response.send(proto::Ack {})?;
3638 Ok(())
3639}
3640
3641fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3642 let message = match message {
3643 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3644 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3645 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3646 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3647 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3648 code: frame.code.into(),
3649 reason: frame.reason.as_str().to_owned().into(),
3650 })),
3651 // We should never receive a frame while reading the message, according
3652 // to the `tungstenite` maintainers:
3653 //
3654 // > It cannot occur when you read messages from the WebSocket, but it
3655 // > can be used when you want to send the raw frames (e.g. you want to
3656 // > send the frames to the WebSocket without composing the full message first).
3657 // >
3658 // > — https://github.com/snapview/tungstenite-rs/issues/268
3659 TungsteniteMessage::Frame(_) => {
3660 bail!("received an unexpected frame while reading the message")
3661 }
3662 };
3663
3664 Ok(message)
3665}
3666
3667fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3668 match message {
3669 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3670 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3671 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3672 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3673 AxumMessage::Close(frame) => {
3674 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3675 code: frame.code.into(),
3676 reason: frame.reason.as_ref().into(),
3677 }))
3678 }
3679 }
3680}
3681
3682fn notify_membership_updated(
3683 connection_pool: &mut ConnectionPool,
3684 result: MembershipUpdated,
3685 user_id: UserId,
3686 peer: &Peer,
3687) {
3688 for membership in &result.new_channels.channel_memberships {
3689 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3690 }
3691 for channel_id in &result.removed_channels {
3692 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3693 }
3694
3695 let user_channels_update = proto::UpdateUserChannels {
3696 channel_memberships: result
3697 .new_channels
3698 .channel_memberships
3699 .iter()
3700 .map(|cm| proto::ChannelMembership {
3701 channel_id: cm.channel_id.to_proto(),
3702 role: cm.role.into(),
3703 })
3704 .collect(),
3705 ..Default::default()
3706 };
3707
3708 let mut update = build_channels_update(result.new_channels);
3709 update.delete_channels = result
3710 .removed_channels
3711 .into_iter()
3712 .map(|id| id.to_proto())
3713 .collect();
3714 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3715
3716 for connection_id in connection_pool.user_connection_ids(user_id) {
3717 peer.send(connection_id, user_channels_update.clone())
3718 .trace_err();
3719 peer.send(connection_id, update.clone()).trace_err();
3720 }
3721}
3722
3723fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3724 proto::UpdateUserChannels {
3725 channel_memberships: channels
3726 .channel_memberships
3727 .iter()
3728 .map(|m| proto::ChannelMembership {
3729 channel_id: m.channel_id.to_proto(),
3730 role: m.role.into(),
3731 })
3732 .collect(),
3733 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3734 }
3735}
3736
3737fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3738 let mut update = proto::UpdateChannels::default();
3739
3740 for channel in channels.channels {
3741 update.channels.push(channel.to_proto());
3742 }
3743
3744 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3745
3746 for (channel_id, participants) in channels.channel_participants {
3747 update
3748 .channel_participants
3749 .push(proto::ChannelParticipants {
3750 channel_id: channel_id.to_proto(),
3751 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3752 });
3753 }
3754
3755 for channel in channels.invited_channels {
3756 update.channel_invitations.push(channel.to_proto());
3757 }
3758
3759 update
3760}
3761
3762fn build_initial_contacts_update(
3763 contacts: Vec<db::Contact>,
3764 pool: &ConnectionPool,
3765) -> proto::UpdateContacts {
3766 let mut update = proto::UpdateContacts::default();
3767
3768 for contact in contacts {
3769 match contact {
3770 db::Contact::Accepted { user_id, busy } => {
3771 update.contacts.push(contact_for_user(user_id, busy, pool));
3772 }
3773 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3774 db::Contact::Incoming { user_id } => {
3775 update
3776 .incoming_requests
3777 .push(proto::IncomingContactRequest {
3778 requester_id: user_id.to_proto(),
3779 })
3780 }
3781 }
3782 }
3783
3784 update
3785}
3786
3787fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3788 proto::Contact {
3789 user_id: user_id.to_proto(),
3790 online: pool.is_user_online(user_id),
3791 busy,
3792 }
3793}
3794
3795fn room_updated(room: &proto::Room, peer: &Peer) {
3796 broadcast(
3797 None,
3798 room.participants
3799 .iter()
3800 .filter_map(|participant| Some(participant.peer_id?.into())),
3801 |peer_id| {
3802 peer.send(
3803 peer_id,
3804 proto::RoomUpdated {
3805 room: Some(room.clone()),
3806 },
3807 )
3808 },
3809 );
3810}
3811
3812fn channel_updated(
3813 channel: &db::channel::Model,
3814 room: &proto::Room,
3815 peer: &Peer,
3816 pool: &ConnectionPool,
3817) {
3818 let participants = room
3819 .participants
3820 .iter()
3821 .map(|p| p.user_id)
3822 .collect::<Vec<_>>();
3823
3824 broadcast(
3825 None,
3826 pool.channel_connection_ids(channel.root_id())
3827 .filter_map(|(channel_id, role)| {
3828 role.can_see_channel(channel.visibility)
3829 .then_some(channel_id)
3830 }),
3831 |peer_id| {
3832 peer.send(
3833 peer_id,
3834 proto::UpdateChannels {
3835 channel_participants: vec![proto::ChannelParticipants {
3836 channel_id: channel.id.to_proto(),
3837 participant_user_ids: participants.clone(),
3838 }],
3839 ..Default::default()
3840 },
3841 )
3842 },
3843 );
3844}
3845
3846async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3847 let db = session.db().await;
3848
3849 let contacts = db.get_contacts(user_id).await?;
3850 let busy = db.is_user_busy(user_id).await?;
3851
3852 let pool = session.connection_pool().await;
3853 let updated_contact = contact_for_user(user_id, busy, &pool);
3854 for contact in contacts {
3855 if let db::Contact::Accepted {
3856 user_id: contact_user_id,
3857 ..
3858 } = contact
3859 {
3860 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3861 session
3862 .peer
3863 .send(
3864 contact_conn_id,
3865 proto::UpdateContacts {
3866 contacts: vec![updated_contact.clone()],
3867 remove_contacts: Default::default(),
3868 incoming_requests: Default::default(),
3869 remove_incoming_requests: Default::default(),
3870 outgoing_requests: Default::default(),
3871 remove_outgoing_requests: Default::default(),
3872 },
3873 )
3874 .trace_err();
3875 }
3876 }
3877 }
3878 Ok(())
3879}
3880
/// Clean up room state for a connection that is leaving (or was detected as
/// stale): broadcasts the updated room, notifies the channel, cancels pending
/// calls, refreshes contacts, and tears down LiveKit participation.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Note: the livekit room name is taken out *before* the room itself,
        // so the `room` broadcast below carries an empty `livekit_room` field.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in any room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell everyone whose incoming call was canceled by this departure,
        // and mark them for a contact (busy-status) refresh.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3954
3955async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3956 let left_channel_buffers = session
3957 .db()
3958 .await
3959 .leave_channel_buffers(session.connection_id)
3960 .await?;
3961
3962 for left_buffer in left_channel_buffers {
3963 channel_buffer_updated(
3964 session.connection_id,
3965 left_buffer.connections,
3966 &proto::UpdateChannelBufferCollaborators {
3967 channel_id: left_buffer.channel_id.to_proto(),
3968 collaborators: left_buffer.collaborators,
3969 },
3970 &session.peer,
3971 );
3972 }
3973
3974 Ok(())
3975}
3976
3977fn project_left(project: &db::LeftProject, session: &Session) {
3978 for connection_id in &project.connection_ids {
3979 if project.should_unshare {
3980 session
3981 .peer
3982 .send(
3983 *connection_id,
3984 proto::UnshareProject {
3985 project_id: project.id.to_proto(),
3986 },
3987 )
3988 .trace_err();
3989 } else {
3990 session
3991 .peer
3992 .send(
3993 *connection_id,
3994 proto::RemoveProjectCollaborator {
3995 project_id: project.id.to_proto(),
3996 peer_id: Some(session.connection_id.into()),
3997 },
3998 )
3999 .trace_err();
4000 }
4001 }
4002}
4003
4004async fn share_agent_thread(
4005 request: proto::ShareAgentThread,
4006 response: Response<proto::ShareAgentThread>,
4007 session: MessageContext,
4008) -> Result<()> {
4009 let user_id = session.user_id();
4010
4011 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4012 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4013
4014 session
4015 .db()
4016 .await
4017 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4018 .await?;
4019
4020 response.send(proto::Ack {})?;
4021
4022 Ok(())
4023}
4024
4025async fn get_shared_agent_thread(
4026 request: proto::GetSharedAgentThread,
4027 response: Response<proto::GetSharedAgentThread>,
4028 session: MessageContext,
4029) -> Result<()> {
4030 let share_id = SharedThreadId::from_proto(request.session_id)
4031 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4032
4033 let result = session.db().await.get_shared_thread(share_id).await?;
4034
4035 match result {
4036 Some((thread, username)) => {
4037 response.send(proto::GetSharedAgentThreadResponse {
4038 title: thread.title,
4039 thread_data: thread.data,
4040 sharer_username: username,
4041 created_at: thread.created_at.and_utc().to_rfc3339(),
4042 })?;
4043 }
4044 None => {
4045 return Err(anyhow!("Shared thread not found").into());
4046 }
4047 }
4048
4049 Ok(())
4050}
4051
/// Extension trait for logging and discarding errors.
pub trait ResultExt {
    type Ok;

    /// Consume the result, logging any error and returning the success value
    /// as an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4057
4058impl<T, E> ResultExt for Result<T, E>
4059where
4060 E: std::fmt::Debug,
4061{
4062 type Ok = T;
4063
4064 #[track_caller]
4065 fn trace_err(self) -> Option<T> {
4066 match self {
4067 Ok(value) => Some(value),
4068 Err(error) => {
4069 tracing::error!("{:?}", error);
4070 None
4071 }
4072 }
4073 }
4074}