1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// Grace period for a dropped client to reconnect — presumably how long
/// server-side room/project state is retained; confirm at call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Hard cap on simultaneously open client connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Count of currently open connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names for per-message timing, recorded in fractional milliseconds.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// message's `TypeId`; returns a boxed future driving the handler to completion.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Release the slot reserved in `try_acquire`.
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}
115
/// One-shot reply handle for an in-flight RPC request.
///
/// `responded` is shared with the wrapper installed by
/// `Server::add_request_handler`, which uses it to detect handlers that
/// returned `Ok(())` without ever sending a reply.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}
121
impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the reply for this request.
    ///
    /// Marks the request as responded-to *before* attempting the send, so
    /// the request-handler wrapper sees the attempt even if `respond` fails.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
/// The authenticated identity behind a connection.
#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Per-message state handed to every handler: the connection's `Session`
/// plus the tracing span created for this particular message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}
151
impl Deref for MessageContext {
    type Target = Session;

    // Handlers mostly need session state, so a `MessageContext` can be used
    // anywhere a `&Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
160impl MessageContext {
161 pub fn forward_request<T: RequestMessage>(
162 &self,
163 receiver_id: ConnectionId,
164 request: T,
165 ) -> impl Future<Output = anyhow::Result<T::Response>> {
166 let request_start_time = Instant::now();
167 let span = self.span.clone();
168 tracing::info!("start forwarding request");
169 self.peer
170 .forward_request(self.connection_id, receiver_id, request)
171 .inspect(move |_| {
172 span.record(
173 HOST_WAITING_MS,
174 request_start_time.elapsed().as_micros() as f64 / 1000.0,
175 );
176 })
177 .inspect_err(|_| tracing::error!("error forwarding request"))
178 .inspect_ok(|_| tracing::info!("finished forwarding request"))
179 }
180}
181
/// Per-connection state shared (cloned) into every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Async mutex so handlers on this connection serialize database access.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Locks and returns the session's database handle.
    ///
    /// Under `test-support`, yields to the scheduler before and after taking
    /// the lock to surface ordering-dependent bugs in tests.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the shared connection pool, returning a `!Send` guard so it
    /// cannot be held across an `.await`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the connected user is a staff member (admin flag in the db).
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The connected user's database id.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collab RPC server: owns the peer, the shared connection pool, and
/// the per-message-type handler table populated in [`Server::new`].
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by the message payload's `TypeId`; see `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcast channel flipped to `true` to make all connections exit.
    teardown: watch::Sender<bool>,
}
262
/// Guard wrapping the locked [`ConnectionPool`].
///
/// `_not_send` (a `PhantomData<Rc<()>>`) makes the guard `!Send`, so it
/// cannot be held across an `.await` point.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers every RPC message handler.
    ///
    /// Handler kinds: `add_request_handler` (must reply), `add_message_handler`
    /// (fire-and-forget). Generic forwards route project requests to the host:
    /// `forward_read_only_project_request` vs `forward_mutating_project_request`
    /// encode the capability required of the requester.
    ///
    /// Panics (in `add_handler`) if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participants.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-originated broadcasts to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat, notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads and project search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
450
    /// Spawns the background cleanup task and returns immediately.
    ///
    /// After `CLEANUP_TIMEOUT` (enough for terminated pods of a previous
    /// deploy to exit), the task clears stale chat participants, refreshes
    /// channel buffers and rooms that still reference dead servers, notifies
    /// affected clients, updates contacts, deletes empty LiveKit rooms, and
    /// finally prunes old worktree entries and stale server rows.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop dead collaborators from each stale channel buffer and
                    // push the refreshed collaborator list to remaining clients.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to each affected user's contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
621
    /// Shuts down RPC processing: tears down the peer, resets the connection
    /// pool, then signals every `handle_connection` future to exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // There may be no live receivers; a failed send is fine.
        let _ = self.teardown.send(true);
    }
627
    /// Test-only: tears the server down, then re-arms it under a new id so
    /// tests can simulate a server restart.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
635
    /// Test-only accessor for the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
640
    /// Registers `handler` for message type `M`, wrapped with instrumentation
    /// that records queue/processing durations on the message span and logs
    /// the handler's outcome.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The dispatcher selected this closure by `M`'s TypeId, so the
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // All durations recorded in fractional milliseconds.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
684
685 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
686 where
687 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
688 Fut: 'static + Send + Future<Output = Result<()>>,
689 M: EnvelopedMessage,
690 {
691 self.add_handler(move |envelope, session| handler(envelope.payload, session));
692 self
693 }
694
695 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
696 where
697 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
698 Fut: Send + Future<Output = Result<()>>,
699 M: RequestMessage,
700 {
701 let handler = Arc::new(handler);
702 self.add_handler(move |envelope, session| {
703 let receipt = envelope.receipt();
704 let handler = handler.clone();
705 async move {
706 let peer = session.peer.clone();
707 let responded = Arc::new(AtomicBool::default());
708 let response = Response {
709 peer: peer.clone(),
710 responded: responded.clone(),
711 receipt,
712 };
713 match (handler)(envelope.payload, response, session).await {
714 Ok(()) => {
715 if responded.load(std::sync::atomic::Ordering::SeqCst) {
716 Ok(())
717 } else {
718 Err(anyhow!("handler did not send a response"))?
719 }
720 }
721 Err(error) => {
722 let proto_err = match &error {
723 Error::Internal(err) => err.to_proto(),
724 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
725 };
726 peer.respond_with_error(receipt, proto_err)?;
727 Err(error)
728 }
729 }
730 }
731 })
732 }
733
    /// Returns a future that runs a client connection to completion.
    ///
    /// Registers the connection with the peer, builds the per-connection
    /// `Session`, sends the initial client update, then loops dispatching
    /// incoming messages until the connection closes, I/O fails, or server
    /// teardown is signaled, finally running `connection_lost` cleanup.
    /// `connection_guard` is held until the handshake finishes so the
    /// concurrency cap covers connection setup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake complete: release the connection-cap slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A message is only pulled off the wire once a semaphore permit
                // is available; the permit is held until its handler completes.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit travels with the handler future so
                                // the concurrency count stays accurate.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
904
    /// Sends the post-handshake state a freshly connected client needs.
    ///
    /// Sends `Hello` (with the peer id), reports the connection id back to
    /// the caller via `send_connection_id`, then — for user principals —
    /// registers the connection in the pool, sends the contact list, any
    /// pending incoming call, optionally subscribes to channels, and pushes
    /// updated contact state to the user's contacts.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // First-ever connection: prompt the client to show contacts.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
960}
961
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    /// Grants shared access to the wrapped pool through the guard.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
969
impl DerefMut for ConnectionPoolGuard<'_> {
    /// Grants exclusive access to the wrapped pool through the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
