1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// Grace period for a dropped client to reconnect before its server-side
/// state is torn down — presumably paired with `CLEANUP_TIMEOUT` below;
/// exported for use elsewhere in the crate.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneous client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently live connections; incremented/decremented by
// `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record per-message timings
// (recorded in `Server::add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased handler for one RPC message type; the registry is populated in
// `Server::new` and keyed by the message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
/// Handle through which a request handler replies to the requesting client.
/// Tracks whether a reply was actually sent, which
/// `Server::add_request_handler` checks after the handler completes.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`; flipped to `true`
    // by `send` so the wrapper can detect handlers that never responded.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` as the reply to the original request and marks this
    /// response as completed. Consumes `self`: a request is answered once.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Per-message view of a `Session`: the session plus the tracing span of the
/// message currently being handled.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Handlers mostly need session data, so expose the session's fields and
// methods directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
impl MessageContext {
    /// Forwards `request` to `receiver_id` on behalf of the current
    /// connection, recording how long the receiving host took to answer
    /// into the message span's `host_waiting_ms` field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs on success and failure alike — we want the
            // elapsed time recorded either way.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // Microsecond precision reported as fractional milliseconds.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// State shared by all message handlers running on behalf of one client
/// connection. Cheap to clone: every field is either small or `Arc`-wrapped.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Async mutex: handlers of this session take turns using the database
    // handle (see `Session::db`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Acquires this session's database handle.
    ///
    /// Under `test-support` we yield before and after taking the lock,
    /// presumably to perturb task ordering in deterministic tests.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the server-wide connection pool. The returned guard is `!Send`
    /// (via its `PhantomData<Rc<()>>`), so the owning future cannot be `Send`
    /// while holding it.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the connected user is a staff (admin) user.
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The id of the authenticated user behind this session.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collab RPC server: owns the peer, the shared connection pool, and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by the message payload's `TypeId`; populated once during `new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Watch channel flipped to `true` in `teardown` so every connection task
    // (see `handle_connection`) knows to stop.
    teardown: watch::Sender<bool>,
}
262
/// Lock guard over the connection pool. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, so it cannot be moved across threads (and a
/// future holding it across an `.await` cannot be `Send`).
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates a server and registers the handler for every RPC message type
    /// it supports. Panics (via `add_handler`) if any message type is
    /// registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        // Handler families:
        // - `forward_read_only_project_request` / `forward_mutating_project_request`
        //   relay a guest's request to the project host;
        // - `broadcast_project_message_from_host` fans a host message out to guests;
        // - everything else is handled directly by this server.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
451
    /// Spawns the background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` (giving terminated pods time to go away — see
    /// the constant's comment), stale rooms, channel buffers, chat
    /// participants, worktree entries, and server rows left behind by old
    /// server instances are removed, and affected clients are notified.
    /// All failures are logged via `trace_err` and do not abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // send the refreshed collaborator set to those remaining.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants belonging to the stale server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell users whose incoming calls were canceled by the
                        // room refresh, on every connection they hold.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed busy state to
                        // all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no
                        // participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // NOTE(review): chat participants are also deleted before the
                // stale-room pass above — presumably a deliberate
                // before-and-after sweep; confirm before deduplicating.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
622
    /// Shuts the server down: tears down the shared `Peer`, resets the
    /// connection pool, and flips the `teardown` watch so every running
    /// `handle_connection` task returns.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Receivers may already be gone; ignore the send error.
        let _ = self.teardown.send(true);
    }
628
629 #[cfg(feature = "test-support")]
630 pub fn reset(&self, id: ServerId) {
631 self.teardown();
632 *self.id.lock() = id;
633 self.peer.reset(id.0 as u32);
634 let _ = self.teardown.send(false);
635 }
636
637 #[cfg(feature = "test-support")]
638 pub fn id(&self) -> ServerId {
639 *self.id.lock()
640 }
641
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and span-timing instrumentation. Panics if `M` already
    /// has a handler (each message type may be registered exactly once).
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The registry is keyed by `TypeId::of::<M>()`, so this
                // downcast cannot fail for a correctly routed envelope.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since the envelope arrived;
                    // queue = time it spent waiting before processing began.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
685
686 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
687 where
688 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
689 Fut: 'static + Send + Future<Output = Result<()>>,
690 M: EnvelopedMessage,
691 {
692 self.add_handler(move |envelope, session| handler(envelope.payload, session));
693 self
694 }
695
    /// Registers a handler for a request/response message, guaranteeing the
    /// requester always hears back: a handler that returns `Ok` without
    /// responding is reported as an error, and a handler error is forwarded
    /// to the requester before being propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set to `true` by `Response::send`; checked below to catch
                // handlers that succeeded without replying.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry proto details; wrap
                        // everything else in a generic internal error code.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
734
    /// Runs the message loop for an authenticated client connection.
    ///
    /// The returned future completes once the connection ends — via I/O
    /// error, the client closing its stream, or server teardown — after
    /// sign-out cleanup (`connection_lost`) has run.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Release the connection-count slot now that the handshake is
            // done; the loop below no longer needs it.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    // Back-pressure: don't pull the next message off the wire
                    // until a handler permit is available.
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // actually finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            // Incoming stream ended: client disconnected.
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-pending foreground handlers, then run
            // sign-out cleanup.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
905
    /// Sends a freshly connected client its initial state: the `Hello`
    /// message, the contact list, channel subscriptions (for supporting
    /// versions), and any pending incoming call. Also registers the
    /// connection in the pool and reports the id via `send_connection_id`.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the id back to the caller that initiated the connection, if
        // it asked for it; a dropped receiver is fine.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // On a user's first-ever connection, prompt the client to
                // show the contacts UI, then persist the flag.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send contacts under the same
                // pool lock.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
961}
962
/// Lets a `ConnectionPoolGuard` be used wherever a `&ConnectionPool` is
/// expected, delegating to the wrapped lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
970
/// Mutable access to the pooled connection state through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
976
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time the guard is
        // released, so corruption is caught close to the mutation that
        // caused it.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
