1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
// (We wait slightly longer than that grace period before reclaiming.)
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields recorded while handling a message.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one message type: receives the envelope, the
/// session it arrived on, and the span to record timing fields into.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
/// One-shot responder handed to request handlers; records whether a reply
/// was actually sent so `add_request_handler` can flag handlers that forgot.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`, which checks it
    // after the handler completes.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requester and marks the request answered.
    /// Consumes `self`, so at most one response can be sent.
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// A `Session` paired with the tracing span of the message currently being
/// handled. Derefs to `Session` so handlers can use it transparently.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    // Span created in `handle_connection`'s receive loop; timing fields
    // (e.g. `host_waiting_ms`) are recorded onto it.
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
impl MessageContext {
    /// Forwards `request` to the peer at `receiver_id` and returns a future
    /// resolving to its response. Records how long the remote host took
    /// (`host_waiting_ms`) on the message's span and logs start/finish/error.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs on completion regardless of success or failure.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // micros -> fractional milliseconds for sub-ms precision.
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// Per-connection state shared with every message handler.
#[derive(Clone)]
struct Session {
    /// Who is connected (currently always a user).
    principal: Principal,
    /// Id of this connection within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; obtain it via `Session::db()`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared pool of all live connections; obtain via `connection_pool()`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    // Client-reported system identifier, recorded at connection time.
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
198impl Session {
199 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
200 #[cfg(feature = "test-support")]
201 tokio::task::yield_now().await;
202 let guard = self.db.lock().await;
203 #[cfg(feature = "test-support")]
204 tokio::task::yield_now().await;
205 guard
206 }
207
208 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
209 #[cfg(feature = "test-support")]
210 tokio::task::yield_now().await;
211 let guard = self.connection_pool.lock();
212 ConnectionPoolGuard {
213 guard,
214 _not_send: PhantomData,
215 }
216 }
217
218 #[expect(dead_code)]
219 fn is_staff(&self) -> bool {
220 match &self.principal {
221 Principal::User(user) => user.admin,
222 }
223 }
224
225 fn user_id(&self) -> UserId {
226 match &self.principal {
227 Principal::User(user) => user.id,
228 }
229 }
230}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collab RPC server: owns the peer, the registry of message handlers,
/// and the pool of live client connections.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Maps a message's `TypeId` to its registered handler; populated in
    /// `Server::new` via `add_handler` and friends.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Flipped to `true` in `teardown` so connection tasks exit their loops.
    teardown: watch::Sender<bool>,
}
262
/// Guard over the connection-pool lock.
///
/// The `PhantomData<Rc<()>>` deliberately makes the guard `!Send`, so a
/// future holding it across an `.await` point cannot be `Send`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers a handler for every RPC message type
    /// it understands. Registering the same message type twice panics (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Room & call lifecycle.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree synchronization.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host (read-only vs. mutating).
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer replication between host and guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users & contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following & acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations, forwarded to the project host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
451
    /// Starts the background task that, after `CLEANUP_TIMEOUT`, reclaims
    /// resources left behind by previous server instances: stale chat
    /// participants, stale rooms and channel buffers, old worktree entries,
    /// and stale server rows. Affected clients are notified of the refreshed
    /// state, and empty LiveKit rooms are deleted.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining connections about the new collaborator set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to dead servers,
                        // then broadcast the refreshed room/channel state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose pending calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy/contact status to each affected
                        // user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
622
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals `teardown` so connection tasks exit their loops.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
628
    /// Test-only: tears the server down, assigns it a new id, resets the
    /// peer, and re-arms the teardown signal so it can accept connections
    /// again.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
636
    /// Test-only: returns the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
641
    /// Registers `handler` for message type `M`, wrapping it with envelope
    /// downcasting, timing instrumentation, and result logging.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The registry keys handlers by `TypeId::of::<M>()`, so this
                // downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time since the envelope was received;
                    // processing = handler run time; queue = the difference.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
685
686 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
687 where
688 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
689 Fut: 'static + Send + Future<Output = Result<()>>,
690 M: EnvelopedMessage,
691 {
692 self.add_handler(move |envelope, session| handler(envelope.payload, session));
693 self
694 }
695
    /// Registers a request handler for `M`. The handler must send exactly one
    /// response via the provided `Response<M>`:
    /// - `Ok` without responding is reported as an error;
    /// - `Err` is converted to its wire form and relayed to the requester
    ///   before being propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`; checked below.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Convert to the wire error representation and relay
                        // the failure to the requester.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
734
    /// Drives one client connection to completion: registers it with the
    /// peer, sends the initial client update, then dispatches incoming
    /// messages to their handlers until the connection closes, I/O fails,
    /// or the server tears down. Finally runs `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Setup succeeded; release the connection-count slot here.
            // NOTE(review): this means the guard only bounds connections that
            // are still being set up — confirm that's intended.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Wait for a semaphore permit before pulling the next message,
                // bounding the number of in-flight handlers per connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    // Poll running foreground handlers so they make progress.
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    // Background messages run detached on the executor.
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
905
    /// Sends the state a client needs right after connecting: the `Hello`
    /// handshake, the initial contact list (registering the connection in the
    /// pool), channel subscriptions, and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Let the caller observe the assigned connection id, if requested.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // First-ever connection: send `ShowContacts` once and persist
                // the flag so it isn't sent again.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contact
                    // update while holding the pool lock, keeping the two
                    // consistent with each other.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
961}
962
963impl Deref for ConnectionPoolGuard<'_> {
964 type Target = ConnectionPool;
965
966 fn deref(&self) -> &Self::Target {
967 &self.guard
968 }
969}
970
971impl DerefMut for ConnectionPoolGuard<'_> {
972 fn deref_mut(&mut self) -> &mut Self::Target {
973 &mut self.guard
974 }
975}
976
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released; this is compiled out of production builds.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
983
984fn broadcast<F>(
985 sender_id: Option<ConnectionId>,
986 receiver_ids: impl IntoIterator<Item = ConnectionId>,
987 mut f: F,
988) where
989 F: FnMut(ConnectionId) -> anyhow::Result<()>,
990{
991 for receiver_id in receiver_ids {
992 if Some(receiver_id) != sender_id
993 && let Err(error) = f(receiver_id)
994 {
995 tracing::error!("failed to send to {:?} {}", receiver_id, error);
996 }
997 }
998}
999
1000pub struct ProtocolVersion(u32);
1001
1002impl Header for ProtocolVersion {
1003 fn name() -> &'static HeaderName {
1004 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1005 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1006 }
1007
1008 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1009 where
1010 Self: Sized,
1011 I: Iterator<Item = &'i axum::http::HeaderValue>,
1012 {
1013 let version = values
1014 .next()
1015 .ok_or_else(axum::headers::Error::invalid)?
1016 .to_str()
1017 .map_err(|_| axum::headers::Error::invalid())?
1018 .parse()
1019 .map_err(|_| axum::headers::Error::invalid())?;
1020 Ok(Self(version))
1021 }
1022
1023 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1024 values.extend([self.0.to_string().parse().unwrap()]);
1025 }
1026}
1027
1028pub struct AppVersionHeader(Version);
1029impl Header for AppVersionHeader {
1030 fn name() -> &'static HeaderName {
1031 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1032 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1033 }
1034
1035 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1036 where
1037 Self: Sized,
1038 I: Iterator<Item = &'i axum::http::HeaderValue>,
1039 {
1040 let version = values
1041 .next()
1042 .ok_or_else(axum::headers::Error::invalid)?
1043 .to_str()
1044 .map_err(|_| axum::headers::Error::invalid())?
1045 .parse()
1046 .map_err(|_| axum::headers::Error::invalid())?;
1047 Ok(Self(version))
1048 }
1049
1050 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1051 values.extend([self.0.to_string().parse().unwrap()]);
1052 }
1053}
1054
1055#[derive(Debug)]
1056pub struct ReleaseChannelHeader(String);
1057
1058impl Header for ReleaseChannelHeader {
1059 fn name() -> &'static HeaderName {
1060 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1061 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1062 }
1063
1064 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1065 where
1066 Self: Sized,
1067 I: Iterator<Item = &'i axum::http::HeaderValue>,
1068 {
1069 Ok(Self(
1070 values
1071 .next()
1072 .ok_or_else(axum::headers::Error::invalid)?
1073 .to_str()
1074 .map_err(|_| axum::headers::Error::invalid())?
1075 .to_owned(),
1076 ))
1077 }
1078
1079 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1080 values.extend([self.0.parse().unwrap()]);
1081 }
1082}
1083
1084pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1085 Router::new()
1086 .route("/rpc", get(handle_websocket_request))
1087 .layer(
1088 ServiceBuilder::new()
1089 .layer(Extension(server.app_state.clone()))
1090 .layer(middleware::from_fn(auth::validate_header)),
1091 )
1092 .route("/metrics", get(handle_metrics))
1093 .layer(Extension(server))
1094}
1095
/// Upgrades an authenticated HTTP request to the collab WebSocket connection.
///
/// Rejects with `426 Upgrade Required` when the client's RPC protocol version
/// doesn't match the server's, when the app-version header is missing, or
/// when the app version is too old to collaborate. Enforces the server-wide
/// concurrent-connection limit (`503` when exceeded) before the upgrade.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The wire protocol must match exactly; there is no cross-version compat.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's WebSocket message types to the tungstenite types the
        // RPC `Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1173
