1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// Grace period associated with client reconnection (name-derived; the
/// consumers of this constant live outside this file — confirm there).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
// gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Presumably the page size for paginated notification queries — used by
// handlers outside this chunk.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections; enforced by
// `ConnectionGuard::try_acquire` below.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently open connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record message-handling timings
// (recorded in `add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler: given the received envelope, the session it
/// arrived on, and the span to record timing fields into, returns a boxed
/// future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
/// Handle through which a request handler replies exactly once. The shared
/// `responded` flag is checked by `add_request_handler` after the handler
/// returns, to catch handlers that claim success without ever responding.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request as
    /// responded-to. Consumes `self`, so a handler can reply at most once.
    fn send(self, payload: R::Response) -> Result<()> {
        // Set the flag before sending so a send error still counts as a reply
        // attempt rather than a missing response.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Everything a message handler receives: the connection's `Session` plus the
/// per-message tracing span, used to record timing fields like
/// `host_waiting_ms`.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Handlers mostly need `Session` methods, so deref straight to the session.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
impl MessageContext {
    /// Forwards `request` to `receiver_id` on behalf of this connection,
    /// recording how long we waited on the receiving host in the
    /// `host_waiting_ms` span field.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs for Ok and Err alike: record elapsed wall-clock
            // time in fractional milliseconds either way.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// Per-connection state cloned into every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquire via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared registry of live connections; acquire via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Acquires the database lock. The `yield_now` calls under `test-support`
    /// force extra task interleavings so tests exercise more orderings.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Locks the connection pool, wrapping the guard so it is `!Send`
    /// (via `PhantomData<Rc<()>>` in `ConnectionPoolGuard`).
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the connected user is an admin ("staff").
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The id of the connected user.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
/// Newtype around the shared `Database`, stored behind the session's async
/// mutex (see `Session::db`).
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
253
/// The collab server: owns the RPC peer, the pool of live client connections,
/// and the message-handler table populated in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // One handler per message `TypeId`; duplicate registration panics.
    handlers: HashMap<TypeId, MessageHandler>,
    // Watch channel flipped to `true` in `teardown` so connection loops exit.
    teardown: watch::Sender<bool>,
}
262
/// Guard over the locked `ConnectionPool`. The `PhantomData<Rc<()>>` field
/// makes the guard `!Send`, so it cannot be moved to another thread (e.g. by
/// being held across a work-stealing `.await`).
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers a handler for every RPC message type.
    /// Registering two handlers for the same message panics (see
    /// `add_handler`), so each line below must name a distinct message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Room and call lifecycle.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and membership.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests, forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests, forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state fan-out from the host to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations, forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads and search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
451
    /// Starts the server's background cleanup task. After `CLEANUP_TIMEOUT`
    /// (the Kubernetes graceful-shutdown grace period has passed by then), it
    /// removes state left behind by previous server instances: stale chat
    /// participants, stale room/channel-buffer collaborators (notifying the
    /// affected connections), orphaned worktree entries, and stale server rows.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and tell the
                    // remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants; broadcast the refreshed
                        // room (and channel, if any) to remaining clients.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-status to each affected user's
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // Run a second pass in case new stale participants appeared
                // while the room cleanup above was in flight.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
622
    /// Shuts the server down: tears down the peer, resets the connection pool,
    /// and flips the teardown watch so every `handle_connection` loop exits.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }

    /// Test-only: tears the server down and re-initializes it under a new id.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Un-set the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }

    /// Test-only: returns the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
641
    /// Registers `handler` for message type `M`. The stored wrapper downcasts
    /// the type-erased envelope, invokes the handler, records timing fields
    /// (total, processing, and queue durations) on the span, and logs the
    /// outcome. Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The dispatcher selected this handler by the envelope's
                // payload TypeId, so this downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Fractional-millisecond timings; queue time is whatever
                    // elapsed between receipt and the handler starting.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
685
686 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
687 where
688 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
689 Fut: 'static + Send + Future<Output = Result<()>>,
690 M: EnvelopedMessage,
691 {
692 self.add_handler(move |envelope, session| handler(envelope.payload, session));
693 self
694 }
695
    /// Registers `handler` for request message `M`. The wrapper builds the
    /// `Response` the handler must reply through; a handler returning `Ok`
    /// without having responded is surfaced as an error, and a handler
    /// returning `Err` has the error sent back to the requester as well.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        // `Response::send` sets this flag; succeeding without
                        // replying is a handler bug.
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Forward the failure to the requester, preserving
                        // structured error codes where available.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
734
    /// Returns a future that runs one client connection to completion: it
    /// registers the connection with the peer, sends the initial client
    /// update, then loops dispatching incoming messages to their handlers
    /// until the connection closes or the server tears down, and finally runs
    /// `connection_lost` cleanup. The `connection_guard` (slot in the global
    /// connection budget) is released once the initial update has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is shutting down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established; free its admission slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Wait for a handler slot before pulling the next message, so
                // at most MAX_CONCURRENT_HANDLERS run at once.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The semaphore permit is held until the
                                // handler future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
905
    /// Sends the state a freshly connected client needs: the `Hello` message,
    /// a contacts snapshot, channel subscriptions (for client versions that
    /// auto-subscribe), and any pending incoming call. Also registers the
    /// connection in the pool and refreshes the user's contacts.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to whoever initiated this connection.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                if !user.connected_once {
                    // First-ever connection for this user: show contacts once.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register in the pool and send the initial contact list
                    // under the same lock so the snapshot is consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
961}
962
963impl Deref for ConnectionPoolGuard<'_> {
964 type Target = ConnectionPool;
965
966 fn deref(&self) -> &Self::Target {
967 &self.guard
968 }
969}
970
971impl DerefMut for ConnectionPoolGuard<'_> {
972 fn deref_mut(&mut self) -> &mut Self::Target {
973 &mut self.guard
974 }
975}
976
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released; compiled out entirely in production.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
983
984fn broadcast<F>(
985 sender_id: Option<ConnectionId>,
986 receiver_ids: impl IntoIterator<Item = ConnectionId>,
987 mut f: F,
988) where
989 F: FnMut(ConnectionId) -> anyhow::Result<()>,
990{
991 for receiver_id in receiver_ids {
992 if Some(receiver_id) != sender_id
993 && let Err(error) = f(receiver_id)
994 {
995 tracing::error!("failed to send to {:?} {}", receiver_id, error);
996 }
997 }
998}
999
1000pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003 fn name() -> &'static HeaderName {
1004 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006 }
1007
1008 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009 where
1010 Self: Sized,
1011 I: Iterator<Item = &'i axum::http::HeaderValue>,
1012 {
1013 let version = values
1014 .next()
1015 .ok_or_else(axum::headers::Error::invalid)?
1016 .to_str()
1017 .map_err(|_| axum::headers::Error::invalid())?
1018 .parse()
1019 .map_err(|_| axum::headers::Error::invalid())?;
1020 Ok(Self(version))
1021 }
1022
1023 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024 values.extend([self.0.to_string().parse().unwrap()]);
1025 }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030 fn name() -> &'static HeaderName {
1031 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033 }
1034
1035 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036 where
1037 Self: Sized,
1038 I: Iterator<Item = &'i axum::http::HeaderValue>,
1039 {
1040 let version = values
1041 .next()
1042 .ok_or_else(axum::headers::Error::invalid)?
1043 .to_str()
1044 .map_err(|_| axum::headers::Error::invalid())?
1045 .parse()
1046 .map_err(|_| axum::headers::Error::invalid())?;
1047 Ok(Self(version))
1048 }
1049
1050 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051 values.extend([self.0.to_string().parse().unwrap()]);
1052 }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059 fn name() -> &'static HeaderName {
1060 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062 }
1063
1064 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065 where
1066 Self: Sized,
1067 I: Iterator<Item = &'i axum::http::HeaderValue>,
1068 {
1069 Ok(Self(
1070 values
1071 .next()
1072 .ok_or_else(axum::headers::Error::invalid)?
1073 .to_str()
1074 .map_err(|_| axum::headers::Error::invalid())?
1075 .to_owned(),
1076 ))
1077 }
1078
1079 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080 values.extend([self.0.parse().unwrap()]);
1081 }
1082}
1083
/// Builds the collab server's HTTP router: the authenticated `/rpc` WebSocket
/// endpoint and the unauthenticated `/metrics` endpoint.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            // This layer wraps only the routes added before it, so `/rpc`
            // requires auth while `/metrics` (added below) does not.
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1095
/// Axum handler for `/rpc`: validates the protocol and app-version headers,
/// enforces the concurrent-connection limit, then upgrades to a WebSocket and
/// hands the connection to the server.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Reject clients speaking a different RPC protocol outright.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is mandatory; version-dependent behavior below
    // (e.g. `can_collaborate`) needs it.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket stream/sink into the tungstenite message
        // types the RPC layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176 let connections_metric = CONNECTIONS_METRIC