983
984fn broadcast<F>(
985 sender_id: Option<ConnectionId>,
986 receiver_ids: impl IntoIterator<Item = ConnectionId>,
987 mut f: F,
988) where
989 F: FnMut(ConnectionId) -> anyhow::Result<()>,
990{
991 for receiver_id in receiver_ids {
992 if Some(receiver_id) != sender_id
993 && let Err(error) = f(receiver_id)
994 {
995 tracing::error!("failed to send to {:?} {}", receiver_id, error);
996 }
997 }
998}
999
/// Typed wrapper for the `x-zed-protocol-version` request header value.
pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003 fn name() -> &'static HeaderName {
1004 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006 }
1007
1008 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009 where
1010 Self: Sized,
1011 I: Iterator<Item = &'i axum::http::HeaderValue>,
1012 {
1013 let version = values
1014 .next()
1015 .ok_or_else(axum::headers::Error::invalid)?
1016 .to_str()
1017 .map_err(|_| axum::headers::Error::invalid())?
1018 .parse()
1019 .map_err(|_| axum::headers::Error::invalid())?;
1020 Ok(Self(version))
1021 }
1022
1023 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024 values.extend([self.0.to_string().parse().unwrap()]);
1025 }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030 fn name() -> &'static HeaderName {
1031 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033 }
1034
1035 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036 where
1037 Self: Sized,
1038 I: Iterator<Item = &'i axum::http::HeaderValue>,
1039 {
1040 let version = values
1041 .next()
1042 .ok_or_else(axum::headers::Error::invalid)?
1043 .to_str()
1044 .map_err(|_| axum::headers::Error::invalid())?
1045 .parse()
1046 .map_err(|_| axum::headers::Error::invalid())?;
1047 Ok(Self(version))
1048 }
1049
1050 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051 values.extend([self.0.to_string().parse().unwrap()]);
1052 }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059 fn name() -> &'static HeaderName {
1060 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062 }
1063
1064 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065 where
1066 Self: Sized,
1067 I: Iterator<Item = &'i axum::http::HeaderValue>,
1068 {
1069 Ok(Self(
1070 values
1071 .next()
1072 .ok_or_else(axum::headers::Error::invalid)?
1073 .to_str()
1074 .map_err(|_| axum::headers::Error::invalid())?
1075 .to_owned(),
1076 ))
1077 }
1078
1079 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080 values.extend([self.0.parse().unwrap()]);
1081 }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085 Router::new()
1086 .route("/rpc", get(handle_websocket_request))
1087 .layer(
1088 ServiceBuilder::new()
1089 .layer(Extension(server.app_state.clone()))
1090 .layer(middleware::from_fn(auth::validate_header)),
1091 )
1092 .route("/metrics", get(handle_metrics))
1093 .layer(Extension(server))
1094}
1095
/// Axum handler that upgrades `/rpc` requests to the collab websocket.
///
/// Rejects clients whose protocol version mismatches or whose app version
/// is missing/too old, and enforces a global connection limit before
/// accepting the upgrade.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly; otherwise the client
    // cannot speak this server's wire format.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an
    // over-capacity server refuses with 503 instead of accepting the
    // socket and dropping it immediately afterwards.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket stream/sink to the tungstenite message
        // types that `rpc::Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176 let connections_metric = CONNECTIONS_METRIC
1177 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179 let connections = server
1180 .connection_pool
1181 .lock()
1182 .connections()
1183 .filter(|connection| !connection.admin)
1184 .count();
1185 connections_metric.set(connections as _);
1186
1187 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189 register_int_gauge!(
1190 "shared_projects",
1191 "number of open projects with one or more guests"
1192 )
1193 .unwrap()
1194 });
1195
1196 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197 shared_projects_metric.set(shared_projects as _);
1198
1199 let encoder = prometheus::TextEncoder::new();
1200 let metric_families = prometheus::gather();
1201 let encoded_metrics = encoder
1202 .encode_to_string(&metric_families)
1203 .map_err(|err| anyhow!("{err}"))?;
1204 Ok(encoded_metrics)
1205}
1206
/// Cleans up after a dropped connection.
///
/// The connection itself is removed immediately, but room and channel
/// resources are only released after `RECONNECT_TIMEOUT`, giving the client
/// a window to rejoin with the same state — unless `teardown` fires first,
/// in which case the delayed cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If this was the user's last connection, decline any call that
            // was still ringing them and notify the room.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256 _: proto::Ping,
1257 response: Response<proto::Ping>,
1258 _session: MessageContext,
1259) -> Result<()> {
1260 response.send(proto::Ack {})?;
1261 Ok(())
1262}
1263
1264/// Creates a new room for calling (outside of channels)
1265async fn create_room(
1266 _request: proto::CreateRoom,
1267 response: Response<proto::CreateRoom>,
1268 session: MessageContext,
1269) -> Result<()> {
1270 let livekit_room = nanoid::nanoid!(30);
1271
1272 let live_kit_connection_info = util::maybe!(async {
1273 let live_kit = session.app_state.livekit_client.as_ref();
1274 let live_kit = live_kit?;
1275 let user_id = session.user_id().to_string();
1276
1277 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1278
1279 Some(proto::LiveKitConnectionInfo {
1280 server_url: live_kit.url().into(),
1281 token,
1282 can_publish: true,
1283 })
1284 })
1285 .await;
1286
1287 let room = session
1288 .db()
1289 .await
1290 .create_room(session.user_id(), session.connection_id, &livekit_room)
1291 .await?;
1292
1293 response.send(proto::CreateRoomResponse {
1294 room: Some(room.clone()),
1295 live_kit_connection_info,
1296 })?;
1297
1298 update_user_contacts(session.user_id(), &session).await?;
1299 Ok(())
1300}
1301
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms that belong to a channel use the channel join flow instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        // Scope the DB guard: extract the owned room state before touching
        // the connection pool below.
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Stop the call from ringing on all of this user's other connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // LiveKit is optional; a token failure just yields a room without A/V.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // The guard returned by `rejoin_room` must stay alive while we fan
        // out updates; it is dropped at the end of this block, before the
        // connection-pool access below.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this host's peer id changed.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree list to all guests.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        // Channel-backed rooms also refresh their channel metadata for
        // everyone who can see the channel.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
/// Streams the post-rejoin state of each project back to the rejoining
/// client (worktree entries, diagnostics, settings, repositories) and
/// informs existing collaborators of the client's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        // Tell every collaborator about this client's new connection id.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into multiple chunked messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1546
1547/// leave room disconnects from the room.
1548async fn leave_room(
1549 _: proto::LeaveRoom,
1550 response: Response<proto::LeaveRoom>,
1551 session: MessageContext,
1552) -> Result<()> {
1553 leave_room_for_session(&session, session.connection_id).await?;
1554 response.send(proto::Ack {})?;
1555 Ok(())
1556}
1557
1558/// Updates the permissions of someone else in the room.
1559async fn set_room_participant_role(
1560 request: proto::SetRoomParticipantRole,
1561 response: Response<proto::SetRoomParticipantRole>,
1562 session: MessageContext,
1563) -> Result<()> {
1564 let user_id = UserId::from_proto(request.user_id);
1565 let role = ChannelRole::from(request.role());
1566
1567 let (livekit_room, can_publish) = {
1568 let room = session
1569 .db()
1570 .await
1571 .set_room_participant_role(
1572 session.user_id(),
1573 RoomId::from_proto(request.room_id),
1574 user_id,
1575 role,
1576 )
1577 .await?;
1578
1579 let livekit_room = room.livekit_room.clone();
1580 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1581 room_updated(&room, &session.peer);
1582 (livekit_room, can_publish)
1583 };
1584
1585 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1586 live_kit
1587 .update_participant(
1588 livekit_room.clone(),
1589 request.user_id.to_string(),
1590 livekit_api::proto::ParticipantPermission {
1591 can_subscribe: true,
1592 can_publish,
1593 can_publish_data: can_publish,
1594 hidden: false,
1595 recorder: false,
1596 },
1597 )
1598 .await
1599 .trace_err();
1600 }
1601
1602 response.send(proto::Ack {})?;
1603 Ok(())
1604}
1605
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the call message out of the guard so it can be sent after
        // the DB guard is dropped at the end of this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first one to
    // acknowledge wins and completes this request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll back the pending call in the database
    // and re-sync room/contact state before reporting failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1674