1174pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1175 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1176 let connections_metric = CONNECTIONS_METRIC
1177 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1178
1179 let connections = server
1180 .connection_pool
1181 .lock()
1182 .connections()
1183 .filter(|connection| !connection.admin)
1184 .count();
1185 connections_metric.set(connections as _);
1186
1187 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1188 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1189 register_int_gauge!(
1190 "shared_projects",
1191 "number of open projects with one or more guests"
1192 )
1193 .unwrap()
1194 });
1195
1196 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1197 shared_projects_metric.set(shared_projects as _);
1198
1199 let encoder = prometheus::TextEncoder::new();
1200 let metric_families = prometheus::gather();
1201 let encoded_metrics = encoder
1202 .encode_to_string(&metric_families)
1203 .map_err(|err| anyhow!("{err}"))?;
1204 Ok(encoded_metrics)
1205}
1206
/// Cleans up after a dropped client connection.
///
/// Immediately disconnects the peer, removes the connection from the pool,
/// and records the loss in the database. Then waits up to `RECONNECT_TIMEOUT`
/// for the client to reconnect before releasing its room and channel-buffer
/// resources; if `teardown` fires first (server shutdown), that cleanup is
/// skipped entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: a failure to record the loss shouldn't abort cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending incoming call if this was the user's
            // last live connection — another connection may still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1253
1254/// Acknowledges a ping from a client, used to keep the connection alive.
1255async fn ping(
1256 _: proto::Ping,
1257 response: Response<proto::Ping>,
1258 _session: MessageContext,
1259) -> Result<()> {
1260 response.send(proto::Ack {})?;
1261 Ok(())
1262}
1263
1264/// Creates a new room for calling (outside of channels)
1265async fn create_room(
1266 _request: proto::CreateRoom,
1267 response: Response<proto::CreateRoom>,
1268 session: MessageContext,
1269) -> Result<()> {
1270 let livekit_room = nanoid::nanoid!(30);
1271
1272 let live_kit_connection_info = util::maybe!(async {
1273 let live_kit = session.app_state.livekit_client.as_ref();
1274 let live_kit = live_kit?;
1275 let user_id = session.user_id().to_string();
1276
1277 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1278
1279 Some(proto::LiveKitConnectionInfo {
1280 server_url: live_kit.url().into(),
1281 token,
1282 can_publish: true,
1283 })
1284 })
1285 .await;
1286
1287 let room = session
1288 .db()
1289 .await
1290 .create_room(session.user_id(), session.connection_id, &livekit_room)
1291 .await?;
1292
1293 response.send(proto::CreateRoomResponse {
1294 room: Some(room.clone()),
1295 live_kit_connection_info,
1296 })?;
1297
1298 update_user_contacts(session.user_id(), &session).await?;
1299 Ok(())
1300}
1301
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room is backed by a channel, the request is delegated to the
/// channel-join path. Otherwise the user joins directly: the room state is
/// broadcast, any ringing connections of this user are told the call was
/// answered, and a LiveKit token is minted when a LiveKit client exists.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Channel-backed rooms are joined through the channel flow instead.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        // Broadcast the new participant while the guard is held, then take
        // ownership of the room data out of the guard.
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Cancel the ring on all of this user's other connections — the call has
    // effectively been answered here.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's membership in the room, responds with the reshared
/// and rejoined projects, and then — for each reshared project — tells that
/// project's collaborators about the caller's new peer id and re-broadcasts
/// the project metadata. Channel and contact state are refreshed afterwards.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the database guard happens inside this scope.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // The caller reconnected with a new connection id; update every
            // collaborator's view of this peer.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree metadata to the guests.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, let channel members see the updated
    // participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
/// Streams the current state of each rejoined project back to the rejoining
/// client and announces the client's new peer id to collaborators.
///
/// For every project: first notify each collaborator of the caller's new
/// connection id, then send the caller its worktree entries (split into
/// size-bounded updates), diagnostic summaries, settings files, and finally
/// repository updates and removals. Worktree/repository collections are
/// `mem::take`n out of the project, so this is a consuming, one-shot send.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                root_repo_common_dir: worktree.root_repo_common_dir,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split so each message stays within limits.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1547
1548/// leave room disconnects from the room.
1549async fn leave_room(
1550 _: proto::LeaveRoom,
1551 response: Response<proto::LeaveRoom>,
1552 session: MessageContext,
1553) -> Result<()> {
1554 leave_room_for_session(&session, session.connection_id).await?;
1555 response.send(proto::Ack {})?;
1556 Ok(())
1557}
1558
1559/// Updates the permissions of someone else in the room.
1560async fn set_room_participant_role(
1561 request: proto::SetRoomParticipantRole,
1562 response: Response<proto::SetRoomParticipantRole>,
1563 session: MessageContext,
1564) -> Result<()> {
1565 let user_id = UserId::from_proto(request.user_id);
1566 let role = ChannelRole::from(request.role());
1567
1568 let (livekit_room, can_publish) = {
1569 let room = session
1570 .db()
1571 .await
1572 .set_room_participant_role(
1573 session.user_id(),
1574 RoomId::from_proto(request.room_id),
1575 user_id,
1576 role,
1577 )
1578 .await?;
1579
1580 let livekit_room = room.livekit_room.clone();
1581 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1582 room_updated(&room, &session.peer);
1583 (livekit_room, can_publish)
1584 };
1585
1586 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1587 live_kit
1588 .update_participant(
1589 livekit_room.clone(),
1590 request.user_id.to_string(),
1591 livekit_api::proto::ParticipantPermission {
1592 can_subscribe: true,
1593 can_publish,
1594 can_publish_data: can_publish,
1595 hidden: false,
1596 recorder: false,
1597 },
1598 )
1599 .await
1600 .trace_err();
1601 }
1602
1603 response.send(proto::Ack {})?;
1604 Ok(())
1605}
1606
/// Call someone else into the current room
///
/// Rings every connection of the callee concurrently; the first connection
/// to accept wins and the request succeeds. If no connection accepts, the
/// call is marked failed in the database and an error is returned. Contacts
/// are refreshed on both the ring and the failure path so the callee's
/// status is accurate everywhere.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Calls are only allowed between contacts.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Take the call message out of the guard so it can be sent below.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee at once; resolve on first accept.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failure and keep waiting on the other connections.
                call_response.trace_err();
            }
        }
    }

    // No connection accepted: roll the room state back and report failure.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1675