975
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time a guard is released.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
982
983fn broadcast<F>(
984 sender_id: Option<ConnectionId>,
985 receiver_ids: impl IntoIterator<Item = ConnectionId>,
986 mut f: F,
987) where
988 F: FnMut(ConnectionId) -> anyhow::Result<()>,
989{
990 for receiver_id in receiver_ids {
991 if Some(receiver_id) != sender_id
992 && let Err(error) = f(receiver_id)
993 {
994 tracing::error!("failed to send to {:?} {}", receiver_id, error);
995 }
996 }
997}
998
/// Typed wrapper for the `x-zed-protocol-version` request header.
pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002 fn name() -> &'static HeaderName {
1003 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005 }
1006
1007 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008 where
1009 Self: Sized,
1010 I: Iterator<Item = &'i axum::http::HeaderValue>,
1011 {
1012 let version = values
1013 .next()
1014 .ok_or_else(axum::headers::Error::invalid)?
1015 .to_str()
1016 .map_err(|_| axum::headers::Error::invalid())?
1017 .parse()
1018 .map_err(|_| axum::headers::Error::invalid())?;
1019 Ok(Self(version))
1020 }
1021
1022 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023 values.extend([self.0.to_string().parse().unwrap()]);
1024 }
1025}
1026
/// Typed wrapper for the `x-zed-app-version` request header.
pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029 fn name() -> &'static HeaderName {
1030 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032 }
1033
1034 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035 where
1036 Self: Sized,
1037 I: Iterator<Item = &'i axum::http::HeaderValue>,
1038 {
1039 let version = values
1040 .next()
1041 .ok_or_else(axum::headers::Error::invalid)?
1042 .to_str()
1043 .map_err(|_| axum::headers::Error::invalid())?
1044 .parse()
1045 .map_err(|_| axum::headers::Error::invalid())?;
1046 Ok(Self(version))
1047 }
1048
1049 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050 values.extend([self.0.to_string().parse().unwrap()]);
1051 }
1052}
1053
/// Typed wrapper for the `x-zed-release-channel` request header.
#[derive(Debug)]
pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058 fn name() -> &'static HeaderName {
1059 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061 }
1062
1063 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064 where
1065 Self: Sized,
1066 I: Iterator<Item = &'i axum::http::HeaderValue>,
1067 {
1068 Ok(Self(
1069 values
1070 .next()
1071 .ok_or_else(axum::headers::Error::invalid)?
1072 .to_str()
1073 .map_err(|_| axum::headers::Error::invalid())?
1074 .to_owned(),
1075 ))
1076 }
1077
1078 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079 values.extend([self.0.parse().unwrap()]);
1080 }
1081}
1082
1083pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1084 Router::new()
1085 .route("/rpc", get(handle_websocket_request))
1086 .layer(
1087 ServiceBuilder::new()
1088 .layer(Extension(server.app_state.clone()))
1089 .layer(middleware::from_fn(auth::validate_header)),
1090 )
1091 .route("/metrics", get(handle_metrics))
1092 .layer(Extension(server))
1093}
1094
/// Axum handler for `/rpc`: validates protocol/app version headers, enforces
/// the connection limit, then upgrades to a websocket and hands the connection
/// to the server.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The wire protocol version must match exactly; older clients must update.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is mandatory.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // The app version itself may be too old to collaborate, even if it sent
    // the right protocol version.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's websocket message type to the tungstenite-based
        // Connection abstraction in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175 let connections_metric = CONNECTIONS_METRIC
1176 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178 let connections = server
1179 .connection_pool
1180 .lock()
1181 .connections()
1182 .filter(|connection| !connection.admin)
1183 .count();
1184 connections_metric.set(connections as _);
1185
1186 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188 register_int_gauge!(
1189 "shared_projects",
1190 "number of open projects with one or more guests"
1191 )
1192 .unwrap()
1193 });
1194
1195 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196 shared_projects_metric.set(shared_projects as _);
1197
1198 let encoder = prometheus::TextEncoder::new();
1199 let metric_families = prometheus::gather();
1200 let encoded_metrics = encoder
1201 .encode_to_string(&metric_families)
1202 .map_err(|err| anyhow!("{err}"))?;
1203 Ok(encoded_metrics)
1204}
1205
/// Tears down a dropped connection. Immediately removes it from the peer and
/// the pool, then waits RECONNECT_TIMEOUT before releasing rooms/buffers so a
/// quick reconnect can rejoin seamlessly; a teardown signal skips the cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort DB bookkeeping; failure is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period elapsed without reconnection: free all resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If this was the user's last connection, also decline any call
            // that was still ringing them.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server is shutting this session down deliberately; skip cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1252
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    // An empty Ack suffices; the reply itself proves liveness.
    response.send(proto::Ack {})?;
    Ok(())
}
1262
1263/// Creates a new room for calling (outside of channels)
1264async fn create_room(
1265 _request: proto::CreateRoom,
1266 response: Response<proto::CreateRoom>,
1267 session: MessageContext,
1268) -> Result<()> {
1269 let livekit_room = nanoid::nanoid!(30);
1270
1271 let live_kit_connection_info = util::maybe!(async {
1272 let live_kit = session.app_state.livekit_client.as_ref();
1273 let live_kit = live_kit?;
1274 let user_id = session.user_id().to_string();
1275
1276 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1277
1278 Some(proto::LiveKitConnectionInfo {
1279 server_url: live_kit.url().into(),
1280 token,
1281 can_publish: true,
1282 })
1283 })
1284 .await;
1285
1286 let room = session
1287 .db()
1288 .await
1289 .create_room(session.user_id(), session.connection_id, &livekit_room)
1290 .await?;
1291
1292 response.send(proto::CreateRoomResponse {
1293 room: Some(room.clone()),
1294 live_kit_connection_info,
1295 })?;
1296
1297 update_user_contacts(session.user_id(), &session).await?;
1298 Ok(())
1299}
1300
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel rooms take a different path entirely.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        // Scope the DB guard: broadcast the refreshed room state, then take
        // ownership of the joined-room data before releasing the guard.
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Stop ringing the user's other connections now that one has answered.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Token failures degrade to "no audio/video" rather than failing the join.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1367
1368/// Rejoin room is used to reconnect to a room after connection errors.
1369async fn rejoin_room(
1370 request: proto::RejoinRoom,
1371 response: Response<proto::RejoinRoom>,
1372 session: MessageContext,
1373) -> Result<()> {
1374 let room;
1375 let channel;
1376 {
1377 let mut rejoined_room = session
1378 .db()
1379 .await
1380 .rejoin_room(request, session.user_id(), session.connection_id)
1381 .await?;
1382
1383 response.send(proto::RejoinRoomResponse {
1384 room: Some(rejoined_room.room.clone()),
1385 reshared_projects: rejoined_room
1386 .reshared_projects
1387 .iter()
1388 .map(|project| proto::ResharedProject {
1389 id: project.id.to_proto(),
1390 collaborators: project
1391 .collaborators
1392 .iter()
1393 .map(|collaborator| collaborator.to_proto())
1394 .collect(),
1395 })
1396 .collect(),
1397 rejoined_projects: rejoined_room
1398 .rejoined_projects
1399 .iter()
1400 .map(|rejoined_project| rejoined_project.to_proto())
1401 .collect(),
1402 })?;
1403 room_updated(&rejoined_room.room, &session.peer);
1404
1405 for project in &rejoined_room.reshared_projects {
1406 for collaborator in &project.collaborators {
1407 session
1408 .peer
1409 .send(
1410 collaborator.connection_id,
1411 proto::UpdateProjectCollaborator {
1412 project_id: project.id.to_proto(),
1413 old_peer_id: Some(project.old_connection_id.into()),
1414 new_peer_id: Some(session.connection_id.into()),
1415 },
1416 )
1417 .trace_err();
1418 }
1419
1420 broadcast(
1421 Some(session.connection_id),
1422 project
1423 .collaborators
1424 .iter()
1425 .map(|collaborator| collaborator.connection_id),
1426 |connection_id| {
1427 session.peer.forward_send(
1428 session.connection_id,
1429 connection_id,
1430 proto::UpdateProject {
1431 project_id: project.id.to_proto(),
1432 worktrees: project.worktrees.clone(),
1433 },
1434 )
1435 },
1436 );
1437 }
1438
1439 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1440
1441 let rejoined_room = rejoined_room.into_inner();
1442
1443 room = rejoined_room.room;
1444 channel = rejoined_room.channel;
1445 }
1446
1447 if let Some(channel) = channel {
1448 channel_updated(
1449 &channel,
1450 &room,
1451 &session.peer,
1452 &*session.connection_pool().await,
1453 );
1454 }
1455
1456 update_user_contacts(session.user_id(), &session).await?;
1457 Ok(())
1458}
1459
/// After a rejoin, announces the session's new peer id to collaborators of
/// each rejoined project, then streams the project state (worktree entries,
/// diagnostics, settings, repositories) back to the rejoining client.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: tell every collaborator about the peer-id change.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream state to the rejoining client. `mem::take` moves the
    // payloads out so they are not sent twice.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are chunked like worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1545