1177 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179 let connections = server
1180 .connection_pool
1181 .lock()
1182 .connections()
1183 .filter(|connection| !connection.admin)
1184 .count();
1185 connections_metric.set(connections as _);
1186
1187 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189 register_int_gauge!(
1190 "shared_projects",
1191 "number of open projects with one or more guests"
1192 )
1193 .unwrap()
1194 });
1195
1196 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197 shared_projects_metric.set(shared_projects as _);
1198
1199 let encoder = prometheus::TextEncoder::new();
1200 let metric_families = prometheus::gather();
1201 let encoded_metrics = encoder
1202 .encode_to_string(&metric_families)
1203 .map_err(|err| anyhow!("{err}"))?;
1204 Ok(encoded_metrics)
1205}
1206
/// Tears down state for a dropped connection. The connection is removed from
/// the pool immediately, but room/project resources are only released after
/// `RECONNECT_TIMEOUT` to give the client a chance to reconnect — unless the
/// server itself is shutting down (`teardown` fires first).
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediately sever the peer connection and drop it from the pool.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Record the loss in the database; best-effort, errors are only logged.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending incoming call if the user has no other
            // live connections; another client of theirs may still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256 _: proto::Ping,
1257 response: Response<proto::Ping>,
1258 _session: MessageContext,
1259) -> Result<()> {
1260 response.send(proto::Ack {})?;
1261 Ok(())
1262}
1263
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random LiveKit room name; the DB room row references it for media.
    let livekit_room = nanoid::nanoid!(30);

    // Mint a LiveKit access token for the caller, when a LiveKit client is
    // configured. Failure here just yields a room without audio/video.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // Being in a room affects how this user appears to their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1301
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel rooms are joined through the channel path so channel membership
    // and permissions are enforced there.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the ringing-call UI on all of this user's connections, since the
    // call has now been answered here.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Issue a LiveKit token so the client can join the media session; absent
    // or failing LiveKit simply means no audio/video.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // Restore this connection's room membership, re-shared host projects,
        // and rejoined guest projects in one database operation.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this connection re-shares as host, tell existing
        // collaborators about the host's new peer id and current worktrees.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the DB guard before the notifications below.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also need the channel's subscribers updated.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
/// Streams the current state of each rejoined guest project back to the
/// reconnecting client (worktree entries, diagnostics, settings, repository
/// state) and tells other collaborators about the client's new peer id.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: inform other collaborators that this peer reconnected under
    // a new connection id. Best-effort per recipient.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
            }
        }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their owned data can be sent
        // without cloning.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                root_repo_common_dir: worktree.root_repo_common_dir,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into multiple messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Re-send each per-worktree settings file.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository state that changed while the client was disconnected.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1547
1548/// leave room disconnects from the room.
1549async fn leave_room(
1550 _: proto::LeaveRoom,
1551 response: Response<proto::LeaveRoom>,
1552 session: MessageContext,
1553) -> Result<()> {
1554 leave_room_for_session(&session, session.connection_id).await?;
1555 response.send(proto::Ack {})?;
1556 Ok(())
1557}
1558
1559/// Updates the permissions of someone else in the room.
1560async fn set_room_participant_role(
1561 request: proto::SetRoomParticipantRole,
1562 response: Response<proto::SetRoomParticipantRole>,
1563 session: MessageContext,
1564) -> Result<()> {
1565 let user_id = UserId::from_proto(request.user_id);
1566 let role = ChannelRole::from(request.role());
1567
1568 let (livekit_room, can_publish) = {
1569 let room = session
1570 .db()
1571 .await
1572 .set_room_participant_role(
1573 session.user_id(),
1574 RoomId::from_proto(request.room_id),
1575 user_id,
1576 role,
1577 )
1578 .await?;
1579
1580 let livekit_room = room.livekit_room.clone();
1581 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1582 room_updated(&room, &session.peer);
1583 (livekit_room, can_publish)
1584 };
1585
1586 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1587 live_kit
1588 .update_participant(
1589 livekit_room.clone(),
1590 request.user_id.to_string(),
1591 livekit_api::proto::ParticipantPermission {
1592 can_subscribe: true,
1593 can_publish,
1594 can_publish_data: can_publish,
1595 hidden: false,
1596 recorder: false,
1597 },
1598 )
1599 .await
1600 .trace_err();
1601 }
1602
1603 response.send(proto::Ack {})?;
1604 Ok(())
1605}
1606
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may ring each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call on the room and take ownership of the
    // IncomingCall message to deliver to the callee.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection the callee currently has; the first connection to
    // acknowledge the ring wins and completes this request.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the ring: roll the pending call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1675