/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        // Scope the DB guard so it is released before the connection-pool
        // access below.
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Stop the call ringing on every connection of the callee.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1712
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        // Scope the DB guard so it is released before the connection-pool
        // access below.
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Dismiss the ringing call on all of this user's own connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1744
1745/// Updates other participants in the room with your current location.
1746async fn update_participant_location(
1747 request: proto::UpdateParticipantLocation,
1748 response: Response<proto::UpdateParticipantLocation>,
1749 session: MessageContext,
1750) -> Result<()> {
1751 let room_id = RoomId::from_proto(request.room_id);
1752 let location = request.location.context("invalid location")?;
1753
1754 let db = session.db().await;
1755 let room = db
1756 .update_room_participant_location(room_id, session.connection_id, location)
1757 .await?;
1758
1759 room_updated(&room, &session.peer);
1760 response.send(proto::Ack {})?;
1761 Ok(())
1762}
1763
1764/// Share a project into the room.
1765async fn share_project(
1766 request: proto::ShareProject,
1767 response: Response<proto::ShareProject>,
1768 session: MessageContext,
1769) -> Result<()> {
1770 let (project_id, room) = &*session
1771 .db()
1772 .await
1773 .share_project(
1774 RoomId::from_proto(request.room_id),
1775 session.connection_id,
1776 &request.worktrees,
1777 request.is_ssh_project,
1778 request.windows_paths.unwrap_or(false),
1779 )
1780 .await?;
1781 response.send(proto::ShareProjectResponse {
1782 project_id: project_id.to_proto(),
1783 })?;
1784 room_updated(room, &session.peer);
1785
1786 Ok(())
1787}
1788
1789/// Unshare a project from the room.
1790async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1791 let project_id = ProjectId::from_proto(message.project_id);
1792 unshare_project_internal(project_id, session.connection_id, &session).await
1793}
1794
/// Shared implementation of project unsharing: notifies guests, refreshes
/// room state, and optionally deletes the project afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        // The guard (and the room/guest data borrowed from it) must be
        // dropped before the second `session.db()` call below.
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the initiator) that the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // If the database flagged the project for deletion, remove it entirely.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1832
/// Join someone elses shared project.
///
/// Registers the caller as a collaborator, replies with project metadata,
/// then streams the full project state (worktree entries, diagnostics,
/// settings, repositories, language-server notices) to the new guest.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the DB guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Existing collaborators, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Announce the new guest to every existing collaborator.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are split into multiple chunked messages.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                    outside_worktree: Some(settings_file.outside_worktree),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Nudge the client to refresh diagnostics for each language server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
1984
1985/// Leave someone elses shared project.
1986async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
1987 let sender_id = session.connection_id;
1988 let project_id = ProjectId::from_proto(request.project_id);
1989 let db = session.db().await;
1990
1991 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
1992 tracing::info!(
1993 %project_id,
1994 "leave project"
1995 );
1996
1997 project_left(project, &session);
1998 if let Some(room) = room {
1999 room_updated(room, &session.peer);
2000 }
2001
2002 Ok(())
2003}
2004
2005/// Updates other participants with changes to the project
2006async fn update_project(
2007 request: proto::UpdateProject,
2008 response: Response<proto::UpdateProject>,
2009 session: MessageContext,
2010) -> Result<()> {
2011 let project_id = ProjectId::from_proto(request.project_id);
2012 let (room, guest_connection_ids) = &*session
2013 .db()
2014 .await
2015 .update_project(project_id, session.connection_id, &request.worktrees)
2016 .await?;
2017 broadcast(
2018 Some(session.connection_id),
2019 guest_connection_ids.iter().copied(),
2020 |connection_id| {
2021 session
2022 .peer
2023 .forward_send(session.connection_id, connection_id, request.clone())
2024 },
2025 );
2026 if let Some(room) = room {
2027 room_updated(room, &session.peer);
2028 }
2029 response.send(proto::Ack {})?;
2030
2031 Ok(())
2032}
2033
2034/// Updates other participants with changes to the worktree
2035async fn update_worktree(
2036 request: proto::UpdateWorktree,
2037 response: Response<proto::UpdateWorktree>,
2038 session: MessageContext,
2039) -> Result<()> {
2040 let guest_connection_ids = session
2041 .db()
2042 .await
2043 .update_worktree(&request, session.connection_id)
2044 .await?;
2045
2046 broadcast(
2047 Some(session.connection_id),
2048 guest_connection_ids.iter().copied(),
2049 |connection_id| {
2050 session
2051 .peer
2052 .forward_send(session.connection_id, connection_id, request.clone())
2053 },
2054 );
2055 response.send(proto::Ack {})?;
2056 Ok(())
2057}
2058
2059async fn update_repository(
2060 request: proto::UpdateRepository,
2061 response: Response<proto::UpdateRepository>,
2062 session: MessageContext,
2063) -> Result<()> {
2064 let guest_connection_ids = session
2065 .db()
2066 .await
2067 .update_repository(&request, session.connection_id)
2068 .await?;
2069
2070 broadcast(
2071 Some(session.connection_id),
2072 guest_connection_ids.iter().copied(),
2073 |connection_id| {
2074 session
2075 .peer
2076 .forward_send(session.connection_id, connection_id, request.clone())
2077 },
2078 );
2079 response.send(proto::Ack {})?;
2080 Ok(())
2081}
2082
2083async fn remove_repository(
2084 request: proto::RemoveRepository,
2085 response: Response<proto::RemoveRepository>,
2086 session: MessageContext,
2087) -> Result<()> {
2088 let guest_connection_ids = session
2089 .db()
2090 .await
2091 .remove_repository(&request, session.connection_id)
2092 .await?;
2093
2094 broadcast(
2095 Some(session.connection_id),
2096 guest_connection_ids.iter().copied(),
2097 |connection_id| {
2098 session
2099 .peer
2100 .forward_send(session.connection_id, connection_id, request.clone())
2101 },
2102 );
2103 response.send(proto::Ack {})?;
2104 Ok(())
2105}
2106
2107/// Updates other participants with changes to the diagnostics
2108async fn update_diagnostic_summary(
2109 message: proto::UpdateDiagnosticSummary,
2110 session: MessageContext,
2111) -> Result<()> {
2112 let guest_connection_ids = session
2113 .db()
2114 .await
2115 .update_diagnostic_summary(&message, session.connection_id)
2116 .await?;
2117
2118 broadcast(
2119 Some(session.connection_id),
2120 guest_connection_ids.iter().copied(),
2121 |connection_id| {
2122 session
2123 .peer
2124 .forward_send(session.connection_id, connection_id, message.clone())
2125 },
2126 );
2127
2128 Ok(())
2129}
2130
2131/// Updates other participants with changes to the worktree settings
2132async fn update_worktree_settings(
2133 message: proto::UpdateWorktreeSettings,
2134 session: MessageContext,
2135) -> Result<()> {
2136 let guest_connection_ids = session
2137 .db()
2138 .await
2139 .update_worktree_settings(&message, session.connection_id)
2140 .await?;
2141
2142 broadcast(
2143 Some(session.connection_id),
2144 guest_connection_ids.iter().copied(),
2145 |connection_id| {
2146 session
2147 .peer
2148 .forward_send(session.connection_id, connection_id, message.clone())
2149 },
2150 );
2151
2152 Ok(())
2153}
2154
2155/// Notify other participants that a language server has started.
2156async fn start_language_server(
2157 request: proto::StartLanguageServer,
2158 session: MessageContext,
2159) -> Result<()> {
2160 let guest_connection_ids = session
2161 .db()
2162 .await
2163 .start_language_server(&request, session.connection_id)
2164 .await?;
2165
2166 broadcast(
2167 Some(session.connection_id),
2168 guest_connection_ids.iter().copied(),
2169 |connection_id| {
2170 session
2171 .peer
2172 .forward_send(session.connection_id, connection_id, request.clone())
2173 },
2174 );
2175 Ok(())
2176}
2177
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist newly reported server capabilities before broadcasting —
    // presumably so the server can answer capability queries later without
    // another metadata update from the host (confirm against callers).
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Fan the update out to every other connection on the project.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2207
2208/// forward a project request to the host. These requests should be read only
2209/// as guests are allowed to send them.
2210async fn forward_read_only_project_request<T>(
2211 request: T,
2212 response: Response<T>,
2213 session: MessageContext,
2214) -> Result<()>
2215where
2216 T: EntityMessage + RequestMessage,
2217{
2218 let project_id = ProjectId::from_proto(request.remote_entity_id());
2219 let host_connection_id = session
2220 .db()
2221 .await
2222 .host_for_read_only_project_request(project_id, session.connection_id)
2223 .await?;
2224 let payload = session.forward_request(host_connection_id, request).await?;
2225 response.send(payload)?;
2226 Ok(())
2227}
2228
2229/// forward a project request to the host. These requests are disallowed
2230/// for guests.
2231async fn forward_mutating_project_request<T>(
2232 request: T,
2233 response: Response<T>,
2234 session: MessageContext,
2235) -> Result<()>
2236where
2237 T: EntityMessage + RequestMessage,
2238{
2239 let project_id = ProjectId::from_proto(request.remote_entity_id());
2240
2241 let host_connection_id = session
2242 .db()
2243 .await
2244 .host_for_mutating_project_request(project_id, session.connection_id)
2245 .await?;
2246 let payload = session.forward_request(host_connection_id, request).await?;
2247 response.send(payload)?;
2248 Ok(())
2249}
2250
2251async fn lsp_query(
2252 request: proto::LspQuery,
2253 response: Response<proto::LspQuery>,
2254 session: MessageContext,
2255) -> Result<()> {
2256 let (name, should_write) = request.query_name_and_write_permissions();
2257 tracing::Span::current().record("lsp_query_request", name);
2258 tracing::info!("lsp_query message received");
2259 if should_write {
2260 forward_mutating_project_request(request, response, session).await
2261 } else {
2262 forward_read_only_project_request(request, response, session).await
2263 }
2264}
2265
2266/// Notify other participants that a new buffer has been created
2267async fn create_buffer_for_peer(
2268 request: proto::CreateBufferForPeer,
2269 session: MessageContext,
2270) -> Result<()> {
2271 session
2272 .db()
2273 .await
2274 .check_user_is_project_host(
2275 ProjectId::from_proto(request.project_id),
2276 session.connection_id,
2277 )
2278 .await?;
2279 let peer_id = request.peer_id.context("invalid peer id")?;
2280 session
2281 .peer
2282 .forward_send(session.connection_id, peer_id.into(), request)?;
2283 Ok(())
2284}
2285
2286/// Notify other participants that a new image has been created
2287async fn create_image_for_peer(
2288 request: proto::CreateImageForPeer,
2289 session: MessageContext,
2290) -> Result<()> {
2291 session
2292 .db()
2293 .await
2294 .check_user_is_project_host(
2295 ProjectId::from_proto(request.project_id),
2296 session.connection_id,
2297 )
2298 .await?;
2299 let peer_id = request.peer_id.context("invalid peer id")?;
2300 session
2301 .peer
2302 .forward_send(session.connection_id, peer_id.into(), request)?;
2303 Ok(())
2304}
2305
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations are permitted with read-only
    // access; anything else escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // Scope the db guard so it is dropped before we await the host below.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Guests get a fire-and-forget broadcast; errors are handled inside
        // `broadcast`.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host applies the update authoritatively, so wait for its reply —
    // unless the sender *is* the host.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2352