/// leave room disconnects from the room.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: MessageContext,
) -> Result<()> {
    // All the heavy lifting (unsharing projects, room/contact updates) lives
    // in the shared helper also used on connection teardown.
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1556
1557/// Updates the permissions of someone else in the room.
1558async fn set_room_participant_role(
1559 request: proto::SetRoomParticipantRole,
1560 response: Response<proto::SetRoomParticipantRole>,
1561 session: MessageContext,
1562) -> Result<()> {
1563 let user_id = UserId::from_proto(request.user_id);
1564 let role = ChannelRole::from(request.role());
1565
1566 let (livekit_room, can_publish) = {
1567 let room = session
1568 .db()
1569 .await
1570 .set_room_participant_role(
1571 session.user_id(),
1572 RoomId::from_proto(request.room_id),
1573 user_id,
1574 role,
1575 )
1576 .await?;
1577
1578 let livekit_room = room.livekit_room.clone();
1579 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1580 room_updated(&room, &session.peer);
1581 (livekit_room, can_publish)
1582 };
1583
1584 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1585 live_kit
1586 .update_participant(
1587 livekit_room.clone(),
1588 request.user_id.to_string(),
1589 livekit_api::proto::ParticipantPermission {
1590 can_subscribe: true,
1591 can_publish,
1592 can_publish_data: can_publish,
1593 hidden: false,
1594 recorder: false,
1595 },
1596 )
1597 .await
1598 .trace_err();
1599 }
1600
1601 response.send(proto::Ack {})?;
1602 Ok(())
1603}
1604
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may ring each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        // Register the pending call; the block scope releases the DB guard
        // before we start ringing connections below.
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first one to
    // acknowledge wins and completes the request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log and keep waiting for the remaining connections.
                call_response.trace_err();
            }
        }
    }

    {
        // No connection acknowledged the ring: record the failed call and
        // broadcast the updated room state.
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676 request: proto::CancelCall,
1677 response: Response<proto::CancelCall>,
1678 session: MessageContext,
1679) -> Result<()> {
1680 let called_user_id = UserId::from_proto(request.called_user_id);
1681 let room_id = RoomId::from_proto(request.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .cancel_call(room_id, session.connection_id, called_user_id)
1687 .await?;
1688 room_updated(&room, &session.peer);
1689 }
1690
1691 for connection_id in session
1692 .connection_pool()
1693 .await
1694 .user_connection_ids(called_user_id)
1695 {
1696 session
1697 .peer
1698 .send(
1699 connection_id,
1700 proto::CallCanceled {
1701 room_id: room_id.to_proto(),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 response.send(proto::Ack {})?;
1707
1708 update_user_contacts(called_user_id, &session).await?;
1709 Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1714 let room_id = RoomId::from_proto(message.room_id);
1715 {
1716 let room = session
1717 .db()
1718 .await
1719 .decline_call(Some(room_id), session.user_id())
1720 .await?
1721 .context("declining call")?;
1722 room_updated(&room, &session.peer);
1723 }
1724
1725 for connection_id in session
1726 .connection_pool()
1727 .await
1728 .user_connection_ids(session.user_id())
1729 {
1730 session
1731 .peer
1732 .send(
1733 connection_id,
1734 proto::CallCanceled {
1735 room_id: room_id.to_proto(),
1736 },
1737 )
1738 .trace_err();
1739 }
1740 update_user_contacts(session.user_id(), &session).await?;
1741 Ok(())
1742}
1743
/// Updates other participants in the room with your current location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // A location payload is required; reject malformed requests early.
    let location = request.location.context("invalid location")?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    // Fan the refreshed room state out to all participants.
    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1762
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Record the shared project; the DB assigns the project id and returns the
    // refreshed room state to broadcast.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1787
/// Unshare a project from the room.
async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);
    // Delegates to the helper shared with connection-teardown paths.
    unshare_project_internal(project_id, session.connection_id, &session).await
}
1793
1794async fn unshare_project_internal(
1795 project_id: ProjectId,
1796 connection_id: ConnectionId,
1797 session: &Session,
1798) -> Result<()> {
1799 let delete = {
1800 let room_guard = session
1801 .db()
1802 .await
1803 .unshare_project(project_id, connection_id)
1804 .await?;
1805
1806 let (delete, room, guest_connection_ids) = &*room_guard;
1807
1808 let message = proto::UnshareProject {
1809 project_id: project_id.to_proto(),
1810 };
1811
1812 broadcast(
1813 Some(connection_id),
1814 guest_connection_ids.iter().copied(),
1815 |conn_id| session.peer.send(conn_id, message.clone()),
1816 );
1817 if let Some(room) = room {
1818 room_updated(room, &session.peer);
1819 }
1820
1821 *delete
1822 };
1823
1824 if delete {
1825 let db = session.db().await;
1826 db.delete_project(project_id).await?;
1827 }
1828
1829 Ok(())
1830}
1831
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the DB guard before the long streaming section below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Existing collaborators, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new guest to everyone already in the project.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    // Then stream the full contents of each worktree to the joining client.
    // `mem::take` moves the data out so it is only sent once.
    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream per-worktree settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                    outside_worktree: Some(settings_file.outside_worktree),
                },
            )?;
        }
    }

    // Stream repository state, chunked like worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Nudge the client to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1983