1676/// Cancel an outgoing call.
1677async fn cancel_call(
1678 request: proto::CancelCall,
1679 response: Response<proto::CancelCall>,
1680 session: MessageContext,
1681) -> Result<()> {
1682 let called_user_id = UserId::from_proto(request.called_user_id);
1683 let room_id = RoomId::from_proto(request.room_id);
1684 {
1685 let room = session
1686 .db()
1687 .await
1688 .cancel_call(room_id, session.connection_id, called_user_id)
1689 .await?;
1690 room_updated(&room, &session.peer);
1691 }
1692
1693 for connection_id in session
1694 .connection_pool()
1695 .await
1696 .user_connection_ids(called_user_id)
1697 {
1698 session
1699 .peer
1700 .send(
1701 connection_id,
1702 proto::CallCanceled {
1703 room_id: room_id.to_proto(),
1704 },
1705 )
1706 .trace_err();
1707 }
1708 response.send(proto::Ack {})?;
1709
1710 update_user_contacts(called_user_id, &session).await?;
1711 Ok(())
1712}
1713
1714/// Decline an incoming call.
1715async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1716 let room_id = RoomId::from_proto(message.room_id);
1717 {
1718 let room = session
1719 .db()
1720 .await
1721 .decline_call(Some(room_id), session.user_id())
1722 .await?
1723 .context("declining call")?;
1724 room_updated(&room, &session.peer);
1725 }
1726
1727 for connection_id in session
1728 .connection_pool()
1729 .await
1730 .user_connection_ids(session.user_id())
1731 {
1732 session
1733 .peer
1734 .send(
1735 connection_id,
1736 proto::CallCanceled {
1737 room_id: room_id.to_proto(),
1738 },
1739 )
1740 .trace_err();
1741 }
1742 update_user_contacts(session.user_id(), &session).await?;
1743 Ok(())
1744}
1745
1746/// Updates other participants in the room with your current location.
1747async fn update_participant_location(
1748 request: proto::UpdateParticipantLocation,
1749 response: Response<proto::UpdateParticipantLocation>,
1750 session: MessageContext,
1751) -> Result<()> {
1752 let room_id = RoomId::from_proto(request.room_id);
1753 let location = request.location.context("invalid location")?;
1754
1755 let db = session.db().await;
1756 let room = db
1757 .update_room_participant_location(room_id, session.connection_id, location)
1758 .await?;
1759
1760 room_updated(&room, &session.peer);
1761 response.send(proto::Ack {})?;
1762 Ok(())
1763}
1764
1765/// Share a project into the room.
1766async fn share_project(
1767 request: proto::ShareProject,
1768 response: Response<proto::ShareProject>,
1769 session: MessageContext,
1770) -> Result<()> {
1771 let (project_id, room) = &*session
1772 .db()
1773 .await
1774 .share_project(
1775 RoomId::from_proto(request.room_id),
1776 session.connection_id,
1777 &request.worktrees,
1778 request.is_ssh_project,
1779 request.windows_paths.unwrap_or(false),
1780 &request.features,
1781 )
1782 .await?;
1783 response.send(proto::ShareProjectResponse {
1784 project_id: project_id.to_proto(),
1785 })?;
1786 room_updated(room, &session.peer);
1787
1788 Ok(())
1789}
1790
1791/// Unshare a project from the room.
1792async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1793 let project_id = ProjectId::from_proto(message.project_id);
1794 unshare_project_internal(project_id, session.connection_id, &session).await
1795}
1796
/// Removes a shared project on behalf of `connection_id`.
///
/// Marks the project unshared in the database, notifies the guests, and
/// broadcasts the updated room state. When the database indicates the
/// project should be deleted entirely (the `delete` flag), deletes it after
/// the room guard has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The guard scope bounds how long the room lock is held; only the
    // `delete` decision escapes it.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Deletion happens after the guard is dropped.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1834