/// Broadcast a context (assistant conversation) operation to the project's
/// host and guests, enforcing write access for mutating operations.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirror `update_buffer`'s rule: selection-only buffer operations may be
    // sent with read-only access; every other operation requires read-write.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast here and
    // no acknowledgment from the host is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2394
2395async fn forward_project_search_chunk(
2396 message: proto::FindSearchCandidatesChunk,
2397 response: Response<proto::FindSearchCandidatesChunk>,
2398 session: MessageContext,
2399) -> Result<()> {
2400 let peer_id = message.peer_id.context("missing peer_id")?;
2401 let payload = session
2402 .peer
2403 .forward_request(session.connection_id, peer_id.into(), message)
2404 .await?;
2405 response.send(payload)?;
2406 Ok(())
2407}
2408
2409/// Notify other participants that a project has been updated.
2410async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2411 request: T,
2412 session: MessageContext,
2413) -> Result<()> {
2414 let project_id = ProjectId::from_proto(request.remote_entity_id());
2415 let project_connection_ids = session
2416 .db()
2417 .await
2418 .project_connection_ids(project_id, session.connection_id, false)
2419 .await?;
2420
2421 broadcast(
2422 Some(session.connection_id),
2423 project_connection_ids.iter().copied(),
2424 |connection_id| {
2425 session
2426 .peer
2427 .forward_send(session.connection_id, connection_id, request.clone())
2428 },
2429 );
2430 Ok(())
2431}
2432
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state first; the follow is only recorded
    // in the database (and broadcast to the room) after the leader responds.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2464
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Fire-and-forget: no acknowledgment from the leader is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2493
2494/// Notify everyone following you of your current location.
2495async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2496 let room_id = RoomId::from_proto(request.room_id);
2497 let database = session.db.lock().await;
2498
2499 let connection_ids = if let Some(project_id) = request.project_id {
2500 let project_id = ProjectId::from_proto(project_id);
2501 database
2502 .project_connection_ids(project_id, session.connection_id, true)
2503 .await?
2504 } else {
2505 database
2506 .room_connection_ids(room_id, session.connection_id)
2507 .await?
2508 };
2509
2510 // For now, don't send view update messages back to that view's current leader.
2511 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2512 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2513 _ => None,
2514 });
2515
2516 for connection_id in connection_ids.iter().cloned() {
2517 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2518 session
2519 .peer
2520 .forward_send(session.connection_id, connection_id, request.clone())?;
2521 }
2522 }
2523 Ok(())
2524}
2525
2526/// Get public data about users.
2527async fn get_users(
2528 request: proto::GetUsers,
2529 response: Response<proto::GetUsers>,
2530 session: MessageContext,
2531) -> Result<()> {
2532 let user_ids = request
2533 .user_ids
2534 .into_iter()
2535 .map(UserId::from_proto)
2536 .collect();
2537 let users = session
2538 .db()
2539 .await
2540 .get_users_by_ids(user_ids)
2541 .await?
2542 .into_iter()
2543 .map(|user| proto::User {
2544 id: user.id.to_proto(),
2545 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2546 github_login: user.github_login,
2547 name: user.name,
2548 })
2549 .collect();
2550 response.send(proto::UsersResponse { users })?;
2551 Ok(())
2552}
2553
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // One- and two-character queries are treated as exact login lookups;
    // longer queries run a fuzzy search capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requesting user to themselves.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2585