1676/// Cancel an outgoing call.
1677async fn cancel_call(
1678 request: proto::CancelCall,
1679 response: Response<proto::CancelCall>,
1680 session: MessageContext,
1681) -> Result<()> {
1682 let called_user_id = UserId::from_proto(request.called_user_id);
1683 let room_id = RoomId::from_proto(request.room_id);
1684 {
1685 let room = session
1686 .db()
1687 .await
1688 .cancel_call(room_id, session.connection_id, called_user_id)
1689 .await?;
1690 room_updated(&room, &session.peer);
1691 }
1692
1693 for connection_id in session
1694 .connection_pool()
1695 .await
1696 .user_connection_ids(called_user_id)
1697 {
1698 session
1699 .peer
1700 .send(
1701 connection_id,
1702 proto::CallCanceled {
1703 room_id: room_id.to_proto(),
1704 },
1705 )
1706 .trace_err();
1707 }
1708 response.send(proto::Ack {})?;
1709
1710 update_user_contacts(called_user_id, &session).await?;
1711 Ok(())
1712}
1713
1714/// Decline an incoming call.
1715async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1716 let room_id = RoomId::from_proto(message.room_id);
1717 {
1718 let room = session
1719 .db()
1720 .await
1721 .decline_call(Some(room_id), session.user_id())
1722 .await?
1723 .context("declining call")?;
1724 room_updated(&room, &session.peer);
1725 }
1726
1727 for connection_id in session
1728 .connection_pool()
1729 .await
1730 .user_connection_ids(session.user_id())
1731 {
1732 session
1733 .peer
1734 .send(
1735 connection_id,
1736 proto::CallCanceled {
1737 room_id: room_id.to_proto(),
1738 },
1739 )
1740 .trace_err();
1741 }
1742 update_user_contacts(session.user_id(), &session).await?;
1743 Ok(())
1744}
1745
1746/// Updates other participants in the room with your current location.
1747async fn update_participant_location(
1748 request: proto::UpdateParticipantLocation,
1749 response: Response<proto::UpdateParticipantLocation>,
1750 session: MessageContext,
1751) -> Result<()> {
1752 let room_id = RoomId::from_proto(request.room_id);
1753 let location = request.location.context("invalid location")?;
1754
1755 let db = session.db().await;
1756 let room = db
1757 .update_room_participant_location(room_id, session.connection_id, location)
1758 .await?;
1759
1760 room_updated(&room, &session.peer);
1761 response.send(proto::Ack {})?;
1762 Ok(())
1763}
1764
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Register the project and its worktree metadata with the room in the
    // database; yields the new project id and the updated room state.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
            &request.features,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    // Other participants learn about the newly shared project via the room update.
    room_updated(room, &session.peer);

    Ok(())
}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1793 let project_id = ProjectId::from_proto(message.project_id);
1794 unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
/// Removes a shared project from its room, notifying guests, and deletes the
/// project row entirely when the database indicates it should no longer exist.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the initiating connection) that the
        // project is gone; best-effort per recipient.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The deletion runs after the room guard above is dropped, keeping the
    // guard's scope small.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1834
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let project_model = db.get_project(project_id).await?;
    // Host and guest must agree on the exact feature set; otherwise one side
    // would emit messages the other cannot handle.
    let host_features: Vec<String> =
        serde_json::from_str(&project_model.features).unwrap_or_default();
    let guest_features: HashSet<_> = request.features.iter().collect();
    let host_features_set: HashSet<_> = host_features.iter().collect();
    if guest_features != host_features_set {
        let host_connection_id = project_model.host_connection()?;
        let mut pool = session.connection_pool().await;
        let host_version = pool
            .connection(host_connection_id)
            .map(|c| c.zed_version.to_string());
        let guest_version = pool
            .connection(session.connection_id)
            .map(|c| c.zed_version.to_string());
        drop(pool);
        Err(anyhow!(
            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
            host_version.as_deref().unwrap_or("unknown"),
            guest_version.as_deref().unwrap_or("unknown"),
        ))?;
    }

    // Register this connection as a collaborator and get a fresh replica id.
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    drop(db);

    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
            root_repo_common_dir: None,
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Announce the new collaborator to everyone else; best-effort per peer.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
        features: project.features.clone(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            root_repo_common_dir: worktree.root_repo_common_dir,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        // NOTE(review): the messages below use `worktree.id` while the update
        // above uses the map key `worktree_id` — presumably the same value;
        // verify against the db layer.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Send each per-worktree settings file.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                    outside_worktree: Some(settings_file.outside_worktree),
                },
            )?;
        }
    }

    // Stream current repository state, chunked to respect message size limits.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Nudge the client to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2012
2013/// Leave someone elses shared project.
2014async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2015 let sender_id = session.connection_id;
2016 let project_id = ProjectId::from_proto(request.project_id);
2017 let db = session.db().await;
2018
2019 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2020 tracing::info!(
2021 %project_id,
2022 "leave project"
2023 );
2024
2025 project_left(project, &session);
2026 if let Some(room) = room {
2027 room_updated(room, &session.peer);
2028 }
2029
2030 Ok(())
2031}
2032
2033/// Updates other participants with changes to the project
2034async fn update_project(
2035 request: proto::UpdateProject,
2036 response: Response<proto::UpdateProject>,
2037 session: MessageContext,
2038) -> Result<()> {
2039 let project_id = ProjectId::from_proto(request.project_id);
2040 let (room, guest_connection_ids) = &*session
2041 .db()
2042 .await
2043 .update_project(project_id, session.connection_id, &request.worktrees)
2044 .await?;
2045 broadcast(
2046 Some(session.connection_id),
2047 guest_connection_ids.iter().copied(),
2048 |connection_id| {
2049 session
2050 .peer
2051 .forward_send(session.connection_id, connection_id, request.clone())
2052 },
2053 );
2054 if let Some(room) = room {
2055 room_updated(room, &session.peer);
2056 }
2057 response.send(proto::Ack {})?;
2058
2059 Ok(())
2060}
2061
2062/// Updates other participants with changes to the worktree
2063async fn update_worktree(
2064 request: proto::UpdateWorktree,
2065 response: Response<proto::UpdateWorktree>,
2066 session: MessageContext,
2067) -> Result<()> {
2068 let guest_connection_ids = session
2069 .db()
2070 .await
2071 .update_worktree(&request, session.connection_id)
2072 .await?;
2073
2074 broadcast(
2075 Some(session.connection_id),
2076 guest_connection_ids.iter().copied(),
2077 |connection_id| {
2078 session
2079 .peer
2080 .forward_send(session.connection_id, connection_id, request.clone())
2081 },
2082 );
2083 response.send(proto::Ack {})?;
2084 Ok(())
2085}
2086
2087async fn update_repository(
2088 request: proto::UpdateRepository,
2089 response: Response<proto::UpdateRepository>,
2090 session: MessageContext,
2091) -> Result<()> {
2092 let guest_connection_ids = session
2093 .db()
2094 .await
2095 .update_repository(&request, session.connection_id)
2096 .await?;
2097
2098 broadcast(
2099 Some(session.connection_id),
2100 guest_connection_ids.iter().copied(),
2101 |connection_id| {
2102 session
2103 .peer
2104 .forward_send(session.connection_id, connection_id, request.clone())
2105 },
2106 );
2107 response.send(proto::Ack {})?;
2108 Ok(())
2109}
2110
2111async fn remove_repository(
2112 request: proto::RemoveRepository,
2113 response: Response<proto::RemoveRepository>,
2114 session: MessageContext,
2115) -> Result<()> {
2116 let guest_connection_ids = session
2117 .db()
2118 .await
2119 .remove_repository(&request, session.connection_id)
2120 .await?;
2121
2122 broadcast(
2123 Some(session.connection_id),
2124 guest_connection_ids.iter().copied(),
2125 |connection_id| {
2126 session
2127 .peer
2128 .forward_send(session.connection_id, connection_id, request.clone())
2129 },
2130 );
2131 response.send(proto::Ack {})?;
2132 Ok(())
2133}
2134
2135/// Updates other participants with changes to the diagnostics
2136async fn update_diagnostic_summary(
2137 message: proto::UpdateDiagnosticSummary,
2138 session: MessageContext,
2139) -> Result<()> {
2140 let guest_connection_ids = session
2141 .db()
2142 .await
2143 .update_diagnostic_summary(&message, session.connection_id)
2144 .await?;
2145
2146 broadcast(
2147 Some(session.connection_id),
2148 guest_connection_ids.iter().copied(),
2149 |connection_id| {
2150 session
2151 .peer
2152 .forward_send(session.connection_id, connection_id, message.clone())
2153 },
2154 );
2155
2156 Ok(())
2157}
2158
2159/// Updates other participants with changes to the worktree settings
2160async fn update_worktree_settings(
2161 message: proto::UpdateWorktreeSettings,
2162 session: MessageContext,
2163) -> Result<()> {
2164 let guest_connection_ids = session
2165 .db()
2166 .await
2167 .update_worktree_settings(&message, session.connection_id)
2168 .await?;
2169
2170 broadcast(
2171 Some(session.connection_id),
2172 guest_connection_ids.iter().copied(),
2173 |connection_id| {
2174 session
2175 .peer
2176 .forward_send(session.connection_id, connection_id, message.clone())
2177 },
2178 );
2179
2180 Ok(())
2181}
2182
2183/// Notify other participants that a language server has started.
2184async fn start_language_server(
2185 request: proto::StartLanguageServer,
2186 session: MessageContext,
2187) -> Result<()> {
2188 let guest_connection_ids = session
2189 .db()
2190 .await
2191 .start_language_server(&request, session.connection_id)
2192 .await?;
2193
2194 broadcast(
2195 Some(session.connection_id),
2196 guest_connection_ids.iter().copied(),
2197 |connection_id| {
2198 session
2199 .peer
2200 .forward_send(session.connection_id, connection_id, request.clone())
2201 },
2202 );
2203 Ok(())
2204}
2205
/// Notify other participants that a language server has changed.
///
/// If the update carries fresh server capabilities (`MetadataUpdated`), they
/// are persisted before the message is fanned out, so that the stored state
/// never lags behind what participants were told.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Hold one db guard for both the capability write and the recipient query.
    let db = session.db().await;

    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the trailing `true` flag presumably widens the recipient
    // set (e.g. to include the host) — confirm against `project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2235