1835/// Join someone elses shared project.
1836async fn join_project(
1837 request: proto::JoinProject,
1838 response: Response<proto::JoinProject>,
1839 session: MessageContext,
1840) -> Result<()> {
1841 let project_id = ProjectId::from_proto(request.project_id);
1842
1843 tracing::info!(%project_id, "join project");
1844
1845 let db = session.db().await;
1846 let project_model = db.get_project(project_id).await?;
1847 let host_features: Vec<String> =
1848 serde_json::from_str(&project_model.features).unwrap_or_default();
1849 let guest_features: HashSet<_> = request.features.iter().collect();
1850 let host_features_set: HashSet<_> = host_features.iter().collect();
1851 if guest_features != host_features_set {
1852 let host_connection_id = project_model.host_connection()?;
1853 let mut pool = session.connection_pool().await;
1854 let host_version = pool
1855 .connection(host_connection_id)
1856 .map(|c| c.zed_version.to_string());
1857 let guest_version = pool
1858 .connection(session.connection_id)
1859 .map(|c| c.zed_version.to_string());
1860 drop(pool);
1861 Err(anyhow!(
1862 "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1863 host_version.as_deref().unwrap_or("unknown"),
1864 guest_version.as_deref().unwrap_or("unknown"),
1865 ))?;
1866 }
1867
1868 let (project, replica_id) = &mut *db
1869 .join_project(
1870 project_id,
1871 session.connection_id,
1872 session.user_id(),
1873 request.committer_name.clone(),
1874 request.committer_email.clone(),
1875 )
1876 .await?;
1877 drop(db);
1878
1879 tracing::info!(%project_id, "join remote project");
1880 let collaborators = project
1881 .collaborators
1882 .iter()
1883 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1884 .map(|collaborator| collaborator.to_proto())
1885 .collect::<Vec<_>>();
1886 let project_id = project.id;
1887 let guest_user_id = session.user_id();
1888
1889 let worktrees = project
1890 .worktrees
1891 .iter()
1892 .map(|(id, worktree)| proto::WorktreeMetadata {
1893 id: *id,
1894 root_name: worktree.root_name.clone(),
1895 visible: worktree.visible,
1896 abs_path: worktree.abs_path.clone(),
1897 })
1898 .collect::<Vec<_>>();
1899
1900 let add_project_collaborator = proto::AddProjectCollaborator {
1901 project_id: project_id.to_proto(),
1902 collaborator: Some(proto::Collaborator {
1903 peer_id: Some(session.connection_id.into()),
1904 replica_id: replica_id.0 as u32,
1905 user_id: guest_user_id.to_proto(),
1906 is_host: false,
1907 committer_name: request.committer_name.clone(),
1908 committer_email: request.committer_email.clone(),
1909 }),
1910 };
1911
1912 for collaborator in &collaborators {
1913 session
1914 .peer
1915 .send(
1916 collaborator.peer_id.unwrap().into(),
1917 add_project_collaborator.clone(),
1918 )
1919 .trace_err();
1920 }
1921
1922 // First, we send the metadata associated with each worktree.
1923 let (language_servers, language_server_capabilities) = project
1924 .language_servers
1925 .clone()
1926 .into_iter()
1927 .map(|server| (server.server, server.capabilities))
1928 .unzip();
1929 response.send(proto::JoinProjectResponse {
1930 project_id: project.id.0 as u64,
1931 worktrees,
1932 replica_id: replica_id.0 as u32,
1933 collaborators,
1934 language_servers,
1935 language_server_capabilities,
1936 role: project.role.into(),
1937 windows_paths: project.path_style == PathStyle::Windows,
1938 features: project.features.clone(),
1939 })?;
1940
1941 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1942 // Stream this worktree's entries.
1943 let message = proto::UpdateWorktree {
1944 project_id: project_id.to_proto(),
1945 worktree_id,
1946 abs_path: worktree.abs_path.clone(),
1947 root_name: worktree.root_name,
1948 root_repo_common_dir: worktree.root_repo_common_dir,
1949 updated_entries: worktree.entries,
1950 removed_entries: Default::default(),
1951 scan_id: worktree.scan_id,
1952 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1953 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1954 removed_repositories: Default::default(),
1955 };
1956 for update in proto::split_worktree_update(message) {
1957 session.peer.send(session.connection_id, update.clone())?;
1958 }
1959
1960 // Stream this worktree's diagnostics.
1961 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1962 if let Some(summary) = worktree_diagnostics.next() {
1963 let message = proto::UpdateDiagnosticSummary {
1964 project_id: project.id.to_proto(),
1965 worktree_id: worktree.id,
1966 summary: Some(summary),
1967 more_summaries: worktree_diagnostics.collect(),
1968 };
1969 session.peer.send(session.connection_id, message)?;
1970 }
1971
1972 for settings_file in worktree.settings_files {
1973 session.peer.send(
1974 session.connection_id,
1975 proto::UpdateWorktreeSettings {
1976 project_id: project_id.to_proto(),
1977 worktree_id: worktree.id,
1978 path: settings_file.path,
1979 content: Some(settings_file.content),
1980 kind: Some(settings_file.kind.to_proto() as i32),
1981 outside_worktree: Some(settings_file.outside_worktree),
1982 },
1983 )?;
1984 }
1985 }
1986
1987 for repository in mem::take(&mut project.repositories) {
1988 for update in split_repository_update(repository) {
1989 session.peer.send(session.connection_id, update)?;
1990 }
1991 }
1992
1993 for language_server in &project.language_servers {
1994 session.peer.send(
1995 session.connection_id,
1996 proto::UpdateLanguageServer {
1997 project_id: project_id.to_proto(),
1998 server_name: Some(language_server.server.name.clone()),
1999 language_server_id: language_server.server.id,
2000 variant: Some(
2001 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2002 proto::LspDiskBasedDiagnosticsUpdated {},
2003 ),
2004 ),
2005 },
2006 )?;
2007 }
2008
2009 Ok(())
2010}
2011
2012/// Leave someone elses shared project.
2013async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2014 let sender_id = session.connection_id;
2015 let project_id = ProjectId::from_proto(request.project_id);
2016 let db = session.db().await;
2017
2018 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2019 tracing::info!(
2020 %project_id,
2021 "leave project"
2022 );
2023
2024 project_left(project, &session);
2025 if let Some(room) = room {
2026 room_updated(room, &session.peer);
2027 }
2028
2029 Ok(())
2030}
2031
2032/// Updates other participants with changes to the project
2033async fn update_project(
2034 request: proto::UpdateProject,
2035 response: Response<proto::UpdateProject>,
2036 session: MessageContext,
2037) -> Result<()> {
2038 let project_id = ProjectId::from_proto(request.project_id);
2039 let (room, guest_connection_ids) = &*session
2040 .db()
2041 .await
2042 .update_project(project_id, session.connection_id, &request.worktrees)
2043 .await?;
2044 broadcast(
2045 Some(session.connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |connection_id| {
2048 session
2049 .peer
2050 .forward_send(session.connection_id, connection_id, request.clone())
2051 },
2052 );
2053 if let Some(room) = room {
2054 room_updated(room, &session.peer);
2055 }
2056 response.send(proto::Ack {})?;
2057
2058 Ok(())
2059}
2060
2061/// Updates other participants with changes to the worktree
2062async fn update_worktree(
2063 request: proto::UpdateWorktree,
2064 response: Response<proto::UpdateWorktree>,
2065 session: MessageContext,
2066) -> Result<()> {
2067 let guest_connection_ids = session
2068 .db()
2069 .await
2070 .update_worktree(&request, session.connection_id)
2071 .await?;
2072
2073 broadcast(
2074 Some(session.connection_id),
2075 guest_connection_ids.iter().copied(),
2076 |connection_id| {
2077 session
2078 .peer
2079 .forward_send(session.connection_id, connection_id, request.clone())
2080 },
2081 );
2082 response.send(proto::Ack {})?;
2083 Ok(())
2084}
2085
2086async fn update_repository(
2087 request: proto::UpdateRepository,
2088 response: Response<proto::UpdateRepository>,
2089 session: MessageContext,
2090) -> Result<()> {
2091 let guest_connection_ids = session
2092 .db()
2093 .await
2094 .update_repository(&request, session.connection_id)
2095 .await?;
2096
2097 broadcast(
2098 Some(session.connection_id),
2099 guest_connection_ids.iter().copied(),
2100 |connection_id| {
2101 session
2102 .peer
2103 .forward_send(session.connection_id, connection_id, request.clone())
2104 },
2105 );
2106 response.send(proto::Ack {})?;
2107 Ok(())
2108}
2109
2110async fn remove_repository(
2111 request: proto::RemoveRepository,
2112 response: Response<proto::RemoveRepository>,
2113 session: MessageContext,
2114) -> Result<()> {
2115 let guest_connection_ids = session
2116 .db()
2117 .await
2118 .remove_repository(&request, session.connection_id)
2119 .await?;
2120
2121 broadcast(
2122 Some(session.connection_id),
2123 guest_connection_ids.iter().copied(),
2124 |connection_id| {
2125 session
2126 .peer
2127 .forward_send(session.connection_id, connection_id, request.clone())
2128 },
2129 );
2130 response.send(proto::Ack {})?;
2131 Ok(())
2132}
2133
2134/// Updates other participants with changes to the diagnostics
2135async fn update_diagnostic_summary(
2136 message: proto::UpdateDiagnosticSummary,
2137 session: MessageContext,
2138) -> Result<()> {
2139 let guest_connection_ids = session
2140 .db()
2141 .await
2142 .update_diagnostic_summary(&message, session.connection_id)
2143 .await?;
2144
2145 broadcast(
2146 Some(session.connection_id),
2147 guest_connection_ids.iter().copied(),
2148 |connection_id| {
2149 session
2150 .peer
2151 .forward_send(session.connection_id, connection_id, message.clone())
2152 },
2153 );
2154
2155 Ok(())
2156}
2157
2158/// Updates other participants with changes to the worktree settings
2159async fn update_worktree_settings(
2160 message: proto::UpdateWorktreeSettings,
2161 session: MessageContext,
2162) -> Result<()> {
2163 let guest_connection_ids = session
2164 .db()
2165 .await
2166 .update_worktree_settings(&message, session.connection_id)
2167 .await?;
2168
2169 broadcast(
2170 Some(session.connection_id),
2171 guest_connection_ids.iter().copied(),
2172 |connection_id| {
2173 session
2174 .peer
2175 .forward_send(session.connection_id, connection_id, message.clone())
2176 },
2177 );
2178
2179 Ok(())
2180}
2181
2182/// Notify other participants that a language server has started.
2183async fn start_language_server(
2184 request: proto::StartLanguageServer,
2185 session: MessageContext,
2186) -> Result<()> {
2187 let guest_connection_ids = session
2188 .db()
2189 .await
2190 .start_language_server(&request, session.connection_id)
2191 .await?;
2192
2193 broadcast(
2194 Some(session.connection_id),
2195 guest_connection_ids.iter().copied(),
2196 |connection_id| {
2197 session
2198 .peer
2199 .forward_send(session.connection_id, connection_id, request.clone())
2200 },
2201 );
2202 Ok(())
2203}
2204
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Metadata updates may carry new server capabilities; persist them in the
    // database before broadcasting, so the stored state never lags the clients'.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // Forward the update verbatim to every other connection in this project.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2234
2235/// forward a project request to the host. These requests should be read only
2236/// as guests are allowed to send them.
2237async fn forward_read_only_project_request<T>(
2238 request: T,
2239 response: Response<T>,
2240 session: MessageContext,
2241) -> Result<()>
2242where
2243 T: EntityMessage + RequestMessage,
2244{
2245 let project_id = ProjectId::from_proto(request.remote_entity_id());
2246 let host_connection_id = session
2247 .db()
2248 .await
2249 .host_for_read_only_project_request(project_id, session.connection_id)
2250 .await?;
2251 let payload = session.forward_request(host_connection_id, request).await?;
2252 response.send(payload)?;
2253 Ok(())
2254}
2255
2256/// forward a project request to the host. These requests are disallowed
2257/// for guests.
2258async fn forward_mutating_project_request<T>(
2259 request: T,
2260 response: Response<T>,
2261 session: MessageContext,
2262) -> Result<()>
2263where
2264 T: EntityMessage + RequestMessage,
2265{
2266 let project_id = ProjectId::from_proto(request.remote_entity_id());
2267
2268 let host_connection_id = session
2269 .db()
2270 .await
2271 .host_for_mutating_project_request(project_id, session.connection_id)
2272 .await?;
2273 let payload = session.forward_request(host_connection_id, request).await?;
2274 response.send(payload)?;
2275 Ok(())
2276}
2277
/// Reject a request outright because the sender is a guest.
///
/// Responds with a `Forbidden` error instead of forwarding anything.
async fn disallow_guest_request<T>(
    _request: T,
    response: Response<T>,
    _session: MessageContext,
) -> Result<()>
where
    T: RequestMessage,
{
    response.peer.respond_with_error(
        response.receipt,
        ErrorCode::Forbidden
            .message("request is not allowed for guests".to_string())
            .to_proto(),
    )?;
    // Mark the response as already sent — presumably so the surrounding
    // handler machinery doesn't respond a second time; confirm at call site.
    response.responded.store(true, SeqCst);
    Ok(())
}
2295
2296async fn lsp_query(
2297 request: proto::LspQuery,
2298 response: Response<proto::LspQuery>,
2299 session: MessageContext,
2300) -> Result<()> {
2301 let (name, should_write) = request.query_name_and_write_permissions();
2302 tracing::Span::current().record("lsp_query_request", name);
2303 tracing::info!("lsp_query message received");
2304 if should_write {
2305 forward_mutating_project_request(request, response, session).await
2306 } else {
2307 forward_read_only_project_request(request, response, session).await
2308 }
2309}
2310
2311/// Notify other participants that a new buffer has been created
2312async fn create_buffer_for_peer(
2313 request: proto::CreateBufferForPeer,
2314 session: MessageContext,
2315) -> Result<()> {
2316 session
2317 .db()
2318 .await
2319 .check_user_is_project_host(
2320 ProjectId::from_proto(request.project_id),
2321 session.connection_id,
2322 )
2323 .await?;
2324 let peer_id = request.peer_id.context("invalid peer id")?;
2325 session
2326 .peer
2327 .forward_send(session.connection_id, peer_id.into(), request)?;
2328 Ok(())
2329}
2330
2331/// Notify other participants that a new image has been created
2332async fn create_image_for_peer(
2333 request: proto::CreateImageForPeer,
2334 session: MessageContext,
2335) -> Result<()> {
2336 session
2337 .db()
2338 .await
2339 .check_user_is_project_host(
2340 ProjectId::from_proto(request.project_id),
2341 session.connection_id,
2342 )
2343 .await?;
2344 let peer_id = request.peer_id.context("invalid peer id")?;
2345 session
2346 .peer
2347 .forward_send(session.connection_id, peer_id.into(), request)?;
2348 Ok(())
2349}
2350
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Determine the capability this update needs: selection-only (or empty)
    // operations are read-only; anything else requires write access.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard to this block so it is released before we
    // await the host round-trip below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the update out to all guests except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If the sender isn't the host, wait for the host to acknowledge the
    // update before acking the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2397
2398async fn forward_project_search_chunk(
2399 message: proto::FindSearchCandidatesChunk,
2400 response: Response<proto::FindSearchCandidatesChunk>,
2401 session: MessageContext,
2402) -> Result<()> {
2403 let peer_id = message.peer_id.context("missing peer_id")?;
2404 let payload = session
2405 .peer
2406 .forward_request(session.connection_id, peer_id.into(), message)
2407 .await?;
2408 response.send(payload)?;
2409 Ok(())
2410}
2411
2412/// Notify other participants that a project has been updated.
2413async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2414 request: T,
2415 session: MessageContext,
2416) -> Result<()> {
2417 let project_id = ProjectId::from_proto(request.remote_entity_id());
2418 let project_connection_ids = session
2419 .db()
2420 .await
2421 .project_connection_ids(project_id, session.connection_id, false)
2422 .await?;
2423
2424 broadcast(
2425 Some(session.connection_id),
2426 project_connection_ids.iter().copied(),
2427 |connection_id| {
2428 session
2429 .peer
2430 .forward_send(session.connection_id, connection_id, request.clone())
2431 },
2432 );
2433 Ok(())
2434}
2435
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay it to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // If the follow happens within a project, record it so the room state
    // reflects the follower relationship and notify room participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2467
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader this follower is gone; no reply is expected.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // If the follow was within a project, clear it from the room state and
    // notify the room's participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2496
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let database = session.db.lock().await;

    // Project-scoped updates go to the project's connections; otherwise fall
    // back to everyone else in the room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Forward to everyone except the omitted leader and the sender itself.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2528