2586/// Send a contact request to another user.
2587async fn request_contact(
2588 request: proto::RequestContact,
2589 response: Response<proto::RequestContact>,
2590 session: MessageContext,
2591) -> Result<()> {
2592 let requester_id = session.user_id();
2593 let responder_id = UserId::from_proto(request.responder_id);
2594 if requester_id == responder_id {
2595 return Err(anyhow!("cannot add yourself as a contact"))?;
2596 }
2597
2598 let notifications = session
2599 .db()
2600 .await
2601 .send_contact_request(requester_id, responder_id)
2602 .await?;
2603
2604 // Update outgoing contact requests of requester
2605 let mut update = proto::UpdateContacts::default();
2606 update.outgoing_requests.push(responder_id.to_proto());
2607 for connection_id in session
2608 .connection_pool()
2609 .await
2610 .user_connection_ids(requester_id)
2611 {
2612 session.peer.send(connection_id, update.clone())?;
2613 }
2614
2615 // Update incoming contact requests of responder
2616 let mut update = proto::UpdateContacts::default();
2617 update
2618 .incoming_requests
2619 .push(proto::IncomingContactRequest {
2620 requester_id: requester_id.to_proto(),
2621 });
2622 let connection_pool = session.connection_pool().await;
2623 for connection_id in connection_pool.user_connection_ids(responder_id) {
2624 session.peer.send(connection_id, update.clone())?;
2625 }
2626
2627 send_notifications(&connection_pool, &session.peer, notifications);
2628
2629 response.send(proto::Ack {})?;
2630 Ok(())
2631}
2632
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; the request itself is neither
        // accepted nor declined.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact records sent to each side.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2690
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // cancelling a still-pending request; that decides which lists to prune.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the original contact-request notification if the
        // database deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2741
2742fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2743 version.0.minor < 139
2744}
2745
2746async fn subscribe_to_channels(
2747 _: proto::SubscribeToChannels,
2748 session: MessageContext,
2749) -> Result<()> {
2750 subscribe_user_to_channels(session.user_id(), &session).await?;
2751 Ok(())
2752}
2753
2754async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2755 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2756 let mut pool = session.connection_pool().await;
2757 for membership in &channels_for_user.channel_memberships {
2758 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2759 }
2760 session.peer.send(
2761 session.connection_id,
2762 build_update_user_channels(&channels_for_user),
2763 )?;
2764 session.peer.send(
2765 session.connection_id,
2766 build_channels_update(channels_for_user),
2767 )?;
2768 Ok(())
2769}
2770
2771/// Creates a new channel.
2772async fn create_channel(
2773 request: proto::CreateChannel,
2774 response: Response<proto::CreateChannel>,
2775 session: MessageContext,
2776) -> Result<()> {
2777 let db = session.db().await;
2778
2779 let parent_id = request.parent_id.map(ChannelId::from_proto);
2780 let (channel, membership) = db
2781 .create_channel(&request.name, parent_id, session.user_id())
2782 .await?;
2783
2784 let root_id = channel.root_id();
2785 let channel = Channel::from_model(channel);
2786
2787 response.send(proto::CreateChannelResponse {
2788 channel: Some(channel.to_proto()),
2789 parent_id: request.parent_id,
2790 })?;
2791
2792 let mut connection_pool = session.connection_pool().await;
2793 if let Some(membership) = membership {
2794 connection_pool.subscribe_to_channel(
2795 membership.user_id,
2796 membership.channel_id,
2797 membership.role,
2798 );
2799 let update = proto::UpdateUserChannels {
2800 channel_memberships: vec![proto::ChannelMembership {
2801 channel_id: membership.channel_id.to_proto(),
2802 role: membership.role.into(),
2803 }],
2804 ..Default::default()
2805 };
2806 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2807 session.peer.send(connection_id, update.clone())?;
2808 }
2809 }
2810
2811 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2812 if !role.can_see_channel(channel.visibility) {
2813 continue;
2814 }
2815
2816 let update = proto::UpdateChannels {
2817 channels: vec![channel.to_proto()],
2818 ..Default::default()
2819 };
2820 session.peer.send(connection_id, update.clone())?;
2821 }
2822
2823 Ok(())
2824}
2825
2826/// Delete a channel
2827async fn delete_channel(
2828 request: proto::DeleteChannel,
2829 response: Response<proto::DeleteChannel>,
2830 session: MessageContext,
2831) -> Result<()> {
2832 let db = session.db().await;
2833
2834 let channel_id = request.channel_id;
2835 let (root_channel, removed_channels) = db
2836 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2837 .await?;
2838 response.send(proto::Ack {})?;
2839
2840 // Notify members of removed channels
2841 let mut update = proto::UpdateChannels::default();
2842 update
2843 .delete_channels
2844 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2845
2846 let connection_pool = session.connection_pool().await;
2847 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2848 session.peer.send(connection_id, update.clone())?;
2849 }
2850
2851 Ok(())
2852}
2853
2854/// Invite someone to join a channel.
2855async fn invite_channel_member(
2856 request: proto::InviteChannelMember,
2857 response: Response<proto::InviteChannelMember>,
2858 session: MessageContext,
2859) -> Result<()> {
2860 let db = session.db().await;
2861 let channel_id = ChannelId::from_proto(request.channel_id);
2862 let invitee_id = UserId::from_proto(request.user_id);
2863 let InviteMemberResult {
2864 channel,
2865 notifications,
2866 } = db
2867 .invite_channel_member(
2868 channel_id,
2869 invitee_id,
2870 session.user_id(),
2871 request.role().into(),
2872 )
2873 .await?;
2874
2875 let update = proto::UpdateChannels {
2876 channel_invitations: vec![channel.to_proto()],
2877 ..Default::default()
2878 };
2879
2880 let connection_pool = session.connection_pool().await;
2881 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2882 session.peer.send(connection_id, update.clone())?;
2883 }
2884
2885 send_notifications(&connection_pool, &session.peer, notifications);
2886
2887 response.send(proto::Ack {})?;
2888 Ok(())
2889}
2890
2891/// remove someone from a channel
2892async fn remove_channel_member(
2893 request: proto::RemoveChannelMember,
2894 response: Response<proto::RemoveChannelMember>,
2895 session: MessageContext,
2896) -> Result<()> {
2897 let db = session.db().await;
2898 let channel_id = ChannelId::from_proto(request.channel_id);
2899 let member_id = UserId::from_proto(request.user_id);
2900
2901 let RemoveChannelMemberResult {
2902 membership_update,
2903 notification_id,
2904 } = db
2905 .remove_channel_member(channel_id, member_id, session.user_id())
2906 .await?;
2907
2908 let mut connection_pool = session.connection_pool().await;
2909 notify_membership_updated(
2910 &mut connection_pool,
2911 membership_update,
2912 member_id,
2913 &session.peer,
2914 );
2915 for connection_id in connection_pool.user_connection_ids(member_id) {
2916 if let Some(notification_id) = notification_id {
2917 session
2918 .peer
2919 .send(
2920 connection_id,
2921 proto::DeleteNotification {
2922 notification_id: notification_id.to_proto(),
2923 },
2924 )
2925 .trace_err();
2926 }
2927 }
2928
2929 response.send(proto::Ack {})?;
2930 Ok(())
2931}
2932
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user ids into a Vec first: the loop body mutates the
    // pool's subscriptions, so the `channel_user_ids` borrow cannot be held
    // across iterations.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can still see the channel get the updated record
        // (and are (re)subscribed); everyone else has it removed from their
        // view and is unsubscribed.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2979
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The role change affected an actual membership: broadcast a full
        // membership update.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // The role change affected a pending invitation: refresh the
        // invitation entry on the member's connections.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3027