1984/// Leave someone elses shared project.
1985async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1986 let sender_id = session.connection_id;
1987 let project_id = ProjectId::from_proto(request.project_id);
1988 let db = session.db().await;
1989
1990 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1991 tracing::info!(
1992 %project_id,
1993 "leave project"
1994 );
1995
1996 project_left(project, &session);
1997 if let Some(room) = room {
1998 room_updated(room, &session.peer);
1999 }
2000
2001 Ok(())
2002}
2003
2004/// Updates other participants with changes to the project
2005async fn update_project(
2006 request: proto::UpdateProject,
2007 response: Response<proto::UpdateProject>,
2008 session: MessageContext,
2009) -> Result<()> {
2010 let project_id = ProjectId::from_proto(request.project_id);
2011 let (room, guest_connection_ids) = &*session
2012 .db()
2013 .await
2014 .update_project(project_id, session.connection_id, &request.worktrees)
2015 .await?;
2016 broadcast(
2017 Some(session.connection_id),
2018 guest_connection_ids.iter().copied(),
2019 |connection_id| {
2020 session
2021 .peer
2022 .forward_send(session.connection_id, connection_id, request.clone())
2023 },
2024 );
2025 if let Some(room) = room {
2026 room_updated(room, &session.peer);
2027 }
2028 response.send(proto::Ack {})?;
2029
2030 Ok(())
2031}
2032
2033/// Updates other participants with changes to the worktree
2034async fn update_worktree(
2035 request: proto::UpdateWorktree,
2036 response: Response<proto::UpdateWorktree>,
2037 session: MessageContext,
2038) -> Result<()> {
2039 let guest_connection_ids = session
2040 .db()
2041 .await
2042 .update_worktree(&request, session.connection_id)
2043 .await?;
2044
2045 broadcast(
2046 Some(session.connection_id),
2047 guest_connection_ids.iter().copied(),
2048 |connection_id| {
2049 session
2050 .peer
2051 .forward_send(session.connection_id, connection_id, request.clone())
2052 },
2053 );
2054 response.send(proto::Ack {})?;
2055 Ok(())
2056}
2057
2058async fn update_repository(
2059 request: proto::UpdateRepository,
2060 response: Response<proto::UpdateRepository>,
2061 session: MessageContext,
2062) -> Result<()> {
2063 let guest_connection_ids = session
2064 .db()
2065 .await
2066 .update_repository(&request, session.connection_id)
2067 .await?;
2068
2069 broadcast(
2070 Some(session.connection_id),
2071 guest_connection_ids.iter().copied(),
2072 |connection_id| {
2073 session
2074 .peer
2075 .forward_send(session.connection_id, connection_id, request.clone())
2076 },
2077 );
2078 response.send(proto::Ack {})?;
2079 Ok(())
2080}
2081
2082async fn remove_repository(
2083 request: proto::RemoveRepository,
2084 response: Response<proto::RemoveRepository>,
2085 session: MessageContext,
2086) -> Result<()> {
2087 let guest_connection_ids = session
2088 .db()
2089 .await
2090 .remove_repository(&request, session.connection_id)
2091 .await?;
2092
2093 broadcast(
2094 Some(session.connection_id),
2095 guest_connection_ids.iter().copied(),
2096 |connection_id| {
2097 session
2098 .peer
2099 .forward_send(session.connection_id, connection_id, request.clone())
2100 },
2101 );
2102 response.send(proto::Ack {})?;
2103 Ok(())
2104}
2105
2106/// Updates other participants with changes to the diagnostics
2107async fn update_diagnostic_summary(
2108 message: proto::UpdateDiagnosticSummary,
2109 session: MessageContext,
2110) -> Result<()> {
2111 let guest_connection_ids = session
2112 .db()
2113 .await
2114 .update_diagnostic_summary(&message, session.connection_id)
2115 .await?;
2116
2117 broadcast(
2118 Some(session.connection_id),
2119 guest_connection_ids.iter().copied(),
2120 |connection_id| {
2121 session
2122 .peer
2123 .forward_send(session.connection_id, connection_id, message.clone())
2124 },
2125 );
2126
2127 Ok(())
2128}
2129
2130/// Updates other participants with changes to the worktree settings
2131async fn update_worktree_settings(
2132 message: proto::UpdateWorktreeSettings,
2133 session: MessageContext,
2134) -> Result<()> {
2135 let guest_connection_ids = session
2136 .db()
2137 .await
2138 .update_worktree_settings(&message, session.connection_id)
2139 .await?;
2140
2141 broadcast(
2142 Some(session.connection_id),
2143 guest_connection_ids.iter().copied(),
2144 |connection_id| {
2145 session
2146 .peer
2147 .forward_send(session.connection_id, connection_id, message.clone())
2148 },
2149 );
2150
2151 Ok(())
2152}
2153
2154/// Notify other participants that a language server has started.
2155async fn start_language_server(
2156 request: proto::StartLanguageServer,
2157 session: MessageContext,
2158) -> Result<()> {
2159 let guest_connection_ids = session
2160 .db()
2161 .await
2162 .start_language_server(&request, session.connection_id)
2163 .await?;
2164
2165 broadcast(
2166 Some(session.connection_id),
2167 guest_connection_ids.iter().copied(),
2168 |connection_id| {
2169 session
2170 .peer
2171 .forward_send(session.connection_id, connection_id, request.clone())
2172 },
2173 );
2174 Ok(())
2175}
2176
/// Notify other participants that a language server has changed.
///
/// If the update carries new server capabilities (a `MetadataUpdated` variant),
/// they are persisted before the message is broadcast, so late-joining guests
/// can read them from the database.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist capability changes so they survive beyond this broadcast.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the update out to every other connection in the project (the `true`
    // flag's exact meaning is defined by `project_connection_ids`; it matches
    // the usage in `update_followers`).
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2206
2207/// forward a project request to the host. These requests should be read only
2208/// as guests are allowed to send them.
2209async fn forward_read_only_project_request<T>(
2210 request: T,
2211 response: Response<T>,
2212 session: MessageContext,
2213) -> Result<()>
2214where
2215 T: EntityMessage + RequestMessage,
2216{
2217 let project_id = ProjectId::from_proto(request.remote_entity_id());
2218 let host_connection_id = session
2219 .db()
2220 .await
2221 .host_for_read_only_project_request(project_id, session.connection_id)
2222 .await?;
2223 let payload = session.forward_request(host_connection_id, request).await?;
2224 response.send(payload)?;
2225 Ok(())
2226}
2227
2228/// forward a project request to the host. These requests are disallowed
2229/// for guests.
2230async fn forward_mutating_project_request<T>(
2231 request: T,
2232 response: Response<T>,
2233 session: MessageContext,
2234) -> Result<()>
2235where
2236 T: EntityMessage + RequestMessage,
2237{
2238 let project_id = ProjectId::from_proto(request.remote_entity_id());
2239
2240 let host_connection_id = session
2241 .db()
2242 .await
2243 .host_for_mutating_project_request(project_id, session.connection_id)
2244 .await?;
2245 let payload = session.forward_request(host_connection_id, request).await?;
2246 response.send(payload)?;
2247 Ok(())
2248}
2249
2250async fn disallow_guest_request<T>(
2251 _request: T,
2252 response: Response<T>,
2253 _session: MessageContext,
2254) -> Result<()>
2255where
2256 T: RequestMessage,
2257{
2258 response.peer.respond_with_error(
2259 response.receipt,
2260 ErrorCode::Forbidden
2261 .message("request is not allowed for guests".to_string())
2262 .to_proto(),
2263 )?;
2264 response.responded.store(true, SeqCst);
2265 Ok(())
2266}
2267
2268async fn lsp_query(
2269 request: proto::LspQuery,
2270 response: Response<proto::LspQuery>,
2271 session: MessageContext,
2272) -> Result<()> {
2273 let (name, should_write) = request.query_name_and_write_permissions();
2274 tracing::Span::current().record("lsp_query_request", name);
2275 tracing::info!("lsp_query message received");
2276 if should_write {
2277 forward_mutating_project_request(request, response, session).await
2278 } else {
2279 forward_read_only_project_request(request, response, session).await
2280 }
2281}
2282
2283/// Notify other participants that a new buffer has been created
2284async fn create_buffer_for_peer(
2285 request: proto::CreateBufferForPeer,
2286 session: MessageContext,
2287) -> Result<()> {
2288 session
2289 .db()
2290 .await
2291 .check_user_is_project_host(
2292 ProjectId::from_proto(request.project_id),
2293 session.connection_id,
2294 )
2295 .await?;
2296 let peer_id = request.peer_id.context("invalid peer id")?;
2297 session
2298 .peer
2299 .forward_send(session.connection_id, peer_id.into(), request)?;
2300 Ok(())
2301}
2302
2303/// Notify other participants that a new image has been created
2304async fn create_image_for_peer(
2305 request: proto::CreateImageForPeer,
2306 session: MessageContext,
2307) -> Result<()> {
2308 session
2309 .db()
2310 .await
2311 .check_user_is_project_host(
2312 ProjectId::from_proto(request.project_id),
2313 session.connection_id,
2314 )
2315 .await?;
2316 let peer_id = request.peer_id.context("invalid peer id")?;
2317 session
2318 .peer
2319 .forward_send(session.connection_id, peer_id.into(), request)?;
2320 Ok(())
2321}
2322
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start by assuming the weakest capability; escalate to ReadWrite as soon
    // as any operation other than a selection update is present.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The block scope bounds how long the DB guard is held: the broadcast to
    // guests happens under the guard, but only the host's connection id
    // escapes it for the follow-up request below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If the sender is not the host, the host must also apply the update and
    // acknowledge it before we ack the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2369
2370async fn forward_project_search_chunk(
2371 message: proto::FindSearchCandidatesChunk,
2372 response: Response<proto::FindSearchCandidatesChunk>,
2373 session: MessageContext,
2374) -> Result<()> {
2375 let peer_id = message.peer_id.context("missing peer_id")?;
2376 let payload = session
2377 .peer
2378 .forward_request(session.connection_id, peer_id.into(), message)
2379 .await?;
2380 response.send(payload)?;
2381 Ok(())
2382}
2383
2384/// Notify other participants that a project has been updated.
2385async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2386 request: T,
2387 session: MessageContext,
2388) -> Result<()> {
2389 let project_id = ProjectId::from_proto(request.remote_entity_id());
2390 let project_connection_ids = session
2391 .db()
2392 .await
2393 .project_connection_ids(project_id, session.connection_id, false)
2394 .await?;
2395
2396 broadcast(
2397 Some(session.connection_id),
2398 project_connection_ids.iter().copied(),
2399 |connection_id| {
2400 session
2401 .peer
2402 .forward_send(session.connection_id, connection_id, request.clone())
2403 },
2404 );
2405 Ok(())
2406}
2407
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both leader and follower really are in this room before talking
    // to the leader.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay it to the follower
    // before recording the follow relationship below.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are persisted and reflected in room state.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2439
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify both parties are in the room before notifying the leader.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Fire-and-forget notification to the leader (no reply expected).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Mirror of `follow`: only project-scoped follows are persisted, so only
    // those need to be removed and broadcast via a room update.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2468