2529/// Get public data about users.
2530async fn get_users(
2531 request: proto::GetUsers,
2532 response: Response<proto::GetUsers>,
2533 session: MessageContext,
2534) -> Result<()> {
2535 let user_ids = request
2536 .user_ids
2537 .into_iter()
2538 .map(UserId::from_proto)
2539 .collect();
2540 let users = session
2541 .db()
2542 .await
2543 .get_users_by_ids(user_ids)
2544 .await?
2545 .into_iter()
2546 .map(|user| proto::User {
2547 id: user.id.to_proto(),
2548 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2549 github_login: user.github_login,
2550 name: user.name,
2551 })
2552 .collect();
2553 response.send(proto::UsersResponse { users })?;
2554 Ok(())
2555}
2556
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // NOTE(review): `len()` counts bytes, not characters; a one-character
    // non-ASCII query takes the fuzzy path. Presumably fine for GitHub logins
    // (ASCII-only) — confirm if the query source ever changes.
    let users = match query.len() {
        0 => vec![],
        // Very short queries: exact login lookup only, no fuzzy matching.
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        // Longer queries: fuzzy search, capped at 10 results.
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the caller from their own search results.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2588
/// Send a contact request to another user.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    // Persist the request; the database returns notifications to deliver.
    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Deliver any notifications created by the contact request.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2635
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismissal only clears the notification; the request stays pending.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact payloads sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2693
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request; the two need different updates.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // If the pending-request notification was deleted, retract it from
        // each of the responder's connections as well.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2744
/// Whether the server should subscribe this client to its channels
/// automatically — presumably because clients before 0.139 do not send
/// `SubscribeToChannels` themselves; confirm against client history.
/// NOTE(review): only `minor` is compared; `major` is ignored — verify this
/// is safe if the client major version ever exceeds 0.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2748
2749async fn subscribe_to_channels(
2750 _: proto::SubscribeToChannels,
2751 session: MessageContext,
2752) -> Result<()> {
2753 subscribe_user_to_channels(session.user_id(), &session).await?;
2754 Ok(())
2755}
2756
/// Subscribe all of `user_id`'s channel memberships in the connection pool
/// and send the initial channel state down to the client.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register interest in each channel so future channel updates reach this user.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Send the membership list first, then the full channels payload.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2773
2774/// Creates a new channel.
2775async fn create_channel(
2776 request: proto::CreateChannel,
2777 response: Response<proto::CreateChannel>,
2778 session: MessageContext,
2779) -> Result<()> {
2780 let db = session.db().await;
2781
2782 let parent_id = request.parent_id.map(ChannelId::from_proto);
2783 let (channel, membership) = db
2784 .create_channel(&request.name, parent_id, session.user_id())
2785 .await?;
2786
2787 let root_id = channel.root_id();
2788 let channel = Channel::from_model(channel);
2789
2790 response.send(proto::CreateChannelResponse {
2791 channel: Some(channel.to_proto()),
2792 parent_id: request.parent_id,
2793 })?;
2794
2795 let mut connection_pool = session.connection_pool().await;
2796 if let Some(membership) = membership {
2797 connection_pool.subscribe_to_channel(
2798 membership.user_id,
2799 membership.channel_id,
2800 membership.role,
2801 );
2802 let update = proto::UpdateUserChannels {
2803 channel_memberships: vec![proto::ChannelMembership {
2804 channel_id: membership.channel_id.to_proto(),
2805 role: membership.role.into(),
2806 }],
2807 ..Default::default()
2808 };
2809 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2810 session.peer.send(connection_id, update.clone())?;
2811 }
2812 }
2813
2814 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2815 if !role.can_see_channel(channel.visibility) {
2816 continue;
2817 }
2818
2819 let update = proto::UpdateChannels {
2820 channels: vec![channel.to_proto()],
2821 ..Default::default()
2822 };
2823 session.peer.send(connection_id, update.clone())?;
2824 }
2825
2826 Ok(())
2827}
2828
/// Delete a channel
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    // The database returns the root of the affected channel tree plus every
    // channel id that was removed (presumably the channel and its
    // descendants — confirm in the db layer).
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2856
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // The database records the invitation and produces the notifications to
    // deliver alongside the channel update.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Show the invitation on every one of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2893