3028/// Change the name of a channel
3029async fn rename_channel(
3030 request: proto::RenameChannel,
3031 response: Response<proto::RenameChannel>,
3032 session: MessageContext,
3033) -> Result<()> {
3034 let db = session.db().await;
3035 let channel_id = ChannelId::from_proto(request.channel_id);
3036 let channel_model = db
3037 .rename_channel(channel_id, session.user_id(), &request.name)
3038 .await?;
3039 let root_id = channel_model.root_id();
3040 let channel = Channel::from_model(channel_model);
3041
3042 response.send(proto::RenameChannelResponse {
3043 channel: Some(channel.to_proto()),
3044 })?;
3045
3046 let connection_pool = session.connection_pool().await;
3047 let update = proto::UpdateChannels {
3048 channels: vec![channel.to_proto()],
3049 ..Default::default()
3050 };
3051 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3052 if role.can_see_channel(channel.visibility) {
3053 session.peer.send(connection_id, update.clone())?;
3054 }
3055 }
3056
3057 Ok(())
3058}
3059
3060/// Move a channel to a new parent.
3061async fn move_channel(
3062 request: proto::MoveChannel,
3063 response: Response<proto::MoveChannel>,
3064 session: MessageContext,
3065) -> Result<()> {
3066 let channel_id = ChannelId::from_proto(request.channel_id);
3067 let to = ChannelId::from_proto(request.to);
3068
3069 let (root_id, channels) = session
3070 .db()
3071 .await
3072 .move_channel(channel_id, to, session.user_id())
3073 .await?;
3074
3075 let connection_pool = session.connection_pool().await;
3076 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3077 let channels = channels
3078 .iter()
3079 .filter_map(|channel| {
3080 if role.can_see_channel(channel.visibility) {
3081 Some(channel.to_proto())
3082 } else {
3083 None
3084 }
3085 })
3086 .collect::<Vec<_>>();
3087 if channels.is_empty() {
3088 continue;
3089 }
3090
3091 let update = proto::UpdateChannels {
3092 channels,
3093 ..Default::default()
3094 };
3095
3096 session.peer.send(connection_id, update.clone())?;
3097 }
3098
3099 response.send(Ack {})?;
3100 Ok(())
3101}
3102
3103async fn reorder_channel(
3104 request: proto::ReorderChannel,
3105 response: Response<proto::ReorderChannel>,
3106 session: MessageContext,
3107) -> Result<()> {
3108 let channel_id = ChannelId::from_proto(request.channel_id);
3109 let direction = request.direction();
3110
3111 let updated_channels = session
3112 .db()
3113 .await
3114 .reorder_channel(channel_id, direction, session.user_id())
3115 .await?;
3116
3117 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3118 let connection_pool = session.connection_pool().await;
3119 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3120 let channels = updated_channels
3121 .iter()
3122 .filter_map(|channel| {
3123 if role.can_see_channel(channel.visibility) {
3124 Some(channel.to_proto())
3125 } else {
3126 None
3127 }
3128 })
3129 .collect::<Vec<_>>();
3130
3131 if channels.is_empty() {
3132 continue;
3133 }
3134
3135 let update = proto::UpdateChannels {
3136 channels,
3137 ..Default::default()
3138 };
3139
3140 session.peer.send(connection_id, update.clone())?;
3141 }
3142 }
3143
3144 response.send(Ack {})?;
3145 Ok(())
3146}
3147
3148/// Get the list of channel members
3149async fn get_channel_members(
3150 request: proto::GetChannelMembers,
3151 response: Response<proto::GetChannelMembers>,
3152 session: MessageContext,
3153) -> Result<()> {
3154 let db = session.db().await;
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let limit = if request.limit == 0 {
3157 u16::MAX as u64
3158 } else {
3159 request.limit
3160 };
3161 let (members, users) = db
3162 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3163 .await?;
3164 response.send(proto::GetChannelMembersResponse { members, users })?;
3165 Ok(())
3166}
3167
/// Accept or decline a channel invitation.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    // On accept, the database returns a membership update describing the
    // user's new channel access; on decline, `membership_update` is `None`.
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        // Accepted: subscribe the user's connections to the new channels and
        // push the full membership delta (which also clears the invitation).
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        // Declined: just remove the invitation on all of the user's
        // connections.
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    // Deliver any notifications produced while responding to the invite.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3208
3209/// Join the channels' room
3210async fn join_channel(
3211 request: proto::JoinChannel,
3212 response: Response<proto::JoinChannel>,
3213 session: MessageContext,
3214) -> Result<()> {
3215 let channel_id = ChannelId::from_proto(request.channel_id);
3216 join_channel_internal(channel_id, Box::new(response), session).await
3217}
3218
/// Abstraction over the two RPC responses that carry a `JoinRoomResponse`
/// payload (`JoinChannel` and `JoinRoom`), so `join_channel_internal` can
/// serve both entry points.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3232
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers (see `JoinChannelInternalResponse`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so the
            // guard must be dropped first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a LiveKit client is configured. Guests
        // get a token with `can_publish: false`; token failures are logged
        // (`trace_err`) and degrade to joining without call connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before fanning out updates to other clients.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Joining via a channel must yield a channel; error out otherwise.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3326
3327/// Start editing the channel notes
3328async fn join_channel_buffer(
3329 request: proto::JoinChannelBuffer,
3330 response: Response<proto::JoinChannelBuffer>,
3331 session: MessageContext,
3332) -> Result<()> {
3333 let db = session.db().await;
3334 let channel_id = ChannelId::from_proto(request.channel_id);
3335
3336 let open_response = db
3337 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3338 .await?;
3339
3340 let collaborators = open_response.collaborators.clone();
3341 response.send(open_response)?;
3342
3343 let update = UpdateChannelBufferCollaborators {
3344 channel_id: channel_id.to_proto(),
3345 collaborators: collaborators.clone(),
3346 };
3347 channel_buffer_updated(
3348 session.connection_id,
3349 collaborators
3350 .iter()
3351 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3352 &update,
3353 &session.peer,
3354 );
3355
3356 Ok(())
3357}
3358
/// Edit the channel notes
///
/// Applies the operations to the channel's buffer, forwards them to the
/// other active collaborators, and tells every non-collaborating channel
/// subscriber about the new buffer version.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to everyone currently editing the buffer;
    // the sender's own id is passed along so `broadcast` can skip it.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Connections subscribed to the channel but not editing the buffer only
    // need the latest version marker, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3409
3410/// Rejoin the channel notes after a connection blip
3411async fn rejoin_channel_buffers(
3412 request: proto::RejoinChannelBuffers,
3413 response: Response<proto::RejoinChannelBuffers>,
3414 session: MessageContext,
3415) -> Result<()> {
3416 let db = session.db().await;
3417 let buffers = db
3418 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3419 .await?;
3420
3421 for rejoined_buffer in &buffers {
3422 let collaborators_to_notify = rejoined_buffer
3423 .buffer
3424 .collaborators
3425 .iter()
3426 .filter_map(|c| Some(c.peer_id?.into()));
3427 channel_buffer_updated(
3428 session.connection_id,
3429 collaborators_to_notify,
3430 &proto::UpdateChannelBufferCollaborators {
3431 channel_id: rejoined_buffer.buffer.channel_id,
3432 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3433 },
3434 &session.peer,
3435 );
3436 }
3437
3438 response.send(proto::RejoinChannelBuffersResponse {
3439 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3440 })?;
3441
3442 Ok(())
3443}
3444
3445/// Stop editing the channel notes
3446async fn leave_channel_buffer(
3447 request: proto::LeaveChannelBuffer,
3448 response: Response<proto::LeaveChannelBuffer>,
3449 session: MessageContext,
3450) -> Result<()> {
3451 let db = session.db().await;
3452 let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454 let left_buffer = db
3455 .leave_channel_buffer(channel_id, session.connection_id)
3456 .await?;
3457
3458 response.send(Ack {})?;
3459
3460 channel_buffer_updated(
3461 session.connection_id,
3462 left_buffer.connections,
3463 &proto::UpdateChannelBufferCollaborators {
3464 channel_id: channel_id.to_proto(),
3465 collaborators: left_buffer.collaborators,
3466 },
3467 &session.peer,
3468 );
3469
3470 Ok(())
3471}
3472
3473fn channel_buffer_updated<T: EnvelopedMessage>(
3474 sender_id: ConnectionId,
3475 collaborators: impl IntoIterator<Item = ConnectionId>,
3476 message: &T,
3477 peer: &Peer,
3478) {
3479 broadcast(Some(sender_id), collaborators, |peer_id| {
3480 peer.send(peer_id, message.clone())
3481 });
3482}
3483
3484fn send_notifications(
3485 connection_pool: &ConnectionPool,
3486 peer: &Peer,
3487 notifications: db::NotificationBatch,
3488) {
3489 for (user_id, notification) in notifications {
3490 for connection_id in connection_pool.user_connection_ids(user_id) {
3491 if let Err(error) = peer.send(
3492 connection_id,
3493 proto::AddNotification {
3494 notification: Some(notification.clone()),
3495 },
3496 ) {
3497 tracing::error!(
3498 "failed to send notification to {:?} {}",
3499 connection_id,
3500 error
3501 );
3502 }
3503 }
3504 }
3505}
3506
/// Send a message to the channel
///
/// Chat was removed from the product; this handler always errors.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Delete a channel message
///
/// Chat was removed from the product; this handler always errors.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Edit a channel message
///
/// Chat was removed from the product; this handler always errors.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Mark a channel message as read
///
/// Chat was removed from the product; this handler always errors.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3540
3541/// Mark a buffer version as synced
3542async fn acknowledge_buffer_version(
3543 request: proto::AckBufferOperation,
3544 session: MessageContext,
3545) -> Result<()> {
3546 let buffer_id = BufferId::from_proto(request.buffer_id);
3547 session
3548 .db()
3549 .await
3550 .observe_buffer_version(
3551 buffer_id,
3552 session.user_id(),
3553 request.epoch as i32,
3554 &request.version,
3555 )
3556 .await?;
3557 Ok(())
3558}
3559
/// Start receiving chat updates for a channel
///
/// Chat was removed from the product; this handler always errors.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Stop receiving chat updates for a channel
///
/// Chat was removed from the product; this handler always errors.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Retrieve the chat history for a channel
///
/// Chat was removed from the product; this handler always errors.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}

/// Retrieve specific chat messages
///
/// Chat was removed from the product; this handler always errors.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3594
3595/// Retrieve the current users notifications
3596async fn get_notifications(
3597 request: proto::GetNotifications,
3598 response: Response<proto::GetNotifications>,
3599 session: MessageContext,
3600) -> Result<()> {
3601 let notifications = session
3602 .db()
3603 .await
3604 .get_notifications(
3605 session.user_id(),
3606 NOTIFICATION_COUNT_PER_PAGE,
3607 request.before_id.map(db::NotificationId::from_proto),
3608 )
3609 .await?;
3610 response.send(proto::GetNotificationsResponse {
3611 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3612 notifications,
3613 })?;
3614 Ok(())
3615}
3616
3617/// Mark notifications as read
3618async fn mark_notification_as_read(
3619 request: proto::MarkNotificationRead,
3620 response: Response<proto::MarkNotificationRead>,
3621 session: MessageContext,
3622) -> Result<()> {
3623 let database = &session.db().await;
3624 let notifications = database
3625 .mark_notification_as_read_by_id(
3626 session.user_id(),
3627 NotificationId::from_proto(request.notification_id),
3628 )
3629 .await?;
3630 send_notifications(
3631 &*session.connection_pool().await,
3632 &session.peer,
3633 notifications,
3634 );
3635 response.send(proto::Ack {})?;
3636 Ok(())
3637}
3638
3639fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3640 let message = match message {
3641 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3642 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3643 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3644 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3645 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3646 code: frame.code.into(),
3647 reason: frame.reason.as_str().to_owned().into(),
3648 })),
3649 // We should never receive a frame while reading the message, according
3650 // to the `tungstenite` maintainers:
3651 //
3652 // > It cannot occur when you read messages from the WebSocket, but it
3653 // > can be used when you want to send the raw frames (e.g. you want to
3654 // > send the frames to the WebSocket without composing the full message first).
3655 // >
3656 // > — https://github.com/snapview/tungstenite-rs/issues/268
3657 TungsteniteMessage::Frame(_) => {
3658 bail!("received an unexpected frame while reading the message")
3659 }
3660 };
3661
3662 Ok(message)
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666 match message {
3667 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3668 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3669 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3670 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3671 AxumMessage::Close(frame) => {
3672 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673 code: frame.code.into(),
3674 reason: frame.reason.as_ref().into(),
3675 }))
3676 }
3677 }
3678}
3679
/// Update the connection pool's channel subscriptions for `user_id` after a
/// membership change, then push the resulting channel updates to all of the
/// user's connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the pool's subscription index in sync with the new memberships.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this builds the membership list by hand and leaves
    // `observed_channel_buffer_version` at its default, unlike
    // `build_update_user_channels` — confirm that omission is intentional.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The membership change resolves any pending invitation to this channel.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Send failures are logged rather than propagated so the remaining
    // connections still receive their updates.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722 proto::UpdateUserChannels {