2469/// Notify everyone following you of your current location.
2470async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2471 let room_id = RoomId::from_proto(request.room_id);
2472 let database = session.db.lock().await;
2473
2474 let connection_ids = if let Some(project_id) = request.project_id {
2475 let project_id = ProjectId::from_proto(project_id);
2476 database
2477 .project_connection_ids(project_id, session.connection_id, true)
2478 .await?
2479 } else {
2480 database
2481 .room_connection_ids(room_id, session.connection_id)
2482 .await?
2483 };
2484
2485 // For now, don't send view update messages back to that view's current leader.
2486 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2487 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2488 _ => None,
2489 });
2490
2491 for connection_id in connection_ids.iter().cloned() {
2492 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2493 session
2494 .peer
2495 .forward_send(session.connection_id, connection_id, request.clone())?;
2496 }
2497 }
2498 Ok(())
2499}
2500
2501/// Get public data about users.
2502async fn get_users(
2503 request: proto::GetUsers,
2504 response: Response<proto::GetUsers>,
2505 session: MessageContext,
2506) -> Result<()> {
2507 let user_ids = request
2508 .user_ids
2509 .into_iter()
2510 .map(UserId::from_proto)
2511 .collect();
2512 let users = session
2513 .db()
2514 .await
2515 .get_users_by_ids(user_ids)
2516 .await?
2517 .into_iter()
2518 .map(|user| proto::User {
2519 id: user.id.to_proto(),
2520 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2521 github_login: user.github_login,
2522 name: user.name,
2523 })
2524 .collect();
2525 response.send(proto::UsersResponse { users })?;
2526 Ok(())
2527}
2528
2529/// Search for users (to invite) buy Github login
2530async fn fuzzy_search_users(
2531 request: proto::FuzzySearchUsers,
2532 response: Response<proto::FuzzySearchUsers>,
2533 session: MessageContext,
2534) -> Result<()> {
2535 let query = request.query;
2536 let users = match query.len() {
2537 0 => vec![],
2538 1 | 2 => session
2539 .db()
2540 .await
2541 .get_user_by_github_login(&query)
2542 .await?
2543 .into_iter()
2544 .collect(),
2545 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2546 };
2547 let users = users
2548 .into_iter()
2549 .filter(|user| user.id != session.user_id())
2550 .map(|user| proto::User {
2551 id: user.id.to_proto(),
2552 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2553 github_login: user.github_login,
2554 name: user.name,
2555 })
2556 .collect();
2557 response.send(proto::UsersResponse { users })?;
2558 Ok(())
2559}
2560
/// Send a contact request to another user.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        // `Err(..)?` converts the anyhow error into this handler's Error type.
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    // Persist the request; the DB also produces the notifications to deliver.
    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2607
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in each side's contact entry below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2665
/// Remove a contact.
///
/// Works for both accepted contacts and still-pending requests: the DB reports
/// which case applied, and the updates sent to each side differ accordingly.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If removing the contact deleted a pending notification, retract it
        // from every one of the responder's connections too.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2716
/// Whether the server should subscribe this client to channels automatically.
/// Clients below minor version 139 get auto-subscribed; newer ones presumably
/// request subscriptions themselves via `SubscribeToChannels` — see
/// `subscribe_to_channels`. (NOTE(review): the 139 cutoff is a protocol
/// convention; confirm against client release notes before changing.)
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2720
/// Handler for the client-initiated channel subscription request; subscribes
/// the requesting session's own user to all of their channels.
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
2728
/// Subscribes `user_id`'s connection-pool entries to every channel they are a
/// member of, then sends the user their membership list followed by the full
/// channels snapshot (two separate messages, in that order).
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register pool subscriptions first so subsequent channel broadcasts
    // reach this user.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2745
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is Some only when the creator gained a new membership row
    // (e.g. a root channel); sub-channels inherit existing memberships.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out to other connections.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Register the new membership in the pool so later channel broadcasts
        // reach this user, then tell all their connections about it.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone under the root whose role may see
    // channels of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2800
2801/// Delete a channel
2802async fn delete_channel(
2803 request: proto::DeleteChannel,
2804 response: Response<proto::DeleteChannel>,
2805 session: MessageContext,
2806) -> Result<()> {
2807 let db = session.db().await;
2808
2809 let channel_id = request.channel_id;
2810 let (root_channel, removed_channels) = db
2811 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2812 .await?;
2813 response.send(proto::Ack {})?;
2814
2815 // Notify members of removed channels
2816 let mut update = proto::UpdateChannels::default();
2817 update
2818 .delete_channels
2819 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2820
2821 let connection_pool = session.connection_pool().await;
2822 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2823 session.peer.send(connection_id, update.clone())?;
2824 }
2825
2826 Ok(())
2827}
2828
2829/// Invite someone to join a channel.
2830async fn invite_channel_member(
2831 request: proto::InviteChannelMember,
2832 response: Response<proto::InviteChannelMember>,
2833 session: MessageContext,
2834) -> Result<()> {
2835 let db = session.db().await;
2836 let channel_id = ChannelId::from_proto(request.channel_id);
2837 let invitee_id = UserId::from_proto(request.user_id);
2838 let InviteMemberResult {
2839 channel,
2840 notifications,
2841 } = db
2842 .invite_channel_member(
2843 channel_id,
2844 invitee_id,
2845 session.user_id(),
2846 request.role().into(),
2847 )
2848 .await?;
2849
2850 let update = proto::UpdateChannels {
2851 channel_invitations: vec![channel.to_proto()],
2852 ..Default::default()
2853 };
2854
2855 let connection_pool = session.connection_pool().await;
2856 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2857 session.peer.send(connection_id, update.clone())?;
2858 }
2859
2860 send_notifications(&connection_pool, &session.peer, notifications);
2861
2862 response.send(proto::Ack {})?;
2863 Ok(())
2864}
2865
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    // `notification_id` is the (possibly already-gone) invite notification to
    // retract from the removed member's clients.
    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            // Best-effort: a failed delete-notification send is logged via
            // `trace_err` rather than failing the whole removal.
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2907
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect into a Vec first: the loop below mutates `connection_pool`
    // (subscribe/unsubscribe), so the iterator's borrow must end here.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel are (re)subscribed and get the
        // channel; users who no longer can are unsubscribed and told to
        // delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2954
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an accepted member (membership changed) or still only
    // invited (invite changed); each case notifies differently.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Re-send the invitation with the updated role to all of the
            // member's connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3002