2894/// remove someone from a channel
2895async fn remove_channel_member(
2896 request: proto::RemoveChannelMember,
2897 response: Response<proto::RemoveChannelMember>,
2898 session: MessageContext,
2899) -> Result<()> {
2900 let db = session.db().await;
2901 let channel_id = ChannelId::from_proto(request.channel_id);
2902 let member_id = UserId::from_proto(request.user_id);
2903
2904 let RemoveChannelMemberResult {
2905 membership_update,
2906 notification_id,
2907 } = db
2908 .remove_channel_member(channel_id, member_id, session.user_id())
2909 .await?;
2910
2911 let mut connection_pool = session.connection_pool().await;
2912 notify_membership_updated(
2913 &mut connection_pool,
2914 membership_update,
2915 member_id,
2916 &session.peer,
2917 );
2918 for connection_id in connection_pool.user_connection_ids(member_id) {
2919 if let Some(notification_id) = notification_id {
2920 session
2921 .peer
2922 .send(
2923 connection_id,
2924 proto::DeleteNotification {
2925 notification_id: notification_id.to_proto(),
2926 },
2927 )
2928 .trace_err();
2929 }
2930 }
2931
2932 response.send(proto::Ack {})?;
2933 Ok(())
2934}
2935
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the loop body mutates the pool
    // (subscribe/unsubscribe), so we can't iterate it borrowed.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel get it (and a subscription);
        // users who no longer can are unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2982
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an actual member (membership changed) or only an
    // invitee (pending invitation changed) — each needs a different update.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Refresh the pending invitation on the invitee's connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3030