3723 channel_memberships: channels
3724 .channel_memberships
3725 .iter()
3726 .map(|m| proto::ChannelMembership {
3727 channel_id: m.channel_id.to_proto(),
3728 role: m.role.into(),
3729 })
3730 .collect(),
3731 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732 }
3733}
3734
3735fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3736 let mut update = proto::UpdateChannels::default();
3737
3738 for channel in channels.channels {
3739 update.channels.push(channel.to_proto());
3740 }
3741
3742 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3743
3744 for (channel_id, participants) in channels.channel_participants {
3745 update
3746 .channel_participants
3747 .push(proto::ChannelParticipants {
3748 channel_id: channel_id.to_proto(),
3749 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3750 });
3751 }
3752
3753 for channel in channels.invited_channels {
3754 update.channel_invitations.push(channel.to_proto());
3755 }
3756
3757 update
3758}
3759
3760fn build_initial_contacts_update(
3761 contacts: Vec<db::Contact>,
3762 pool: &ConnectionPool,
3763) -> proto::UpdateContacts {
3764 let mut update = proto::UpdateContacts::default();
3765
3766 for contact in contacts {
3767 match contact {
3768 db::Contact::Accepted { user_id, busy } => {
3769 update.contacts.push(contact_for_user(user_id, busy, pool));
3770 }
3771 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3772 db::Contact::Incoming { user_id } => {
3773 update
3774 .incoming_requests
3775 .push(proto::IncomingContactRequest {
3776 requester_id: user_id.to_proto(),
3777 })
3778 }
3779 }
3780 }
3781
3782 update
3783}
3784
3785fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3786 proto::Contact {
3787 user_id: user_id.to_proto(),
3788 online: pool.is_user_online(user_id),
3789 busy,
3790 }
3791}
3792
3793fn room_updated(room: &proto::Room, peer: &Peer) {
3794 broadcast(
3795 None,
3796 room.participants
3797 .iter()
3798 .filter_map(|participant| Some(participant.peer_id?.into())),
3799 |peer_id| {
3800 peer.send(
3801 peer_id,
3802 proto::RoomUpdated {
3803 room: Some(room.clone()),
3804 },
3805 )
3806 },
3807 );
3808}
3809
3810fn channel_updated(
3811 channel: &db::channel::Model,
3812 room: &proto::Room,
3813 peer: &Peer,
3814 pool: &ConnectionPool,
3815) {
3816 let participants = room
3817 .participants
3818 .iter()
3819 .map(|p| p.user_id)
3820 .collect::<Vec<_>>();
3821
3822 broadcast(
3823 None,
3824 pool.channel_connection_ids(channel.root_id())
3825 .filter_map(|(channel_id, role)| {
3826 role.can_see_channel(channel.visibility)
3827 .then_some(channel_id)
3828 }),
3829 |peer_id| {
3830 peer.send(
3831 peer_id,
3832 proto::UpdateChannels {
3833 channel_participants: vec![proto::ChannelParticipants {
3834 channel_id: channel.id.to_proto(),
3835 participant_user_ids: participants.clone(),
3836 }],
3837 ..Default::default()
3838 },
3839 )
3840 },
3841 );
3842}
3843
3844async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3845 let db = session.db().await;
3846
3847 let contacts = db.get_contacts(user_id).await?;
3848 let busy = db.is_user_busy(user_id).await?;
3849
3850 let pool = session.connection_pool().await;
3851 let updated_contact = contact_for_user(user_id, busy, &pool);
3852 for contact in contacts {
3853 if let db::Contact::Accepted {
3854 user_id: contact_user_id,
3855 ..
3856 } = contact
3857 {
3858 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3859 session
3860 .peer
3861 .send(
3862 contact_conn_id,
3863 proto::UpdateContacts {
3864 contacts: vec![updated_contact.clone()],
3865 remove_contacts: Default::default(),
3866 incoming_requests: Default::default(),
3867 remove_incoming_requests: Default::default(),
3868 outgoing_requests: Default::default(),
3869 remove_outgoing_requests: Default::default(),
3870 },
3871 )
3872 .trace_err();
3873 }
3874 }
3875 }
3876 Ok(())
3877}
3878
/// Tear down a connection's room membership: leave the room in the database,
/// notify affected projects/rooms/channels/contacts, cancel pending calls,
/// and clean up the LiveKit room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared ahead of the `if let` so the extracted values remain usable
    // after `left_room` is dropped.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Tell each user whose incoming call was canceled, on every one of their
    // connections; best-effort, failures are only logged.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3952