3003/// Change the name of a channel
3004async fn rename_channel(
3005 request: proto::RenameChannel,
3006 response: Response<proto::RenameChannel>,
3007 session: MessageContext,
3008) -> Result<()> {
3009 let db = session.db().await;
3010 let channel_id = ChannelId::from_proto(request.channel_id);
3011 let channel_model = db
3012 .rename_channel(channel_id, session.user_id(), &request.name)
3013 .await?;
3014 let root_id = channel_model.root_id();
3015 let channel = Channel::from_model(channel_model);
3016
3017 response.send(proto::RenameChannelResponse {
3018 channel: Some(channel.to_proto()),
3019 })?;
3020
3021 let connection_pool = session.connection_pool().await;
3022 let update = proto::UpdateChannels {
3023 channels: vec![channel.to_proto()],
3024 ..Default::default()
3025 };
3026 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3027 if role.can_see_channel(channel.visibility) {
3028 session.peer.send(connection_id, update.clone())?;
3029 }
3030 }
3031
3032 Ok(())
3033}
3034
3035/// Move a channel to a new parent.
3036async fn move_channel(
3037 request: proto::MoveChannel,
3038 response: Response<proto::MoveChannel>,
3039 session: MessageContext,
3040) -> Result<()> {
3041 let channel_id = ChannelId::from_proto(request.channel_id);
3042 let to = ChannelId::from_proto(request.to);
3043
3044 let (root_id, channels) = session
3045 .db()
3046 .await
3047 .move_channel(channel_id, to, session.user_id())
3048 .await?;
3049
3050 let connection_pool = session.connection_pool().await;
3051 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3052 let channels = channels
3053 .iter()
3054 .filter_map(|channel| {
3055 if role.can_see_channel(channel.visibility) {
3056 Some(channel.to_proto())
3057 } else {
3058 None
3059 }
3060 })
3061 .collect::<Vec<_>>();
3062 if channels.is_empty() {
3063 continue;
3064 }
3065
3066 let update = proto::UpdateChannels {
3067 channels,
3068 ..Default::default()
3069 };
3070
3071 session.peer.send(connection_id, update.clone())?;
3072 }
3073
3074 response.send(Ack {})?;
3075 Ok(())
3076}
3077
3078async fn reorder_channel(
3079 request: proto::ReorderChannel,
3080 response: Response<proto::ReorderChannel>,
3081 session: MessageContext,
3082) -> Result<()> {
3083 let channel_id = ChannelId::from_proto(request.channel_id);
3084 let direction = request.direction();
3085
3086 let updated_channels = session
3087 .db()
3088 .await
3089 .reorder_channel(channel_id, direction, session.user_id())
3090 .await?;
3091
3092 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3093 let connection_pool = session.connection_pool().await;
3094 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3095 let channels = updated_channels
3096 .iter()
3097 .filter_map(|channel| {
3098 if role.can_see_channel(channel.visibility) {
3099 Some(channel.to_proto())
3100 } else {
3101 None
3102 }
3103 })
3104 .collect::<Vec<_>>();
3105
3106 if channels.is_empty() {
3107 continue;
3108 }
3109
3110 let update = proto::UpdateChannels {
3111 channels,
3112 ..Default::default()
3113 };
3114
3115 session.peer.send(connection_id, update.clone())?;
3116 }
3117 }
3118
3119 response.send(Ack {})?;
3120 Ok(())
3121}
3122
3123/// Get the list of channel members
3124async fn get_channel_members(
3125 request: proto::GetChannelMembers,
3126 response: Response<proto::GetChannelMembers>,
3127 session: MessageContext,
3128) -> Result<()> {
3129 let db = session.db().await;
3130 let channel_id = ChannelId::from_proto(request.channel_id);
3131 let limit = if request.limit == 0 {
3132 u16::MAX as u64
3133 } else {
3134 request.limit
3135 };
3136 let (members, users) = db
3137 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3138 .await?;
3139 response.send(proto::GetChannelMembersResponse { members, users })?;
3140 Ok(())
3141}
3142
3143/// Accept or decline a channel invitation.
3144async fn respond_to_channel_invite(
3145 request: proto::RespondToChannelInvite,
3146 response: Response<proto::RespondToChannelInvite>,
3147 session: MessageContext,
3148) -> Result<()> {
3149 let db = session.db().await;
3150 let channel_id = ChannelId::from_proto(request.channel_id);
3151 let RespondToChannelInvite {
3152 membership_update,
3153 notifications,
3154 } = db
3155 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3156 .await?;
3157
3158 let mut connection_pool = session.connection_pool().await;
3159 if let Some(membership_update) = membership_update {
3160 notify_membership_updated(
3161 &mut connection_pool,
3162 membership_update,
3163 session.user_id(),
3164 &session.peer,
3165 );
3166 } else {
3167 let update = proto::UpdateChannels {
3168 remove_channel_invitations: vec![channel_id.to_proto()],
3169 ..Default::default()
3170 };
3171
3172 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3173 session.peer.send(connection_id, update.clone())?;
3174 }
3175 };
3176
3177 send_notifications(&connection_pool, &session.peer, notifications);
3178
3179 response.send(proto::Ack {})?;
3180
3181 Ok(())
3182}
3183
3184/// Join the channels' room
3185async fn join_channel(
3186 request: proto::JoinChannel,
3187 response: Response<proto::JoinChannel>,
3188 session: MessageContext,
3189) -> Result<()> {
3190 let channel_id = ChannelId::from_proto(request.channel_id);
3191 join_channel_internal(channel_id, Box::new(response), session).await
3192}
3193
/// Abstraction over the two RPC response types that can complete a channel
/// join (`JoinChannel` and `JoinRoom`), letting `join_channel_internal`
/// serve both entry points with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate to the concrete `Response::send` for this message type.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3207
/// Shared implementation for `JoinChannel` and `JoinRoom`: joins the
/// channel's room, provisions a LiveKit token when a client is configured,
/// replies to the caller, and broadcasts the resulting room/channel state.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so
            // release our guard first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a subscribe-only LiveKit token; everyone else may
        // publish. If token creation fails it is logged (`trace_err`) and
        // the user joins without LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the joining client before broadcasting to others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // If joining changed the user's channel memberships, fan that out
        // to all of their connections.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Refresh the channel's participant list for all channel subscribers.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // The user's busy state may have changed; tell their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3301