3031/// Change the name of a channel
3032async fn rename_channel(
3033 request: proto::RenameChannel,
3034 response: Response<proto::RenameChannel>,
3035 session: MessageContext,
3036) -> Result<()> {
3037 let db = session.db().await;
3038 let channel_id = ChannelId::from_proto(request.channel_id);
3039 let channel_model = db
3040 .rename_channel(channel_id, session.user_id(), &request.name)
3041 .await?;
3042 let root_id = channel_model.root_id();
3043 let channel = Channel::from_model(channel_model);
3044
3045 response.send(proto::RenameChannelResponse {
3046 channel: Some(channel.to_proto()),
3047 })?;
3048
3049 let connection_pool = session.connection_pool().await;
3050 let update = proto::UpdateChannels {
3051 channels: vec![channel.to_proto()],
3052 ..Default::default()
3053 };
3054 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055 if role.can_see_channel(channel.visibility) {
3056 session.peer.send(connection_id, update.clone())?;
3057 }
3058 }
3059
3060 Ok(())
3061}
3062
3063/// Move a channel to a new parent.
3064async fn move_channel(
3065 request: proto::MoveChannel,
3066 response: Response<proto::MoveChannel>,
3067 session: MessageContext,
3068) -> Result<()> {
3069 let channel_id = ChannelId::from_proto(request.channel_id);
3070 let to = ChannelId::from_proto(request.to);
3071
3072 let (root_id, channels) = session
3073 .db()
3074 .await
3075 .move_channel(channel_id, to, session.user_id())
3076 .await?;
3077
3078 let connection_pool = session.connection_pool().await;
3079 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3080 let channels = channels
3081 .iter()
3082 .filter_map(|channel| {
3083 if role.can_see_channel(channel.visibility) {
3084 Some(channel.to_proto())
3085 } else {
3086 None
3087 }
3088 })
3089 .collect::<Vec<_>>();
3090 if channels.is_empty() {
3091 continue;
3092 }
3093
3094 let update = proto::UpdateChannels {
3095 channels,
3096 ..Default::default()
3097 };
3098
3099 session.peer.send(connection_id, update.clone())?;
3100 }
3101
3102 response.send(Ack {})?;
3103 Ok(())
3104}
3105
3106async fn reorder_channel(
3107 request: proto::ReorderChannel,
3108 response: Response<proto::ReorderChannel>,
3109 session: MessageContext,
3110) -> Result<()> {
3111 let channel_id = ChannelId::from_proto(request.channel_id);
3112 let direction = request.direction();
3113
3114 let updated_channels = session
3115 .db()
3116 .await
3117 .reorder_channel(channel_id, direction, session.user_id())
3118 .await?;
3119
3120 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3121 let connection_pool = session.connection_pool().await;
3122 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3123 let channels = updated_channels
3124 .iter()
3125 .filter_map(|channel| {
3126 if role.can_see_channel(channel.visibility) {
3127 Some(channel.to_proto())
3128 } else {
3129 None
3130 }
3131 })
3132 .collect::<Vec<_>>();
3133
3134 if channels.is_empty() {
3135 continue;
3136 }
3137
3138 let update = proto::UpdateChannels {
3139 channels,
3140 ..Default::default()
3141 };
3142
3143 session.peer.send(connection_id, update.clone())?;
3144 }
3145 }
3146
3147 response.send(Ack {})?;
3148 Ok(())
3149}
3150
3151/// Get the list of channel members
3152async fn get_channel_members(
3153 request: proto::GetChannelMembers,
3154 response: Response<proto::GetChannelMembers>,
3155 session: MessageContext,
3156) -> Result<()> {
3157 let db = session.db().await;
3158 let channel_id = ChannelId::from_proto(request.channel_id);
3159 let limit = if request.limit == 0 {
3160 u16::MAX as u64
3161 } else {
3162 request.limit
3163 };
3164 let (members, users) = db
3165 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3166 .await?;
3167 response.send(proto::GetChannelMembersResponse { members, users })?;
3168 Ok(())
3169}
3170
3171/// Accept or decline a channel invitation.
3172async fn respond_to_channel_invite(
3173 request: proto::RespondToChannelInvite,
3174 response: Response<proto::RespondToChannelInvite>,
3175 session: MessageContext,
3176) -> Result<()> {
3177 let db = session.db().await;
3178 let channel_id = ChannelId::from_proto(request.channel_id);
3179 let RespondToChannelInvite {
3180 membership_update,
3181 notifications,
3182 } = db
3183 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3184 .await?;
3185
3186 let mut connection_pool = session.connection_pool().await;
3187 if let Some(membership_update) = membership_update {
3188 notify_membership_updated(
3189 &mut connection_pool,
3190 membership_update,
3191 session.user_id(),
3192 &session.peer,
3193 );
3194 } else {
3195 let update = proto::UpdateChannels {
3196 remove_channel_invitations: vec![channel_id.to_proto()],
3197 ..Default::default()
3198 };
3199
3200 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3201 session.peer.send(connection_id, update.clone())?;
3202 }
3203 };
3204
3205 send_notifications(&connection_pool, &session.peer, notifications);
3206
3207 response.send(proto::Ack {})?;
3208
3209 Ok(())
3210}
3211
3212/// Join the channels' room
3213async fn join_channel(
3214 request: proto::JoinChannel,
3215 response: Response<proto::JoinChannel>,
3216 session: MessageContext,
3217) -> Result<()> {
3218 let channel_id = ChannelId::from_proto(request.channel_id);
3219 join_channel_internal(channel_id, Box::new(response), session).await
3220}
3221
/// Abstracts over the two response types that can carry a `JoinRoomResponse`
/// (`JoinChannel` and `JoinRoom`), so `join_channel_internal` can serve both.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    // Forward to the concrete `Response::send` for `JoinChannel` requests.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    // Forward to the concrete `Response::send` for `JoinRoom` requests.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3235
/// Shared implementation for joining a channel's room: cleans up any stale
/// room connection, joins the room, mints a LiveKit token when configured,
/// and notifies room participants, channel subscribers, and contacts.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the DB guard before the cleanup (which re-acquires it),
            // then take it again for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, all other roles may also publish. Token
        // minting failures are logged and treated as "no LiveKit info".
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Push the new room state to the other room participants.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Publish the channel's updated participant list to channel subscribers.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Joining a room changes the user's busy state; refresh their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3329
3330/// Start editing the channel notes
3331async fn join_channel_buffer(
3332 request: proto::JoinChannelBuffer,
3333 response: Response<proto::JoinChannelBuffer>,
3334 session: MessageContext,
3335) -> Result<()> {
3336 let db = session.db().await;
3337 let channel_id = ChannelId::from_proto(request.channel_id);
3338
3339 let open_response = db
3340 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3341 .await?;
3342
3343 let collaborators = open_response.collaborators.clone();
3344 response.send(open_response)?;
3345
3346 let update = UpdateChannelBufferCollaborators {
3347 channel_id: channel_id.to_proto(),
3348 collaborators: collaborators.clone(),
3349 };
3350 channel_buffer_updated(
3351 session.connection_id,
3352 collaborators
3353 .iter()
3354 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3355 &update,
3356 &session.peer,
3357 );
3358
3359 Ok(())
3360}
3361
/// Edit the channel notes
///
/// Persists the operations, relays them to the buffer's collaborators, and
/// tells the remaining channel subscribers that a newer buffer version exists.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    // Returns the collaborator connections to notify, plus the buffer's epoch
    // and version after applying the operations.
    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Relay the raw operations to the collaborators (the sender is excluded
    // inside `channel_buffer_updated`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers who are not editing the buffer only need to learn
    // that its latest version advanced, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3412
3413/// Rejoin the channel notes after a connection blip
3414async fn rejoin_channel_buffers(
3415 request: proto::RejoinChannelBuffers,
3416 response: Response<proto::RejoinChannelBuffers>,
3417 session: MessageContext,
3418) -> Result<()> {
3419 let db = session.db().await;
3420 let buffers = db
3421 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3422 .await?;
3423
3424 for rejoined_buffer in &buffers {
3425 let collaborators_to_notify = rejoined_buffer
3426 .buffer
3427 .collaborators
3428 .iter()
3429 .filter_map(|c| Some(c.peer_id?.into()));
3430 channel_buffer_updated(
3431 session.connection_id,
3432 collaborators_to_notify,
3433 &proto::UpdateChannelBufferCollaborators {
3434 channel_id: rejoined_buffer.buffer.channel_id,
3435 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3436 },
3437 &session.peer,
3438 );
3439 }
3440
3441 response.send(proto::RejoinChannelBuffersResponse {
3442 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3443 })?;
3444
3445 Ok(())
3446}
3447
3448/// Stop editing the channel notes
3449async fn leave_channel_buffer(
3450 request: proto::LeaveChannelBuffer,
3451 response: Response<proto::LeaveChannelBuffer>,
3452 session: MessageContext,
3453) -> Result<()> {
3454 let db = session.db().await;
3455 let channel_id = ChannelId::from_proto(request.channel_id);
3456
3457 let left_buffer = db
3458 .leave_channel_buffer(channel_id, session.connection_id)
3459 .await?;
3460
3461 response.send(Ack {})?;
3462
3463 channel_buffer_updated(
3464 session.connection_id,
3465 left_buffer.connections,
3466 &proto::UpdateChannelBufferCollaborators {
3467 channel_id: channel_id.to_proto(),
3468 collaborators: left_buffer.collaborators,
3469 },
3470 &session.peer,
3471 );
3472
3473 Ok(())
3474}
3475
3476fn channel_buffer_updated<T: EnvelopedMessage>(
3477 sender_id: ConnectionId,
3478 collaborators: impl IntoIterator<Item = ConnectionId>,
3479 message: &T,
3480 peer: &Peer,
3481) {
3482 broadcast(Some(sender_id), collaborators, |peer_id| {
3483 peer.send(peer_id, message.clone())
3484 });
3485}
3486
3487fn send_notifications(
3488 connection_pool: &ConnectionPool,
3489 peer: &Peer,
3490 notifications: db::NotificationBatch,
3491) {
3492 for (user_id, notification) in notifications {
3493 for connection_id in connection_pool.user_connection_ids(user_id) {
3494 if let Err(error) = peer.send(
3495 connection_id,
3496 proto::AddNotification {
3497 notification: Some(notification.clone()),
3498 },
3499 ) {
3500 tracing::error!(
3501 "failed to send notification to {:?} {}",
3502 connection_id,
3503 error
3504 );
3505 }
3506 }
3507 }
3508}
3509
/// Send a message to the channel
///
/// Chat has been removed; this handler remains so callers that still send the
/// message receive an explanatory error rather than an unknown-message failure.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3518
/// Delete a channel message
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3527
/// Edit a channel message
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3535
/// Mark a channel message as read
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3543
3544/// Mark a buffer version as synced
3545async fn acknowledge_buffer_version(
3546 request: proto::AckBufferOperation,
3547 session: MessageContext,
3548) -> Result<()> {
3549 let buffer_id = BufferId::from_proto(request.buffer_id);
3550 session
3551 .db()
3552 .await
3553 .observe_buffer_version(
3554 buffer_id,
3555 session.user_id(),
3556 request.epoch as i32,
3557 &request.version,
3558 )
3559 .await?;
3560 Ok(())
3561}
3562
/// Start receiving chat updates for a channel
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3571
/// Stop receiving chat updates for a channel
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3579
/// Retrieve the chat history for a channel
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3588
/// Retrieve specific chat messages
///
/// Chat has been removed; this stub only returns an explanatory error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3597
3598/// Retrieve the current users notifications
3599async fn get_notifications(
3600 request: proto::GetNotifications,
3601 response: Response<proto::GetNotifications>,
3602 session: MessageContext,
3603) -> Result<()> {
3604 let notifications = session
3605 .db()
3606 .await
3607 .get_notifications(
3608 session.user_id(),
3609 NOTIFICATION_COUNT_PER_PAGE,
3610 request.before_id.map(db::NotificationId::from_proto),
3611 )
3612 .await?;
3613 response.send(proto::GetNotificationsResponse {
3614 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3615 notifications,
3616 })?;
3617 Ok(())
3618}
3619
3620/// Mark notifications as read
3621async fn mark_notification_as_read(
3622 request: proto::MarkNotificationRead,
3623 response: Response<proto::MarkNotificationRead>,
3624 session: MessageContext,
3625) -> Result<()> {
3626 let database = &session.db().await;
3627 let notifications = database
3628 .mark_notification_as_read_by_id(
3629 session.user_id(),
3630 NotificationId::from_proto(request.notification_id),
3631 )
3632 .await?;
3633 send_notifications(
3634 &*session.connection_pool().await,
3635 &session.peer,
3636 notifications,
3637 );
3638 response.send(proto::Ack {})?;
3639 Ok(())
3640}
3641
3642fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3643 let message = match message {
3644 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3645 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3646 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3647 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3648 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3649 code: frame.code.into(),
3650 reason: frame.reason.as_str().to_owned().into(),
3651 })),
3652 // We should never receive a frame while reading the message, according
3653 // to the `tungstenite` maintainers:
3654 //
3655 // > It cannot occur when you read messages from the WebSocket, but it
3656 // > can be used when you want to send the raw frames (e.g. you want to
3657 // > send the frames to the WebSocket without composing the full message first).
3658 // >
3659 // > — https://github.com/snapview/tungstenite-rs/issues/268
3660 TungsteniteMessage::Frame(_) => {
3661 bail!("received an unexpected frame while reading the message")
3662 }
3663 };
3664
3665 Ok(message)
3666}
3667
3668fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3669 match message {
3670 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3671 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3672 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3673 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3674 AxumMessage::Close(frame) => {
3675 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3676 code: frame.code.into(),
3677 reason: frame.reason.as_ref().into(),
3678 }))
3679 }
3680 }
3681}
3682
/// Sync the connection pool's channel subscriptions for `user_id` after a
/// membership change, then push the resulting channel updates to all of the
/// user's connections. Send failures are logged, not propagated.
fn notify_membership_updated(
    connection_pool: &mut ConnectionPool,
    result: MembershipUpdated,
    user_id: UserId,
    peer: &Peer,
) {
    // Keep the in-memory pool consistent with the new membership state.
    for membership in &result.new_channels.channel_memberships {
        connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    for channel_id in &result.removed_channels {
        connection_pool.unsubscribe_from_channel(&user_id, channel_id)
    }

    // NOTE(review): this duplicates most of `build_update_user_channels`, but
    // omits `observed_channel_buffer_version` — confirm whether that omission
    // is intentional before unifying the two.
    let user_channels_update = proto::UpdateUserChannels {
        channel_memberships: result
            .new_channels
            .channel_memberships
            .iter()
            .map(|cm| proto::ChannelMembership {
                channel_id: cm.channel_id.to_proto(),
                role: cm.role.into(),
            })
            .collect(),
        ..Default::default()
    };

    let mut update = build_channels_update(result.new_channels);
    update.delete_channels = result
        .removed_channels
        .into_iter()
        .map(|id| id.to_proto())
        .collect();
    // The membership change resolves any pending invitation to this channel.
    update.remove_channel_invitations = vec![result.channel_id.to_proto()];

    for connection_id in connection_pool.user_connection_ids(user_id) {
        peer.send(connection_id, user_channels_update.clone())
            .trace_err();
        peer.send(connection_id, update.clone()).trace_err();
    }
}
3723
3724fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3725 proto::UpdateUserChannels {
3726 channel_memberships: channels
3727 .channel_memberships
3728 .iter()
3729 .map(|m| proto::ChannelMembership {
3730 channel_id: m.channel_id.to_proto(),
3731 role: m.role.into(),
3732 })
3733 .collect(),
3734 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3735 }
3736}
3737
3738fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3739 let mut update = proto::UpdateChannels::default();
3740
3741 for channel in channels.channels {
3742 update.channels.push(channel.to_proto());
3743 }
3744
3745 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3746
3747 for (channel_id, participants) in channels.channel_participants {
3748 update
3749 .channel_participants
3750 .push(proto::ChannelParticipants {
3751 channel_id: channel_id.to_proto(),
3752 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3753 });
3754 }
3755
3756 for channel in channels.invited_channels {
3757 update.channel_invitations.push(channel.to_proto());
3758 }
3759
3760 update
3761}
3762
3763fn build_initial_contacts_update(
3764 contacts: Vec<db::Contact>,
3765 pool: &ConnectionPool,
3766) -> proto::UpdateContacts {
3767 let mut update = proto::UpdateContacts::default();
3768
3769 for contact in contacts {
3770 match contact {
3771 db::Contact::Accepted { user_id, busy } => {
3772 update.contacts.push(contact_for_user(user_id, busy, pool));
3773 }
3774 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3775 db::Contact::Incoming { user_id } => {
3776 update
3777 .incoming_requests
3778 .push(proto::IncomingContactRequest {
3779 requester_id: user_id.to_proto(),
3780 })
3781 }
3782 }
3783 }
3784
3785 update
3786}
3787
3788fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3789 proto::Contact {
3790 user_id: user_id.to_proto(),
3791 online: pool.is_user_online(user_id),
3792 busy,
3793 }
3794}
3795
3796fn room_updated(room: &proto::Room, peer: &Peer) {
3797 broadcast(
3798 None,
3799 room.participants
3800 .iter()
3801 .filter_map(|participant| Some(participant.peer_id?.into())),
3802 |peer_id| {
3803 peer.send(
3804 peer_id,
3805 proto::RoomUpdated {
3806 room: Some(room.clone()),
3807 },
3808 )
3809 },
3810 );
3811}
3812
3813fn channel_updated(
3814 channel: &db::channel::Model,
3815 room: &proto::Room,
3816 peer: &Peer,
3817 pool: &ConnectionPool,
3818) {
3819 let participants = room
3820 .participants
3821 .iter()
3822 .map(|p| p.user_id)
3823 .collect::<Vec<_>>();
3824
3825 broadcast(
3826 None,
3827 pool.channel_connection_ids(channel.root_id())
3828 .filter_map(|(channel_id, role)| {
3829 role.can_see_channel(channel.visibility)
3830 .then_some(channel_id)
3831 }),
3832 |peer_id| {
3833 peer.send(
3834 peer_id,
3835 proto::UpdateChannels {
3836 channel_participants: vec![proto::ChannelParticipants {
3837 channel_id: channel.id.to_proto(),
3838 participant_user_ids: participants.clone(),
3839 }],
3840 ..Default::default()
3841 },
3842 )
3843 },
3844 );
3845}
3846
3847async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3848 let db = session.db().await;
3849
3850 let contacts = db.get_contacts(user_id).await?;
3851 let busy = db.is_user_busy(user_id).await?;
3852
3853 let pool = session.connection_pool().await;
3854 let updated_contact = contact_for_user(user_id, busy, &pool);
3855 for contact in contacts {
3856 if let db::Contact::Accepted {
3857 user_id: contact_user_id,
3858 ..
3859 } = contact
3860 {
3861 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3862 session
3863 .peer
3864 .send(
3865 contact_conn_id,
3866 proto::UpdateContacts {
3867 contacts: vec![updated_contact.clone()],
3868 remove_contacts: Default::default(),
3869 incoming_requests: Default::default(),
3870 remove_incoming_requests: Default::default(),
3871 outgoing_requests: Default::default(),
3872 remove_outgoing_requests: Default::default(),
3873 },
3874 )
3875 .trace_err();
3876 }
3877 }
3878 }
3879 Ok(())
3880}
3881
/// Remove `connection_id` from whatever room it is in, notifying left
/// projects, room participants, the owning channel, affected contacts, and
/// LiveKit. No-op if the connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // State captured out of `left_room` so it outlives the first block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare / remove this collaborator from each project left behind.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list for
    // channel subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell users whose incoming calls were canceled by this departure;
        // their contact state also needs refreshing below.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Busy state may have changed for everyone affected.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup; failures are logged rather than propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3955