2236/// forward a project request to the host. These requests should be read only
2237/// as guests are allowed to send them.
2238async fn forward_read_only_project_request<T>(
2239 request: T,
2240 response: Response<T>,
2241 session: MessageContext,
2242) -> Result<()>
2243where
2244 T: EntityMessage + RequestMessage,
2245{
2246 let project_id = ProjectId::from_proto(request.remote_entity_id());
2247 let host_connection_id = session
2248 .db()
2249 .await
2250 .host_for_read_only_project_request(project_id, session.connection_id)
2251 .await?;
2252 let payload = session.forward_request(host_connection_id, request).await?;
2253 response.send(payload)?;
2254 Ok(())
2255}
2256
2257/// forward a project request to the host. These requests are disallowed
2258/// for guests.
2259async fn forward_mutating_project_request<T>(
2260 request: T,
2261 response: Response<T>,
2262 session: MessageContext,
2263) -> Result<()>
2264where
2265 T: EntityMessage + RequestMessage,
2266{
2267 let project_id = ProjectId::from_proto(request.remote_entity_id());
2268
2269 let host_connection_id = session
2270 .db()
2271 .await
2272 .host_for_mutating_project_request(project_id, session.connection_id)
2273 .await?;
2274 let payload = session.forward_request(host_connection_id, request).await?;
2275 response.send(payload)?;
2276 Ok(())
2277}
2278
2279async fn disallow_guest_request<T>(
2280 _request: T,
2281 response: Response<T>,
2282 _session: MessageContext,
2283) -> Result<()>
2284where
2285 T: RequestMessage,
2286{
2287 response.peer.respond_with_error(
2288 response.receipt,
2289 ErrorCode::Forbidden
2290 .message("request is not allowed for guests".to_string())
2291 .to_proto(),
2292 )?;
2293 response.responded.store(true, SeqCst);
2294 Ok(())
2295}
2296
2297async fn lsp_query(
2298 request: proto::LspQuery,
2299 response: Response<proto::LspQuery>,
2300 session: MessageContext,
2301) -> Result<()> {
2302 let (name, should_write) = request.query_name_and_write_permissions();
2303 tracing::Span::current().record("lsp_query_request", name);
2304 tracing::info!("lsp_query message received");
2305 if should_write {
2306 forward_mutating_project_request(request, response, session).await
2307 } else {
2308 forward_read_only_project_request(request, response, session).await
2309 }
2310}
2311
2312/// Notify other participants that a new buffer has been created
2313async fn create_buffer_for_peer(
2314 request: proto::CreateBufferForPeer,
2315 session: MessageContext,
2316) -> Result<()> {
2317 session
2318 .db()
2319 .await
2320 .check_user_is_project_host(
2321 ProjectId::from_proto(request.project_id),
2322 session.connection_id,
2323 )
2324 .await?;
2325 let peer_id = request.peer_id.context("invalid peer id")?;
2326 session
2327 .peer
2328 .forward_send(session.connection_id, peer_id.into(), request)?;
2329 Ok(())
2330}
2331
2332/// Notify other participants that a new image has been created
2333async fn create_image_for_peer(
2334 request: proto::CreateImageForPeer,
2335 session: MessageContext,
2336) -> Result<()> {
2337 session
2338 .db()
2339 .await
2340 .check_user_is_project_host(
2341 ProjectId::from_proto(request.project_id),
2342 session.connection_id,
2343 )
2344 .await?;
2345 let peer_id = request.peer_id.context("invalid peer id")?;
2346 session
2347 .peer
2348 .forward_send(session.connection_id, peer_id.into(), request)?;
2349 Ok(())
2350}
2351
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The operations are broadcast to all guests, and additionally forwarded to
/// the host as a request (unless the sender is the host); the Ack is only
/// sent after the host has accepted the update.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates are permitted from read-only guests; any other
    // operation escalates the required capability to read-write, which the
    // db call below enforces.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // Scope the db guard to this block so it is released before the
        // (potentially slow) request to the host below.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host applies the operations authoritatively; don't echo to self.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2398
2399async fn forward_project_search_chunk(
2400 message: proto::FindSearchCandidatesChunk,
2401 response: Response<proto::FindSearchCandidatesChunk>,
2402 session: MessageContext,
2403) -> Result<()> {
2404 let peer_id = message.peer_id.context("missing peer_id")?;
2405 let payload = session
2406 .peer
2407 .forward_request(session.connection_id, peer_id.into(), message)
2408 .await?;
2409 response.send(payload)?;
2410 Ok(())
2411}
2412
2413/// Notify other participants that a project has been updated.
2414async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2415 request: T,
2416 session: MessageContext,
2417) -> Result<()> {
2418 let project_id = ProjectId::from_proto(request.remote_entity_id());
2419 let project_connection_ids = session
2420 .db()
2421 .await
2422 .project_connection_ids(project_id, session.connection_id, false)
2423 .await?;
2424
2425 broadcast(
2426 Some(session.connection_id),
2427 project_connection_ids.iter().copied(),
2428 |connection_id| {
2429 session
2430 .peer
2431 .forward_send(session.connection_id, connection_id, request.clone())
2432 },
2433 );
2434 Ok(())
2435}
2436
/// Start following another user in a call.
///
/// Verifies leader and follower share the room, forwards the request to the
/// leader, and — when a project is involved — records the follow and
/// broadcasts the updated room state.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader first; only persist the follow once they responded
    // successfully, so the db never records a follow the leader rejected.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2468