3302/// Start editing the channel notes
3303async fn join_channel_buffer(
3304 request: proto::JoinChannelBuffer,
3305 response: Response<proto::JoinChannelBuffer>,
3306 session: MessageContext,
3307) -> Result<()> {
3308 let db = session.db().await;
3309 let channel_id = ChannelId::from_proto(request.channel_id);
3310
3311 let open_response = db
3312 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3313 .await?;
3314
3315 let collaborators = open_response.collaborators.clone();
3316 response.send(open_response)?;
3317
3318 let update = UpdateChannelBufferCollaborators {
3319 channel_id: channel_id.to_proto(),
3320 collaborators: collaborators.clone(),
3321 };
3322 channel_buffer_updated(
3323 session.connection_id,
3324 collaborators
3325 .iter()
3326 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3327 &update,
3328 &session.peer,
3329 );
3330
3331 Ok(())
3332}
3333
3334/// Edit the channel notes
3335async fn update_channel_buffer(
3336 request: proto::UpdateChannelBuffer,
3337 session: MessageContext,
3338) -> Result<()> {
3339 let db = session.db().await;
3340 let channel_id = ChannelId::from_proto(request.channel_id);
3341
3342 let (collaborators, epoch, version) = db
3343 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3344 .await?;
3345
3346 channel_buffer_updated(
3347 session.connection_id,
3348 collaborators.clone(),
3349 &proto::UpdateChannelBuffer {
3350 channel_id: channel_id.to_proto(),
3351 operations: request.operations,
3352 },
3353 &session.peer,
3354 );
3355
3356 let pool = &*session.connection_pool().await;
3357
3358 let non_collaborators =
3359 pool.channel_connection_ids(channel_id)
3360 .filter_map(|(connection_id, _)| {
3361 if collaborators.contains(&connection_id) {
3362 None
3363 } else {
3364 Some(connection_id)
3365 }
3366 });
3367
3368 broadcast(None, non_collaborators, |peer_id| {
3369 session.peer.send(
3370 peer_id,
3371 proto::UpdateChannels {
3372 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3373 channel_id: channel_id.to_proto(),
3374 epoch: epoch as u64,
3375 version: version.clone(),
3376 }],
3377 ..Default::default()
3378 },
3379 )
3380 });
3381
3382 Ok(())
3383}
3384
3385/// Rejoin the channel notes after a connection blip
3386async fn rejoin_channel_buffers(
3387 request: proto::RejoinChannelBuffers,
3388 response: Response<proto::RejoinChannelBuffers>,
3389 session: MessageContext,
3390) -> Result<()> {
3391 let db = session.db().await;
3392 let buffers = db
3393 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3394 .await?;
3395
3396 for rejoined_buffer in &buffers {
3397 let collaborators_to_notify = rejoined_buffer
3398 .buffer
3399 .collaborators
3400 .iter()
3401 .filter_map(|c| Some(c.peer_id?.into()));
3402 channel_buffer_updated(
3403 session.connection_id,
3404 collaborators_to_notify,
3405 &proto::UpdateChannelBufferCollaborators {
3406 channel_id: rejoined_buffer.buffer.channel_id,
3407 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3408 },
3409 &session.peer,
3410 );
3411 }
3412
3413 response.send(proto::RejoinChannelBuffersResponse {
3414 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3415 })?;
3416
3417 Ok(())
3418}
3419
3420/// Stop editing the channel notes
3421async fn leave_channel_buffer(
3422 request: proto::LeaveChannelBuffer,
3423 response: Response<proto::LeaveChannelBuffer>,
3424 session: MessageContext,
3425) -> Result<()> {
3426 let db = session.db().await;
3427 let channel_id = ChannelId::from_proto(request.channel_id);
3428
3429 let left_buffer = db
3430 .leave_channel_buffer(channel_id, session.connection_id)
3431 .await?;
3432
3433 response.send(Ack {})?;
3434
3435 channel_buffer_updated(
3436 session.connection_id,
3437 left_buffer.connections,
3438 &proto::UpdateChannelBufferCollaborators {
3439 channel_id: channel_id.to_proto(),
3440 collaborators: left_buffer.collaborators,
3441 },
3442 &session.peer,
3443 );
3444
3445 Ok(())
3446}
3447
3448fn channel_buffer_updated<T: EnvelopedMessage>(
3449 sender_id: ConnectionId,
3450 collaborators: impl IntoIterator<Item = ConnectionId>,
3451 message: &T,
3452 peer: &Peer,
3453) {
3454 broadcast(Some(sender_id), collaborators, |peer_id| {
3455 peer.send(peer_id, message.clone())
3456 });
3457}
3458
3459fn send_notifications(
3460 connection_pool: &ConnectionPool,
3461 peer: &Peer,
3462 notifications: db::NotificationBatch,
3463) {
3464 for (user_id, notification) in notifications {
3465 for connection_id in connection_pool.user_connection_ids(user_id) {
3466 if let Err(error) = peer.send(
3467 connection_id,
3468 proto::AddNotification {
3469 notification: Some(notification.clone()),
3470 },
3471 ) {
3472 tracing::error!(
3473 "failed to send notification to {:?} {}",
3474 connection_id,
3475 error
3476 );
3477 }
3478 }
3479 }
3480}
3481
3482/// Send a message to the channel
3483async fn send_channel_message(
3484 _request: proto::SendChannelMessage,
3485 _response: Response<proto::SendChannelMessage>,
3486 _session: MessageContext,
3487) -> Result<()> {
3488 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3489}
3490
3491/// Delete a channel message
3492async fn remove_channel_message(
3493 _request: proto::RemoveChannelMessage,
3494 _response: Response<proto::RemoveChannelMessage>,
3495 _session: MessageContext,
3496) -> Result<()> {
3497 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3498}
3499
3500async fn update_channel_message(
3501 _request: proto::UpdateChannelMessage,
3502 _response: Response<proto::UpdateChannelMessage>,
3503 _session: MessageContext,
3504) -> Result<()> {
3505 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3506}
3507
3508/// Mark a channel message as read
3509async fn acknowledge_channel_message(
3510 _request: proto::AckChannelMessage,
3511 _session: MessageContext,
3512) -> Result<()> {
3513 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3514}
3515
3516/// Mark a buffer version as synced
3517async fn acknowledge_buffer_version(
3518 request: proto::AckBufferOperation,
3519 session: MessageContext,
3520) -> Result<()> {
3521 let buffer_id = BufferId::from_proto(request.buffer_id);
3522 session
3523 .db()
3524 .await
3525 .observe_buffer_version(
3526 buffer_id,
3527 session.user_id(),
3528 request.epoch as i32,
3529 &request.version,
3530 )
3531 .await?;
3532 Ok(())
3533}
3534
3535/// Start receiving chat updates for a channel
3536async fn join_channel_chat(
3537 _request: proto::JoinChannelChat,
3538 _response: Response<proto::JoinChannelChat>,
3539 _session: MessageContext,
3540) -> Result<()> {
3541 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3542}
3543
3544/// Stop receiving chat updates for a channel
3545async fn leave_channel_chat(
3546 _request: proto::LeaveChannelChat,
3547 _session: MessageContext,
3548) -> Result<()> {
3549 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3550}
3551
3552/// Retrieve the chat history for a channel
3553async fn get_channel_messages(
3554 _request: proto::GetChannelMessages,
3555 _response: Response<proto::GetChannelMessages>,
3556 _session: MessageContext,
3557) -> Result<()> {
3558 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3559}
3560
3561/// Retrieve specific chat messages
3562async fn get_channel_messages_by_id(
3563 _request: proto::GetChannelMessagesById,
3564 _response: Response<proto::GetChannelMessagesById>,
3565 _session: MessageContext,
3566) -> Result<()> {
3567 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3568}
3569
3570/// Retrieve the current users notifications
3571async fn get_notifications(
3572 request: proto::GetNotifications,
3573 response: Response<proto::GetNotifications>,
3574 session: MessageContext,
3575) -> Result<()> {
3576 let notifications = session
3577 .db()
3578 .await
3579 .get_notifications(
3580 session.user_id(),
3581 NOTIFICATION_COUNT_PER_PAGE,
3582 request.before_id.map(db::NotificationId::from_proto),
3583 )
3584 .await?;
3585 response.send(proto::GetNotificationsResponse {
3586 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3587 notifications,
3588 })?;
3589 Ok(())
3590}
3591
3592/// Mark notifications as read
3593async fn mark_notification_as_read(
3594 request: proto::MarkNotificationRead,
3595 response: Response<proto::MarkNotificationRead>,
3596 session: MessageContext,
3597) -> Result<()> {
3598 let database = &session.db().await;
3599 let notifications = database
3600 .mark_notification_as_read_by_id(
3601 session.user_id(),
3602 NotificationId::from_proto(request.notification_id),
3603 )
3604 .await?;
3605 send_notifications(
3606 &*session.connection_pool().await,
3607 &session.peer,
3608 notifications,
3609 );
3610 response.send(proto::Ack {})?;
3611 Ok(())
3612}
3613
3614fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3615 let message = match message {
3616 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3617 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3618 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3619 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3620 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3621 code: frame.code.into(),
3622 reason: frame.reason.as_str().to_owned().into(),
3623 })),
3624 // We should never receive a frame while reading the message, according
3625 // to the `tungstenite` maintainers:
3626 //
3627 // > It cannot occur when you read messages from the WebSocket, but it
3628 // > can be used when you want to send the raw frames (e.g. you want to
3629 // > send the frames to the WebSocket without composing the full message first).
3630 // >
3631 // > — https://github.com/snapview/tungstenite-rs/issues/268
3632 TungsteniteMessage::Frame(_) => {
3633 bail!("received an unexpected frame while reading the message")
3634 }
3635 };
3636
3637 Ok(message)
3638}
3639
3640fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3641 match message {
3642 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3643 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3644 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3645 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3646 AxumMessage::Close(frame) => {
3647 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3648 code: frame.code.into(),
3649 reason: frame.reason.as_ref().into(),
3650 }))
3651 }
3652 }
3653}
3654
/// Apply a membership change to the connection pool's channel subscriptions
/// and push the corresponding channel updates to all of the affected
/// user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the in-memory subscription state in sync with the new memberships.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // Per-user membership/role summary message.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    // Full channel payload, augmented with deletions and the removal of the
    // now-consumed invitation.
    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged and ignored so one dead connection does not
    // block updates to the user's other connections.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3695