3956async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3957 let left_channel_buffers = session
3958 .db()
3959 .await
3960 .leave_channel_buffers(session.connection_id)
3961 .await?;
3962
3963 for left_buffer in left_channel_buffers {
3964 channel_buffer_updated(
3965 session.connection_id,
3966 left_buffer.connections,
3967 &proto::UpdateChannelBufferCollaborators {
3968 channel_id: left_buffer.channel_id.to_proto(),
3969 collaborators: left_buffer.collaborators,
3970 },
3971 &session.peer,
3972 );
3973 }
3974
3975 Ok(())
3976}
3977
3978fn project_left(project: &db::LeftProject, session: &Session) {
3979 for connection_id in &project.connection_ids {
3980 if project.should_unshare {
3981 session
3982 .peer
3983 .send(
3984 *connection_id,
3985 proto::UnshareProject {
3986 project_id: project.id.to_proto(),
3987 },
3988 )
3989 .trace_err();
3990 } else {
3991 session
3992 .peer
3993 .send(
3994 *connection_id,
3995 proto::RemoveProjectCollaborator {
3996 project_id: project.id.to_proto(),
3997 peer_id: Some(session.connection_id.into()),
3998 },
3999 )
4000 .trace_err();
4001 }
4002 }
4003}
4004
4005async fn share_agent_thread(
4006 request: proto::ShareAgentThread,
4007 response: Response<proto::ShareAgentThread>,
4008 session: MessageContext,
4009) -> Result<()> {
4010 let user_id = session.user_id();
4011
4012 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4013 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4014
4015 session
4016 .db()
4017 .await
4018 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4019 .await?;
4020
4021 response.send(proto::Ack {})?;
4022
4023 Ok(())
4024}
4025
4026async fn get_shared_agent_thread(
4027 request: proto::GetSharedAgentThread,
4028 response: Response<proto::GetSharedAgentThread>,
4029 session: MessageContext,
4030) -> Result<()> {
4031 let share_id = SharedThreadId::from_proto(request.session_id)
4032 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4033
4034 let result = session.db().await.get_shared_thread(share_id).await?;
4035
4036 match result {
4037 Some((thread, username)) => {
4038 response.send(proto::GetSharedAgentThreadResponse {
4039 title: thread.title,
4040 thread_data: thread.data,
4041 sharer_username: username,
4042 created_at: thread.created_at.and_utc().to_rfc3339(),
4043 })?;
4044 }
4045 None => {
4046 return Err(anyhow!("Shared thread not found").into());
4047 }
4048 }
4049
4050 Ok(())
4051}
4052
/// Extension trait for converting a `Result` into an `Option`, logging the
/// error (if any) instead of propagating it.
pub trait ResultExt {
    // The success type of the underlying `Result`.
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}
4058
4059impl<T, E> ResultExt for Result<T, E>
4060where
4061 E: std::fmt::Debug,
4062{
4063 type Ok = T;
4064
4065 #[track_caller]
4066 fn trace_err(self) -> Option<T> {
4067 match self {
4068 Ok(value) => Some(value),
4069 Err(error) => {
4070 tracing::error!("{:?}", error);
4071 None
4072 }
4073 }
4074 }
4075}