/// Stop following another user in a call.
///
/// Mirrors [`follow`]: validates room membership, tells the leader
/// (fire-and-forget — no response expected), then removes the follow record
/// and broadcasts the updated room state when a project is involved.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, this is a plain send — no reply from the leader.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2497
/// Notify everyone following you of your current location.
///
/// Recipients are either the project's connections (when the update is
/// project-scoped) or everyone in the room.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly rather than going through
    // `session.db()` like the other handlers — presumably equivalent; confirm.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Also skip the sender itself; everyone else gets the raw message.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2529
2530/// Get public data about users.
2531async fn get_users(
2532 request: proto::GetUsers,
2533 response: Response<proto::GetUsers>,
2534 session: MessageContext,
2535) -> Result<()> {
2536 let user_ids = request
2537 .user_ids
2538 .into_iter()
2539 .map(UserId::from_proto)
2540 .collect();
2541 let users = session
2542 .db()
2543 .await
2544 .get_users_by_ids(user_ids)
2545 .await?
2546 .into_iter()
2547 .map(|user| proto::User {
2548 id: user.id.to_proto(),
2549 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2550 github_login: user.github_login,
2551 name: user.name,
2552 })
2553 .collect();
2554 response.send(proto::UsersResponse { users })?;
2555 Ok(())
2556}
2557
2558/// Search for users (to invite) buy Github login
2559async fn fuzzy_search_users(
2560 request: proto::FuzzySearchUsers,
2561 response: Response<proto::FuzzySearchUsers>,
2562 session: MessageContext,
2563) -> Result<()> {
2564 let query = request.query;
2565 let users = match query.len() {
2566 0 => vec![],
2567 1 | 2 => session
2568 .db()
2569 .await
2570 .get_user_by_github_login(&query)
2571 .await?
2572 .into_iter()
2573 .collect(),
2574 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2575 };
2576 let users = users
2577 .into_iter()
2578 .filter(|user| user.id != session.user_id())
2579 .map(|user| proto::User {
2580 id: user.id.to_proto(),
2581 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2582 github_login: user.github_login,
2583 name: user.name,
2584 })
2585 .collect();
2586 response.send(proto::UsersResponse { users })?;
2587 Ok(())
2588}
2589
/// Send a contact request to another user.
///
/// Persists the request, then pushes live `UpdateContacts` deltas to every
/// connection of both users, followed by any notifications the db created.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    // Keep this pool guard alive for the notification fanout below.
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2636
/// Accept or decline a contact request
///
/// A `Dismiss` response only clears the notification; `Accept`/`Decline`
/// update the contact state and push `UpdateContacts` deltas to every
/// connection of both users.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact entries we push below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2694
/// Remove a contact.
///
/// Works for both established contacts and pending requests: the db reports
/// whether the contact had been accepted, which determines whether we push a
/// contact removal or a request withdrawal to each side.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the original contact-request notification, if any.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2745
2746fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2747 version.0.minor < 139
2748}
2749
2750async fn subscribe_to_channels(
2751 _: proto::SubscribeToChannels,
2752 session: MessageContext,
2753) -> Result<()> {
2754 subscribe_user_to_channels(session.user_id(), &session).await?;
2755 Ok(())
2756}
2757
2758async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2759 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2760 let mut pool = session.connection_pool().await;
2761 for membership in &channels_for_user.channel_memberships {
2762 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2763 }
2764 session.peer.send(
2765 session.connection_id,
2766 build_update_user_channels(&channels_for_user),
2767 )?;
2768 session.peer.send(
2769 session.connection_id,
2770 build_channels_update(channels_for_user),
2771 )?;
2772 Ok(())
2773}
2774
/// Creates a new channel.
///
/// Responds to the creator immediately, then (a) subscribes the creator's
/// connections via their new membership and (b) announces the channel to
/// every connection on the root tree that is allowed to see it.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // A membership is only returned for a root channel creation — TODO confirm
    // against `Database::create_channel`.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to everyone on this tree who can see it.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2829
2830/// Delete a channel
2831async fn delete_channel(
2832 request: proto::DeleteChannel,
2833 response: Response<proto::DeleteChannel>,
2834 session: MessageContext,
2835) -> Result<()> {
2836 let db = session.db().await;
2837
2838 let channel_id = request.channel_id;
2839 let (root_channel, removed_channels) = db
2840 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2841 .await?;
2842 response.send(proto::Ack {})?;
2843
2844 // Notify members of removed channels
2845 let mut update = proto::UpdateChannels::default();
2846 update
2847 .delete_channels
2848 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2849
2850 let connection_pool = session.connection_pool().await;
2851 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2852 session.peer.send(connection_id, update.clone())?;
2853 }
2854
2855 Ok(())
2856}
2857
/// Invite someone to join a channel.
///
/// Pushes the invitation to every connection of the invitee, then delivers
/// any notifications created by the db layer.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // The db call also checks the inviter's permission on the channel.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2894
/// Remove someone from a channel.
///
/// Broadcasts the membership change, and retracts the member's original
/// invite notification (if one is still pending) on every one of their
/// connections.
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            // Best effort: a failed delete-notification send shouldn't fail
            // the whole removal, so the error is logged and swallowed.
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2936
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the (user, role) pairs first: the iterator borrows the pool,
    // and the loop body mutates it (subscribe/unsubscribe).
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel get it (and a subscription);
        // users who no longer can are unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2983