3696fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3697 proto::UpdateUserChannels {
3698 channel_memberships: channels
3699 .channel_memberships
3700 .iter()
3701 .map(|m| proto::ChannelMembership {
3702 channel_id: m.channel_id.to_proto(),
3703 role: m.role.into(),
3704 })
3705 .collect(),
3706 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3707 }
3708}
3709
3710fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3711 let mut update = proto::UpdateChannels::default();
3712
3713 for channel in channels.channels {
3714 update.channels.push(channel.to_proto());
3715 }
3716
3717 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3718
3719 for (channel_id, participants) in channels.channel_participants {
3720 update
3721 .channel_participants
3722 .push(proto::ChannelParticipants {
3723 channel_id: channel_id.to_proto(),
3724 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3725 });
3726 }
3727
3728 for channel in channels.invited_channels {
3729 update.channel_invitations.push(channel.to_proto());
3730 }
3731
3732 update
3733}
3734
3735fn build_initial_contacts_update(
3736 contacts: Vec<db::Contact>,
3737 pool: &ConnectionPool,
3738) -> proto::UpdateContacts {
3739 let mut update = proto::UpdateContacts::default();
3740
3741 for contact in contacts {
3742 match contact {
3743 db::Contact::Accepted { user_id, busy } => {
3744 update.contacts.push(contact_for_user(user_id, busy, pool));
3745 }
3746 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3747 db::Contact::Incoming { user_id } => {
3748 update
3749 .incoming_requests
3750 .push(proto::IncomingContactRequest {
3751 requester_id: user_id.to_proto(),
3752 })
3753 }
3754 }
3755 }
3756
3757 update
3758}
3759
3760fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3761 proto::Contact {
3762 user_id: user_id.to_proto(),
3763 online: pool.is_user_online(user_id),
3764 busy,
3765 }
3766}
3767
3768fn room_updated(room: &proto::Room, peer: &Peer) {
3769 broadcast(
3770 None,
3771 room.participants
3772 .iter()
3773 .filter_map(|participant| Some(participant.peer_id?.into())),
3774 |peer_id| {
3775 peer.send(
3776 peer_id,
3777 proto::RoomUpdated {
3778 room: Some(room.clone()),
3779 },
3780 )
3781 },
3782 );
3783}
3784
3785fn channel_updated(
3786 channel: &db::channel::Model,
3787 room: &proto::Room,
3788 peer: &Peer,
3789 pool: &ConnectionPool,
3790) {
3791 let participants = room
3792 .participants
3793 .iter()
3794 .map(|p| p.user_id)
3795 .collect::<Vec<_>>();
3796
3797 broadcast(
3798 None,
3799 pool.channel_connection_ids(channel.root_id())
3800 .filter_map(|(channel_id, role)| {
3801 role.can_see_channel(channel.visibility)
3802 .then_some(channel_id)
3803 }),
3804 |peer_id| {
3805 peer.send(
3806 peer_id,
3807 proto::UpdateChannels {
3808 channel_participants: vec![proto::ChannelParticipants {
3809 channel_id: channel.id.to_proto(),
3810 participant_user_ids: participants.clone(),
3811 }],
3812 ..Default::default()
3813 },
3814 )
3815 },
3816 );
3817}
3818
3819async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3820 let db = session.db().await;
3821
3822 let contacts = db.get_contacts(user_id).await?;
3823 let busy = db.is_user_busy(user_id).await?;
3824
3825 let pool = session.connection_pool().await;
3826 let updated_contact = contact_for_user(user_id, busy, &pool);
3827 for contact in contacts {
3828 if let db::Contact::Accepted {
3829 user_id: contact_user_id,
3830 ..
3831 } = contact
3832 {
3833 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3834 session
3835 .peer
3836 .send(
3837 contact_conn_id,
3838 proto::UpdateContacts {
3839 contacts: vec![updated_contact.clone()],
3840 remove_contacts: Default::default(),
3841 incoming_requests: Default::default(),
3842 remove_incoming_requests: Default::default(),
3843 outgoing_requests: Default::default(),
3844 remove_outgoing_requests: Default::default(),
3845 },
3846 )
3847 .trace_err();
3848 }
3849 }
3850 }
3851 Ok(())
3852}
3853
/// Remove `connection_id` from its current room (if any), notifying project
/// collaborators, pending callees, channel subscribers, contacts, and LiveKit.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Extracted from `left_room` below; declared here so they remain
    // available after the db guard from `session.db()` is released.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell each project's collaborators that this session left.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to do.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list for
    // channel subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Cancel calls this user had pending to other users, notifying
        // every connection of each callee.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Presence (busy state) may have changed for the leaver and callees.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // Tear down the LiveKit room when the database reports the room
        // itself was deleted.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3927
3928async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3929 let left_channel_buffers = session
3930 .db()
3931 .await
3932 .leave_channel_buffers(session.connection_id)
3933 .await?;
3934
3935 for left_buffer in left_channel_buffers {
3936 channel_buffer_updated(
3937 session.connection_id,
3938 left_buffer.connections,
3939 &proto::UpdateChannelBufferCollaborators {
3940 channel_id: left_buffer.channel_id.to_proto(),
3941 collaborators: left_buffer.collaborators,
3942 },
3943 &session.peer,
3944 );
3945 }
3946
3947 Ok(())
3948}
3949
3950fn project_left(project: &db::LeftProject, session: &Session) {
3951 for connection_id in &project.connection_ids {
3952 if project.should_unshare {
3953 session
3954 .peer
3955 .send(
3956 *connection_id,
3957 proto::UnshareProject {
3958 project_id: project.id.to_proto(),
3959 },
3960 )
3961 .trace_err();
3962 } else {
3963 session
3964 .peer
3965 .send(
3966 *connection_id,
3967 proto::RemoveProjectCollaborator {
3968 project_id: project.id.to_proto(),
3969 peer_id: Some(session.connection_id.into()),
3970 },
3971 )
3972 .trace_err();
3973 }
3974 }
3975}
3976
3977async fn share_agent_thread(
3978 request: proto::ShareAgentThread,
3979 response: Response<proto::ShareAgentThread>,
3980 session: MessageContext,
3981) -> Result<()> {
3982 let user_id = session.user_id();
3983
3984 let share_id = SharedThreadId::from_proto(request.session_id.clone())
3985 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
3986
3987 session
3988 .db()
3989 .await
3990 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
3991 .await?;
3992
3993 response.send(proto::Ack {})?;
3994
3995 Ok(())
3996}
3997
3998async fn get_shared_agent_thread(
3999 request: proto::GetSharedAgentThread,
4000 response: Response<proto::GetSharedAgentThread>,
4001 session: MessageContext,
4002) -> Result<()> {
4003 let share_id = SharedThreadId::from_proto(request.session_id)
4004 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4005
4006 let result = session.db().await.get_shared_thread(share_id).await?;
4007
4008 match result {
4009 Some((thread, username)) => {
4010 response.send(proto::GetSharedAgentThreadResponse {
4011 title: thread.title,
4012 thread_data: thread.data,
4013 sharer_username: username,
4014 created_at: thread.created_at.and_utc().to_rfc3339(),
4015 })?;
4016 }
4017 None => {
4018 return Err(anyhow!("Shared thread not found").into());
4019 }
4020 }
4021
4022 Ok(())
4023}
4024
4025pub trait ResultExt {
4026 type Ok;
4027
4028 fn trace_err(self) -> Option<Self::Ok>;
4029}
4030
4031impl<T, E> ResultExt for Result<T, E>
4032where
4033 E: std::fmt::Debug,
4034{
4035 type Ok = T;
4036
4037 #[track_caller]
4038 fn trace_err(self) -> Option<T> {
4039 match self {
4040 Ok(value) => Some(value),
4041 Err(error) => {
4042 tracing::error!("{:?}", error);
4043 None
4044 }
4045 }
4046 }
4047}