3953async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3954 let left_channel_buffers = session
3955 .db()
3956 .await
3957 .leave_channel_buffers(session.connection_id)
3958 .await?;
3959
3960 for left_buffer in left_channel_buffers {
3961 channel_buffer_updated(
3962 session.connection_id,
3963 left_buffer.connections,
3964 &proto::UpdateChannelBufferCollaborators {
3965 channel_id: left_buffer.channel_id.to_proto(),
3966 collaborators: left_buffer.collaborators,
3967 },
3968 &session.peer,
3969 );
3970 }
3971
3972 Ok(())
3973}
3974
3975fn project_left(project: &db::LeftProject, session: &Session) {
3976 for connection_id in &project.connection_ids {
3977 if project.should_unshare {
3978 session
3979 .peer
3980 .send(
3981 *connection_id,
3982 proto::UnshareProject {
3983 project_id: project.id.to_proto(),
3984 },
3985 )
3986 .trace_err();
3987 } else {
3988 session
3989 .peer
3990 .send(
3991 *connection_id,
3992 proto::RemoveProjectCollaborator {
3993 project_id: project.id.to_proto(),
3994 peer_id: Some(session.connection_id.into()),
3995 },
3996 )
3997 .trace_err();
3998 }
3999 }
4000}
4001
4002async fn share_agent_thread(
4003 request: proto::ShareAgentThread,
4004 response: Response<proto::ShareAgentThread>,
4005 session: MessageContext,
4006) -> Result<()> {
4007 let user_id = session.user_id();
4008
4009 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4010 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4011
4012 session
4013 .db()
4014 .await
4015 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4016 .await?;
4017
4018 response.send(proto::Ack {})?;
4019
4020 Ok(())
4021}
4022
4023async fn get_shared_agent_thread(
4024 request: proto::GetSharedAgentThread,
4025 response: Response<proto::GetSharedAgentThread>,
4026 session: MessageContext,
4027) -> Result<()> {
4028 let share_id = SharedThreadId::from_proto(request.session_id)
4029 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4030
4031 let result = session.db().await.get_shared_thread(share_id).await?;
4032
4033 match result {
4034 Some((thread, username)) => {
4035 response.send(proto::GetSharedAgentThreadResponse {
4036 title: thread.title,
4037 thread_data: thread.data,
4038 sharer_username: username,
4039 created_at: thread.created_at.and_utc().to_rfc3339(),
4040 })?;
4041 }
4042 None => {
4043 return Err(anyhow!("Shared thread not found").into());
4044 }
4045 }
4046
4047 Ok(())
4048}
4049
/// Extension trait for logging an error and discarding it, converting a
/// `Result` into an `Option` of the success value.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) at ERROR level and return `Some(value)` on
    /// success, `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}

impl<T, E> ResultExt for Result<T, E>
where
    E: std::fmt::Debug,
{
    type Ok = T;

    // NOTE(review): `#[track_caller]` makes `Location::caller()` report the
    // call site — confirm the tracing subscriber actually uses it, since
    // `tracing::error!` records its own callsite by default.
    #[track_caller]
    fn trace_err(self) -> Option<T> {
        match self {
            Ok(value) => Some(value),
            Err(error) => {
                tracing::error!("{:?}", error);
                None
            }
        }
    }
}