/// Alter the role for a user in the channel.
///
/// The db distinguishes two cases: the user is already a member (broadcast a
/// full membership update) or only invited (just refresh their invitation).
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Only the invited user needs to see the refreshed invitation.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3031
3032/// Change the name of a channel
3033async fn rename_channel(
3034 request: proto::RenameChannel,
3035 response: Response<proto::RenameChannel>,
3036 session: MessageContext,
3037) -> Result<()> {
3038 let db = session.db().await;
3039 let channel_id = ChannelId::from_proto(request.channel_id);
3040 let channel_model = db
3041 .rename_channel(channel_id, session.user_id(), &request.name)
3042 .await?;
3043 let root_id = channel_model.root_id();
3044 let channel = Channel::from_model(channel_model);
3045
3046 response.send(proto::RenameChannelResponse {
3047 channel: Some(channel.to_proto()),
3048 })?;
3049
3050 let connection_pool = session.connection_pool().await;
3051 let update = proto::UpdateChannels {
3052 channels: vec![channel.to_proto()],
3053 ..Default::default()
3054 };
3055 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3056 if role.can_see_channel(channel.visibility) {
3057 session.peer.send(connection_id, update.clone())?;
3058 }
3059 }
3060
3061 Ok(())
3062}
3063
/// Move a channel to a new parent.
///
/// The db returns the affected tree's root plus every channel whose position
/// changed; each connection on that tree receives only the subset of those
/// channels its role may see.
async fn move_channel(
    request: proto::MoveChannel,
    response: Response<proto::MoveChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let to = ChannelId::from_proto(request.to);

    let (root_id, channels) = session
        .db()
        .await
        .move_channel(channel_id, to, session.user_id())
        .await?;

    let connection_pool = session.connection_pool().await;
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        // Per-connection filtering: visibility depends on this role.
        let channels = channels
            .iter()
            .filter_map(|channel| {
                if role.can_see_channel(channel.visibility) {
                    Some(channel.to_proto())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        if channels.is_empty() {
            continue;
        }

        let update = proto::UpdateChannels {
            channels,
            ..Default::default()
        };

        session.peer.send(connection_id, update.clone())?;
    }

    response.send(Ack {})?;
    Ok(())
}
3106
/// Reorder a channel among its siblings.
///
/// Mirrors [`move_channel`]'s fanout: every connection on the affected tree
/// receives the reordered channels its role may see.
async fn reorder_channel(
    request: proto::ReorderChannel,
    response: Response<proto::ReorderChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let direction = request.direction();

    let updated_channels = session
        .db()
        .await
        .reorder_channel(channel_id, direction, session.user_id())
        .await?;

    // An empty update set means nothing moved, so there is nothing to send.
    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
        let connection_pool = session.connection_pool().await;
        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
            // Per-connection filtering: visibility depends on this role.
            let channels = updated_channels
                .iter()
                .filter_map(|channel| {
                    if role.can_see_channel(channel.visibility) {
                        Some(channel.to_proto())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if channels.is_empty() {
                continue;
            }

            let update = proto::UpdateChannels {
                channels,
                ..Default::default()
            };

            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;
    Ok(())
}
3151
3152/// Get the list of channel members
3153async fn get_channel_members(
3154 request: proto::GetChannelMembers,
3155 response: Response<proto::GetChannelMembers>,
3156 session: MessageContext,
3157) -> Result<()> {
3158 let db = session.db().await;
3159 let channel_id = ChannelId::from_proto(request.channel_id);
3160 let limit = if request.limit == 0 {
3161 u16::MAX as u64
3162 } else {
3163 request.limit
3164 };
3165 let (members, users) = db
3166 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3167 .await?;
3168 response.send(proto::GetChannelMembersResponse { members, users })?;
3169 Ok(())
3170}
3171
3172/// Accept or decline a channel invitation.
3173async fn respond_to_channel_invite(
3174 request: proto::RespondToChannelInvite,
3175 response: Response<proto::RespondToChannelInvite>,
3176 session: MessageContext,
3177) -> Result<()> {
3178 let db = session.db().await;
3179 let channel_id = ChannelId::from_proto(request.channel_id);
3180 let RespondToChannelInvite {
3181 membership_update,
3182 notifications,
3183 } = db
3184 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3185 .await?;
3186
3187 let mut connection_pool = session.connection_pool().await;
3188 if let Some(membership_update) = membership_update {
3189 notify_membership_updated(
3190 &mut connection_pool,
3191 membership_update,
3192 session.user_id(),
3193 &session.peer,
3194 );
3195 } else {
3196 let update = proto::UpdateChannels {
3197 remove_channel_invitations: vec![channel_id.to_proto()],
3198 ..Default::default()
3199 };
3200
3201 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3202 session.peer.send(connection_id, update.clone())?;
3203 }
3204 };
3205
3206 send_notifications(&connection_pool, &session.peer, notifications);
3207
3208 response.send(proto::Ack {})?;
3209
3210 Ok(())
3211}
3212
3213/// Join the channels' room
3214async fn join_channel(
3215 request: proto::JoinChannel,
3216 response: Response<proto::JoinChannel>,
3217 session: MessageContext,
3218) -> Result<()> {
3219 let channel_id = ChannelId::from_proto(request.channel_id);
3220 join_channel_internal(channel_id, Box::new(response), session).await
3221}
3222
/// Abstraction over the two RPC response types that can answer a room join —
/// `JoinChannel` and `JoinRoom` — both of which reply with the same
/// `proto::JoinRoomResponse` payload. Lets `join_channel_internal` serve both.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3236
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` RPCs (see `JoinChannelInternalResponse`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // The DB guard must be released before `leave_room_for_session`
            // acquires its own, then re-acquired for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a LiveKit client is configured: guests
        // receive a token with `can_publish: false`, everyone else `true`.
        // A token-minting failure is logged and yields `None` (no call media).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the caller before fanning out notifications to others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Tell the room's existing participants about the new member.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Broadcast the channel's updated participant list to everyone who can
    // see the channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3330
3331/// Start editing the channel notes
3332async fn join_channel_buffer(
3333 request: proto::JoinChannelBuffer,
3334 response: Response<proto::JoinChannelBuffer>,
3335 session: MessageContext,
3336) -> Result<()> {
3337 let db = session.db().await;
3338 let channel_id = ChannelId::from_proto(request.channel_id);
3339
3340 let open_response = db
3341 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3342 .await?;
3343
3344 let collaborators = open_response.collaborators.clone();
3345 response.send(open_response)?;
3346
3347 let update = UpdateChannelBufferCollaborators {
3348 channel_id: channel_id.to_proto(),
3349 collaborators: collaborators.clone(),
3350 };
3351 channel_buffer_updated(
3352 session.connection_id,
3353 collaborators
3354 .iter()
3355 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3356 &update,
3357 &session.peer,
3358 );
3359
3360 Ok(())
3361}
3362
3363/// Edit the channel notes
3364async fn update_channel_buffer(
3365 request: proto::UpdateChannelBuffer,
3366 session: MessageContext,
3367) -> Result<()> {
3368 let db = session.db().await;
3369 let channel_id = ChannelId::from_proto(request.channel_id);
3370
3371 let (collaborators, epoch, version) = db
3372 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3373 .await?;
3374
3375 channel_buffer_updated(
3376 session.connection_id,
3377 collaborators.clone(),
3378 &proto::UpdateChannelBuffer {
3379 channel_id: channel_id.to_proto(),
3380 operations: request.operations,
3381 },
3382 &session.peer,
3383 );
3384
3385 let pool = &*session.connection_pool().await;
3386
3387 let non_collaborators =
3388 pool.channel_connection_ids(channel_id)
3389 .filter_map(|(connection_id, _)| {
3390 if collaborators.contains(&connection_id) {
3391 None
3392 } else {
3393 Some(connection_id)
3394 }
3395 });
3396
3397 broadcast(None, non_collaborators, |peer_id| {
3398 session.peer.send(
3399 peer_id,
3400 proto::UpdateChannels {
3401 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3402 channel_id: channel_id.to_proto(),
3403 epoch: epoch as u64,
3404 version: version.clone(),
3405 }],
3406 ..Default::default()
3407 },
3408 )
3409 });
3410
3411 Ok(())
3412}
3413
3414/// Rejoin the channel notes after a connection blip
3415async fn rejoin_channel_buffers(
3416 request: proto::RejoinChannelBuffers,
3417 response: Response<proto::RejoinChannelBuffers>,
3418 session: MessageContext,
3419) -> Result<()> {
3420 let db = session.db().await;
3421 let buffers = db
3422 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3423 .await?;
3424
3425 for rejoined_buffer in &buffers {
3426 let collaborators_to_notify = rejoined_buffer
3427 .buffer
3428 .collaborators
3429 .iter()
3430 .filter_map(|c| Some(c.peer_id?.into()));
3431 channel_buffer_updated(
3432 session.connection_id,
3433 collaborators_to_notify,
3434 &proto::UpdateChannelBufferCollaborators {
3435 channel_id: rejoined_buffer.buffer.channel_id,
3436 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3437 },
3438 &session.peer,
3439 );
3440 }
3441
3442 response.send(proto::RejoinChannelBuffersResponse {
3443 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3444 })?;
3445
3446 Ok(())
3447}
3448
3449/// Stop editing the channel notes
3450async fn leave_channel_buffer(
3451 request: proto::LeaveChannelBuffer,
3452 response: Response<proto::LeaveChannelBuffer>,
3453 session: MessageContext,
3454) -> Result<()> {
3455 let db = session.db().await;
3456 let channel_id = ChannelId::from_proto(request.channel_id);
3457
3458 let left_buffer = db
3459 .leave_channel_buffer(channel_id, session.connection_id)
3460 .await?;
3461
3462 response.send(Ack {})?;
3463
3464 channel_buffer_updated(
3465 session.connection_id,
3466 left_buffer.connections,
3467 &proto::UpdateChannelBufferCollaborators {
3468 channel_id: channel_id.to_proto(),
3469 collaborators: left_buffer.collaborators,
3470 },
3471 &session.peer,
3472 );
3473
3474 Ok(())
3475}
3476
3477fn channel_buffer_updated<T: EnvelopedMessage>(
3478 sender_id: ConnectionId,
3479 collaborators: impl IntoIterator<Item = ConnectionId>,
3480 message: &T,
3481 peer: &Peer,
3482) {
3483 broadcast(Some(sender_id), collaborators, |peer_id| {
3484 peer.send(peer_id, message.clone())
3485 });
3486}
3487
3488fn send_notifications(
3489 connection_pool: &ConnectionPool,
3490 peer: &Peer,
3491 notifications: db::NotificationBatch,
3492) {
3493 for (user_id, notification) in notifications {
3494 for connection_id in connection_pool.user_connection_ids(user_id) {
3495 if let Err(error) = peer.send(
3496 connection_id,
3497 proto::AddNotification {
3498 notification: Some(notification.clone()),
3499 },
3500 ) {
3501 tracing::error!(
3502 "failed to send notification to {:?} {}",
3503 connection_id,
3504 error
3505 );
3506 }
3507 }
3508 }
3509}
3510
3511/// Send a message to the channel
3512async fn send_channel_message(
3513 _request: proto::SendChannelMessage,
3514 _response: Response<proto::SendChannelMessage>,
3515 _session: MessageContext,
3516) -> Result<()> {
3517 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3518}
3519
3520/// Delete a channel message
3521async fn remove_channel_message(
3522 _request: proto::RemoveChannelMessage,
3523 _response: Response<proto::RemoveChannelMessage>,
3524 _session: MessageContext,
3525) -> Result<()> {
3526 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3527}
3528
3529async fn update_channel_message(
3530 _request: proto::UpdateChannelMessage,
3531 _response: Response<proto::UpdateChannelMessage>,
3532 _session: MessageContext,
3533) -> Result<()> {
3534 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3535}
3536
3537/// Mark a channel message as read
3538async fn acknowledge_channel_message(
3539 _request: proto::AckChannelMessage,
3540 _session: MessageContext,
3541) -> Result<()> {
3542 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3543}
3544
3545/// Mark a buffer version as synced
3546async fn acknowledge_buffer_version(
3547 request: proto::AckBufferOperation,
3548 session: MessageContext,
3549) -> Result<()> {
3550 let buffer_id = BufferId::from_proto(request.buffer_id);
3551 session
3552 .db()
3553 .await
3554 .observe_buffer_version(
3555 buffer_id,
3556 session.user_id(),
3557 request.epoch as i32,
3558 &request.version,
3559 )
3560 .await?;
3561 Ok(())
3562}
3563
3564/// Start receiving chat updates for a channel
3565async fn join_channel_chat(
3566 _request: proto::JoinChannelChat,
3567 _response: Response<proto::JoinChannelChat>,
3568 _session: MessageContext,
3569) -> Result<()> {
3570 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3571}
3572
3573/// Stop receiving chat updates for a channel
3574async fn leave_channel_chat(
3575 _request: proto::LeaveChannelChat,
3576 _session: MessageContext,
3577) -> Result<()> {
3578 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3579}
3580
3581/// Retrieve the chat history for a channel
3582async fn get_channel_messages(
3583 _request: proto::GetChannelMessages,
3584 _response: Response<proto::GetChannelMessages>,
3585 _session: MessageContext,
3586) -> Result<()> {
3587 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3588}
3589
3590/// Retrieve specific chat messages
3591async fn get_channel_messages_by_id(
3592 _request: proto::GetChannelMessagesById,
3593 _response: Response<proto::GetChannelMessagesById>,
3594 _session: MessageContext,
3595) -> Result<()> {
3596 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3597}
3598
3599/// Retrieve the current users notifications
3600async fn get_notifications(
3601 request: proto::GetNotifications,
3602 response: Response<proto::GetNotifications>,
3603 session: MessageContext,
3604) -> Result<()> {
3605 let notifications = session
3606 .db()
3607 .await
3608 .get_notifications(
3609 session.user_id(),
3610 NOTIFICATION_COUNT_PER_PAGE,
3611 request.before_id.map(db::NotificationId::from_proto),
3612 )
3613 .await?;
3614 response.send(proto::GetNotificationsResponse {
3615 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3616 notifications,
3617 })?;
3618 Ok(())
3619}
3620
3621/// Mark notifications as read
3622async fn mark_notification_as_read(
3623 request: proto::MarkNotificationRead,
3624 response: Response<proto::MarkNotificationRead>,
3625 session: MessageContext,
3626) -> Result<()> {
3627 let database = &session.db().await;
3628 let notifications = database
3629 .mark_notification_as_read_by_id(
3630 session.user_id(),
3631 NotificationId::from_proto(request.notification_id),
3632 )
3633 .await?;
3634 send_notifications(
3635 &*session.connection_pool().await,
3636 &session.peer,
3637 notifications,
3638 );
3639 response.send(proto::Ack {})?;
3640 Ok(())
3641}
3642
3643fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3644 let message = match message {
3645 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3646 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3647 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3648 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3649 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3650 code: frame.code.into(),
3651 reason: frame.reason.as_str().to_owned().into(),
3652 })),
3653 // We should never receive a frame while reading the message, according
3654 // to the `tungstenite` maintainers:
3655 //
3656 // > It cannot occur when you read messages from the WebSocket, but it
3657 // > can be used when you want to send the raw frames (e.g. you want to
3658 // > send the frames to the WebSocket without composing the full message first).
3659 // >
3660 // > — https://github.com/snapview/tungstenite-rs/issues/268
3661 TungsteniteMessage::Frame(_) => {
3662 bail!("received an unexpected frame while reading the message")
3663 }
3664 };
3665
3666 Ok(message)
3667}
3668
3669fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3670 match message {
3671 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3672 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3673 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3674 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3675 AxumMessage::Close(frame) => {
3676 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3677 code: frame.code.into(),
3678 reason: frame.reason.as_ref().into(),
3679 }))
3680 }
3681 }
3682}
3683
/// Sync a user's in-memory channel subscriptions with a membership change
/// (e.g. an accepted invite) and push the resulting updates to every one of
/// their live connections.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the connection pool's channel subscriptions in sync with the DB.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this hand-built message mirrors `build_update_user_channels`
    // but omits `observed_channel_buffer_version` — confirm whether that
    // omission is intentional before unifying the two.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    // Describe the channel changes themselves: new channels from the updated
    // membership, deletions, and the now-consumed invitation.
    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    // Fan both messages out to all of the user's connections; send failures
    // are logged rather than propagated.
    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3724
3725fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3726 proto::UpdateUserChannels {
3727 channel_memberships: channels
3728 .channel_memberships
3729 .iter()
3730 .map(|m| proto::ChannelMembership {
3731 channel_id: m.channel_id.to_proto(),
3732 role: m.role.into(),
3733 })
3734 .collect(),
3735 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3736 }
3737}
3738
3739fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3740 let mut update = proto::UpdateChannels::default();
3741
3742 for channel in channels.channels {
3743 update.channels.push(channel.to_proto());
3744 }
3745
3746 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3747
3748 for (channel_id, participants) in channels.channel_participants {
3749 update
3750 .channel_participants
3751 .push(proto::ChannelParticipants {
3752 channel_id: channel_id.to_proto(),
3753 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3754 });
3755 }
3756
3757 for channel in channels.invited_channels {
3758 update.channel_invitations.push(channel.to_proto());
3759 }
3760
3761 update
3762}
3763
3764fn build_initial_contacts_update(
3765 contacts: Vec<db::Contact>,
3766 pool: &ConnectionPool,
3767) -> proto::UpdateContacts {
3768 let mut update = proto::UpdateContacts::default();
3769
3770 for contact in contacts {
3771 match contact {
3772 db::Contact::Accepted { user_id, busy } => {
3773 update.contacts.push(contact_for_user(user_id, busy, pool));
3774 }
3775 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3776 db::Contact::Incoming { user_id } => {
3777 update
3778 .incoming_requests
3779 .push(proto::IncomingContactRequest {
3780 requester_id: user_id.to_proto(),
3781 })
3782 }
3783 }
3784 }
3785
3786 update
3787}
3788
3789fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3790 proto::Contact {
3791 user_id: user_id.to_proto(),
3792 online: pool.is_user_online(user_id),
3793 busy,
3794 }
3795}
3796
3797fn room_updated(room: &proto::Room, peer: &Peer) {
3798 broadcast(
3799 None,
3800 room.participants
3801 .iter()
3802 .filter_map(|participant| Some(participant.peer_id?.into())),
3803 |peer_id| {
3804 peer.send(
3805 peer_id,
3806 proto::RoomUpdated {
3807 room: Some(room.clone()),
3808 },
3809 )
3810 },
3811 );
3812}
3813
3814fn channel_updated(
3815 channel: &db::channel::Model,
3816 room: &proto::Room,
3817 peer: &Peer,
3818 pool: &ConnectionPool,
3819) {
3820 let participants = room
3821 .participants
3822 .iter()
3823 .map(|p| p.user_id)
3824 .collect::<Vec<_>>();
3825
3826 broadcast(
3827 None,
3828 pool.channel_connection_ids(channel.root_id())
3829 .filter_map(|(channel_id, role)| {
3830 role.can_see_channel(channel.visibility)
3831 .then_some(channel_id)
3832 }),
3833 |peer_id| {
3834 peer.send(
3835 peer_id,
3836 proto::UpdateChannels {
3837 channel_participants: vec![proto::ChannelParticipants {
3838 channel_id: channel.id.to_proto(),
3839 participant_user_ids: participants.clone(),
3840 }],
3841 ..Default::default()
3842 },
3843 )
3844 },
3845 );
3846}
3847
3848async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3849 let db = session.db().await;
3850
3851 let contacts = db.get_contacts(user_id).await?;
3852 let busy = db.is_user_busy(user_id).await?;
3853
3854 let pool = session.connection_pool().await;
3855 let updated_contact = contact_for_user(user_id, busy, &pool);
3856 for contact in contacts {
3857 if let db::Contact::Accepted {
3858 user_id: contact_user_id,
3859 ..
3860 } = contact
3861 {
3862 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3863 session
3864 .peer
3865 .send(
3866 contact_conn_id,
3867 proto::UpdateContacts {
3868 contacts: vec![updated_contact.clone()],
3869 remove_contacts: Default::default(),
3870 incoming_requests: Default::default(),
3871 remove_incoming_requests: Default::default(),
3872 outgoing_requests: Default::default(),
3873 remove_outgoing_requests: Default::default(),
3874 },
3875 )
3876 .trace_err();
3877 }
3878 }
3879 }
3880 Ok(())
3881}
3882
/// Remove `connection_id` from whatever room it occupies, notifying left
/// projects, remaining room participants, the channel (if any), canceled
/// callees, affected contacts, and LiveKit.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized so the DB guard inside the `if let` can be
    // released before the notification fan-out below.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Notify collaborators of each project this connection was part of.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // This connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh that channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Cancel any outgoing calls that were still pending when the user left,
    // and mark those callees' contact cards for refresh.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Mirror the departure in LiveKit; delete the LiveKit room when the DB
    // reported the room itself was deleted. Failures are logged, not fatal.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3956
3957async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3958 let left_channel_buffers = session
3959 .db()
3960 .await
3961 .leave_channel_buffers(session.connection_id)
3962 .await?;
3963
3964 for left_buffer in left_channel_buffers {
3965 channel_buffer_updated(
3966 session.connection_id,
3967 left_buffer.connections,
3968 &proto::UpdateChannelBufferCollaborators {
3969 channel_id: left_buffer.channel_id.to_proto(),
3970 collaborators: left_buffer.collaborators,
3971 },
3972 &session.peer,
3973 );
3974 }
3975
3976 Ok(())
3977}
3978
3979fn project_left(project: &db::LeftProject, session: &Session) {
3980 for connection_id in &project.connection_ids {
3981 if project.should_unshare {
3982 session
3983 .peer
3984 .send(
3985 *connection_id,
3986 proto::UnshareProject {
3987 project_id: project.id.to_proto(),
3988 },
3989 )
3990 .trace_err();
3991 } else {
3992 session
3993 .peer
3994 .send(
3995 *connection_id,
3996 proto::RemoveProjectCollaborator {
3997 project_id: project.id.to_proto(),
3998 peer_id: Some(session.connection_id.into()),
3999 },
4000 )
4001 .trace_err();
4002 }
4003 }
4004}
4005
4006async fn share_agent_thread(
4007 request: proto::ShareAgentThread,
4008 response: Response<proto::ShareAgentThread>,
4009 session: MessageContext,
4010) -> Result<()> {
4011 let user_id = session.user_id();
4012
4013 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4014 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4015
4016 session
4017 .db()
4018 .await
4019 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4020 .await?;
4021
4022 response.send(proto::Ack {})?;
4023
4024 Ok(())
4025}
4026
4027async fn get_shared_agent_thread(
4028 request: proto::GetSharedAgentThread,
4029 response: Response<proto::GetSharedAgentThread>,
4030 session: MessageContext,
4031) -> Result<()> {
4032 let share_id = SharedThreadId::from_proto(request.session_id)
4033 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4034
4035 let result = session.db().await.get_shared_thread(share_id).await?;
4036
4037 match result {
4038 Some((thread, username)) => {
4039 response.send(proto::GetSharedAgentThreadResponse {
4040 title: thread.title,
4041 thread_data: thread.data,
4042 sharer_username: username,
4043 created_at: thread.created_at.and_utc().to_rfc3339(),
4044 })?;
4045 }
4046 None => {
4047 return Err(anyhow!("Shared thread not found").into());
4048 }
4049 }
4050
4051 Ok(())
4052}
4053
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) and convert the result into an `Option`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4059
4060impl<T, E> ResultExt for Result<T, E>
4061where
4062 E: std::fmt::Debug,
4063{
4064 type Ok = T;
4065
4066 #[track_caller]
4067 fn trace_err(self) -> Option<T> {
4068 match self {
4069 Ok(value) => Some(value),
4070 Err(error) => {
4071 tracing::error!("{:?}", error);
4072 None
4073 }
4074 }
4075 }
4076}