1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// How long a disconnected client has to reconnect before its server-side
/// state is considered lost.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size used when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections; incremented in `ConnectionGuard::try_acquire`
// and decremented when a guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of tracing-span fields used to record per-message timing breakdowns.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one message type: receives the envelope, the
/// per-connection `Session`, and the tracing span for the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every registered message handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Let handlers call `Session` methods directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
160impl MessageContext {
161 pub fn forward_request<T: RequestMessage>(
162 &self,
163 receiver_id: ConnectionId,
164 request: T,
165 ) -> impl Future<Output = anyhow::Result<T::Response>> {
166 let request_start_time = Instant::now();
167 let span = self.span.clone();
168 tracing::info!("start forwarding request");
169 self.peer
170 .forward_request(self.connection_id, receiver_id, request)
171 .inspect(move |_| {
172 span.record(
173 HOST_WAITING_MS,
174 request_start_time.elapsed().as_micros() as f64 / 1000.0,
175 );
176 })
177 .inspect_err(|_| tracing::error!("error forwarding request"))
178 .inspect_ok(|_| tracing::info!("finished forwarding request"))
179 }
180}
181
/// Per-connection state handed (via `MessageContext`) to every message
/// handler for that connection.
#[derive(Clone)]
struct Session {
    /// Who is connected.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; clones of this session share the
    /// same mutex, so DB access for this connection is serialized.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide pool of all live connections.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
198impl Session {
199 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
200 #[cfg(feature = "test-support")]
201 tokio::task::yield_now().await;
202 let guard = self.db.lock().await;
203 #[cfg(feature = "test-support")]
204 tokio::task::yield_now().await;
205 guard
206 }
207
208 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
209 #[cfg(feature = "test-support")]
210 tokio::task::yield_now().await;
211 let guard = self.connection_pool.lock();
212 ConnectionPoolGuard {
213 guard,
214 _not_send: PhantomData,
215 }
216 }
217
218 #[expect(dead_code)]
219 fn is_staff(&self) -> bool {
220 match &self.principal {
221 Principal::User(user) => user.admin,
222 }
223 }
224
225 fn user_id(&self) -> UserId {
226 match &self.principal {
227 Principal::User(user) => user.id,
228 }
229 }
230}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collab RPC server: owns the peer, the pool of live connections, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One type-erased handler per message `TypeId`; populated by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasting `true` tells every `handle_connection` loop to exit.
    teardown: watch::Sender<bool>,
}
262
/// Lock guard over the server-wide `ConnectionPool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc` is `!Send`, so this marker makes the guard `!Send` and keeps it
    // from being moved to another thread (e.g. held across an `.await`).
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers a handler for every RPC message type
    /// it understands. Registering the same message type twice panics (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Receivers are the per-connection loops; sending `true` later
            // (in `teardown`) tells them all to shut down.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/language-server state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. "read_only" variants are
            // permitted for read-only participants; "mutating" ones are not.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state broadcast from the host to all guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, channels, chat, and notifications.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Follow state.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
450
    /// Spawns the background task that, after `CLEANUP_TIMEOUT`, reclaims
    /// resources abandoned by previous incarnations of this server: stale
    /// channel chat participants, channel buffer collaborators, room
    /// participants, old worktree entries, and stale server rows. Affected
    /// clients are notified as state is cleaned up.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                // Give terminated pods time to finish shutting down before
                // treating their resources as stale (see CLEANUP_TIMEOUT).
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // send the surviving collaborators the refreshed set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants and broadcast the refreshed
                        // room (and channel, if any) to remaining participants.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose pending incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Re-broadcast each affected user's contact entry (busy
                        // status may have changed) to all of their contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Drop the LiveKit room once it has no participants left.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
621
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every `handle_connection` loop to exit via the
    /// teardown watch channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
627
    /// Test-only: tears the server down, assigns a new `ServerId`, resets the
    /// peer, and re-arms the teardown channel so new connections are accepted.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only: the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
640
    /// Registers the type-erased handler for message type `M`, wrapping it
    /// with logging and span timing (total / processing / queue durations).
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The erased envelope is inserted under `TypeId::of::<M>()`,
                // so this downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = receipt -> completion; processing = handler run
                    // time; queue = time spent waiting before the handler ran.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
684
685 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
686 where
687 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
688 Fut: 'static + Send + Future<Output = Result<()>>,
689 M: EnvelopedMessage,
690 {
691 self.add_handler(move |envelope, session| handler(envelope.payload, session));
692 self
693 }
694
    /// Registers a handler for a request message `M` that must send exactly
    /// one response. On handler success without a sent response, this reports
    /// an error; on handler failure, the error is converted to a proto error
    /// and sent back to the requester before being propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // `Response::send` sets this flag, letting us detect handlers
                // that return Ok(()) without ever responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Forward the failure to the requester, then propagate
                        // it so it is also logged server-side.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
733
    /// Drives one client connection end-to-end: registers it with the peer,
    /// builds its `Session`, sends the initial client update, then loops
    /// dispatching incoming messages to the registered handlers until the
    /// connection closes, I/O fails, or the server tears down. Finally runs
    /// `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection-limit slot is only held through setup; release it
            // before entering the long-lived message loop.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Don't pull the next message until a handler slot is free;
                // this applies backpressure at MAX_CONCURRENT_HANDLERS.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // completes, whichever way it runs.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
904
    /// Sends the post-connect handshake to a freshly accepted client: the
    /// `Hello` message, a one-time contacts prompt on first connection, the
    /// initial contact list, channel subscriptions (for versions that
    /// auto-subscribe), and any pending incoming call.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the assigned connection id back to the caller, if requested.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contact list
                    // while holding the pool lock.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
960}
961
962impl Deref for ConnectionPoolGuard<'_> {
963 type Target = ConnectionPool;
964
965 fn deref(&self) -> &Self::Target {
966 &self.guard
967 }
968}
969
970impl DerefMut for ConnectionPoolGuard<'_> {
971 fn deref_mut(&mut self) -> &mut Self::Target {
972 &mut self.guard
973 }
974}
975
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify the pool's internal invariants every time a
        // guard is released, so corruption is caught close to its source.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
982
983fn broadcast<F>(
984 sender_id: Option<ConnectionId>,
985 receiver_ids: impl IntoIterator<Item = ConnectionId>,
986 mut f: F,
987) where
988 F: FnMut(ConnectionId) -> anyhow::Result<()>,
989{
990 for receiver_id in receiver_ids {
991 if Some(receiver_id) != sender_id
992 && let Err(error) = f(receiver_id)
993 {
994 tracing::error!("failed to send to {:?} {}", receiver_id, error);
995 }
996 }
997}
998
/// Typed wrapper for the `x-zed-protocol-version` request header value.
pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002 fn name() -> &'static HeaderName {
1003 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005 }
1006
1007 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008 where
1009 Self: Sized,
1010 I: Iterator<Item = &'i axum::http::HeaderValue>,
1011 {
1012 let version = values
1013 .next()
1014 .ok_or_else(axum::headers::Error::invalid)?
1015 .to_str()
1016 .map_err(|_| axum::headers::Error::invalid())?
1017 .parse()
1018 .map_err(|_| axum::headers::Error::invalid())?;
1019 Ok(Self(version))
1020 }
1021
1022 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023 values.extend([self.0.to_string().parse().unwrap()]);
1024 }
1025}
1026
1027pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029 fn name() -> &'static HeaderName {
1030 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032 }
1033
1034 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035 where
1036 Self: Sized,
1037 I: Iterator<Item = &'i axum::http::HeaderValue>,
1038 {
1039 let version = values
1040 .next()
1041 .ok_or_else(axum::headers::Error::invalid)?
1042 .to_str()
1043 .map_err(|_| axum::headers::Error::invalid())?
1044 .parse()
1045 .map_err(|_| axum::headers::Error::invalid())?;
1046 Ok(Self(version))
1047 }
1048
1049 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050 values.extend([self.0.to_string().parse().unwrap()]);
1051 }
1052}
1053
1054#[derive(Debug)]
1055pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058 fn name() -> &'static HeaderName {
1059 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061 }
1062
1063 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064 where
1065 Self: Sized,
1066 I: Iterator<Item = &'i axum::http::HeaderValue>,
1067 {
1068 Ok(Self(
1069 values
1070 .next()
1071 .ok_or_else(axum::headers::Error::invalid)?
1072 .to_str()
1073 .map_err(|_| axum::headers::Error::invalid())?
1074 .to_owned(),
1075 ))
1076 }
1077
1078 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079 values.extend([self.0.parse().unwrap()]);
1080 }
1081}
1082
/// Builds the collab server's router: `/rpc` (the websocket endpoint, behind
/// header-based auth) and `/metrics` (Prometheus scrape endpoint).
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // These layers apply to the routes added above them, so auth guards
        // `/rpc` but not `/metrics`.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1094
/// Axum handler for `GET /rpc`: validates the client's protocol and app
/// version headers, reserves a connection slot, then upgrades to a websocket
/// that is served by `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // An RPC protocol mismatch is unrecoverable: tell the client to upgrade.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
        .into_response();
    }

    // The app version header is mandatory.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
        .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    // Versions below the collaboration floor are rejected outright.
    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
        .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt between axum's and tungstenite's websocket message types so
        // the rpc `Connection` can run over the axum socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175 let connections_metric = CONNECTIONS_METRIC
1176 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178 let connections = server
1179 .connection_pool
1180 .lock()
1181 .connections()
1182 .filter(|connection| !connection.admin)
1183 .count();
1184 connections_metric.set(connections as _);
1185
1186 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188 register_int_gauge!(
1189 "shared_projects",
1190 "number of open projects with one or more guests"
1191 )
1192 .unwrap()
1193 });
1194
1195 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196 shared_projects_metric.set(shared_projects as _);
1197
1198 let encoder = prometheus::TextEncoder::new();
1199 let metric_families = prometheus::gather();
1200 let encoded_metrics = encoder
1201 .encode_to_string(&metric_families)
1202 .map_err(|err| anyhow!("{err}"))?;
1203 Ok(encoded_metrics)
1204}
1205
/// Cleans up after a client connection drops.
///
/// Immediately disconnects the peer and removes it from the connection pool,
/// then waits `RECONNECT_TIMEOUT` before releasing rooms, channel buffers,
/// and pending calls — unless a `teardown` signal arrives first, in which
/// case the delayed cleanup is skipped.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending incoming call if the user has no other
            // live connection that could still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1252
/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    // An empty Ack is all the client needs; the ping carries no payload.
    response.send(proto::Ack {})?;
    Ok(())
}
1262
1263/// Creates a new room for calling (outside of channels)
1264async fn create_room(
1265 _request: proto::CreateRoom,
1266 response: Response<proto::CreateRoom>,
1267 session: MessageContext,
1268) -> Result<()> {
1269 let livekit_room = nanoid::nanoid!(30);
1270
1271 let live_kit_connection_info = util::maybe!(async {
1272 let live_kit = session.app_state.livekit_client.as_ref();
1273 let live_kit = live_kit?;
1274 let user_id = session.user_id().to_string();
1275
1276 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1277
1278 Some(proto::LiveKitConnectionInfo {
1279 server_url: live_kit.url().into(),
1280 token,
1281 can_publish: true,
1282 })
1283 })
1284 .await;
1285
1286 let room = session
1287 .db()
1288 .await
1289 .create_room(session.user_id(), session.connection_id, &livekit_room)
1290 .await?;
1291
1292 response.send(proto::CreateRoomResponse {
1293 room: Some(room.clone()),
1294 live_kit_connection_info,
1295 })?;
1296
1297 update_user_contacts(session.user_id(), &session).await?;
1298 Ok(())
1299}
1300
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms are delegated to the channel-join path.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The call was answered on this connection; cancel the ring on all of
    // the user's connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint LiveKit credentials if a client is configured; token failures are
    // logged and surfaced to the client as `None`.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Replays room state to the reconnecting client, points collaborators of
/// the client's reshared projects at its new connection id, and streams back
/// the state of projects the client had joined as a guest.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply with the room plus summaries of the projects this client
        // reshared (as host) and rejoined (as guest).
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point existing collaborators at the host's new connection id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree metadata to the guests.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // If the room belongs to a channel, also broadcast the channel update.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
/// Streams project state to a client that rejoined one or more projects as a
/// guest: collaborator peer-id updates first, then each project's worktree
/// entries, diagnostics, settings files, and repository state.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Point each project's existing collaborators at the rejoiner's new
    // connection id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        // `mem::take` moves the worktrees out so their contents can be
        // consumed without cloning.
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                root_repo_common_dir: worktree.root_repo_common_dir,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is sent in chunks via split_worktree_update.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are chunked the same way as worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1546
/// leave room disconnects from the room.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: MessageContext,
) -> Result<()> {
    // All teardown is handled by the shared helper that is also used when a
    // connection is lost.
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}
1557
1558/// Updates the permissions of someone else in the room.
1559async fn set_room_participant_role(
1560 request: proto::SetRoomParticipantRole,
1561 response: Response<proto::SetRoomParticipantRole>,
1562 session: MessageContext,
1563) -> Result<()> {
1564 let user_id = UserId::from_proto(request.user_id);
1565 let role = ChannelRole::from(request.role());
1566
1567 let (livekit_room, can_publish) = {
1568 let room = session
1569 .db()
1570 .await
1571 .set_room_participant_role(
1572 session.user_id(),
1573 RoomId::from_proto(request.room_id),
1574 user_id,
1575 role,
1576 )
1577 .await?;
1578
1579 let livekit_room = room.livekit_room.clone();
1580 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1581 room_updated(&room, &session.peer);
1582 (livekit_room, can_publish)
1583 };
1584
1585 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1586 live_kit
1587 .update_participant(
1588 livekit_room.clone(),
1589 request.user_id.to_string(),
1590 livekit_api::proto::ParticipantPermission {
1591 can_subscribe: true,
1592 can_publish,
1593 can_publish_data: can_publish,
1594 hidden: false,
1595 recorder: false,
1596 },
1597 )
1598 .await
1599 .trace_err();
1600 }
1601
1602 response.send(proto::Ack {})?;
1603 Ok(())
1604}
1605
/// Call someone else into the current room
///
/// Rings every live connection belonging to the callee; the first connection
/// to answer completes the call. If none answers, the call is marked failed
/// and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the call in the database and broadcast the updated room state.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to respond successfully wins; failures from the
    // others are merely logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: roll back the call state and refresh contacts.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1674
1675/// Cancel an outgoing call.
1676async fn cancel_call(
1677 request: proto::CancelCall,
1678 response: Response<proto::CancelCall>,
1679 session: MessageContext,
1680) -> Result<()> {
1681 let called_user_id = UserId::from_proto(request.called_user_id);
1682 let room_id = RoomId::from_proto(request.room_id);
1683 {
1684 let room = session
1685 .db()
1686 .await
1687 .cancel_call(room_id, session.connection_id, called_user_id)
1688 .await?;
1689 room_updated(&room, &session.peer);
1690 }
1691
1692 for connection_id in session
1693 .connection_pool()
1694 .await
1695 .user_connection_ids(called_user_id)
1696 {
1697 session
1698 .peer
1699 .send(
1700 connection_id,
1701 proto::CallCanceled {
1702 room_id: room_id.to_proto(),
1703 },
1704 )
1705 .trace_err();
1706 }
1707 response.send(proto::Ack {})?;
1708
1709 update_user_contacts(called_user_id, &session).await?;
1710 Ok(())
1711}
1712
/// Decline an incoming call.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    // Record the decline and broadcast the updated room state.
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Stop the ring on every one of the decliner's own connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1744
1745/// Updates other participants in the room with your current location.
1746async fn update_participant_location(
1747 request: proto::UpdateParticipantLocation,
1748 response: Response<proto::UpdateParticipantLocation>,
1749 session: MessageContext,
1750) -> Result<()> {
1751 let room_id = RoomId::from_proto(request.room_id);
1752 let location = request.location.context("invalid location")?;
1753
1754 let db = session.db().await;
1755 let room = db
1756 .update_room_participant_location(room_id, session.connection_id, location)
1757 .await?;
1758
1759 room_updated(&room, &session.peer);
1760 response.send(proto::Ack {})?;
1761 Ok(())
1762}
1763
/// Share a project into the room.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // Record the shared project; the database returns the new project id and
    // the room's refreshed state.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
            request.windows_paths.unwrap_or(false),
            &request.features,
        )
        .await?;
    // Reply to the sharer first, then broadcast the updated room state.
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1789
1790/// Unshare a project from the room.
1791async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1792 let project_id = ProjectId::from_proto(message.project_id);
1793 unshare_project_internal(project_id, session.connection_id, &session).await
1794}
1795
/// Shared implementation of project unsharing: removes the share record,
/// notifies guests, refreshes room state, and deletes the project when the
/// database says it should be deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (except the unsharing connection) the project is
        // no longer shared.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Perform the deletion outside the scope above, after the room guard has
    // been released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1833
1834/// Join someone elses shared project.
1835async fn join_project(
1836 request: proto::JoinProject,
1837 response: Response<proto::JoinProject>,
1838 session: MessageContext,
1839) -> Result<()> {
1840 let project_id = ProjectId::from_proto(request.project_id);
1841
1842 tracing::info!(%project_id, "join project");
1843
1844 let db = session.db().await;
1845 let project_model = db.get_project(project_id).await?;
1846 let host_features: Vec<String> =
1847 serde_json::from_str(&project_model.features).unwrap_or_default();
1848 let guest_features: HashSet<_> = request.features.iter().collect();
1849 let host_features_set: HashSet<_> = host_features.iter().collect();
1850 if guest_features != host_features_set {
1851 let host_connection_id = project_model.host_connection()?;
1852 let mut pool = session.connection_pool().await;
1853 let host_version = pool
1854 .connection(host_connection_id)
1855 .map(|c| c.zed_version.to_string());
1856 let guest_version = pool
1857 .connection(session.connection_id)
1858 .map(|c| c.zed_version.to_string());
1859 drop(pool);
1860 Err(anyhow!(
1861 "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1862 host_version.as_deref().unwrap_or("unknown"),
1863 guest_version.as_deref().unwrap_or("unknown"),
1864 ))?;
1865 }
1866
1867 let (project, replica_id) = &mut *db
1868 .join_project(
1869 project_id,
1870 session.connection_id,
1871 session.user_id(),
1872 request.committer_name.clone(),
1873 request.committer_email.clone(),
1874 )
1875 .await?;
1876 drop(db);
1877
1878 tracing::info!(%project_id, "join remote project");
1879 let collaborators = project
1880 .collaborators
1881 .iter()
1882 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1883 .map(|collaborator| collaborator.to_proto())
1884 .collect::<Vec<_>>();
1885 let project_id = project.id;
1886 let guest_user_id = session.user_id();
1887
1888 let worktrees = project
1889 .worktrees
1890 .iter()
1891 .map(|(id, worktree)| proto::WorktreeMetadata {
1892 id: *id,
1893 root_name: worktree.root_name.clone(),
1894 visible: worktree.visible,
1895 abs_path: worktree.abs_path.clone(),
1896 })
1897 .collect::<Vec<_>>();
1898
1899 let add_project_collaborator = proto::AddProjectCollaborator {
1900 project_id: project_id.to_proto(),
1901 collaborator: Some(proto::Collaborator {
1902 peer_id: Some(session.connection_id.into()),
1903 replica_id: replica_id.0 as u32,
1904 user_id: guest_user_id.to_proto(),
1905 is_host: false,
1906 committer_name: request.committer_name.clone(),
1907 committer_email: request.committer_email.clone(),
1908 }),
1909 };
1910
1911 for collaborator in &collaborators {
1912 session
1913 .peer
1914 .send(
1915 collaborator.peer_id.unwrap().into(),
1916 add_project_collaborator.clone(),
1917 )
1918 .trace_err();
1919 }
1920
1921 // First, we send the metadata associated with each worktree.
1922 let (language_servers, language_server_capabilities) = project
1923 .language_servers
1924 .clone()
1925 .into_iter()
1926 .map(|server| (server.server, server.capabilities))
1927 .unzip();
1928 response.send(proto::JoinProjectResponse {
1929 project_id: project.id.0 as u64,
1930 worktrees,
1931 replica_id: replica_id.0 as u32,
1932 collaborators,
1933 language_servers,
1934 language_server_capabilities,
1935 role: project.role.into(),
1936 windows_paths: project.path_style == PathStyle::Windows,
1937 features: project.features.clone(),
1938 })?;
1939
1940 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1941 // Stream this worktree's entries.
1942 let message = proto::UpdateWorktree {
1943 project_id: project_id.to_proto(),
1944 worktree_id,
1945 abs_path: worktree.abs_path.clone(),
1946 root_name: worktree.root_name,
1947 root_repo_common_dir: worktree.root_repo_common_dir,
1948 updated_entries: worktree.entries,
1949 removed_entries: Default::default(),
1950 scan_id: worktree.scan_id,
1951 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1952 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1953 removed_repositories: Default::default(),
1954 };
1955 for update in proto::split_worktree_update(message) {
1956 session.peer.send(session.connection_id, update.clone())?;
1957 }
1958
1959 // Stream this worktree's diagnostics.
1960 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1961 if let Some(summary) = worktree_diagnostics.next() {
1962 let message = proto::UpdateDiagnosticSummary {
1963 project_id: project.id.to_proto(),
1964 worktree_id: worktree.id,
1965 summary: Some(summary),
1966 more_summaries: worktree_diagnostics.collect(),
1967 };
1968 session.peer.send(session.connection_id, message)?;
1969 }
1970
1971 for settings_file in worktree.settings_files {
1972 session.peer.send(
1973 session.connection_id,
1974 proto::UpdateWorktreeSettings {
1975 project_id: project_id.to_proto(),
1976 worktree_id: worktree.id,
1977 path: settings_file.path,
1978 content: Some(settings_file.content),
1979 kind: Some(settings_file.kind.to_proto() as i32),
1980 outside_worktree: Some(settings_file.outside_worktree),
1981 },
1982 )?;
1983 }
1984 }
1985
1986 for repository in mem::take(&mut project.repositories) {
1987 for update in split_repository_update(repository) {
1988 session.peer.send(session.connection_id, update)?;
1989 }
1990 }
1991
1992 for language_server in &project.language_servers {
1993 session.peer.send(
1994 session.connection_id,
1995 proto::UpdateLanguageServer {
1996 project_id: project_id.to_proto(),
1997 server_name: Some(language_server.server.name.clone()),
1998 language_server_id: language_server.server.id,
1999 variant: Some(
2000 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2001 proto::LspDiskBasedDiagnosticsUpdated {},
2002 ),
2003 ),
2004 },
2005 )?;
2006 }
2007
2008 Ok(())
2009}
2010
/// Leave someone elses shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Remove this connection as a collaborator; the database returns the
    // room (if any) and the project state needed for notifications.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Notify the remaining collaborators, then broadcast updated room state.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2030
/// Updates other participants with changes to the project
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Persist the new worktree list; the dereferenced result pairs the
    // project's room (if any) with the guest connections to notify.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Forward the raw update to every guest except the sender itself.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    // Ack only after the fan-out has been queued.
    response.send(proto::Ack {})?;

    Ok(())
}
2059
2060/// Updates other participants with changes to the worktree
2061async fn update_worktree(
2062 request: proto::UpdateWorktree,
2063 response: Response<proto::UpdateWorktree>,
2064 session: MessageContext,
2065) -> Result<()> {
2066 let guest_connection_ids = session
2067 .db()
2068 .await
2069 .update_worktree(&request, session.connection_id)
2070 .await?;
2071
2072 broadcast(
2073 Some(session.connection_id),
2074 guest_connection_ids.iter().copied(),
2075 |connection_id| {
2076 session
2077 .peer
2078 .forward_send(session.connection_id, connection_id, request.clone())
2079 },
2080 );
2081 response.send(proto::Ack {})?;
2082 Ok(())
2083}
2084
2085async fn update_repository(
2086 request: proto::UpdateRepository,
2087 response: Response<proto::UpdateRepository>,
2088 session: MessageContext,
2089) -> Result<()> {
2090 let guest_connection_ids = session
2091 .db()
2092 .await
2093 .update_repository(&request, session.connection_id)
2094 .await?;
2095
2096 broadcast(
2097 Some(session.connection_id),
2098 guest_connection_ids.iter().copied(),
2099 |connection_id| {
2100 session
2101 .peer
2102 .forward_send(session.connection_id, connection_id, request.clone())
2103 },
2104 );
2105 response.send(proto::Ack {})?;
2106 Ok(())
2107}
2108
2109async fn remove_repository(
2110 request: proto::RemoveRepository,
2111 response: Response<proto::RemoveRepository>,
2112 session: MessageContext,
2113) -> Result<()> {
2114 let guest_connection_ids = session
2115 .db()
2116 .await
2117 .remove_repository(&request, session.connection_id)
2118 .await?;
2119
2120 broadcast(
2121 Some(session.connection_id),
2122 guest_connection_ids.iter().copied(),
2123 |connection_id| {
2124 session
2125 .peer
2126 .forward_send(session.connection_id, connection_id, request.clone())
2127 },
2128 );
2129 response.send(proto::Ack {})?;
2130 Ok(())
2131}
2132
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
    message: proto::UpdateDiagnosticSummary,
    session: MessageContext,
) -> Result<()> {
    // Persist the summary; the database returns the guest connections to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_diagnostic_summary(&message, session.connection_id)
        .await?;

    // Relay the same message to every guest, skipping the sender. No ack is
    // sent — this is a fire-and-forget message, not a request.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2156
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
    message: proto::UpdateWorktreeSettings,
    session: MessageContext,
) -> Result<()> {
    // Persist the settings change; the database returns the guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree_settings(&message, session.connection_id)
        .await?;

    // Relay the same message to every guest, skipping the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2180
/// Notify other participants that a language server has started.
async fn start_language_server(
    request: proto::StartLanguageServer,
    session: MessageContext,
) -> Result<()> {
    // Record the server in the database; the result lists the guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .start_language_server(&request, session.connection_id)
        .await?;

    // Relay the same message to every guest, skipping the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2203
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capability changes carried in MetadataUpdated are persisted server-side.
    // NOTE(review): presumably so reconnecting/late-joining clients see the
    // current capabilities — confirm against Database::update_server_capabilities.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): `true` here vs `false` in broadcast_project_message_from_host —
    // presumably toggles which participants are included; confirm in
    // Database::project_connection_ids.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2233
2234/// forward a project request to the host. These requests should be read only
2235/// as guests are allowed to send them.
2236async fn forward_read_only_project_request<T>(
2237 request: T,
2238 response: Response<T>,
2239 session: MessageContext,
2240) -> Result<()>
2241where
2242 T: EntityMessage + RequestMessage,
2243{
2244 let project_id = ProjectId::from_proto(request.remote_entity_id());
2245 let host_connection_id = session
2246 .db()
2247 .await
2248 .host_for_read_only_project_request(project_id, session.connection_id)
2249 .await?;
2250 let payload = session.forward_request(host_connection_id, request).await?;
2251 response.send(payload)?;
2252 Ok(())
2253}
2254
2255/// forward a project request to the host. These requests are disallowed
2256/// for guests.
2257async fn forward_mutating_project_request<T>(
2258 request: T,
2259 response: Response<T>,
2260 session: MessageContext,
2261) -> Result<()>
2262where
2263 T: EntityMessage + RequestMessage,
2264{
2265 let project_id = ProjectId::from_proto(request.remote_entity_id());
2266
2267 let host_connection_id = session
2268 .db()
2269 .await
2270 .host_for_mutating_project_request(project_id, session.connection_id)
2271 .await?;
2272 let payload = session.forward_request(host_connection_id, request).await?;
2273 response.send(payload)?;
2274 Ok(())
2275}
2276
2277async fn disallow_guest_request<T>(
2278 _request: T,
2279 response: Response<T>,
2280 _session: MessageContext,
2281) -> Result<()>
2282where
2283 T: RequestMessage,
2284{
2285 response.peer.respond_with_error(
2286 response.receipt,
2287 ErrorCode::Forbidden
2288 .message("request is not allowed for guests".to_string())
2289 .to_proto(),
2290 )?;
2291 response.responded.store(true, SeqCst);
2292 Ok(())
2293}
2294
2295async fn lsp_query(
2296 request: proto::LspQuery,
2297 response: Response<proto::LspQuery>,
2298 session: MessageContext,
2299) -> Result<()> {
2300 let (name, should_write) = request.query_name_and_write_permissions();
2301 tracing::Span::current().record("lsp_query_request", name);
2302 tracing::info!("lsp_query message received");
2303 if should_write {
2304 forward_mutating_project_request(request, response, session).await
2305 } else {
2306 forward_read_only_project_request(request, response, session).await
2307 }
2308}
2309
2310/// Notify other participants that a new buffer has been created
2311async fn create_buffer_for_peer(
2312 request: proto::CreateBufferForPeer,
2313 session: MessageContext,
2314) -> Result<()> {
2315 session
2316 .db()
2317 .await
2318 .check_user_is_project_host(
2319 ProjectId::from_proto(request.project_id),
2320 session.connection_id,
2321 )
2322 .await?;
2323 let peer_id = request.peer_id.context("invalid peer id")?;
2324 session
2325 .peer
2326 .forward_send(session.connection_id, peer_id.into(), request)?;
2327 Ok(())
2328}
2329
2330/// Notify other participants that a new image has been created
2331async fn create_image_for_peer(
2332 request: proto::CreateImageForPeer,
2333 session: MessageContext,
2334) -> Result<()> {
2335 session
2336 .db()
2337 .await
2338 .check_user_is_project_host(
2339 ProjectId::from_proto(request.project_id),
2340 session.connection_id,
2341 )
2342 .await?;
2343 let peer_id = request.peer_id.context("invalid peer id")?;
2344 session
2345 .peer
2346 .forward_send(session.connection_id, peer_id.into(), request)?;
2347 Ok(())
2348}
2349
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates (or empty variants) may come from read-only
    // guests; any other operation requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard tightly: broadcast to guests while holding it,
    // then release it before awaiting the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If we aren't the host ourselves, wait for the host to acknowledge the
    // update before acking the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2396
2397async fn forward_project_search_chunk(
2398 message: proto::FindSearchCandidatesChunk,
2399 response: Response<proto::FindSearchCandidatesChunk>,
2400 session: MessageContext,
2401) -> Result<()> {
2402 let peer_id = message.peer_id.context("missing peer_id")?;
2403 let payload = session
2404 .peer
2405 .forward_request(session.connection_id, peer_id.into(), message)
2406 .await?;
2407 response.send(payload)?;
2408 Ok(())
2409}
2410
/// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // NOTE(review): the `false` flag differs from update_language_server's
    // `true` — presumably it controls which participants are included; confirm
    // in Database::project_connection_ids.
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id, false)
        .await?;

    // Relay the same message to every participant, skipping the sender.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2434
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay it to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are recorded in the room; when recorded,
    // broadcast the updated room state to all participants.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2466
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both leader and follower must be participants of the same room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader it lost this follower (fire-and-forget).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Mirror follow(): only project-scoped follows were recorded, so only
    // those need removing; broadcast the updated room state afterwards.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2495
2496/// Notify everyone following you of your current location.
2497async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2498 let room_id = RoomId::from_proto(request.room_id);
2499 let database = session.db.lock().await;
2500
2501 let connection_ids = if let Some(project_id) = request.project_id {
2502 let project_id = ProjectId::from_proto(project_id);
2503 database
2504 .project_connection_ids(project_id, session.connection_id, true)
2505 .await?
2506 } else {
2507 database
2508 .room_connection_ids(room_id, session.connection_id)
2509 .await?
2510 };
2511
2512 // For now, don't send view update messages back to that view's current leader.
2513 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2514 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2515 _ => None,
2516 });
2517
2518 for connection_id in connection_ids.iter().cloned() {
2519 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2520 session
2521 .peer
2522 .forward_send(session.connection_id, connection_id, request.clone())?;
2523 }
2524 }
2525 Ok(())
2526}
2527
2528/// Get public data about users.
2529async fn get_users(
2530 request: proto::GetUsers,
2531 response: Response<proto::GetUsers>,
2532 session: MessageContext,
2533) -> Result<()> {
2534 let user_ids = request
2535 .user_ids
2536 .into_iter()
2537 .map(UserId::from_proto)
2538 .collect();
2539 let users = session
2540 .db()
2541 .await
2542 .get_users_by_ids(user_ids)
2543 .await?
2544 .into_iter()
2545 .map(|user| proto::User {
2546 id: user.id.to_proto(),
2547 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2548 github_login: user.github_login,
2549 name: user.name,
2550 })
2551 .collect();
2552 response.send(proto::UsersResponse { users })?;
2553 Ok(())
2554}
2555
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Short queries (1–2 bytes) are treated as exact-login lookups; longer
    // ones go through fuzzy search capped at 10 results. Empty query → empty.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Never suggest the requesting user themselves.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2587
2588/// Send a contact request to another user.
2589async fn request_contact(
2590 request: proto::RequestContact,
2591 response: Response<proto::RequestContact>,
2592 session: MessageContext,
2593) -> Result<()> {
2594 let requester_id = session.user_id();
2595 let responder_id = UserId::from_proto(request.responder_id);
2596 if requester_id == responder_id {
2597 return Err(anyhow!("cannot add yourself as a contact"))?;
2598 }
2599
2600 let notifications = session
2601 .db()
2602 .await
2603 .send_contact_request(requester_id, responder_id)
2604 .await?;
2605
2606 // Update outgoing contact requests of requester
2607 let mut update = proto::UpdateContacts::default();
2608 update.outgoing_requests.push(responder_id.to_proto());
2609 for connection_id in session
2610 .connection_pool()
2611 .await
2612 .user_connection_ids(requester_id)
2613 {
2614 session.peer.send(connection_id, update.clone())?;
2615 }
2616
2617 // Update incoming contact requests of responder
2618 let mut update = proto::UpdateContacts::default();
2619 update
2620 .incoming_requests
2621 .push(proto::IncomingContactRequest {
2622 requester_id: requester_id.to_proto(),
2623 });
2624 let connection_pool = session.connection_pool().await;
2625 for connection_id in connection_pool.user_connection_ids(responder_id) {
2626 session.peer.send(connection_id, update.clone())?;
2627 }
2628
2629 send_notifications(&connection_pool, &session.peer, notifications);
2630
2631 response.send(proto::Ack {})?;
2632 Ok(())
2633}
2634
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; Accept/Decline change the contact
    // relationship and must notify both sides.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact payloads sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Deliver the notification(s) produced by the database write.
        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2692
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` tells us whether this was a full contact or still a
    // pending request, which determines which list to remove it from below.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the now-stale notification on each of the responder's
        // connections, if one was deleted by the database call.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2743
/// Whether the server should eagerly subscribe this client to its channels,
/// rather than waiting for an explicit SubscribeToChannels message.
/// NOTE(review): only the minor version is compared — presumably all relevant
/// releases share one major version; confirm this survives a major bump.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2747
/// Handle an explicit client request to subscribe to its channels.
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    // The message carries no payload; subscribe the requesting user.
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
2755
/// Subscribes `user_id` to every channel they are a member of, then sends
/// this connection its membership list followed by the full channels update.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    // Register interest in each channel so future channel updates reach this user.
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2772
2773/// Creates a new channel.
2774async fn create_channel(
2775 request: proto::CreateChannel,
2776 response: Response<proto::CreateChannel>,
2777 session: MessageContext,
2778) -> Result<()> {
2779 let db = session.db().await;
2780
2781 let parent_id = request.parent_id.map(ChannelId::from_proto);
2782 let (channel, membership) = db
2783 .create_channel(&request.name, parent_id, session.user_id())
2784 .await?;
2785
2786 let root_id = channel.root_id();
2787 let channel = Channel::from_model(channel);
2788
2789 response.send(proto::CreateChannelResponse {
2790 channel: Some(channel.to_proto()),
2791 parent_id: request.parent_id,
2792 })?;
2793
2794 let mut connection_pool = session.connection_pool().await;
2795 if let Some(membership) = membership {
2796 connection_pool.subscribe_to_channel(
2797 membership.user_id,
2798 membership.channel_id,
2799 membership.role,
2800 );
2801 let update = proto::UpdateUserChannels {
2802 channel_memberships: vec![proto::ChannelMembership {
2803 channel_id: membership.channel_id.to_proto(),
2804 role: membership.role.into(),
2805 }],
2806 ..Default::default()
2807 };
2808 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2809 session.peer.send(connection_id, update.clone())?;
2810 }
2811 }
2812
2813 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2814 if !role.can_see_channel(channel.visibility) {
2815 continue;
2816 }
2817
2818 let update = proto::UpdateChannels {
2819 channels: vec![channel.to_proto()],
2820 ..Default::default()
2821 };
2822 session.peer.send(connection_id, update.clone())?;
2823 }
2824
2825 Ok(())
2826}
2827
/// Delete a channel
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let channel_id = request.channel_id;
    // Deleting a channel also removes its descendants; the database returns
    // the root channel plus every channel id that was removed.
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    // Ack the requester before fanning out the deletions.
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2855
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // Record the invite; the database returns the channel payload for the
    // invitee plus any notifications to deliver.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Show the pending invitation on all of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2892
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Propagate the membership change to all affected connections.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Best-effort: retract the associated notification on each of the removed
    // member's connections; failures are logged (trace_err), not propagated.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2934
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the subscribe/unsubscribe calls below
    // mutate the pool, so the `channel_user_ids` borrow must end first.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel get an upsert; users who no longer
        // can get a deletion (and are unsubscribed).
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2981
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The target may be an existing member (membership update) or someone with
    // only a pending invite (invite update); each fans out differently.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Refresh the pending invitation on all of the member's connections.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3029
/// Change the name of a channel
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    // Reply to the requester first with the renamed channel.
    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    // Then broadcast the rename to every connection on the channel's root that
    // is permitted to see the channel.
    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3061
/// Move a channel to a new parent.
async fn move_channel(
    request: proto::MoveChannel,
    response: Response<proto::MoveChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let to = ChannelId::from_proto(request.to);

    // The database returns the root of the affected tree and every channel
    // whose position changed.
    let (root_id, channels) = session
        .db()
        .await
        .move_channel(channel_id, to, session.user_id())
        .await?;

    // Per-connection fan-out: each recipient only gets the channels their role
    // allows them to see; recipients with nothing visible get no message.
    let connection_pool = session.connection_pool().await;
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        let channels = channels
            .iter()
            .filter_map(|channel| {
                if role.can_see_channel(channel.visibility) {
                    Some(channel.to_proto())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        if channels.is_empty() {
            continue;
        }

        let update = proto::UpdateChannels {
            channels,
            ..Default::default()
        };

        session.peer.send(connection_id, update.clone())?;
    }

    response.send(Ack {})?;
    Ok(())
}
3104
/// Move a channel up or down among its siblings.
async fn reorder_channel(
    request: proto::ReorderChannel,
    response: Response<proto::ReorderChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let direction = request.direction();

    let updated_channels = session
        .db()
        .await
        .reorder_channel(channel_id, direction, session.user_id())
        .await?;

    // Nothing to broadcast when the reorder changed no channels. All updated
    // channels share one root, so the first channel's root suffices.
    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
        let connection_pool = session.connection_pool().await;
        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
            // Filter the update per recipient to only channels their role can see.
            let channels = updated_channels
                .iter()
                .filter_map(|channel| {
                    if role.can_see_channel(channel.visibility) {
                        Some(channel.to_proto())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if channels.is_empty() {
                continue;
            }

            let update = proto::UpdateChannels {
                channels,
                ..Default::default()
            };

            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;
    Ok(())
}
3149
3150/// Get the list of channel members
3151async fn get_channel_members(
3152 request: proto::GetChannelMembers,
3153 response: Response<proto::GetChannelMembers>,
3154 session: MessageContext,
3155) -> Result<()> {
3156 let db = session.db().await;
3157 let channel_id = ChannelId::from_proto(request.channel_id);
3158 let limit = if request.limit == 0 {
3159 u16::MAX as u64
3160 } else {
3161 request.limit
3162 };
3163 let (members, users) = db
3164 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3165 .await?;
3166 response.send(proto::GetChannelMembersResponse { members, users })?;
3167 Ok(())
3168}
3169
3170/// Accept or decline a channel invitation.
3171async fn respond_to_channel_invite(
3172 request: proto::RespondToChannelInvite,
3173 response: Response<proto::RespondToChannelInvite>,
3174 session: MessageContext,
3175) -> Result<()> {
3176 let db = session.db().await;
3177 let channel_id = ChannelId::from_proto(request.channel_id);
3178 let RespondToChannelInvite {
3179 membership_update,
3180 notifications,
3181 } = db
3182 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3183 .await?;
3184
3185 let mut connection_pool = session.connection_pool().await;
3186 if let Some(membership_update) = membership_update {
3187 notify_membership_updated(
3188 &mut connection_pool,
3189 membership_update,
3190 session.user_id(),
3191 &session.peer,
3192 );
3193 } else {
3194 let update = proto::UpdateChannels {
3195 remove_channel_invitations: vec![channel_id.to_proto()],
3196 ..Default::default()
3197 };
3198
3199 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3200 session.peer.send(connection_id, update.clone())?;
3201 }
3202 };
3203
3204 send_notifications(&connection_pool, &session.peer, notifications);
3205
3206 response.send(proto::Ack {})?;
3207
3208 Ok(())
3209}
3210
3211/// Join the channels' room
3212async fn join_channel(
3213 request: proto::JoinChannel,
3214 response: Response<proto::JoinChannel>,
3215 session: MessageContext,
3216) -> Result<()> {
3217 let channel_id = ChannelId::from_proto(request.channel_id);
3218 join_channel_internal(channel_id, Box::new(response), session).await
3219}
3220
/// Abstraction over the response channel used by `join_channel_internal`:
/// a channel's room can be joined via either a `JoinChannel` or a `JoinRoom`
/// request, and both reply with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3234
/// Shared implementation of joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers via `JoinChannelInternalResponse`.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database guard itself, so
            // we must release ours first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a LiveKit client is configured: guests get
        // a non-publishing token, all other roles may publish. Token-creation
        // failures are logged (`trace_err`) and degrade to `None` rather than
        // failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joining client before broadcasting to everyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3328
3329/// Start editing the channel notes
3330async fn join_channel_buffer(
3331 request: proto::JoinChannelBuffer,
3332 response: Response<proto::JoinChannelBuffer>,
3333 session: MessageContext,
3334) -> Result<()> {
3335 let db = session.db().await;
3336 let channel_id = ChannelId::from_proto(request.channel_id);
3337
3338 let open_response = db
3339 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3340 .await?;
3341
3342 let collaborators = open_response.collaborators.clone();
3343 response.send(open_response)?;
3344
3345 let update = UpdateChannelBufferCollaborators {
3346 channel_id: channel_id.to_proto(),
3347 collaborators: collaborators.clone(),
3348 };
3349 channel_buffer_updated(
3350 session.connection_id,
3351 collaborators
3352 .iter()
3353 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3354 &update,
3355 &session.peer,
3356 );
3357
3358 Ok(())
3359}
3360
3361/// Edit the channel notes
3362async fn update_channel_buffer(
3363 request: proto::UpdateChannelBuffer,
3364 session: MessageContext,
3365) -> Result<()> {
3366 let db = session.db().await;
3367 let channel_id = ChannelId::from_proto(request.channel_id);
3368
3369 let (collaborators, epoch, version) = db
3370 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3371 .await?;
3372
3373 channel_buffer_updated(
3374 session.connection_id,
3375 collaborators.clone(),
3376 &proto::UpdateChannelBuffer {
3377 channel_id: channel_id.to_proto(),
3378 operations: request.operations,
3379 },
3380 &session.peer,
3381 );
3382
3383 let pool = &*session.connection_pool().await;
3384
3385 let non_collaborators =
3386 pool.channel_connection_ids(channel_id)
3387 .filter_map(|(connection_id, _)| {
3388 if collaborators.contains(&connection_id) {
3389 None
3390 } else {
3391 Some(connection_id)
3392 }
3393 });
3394
3395 broadcast(None, non_collaborators, |peer_id| {
3396 session.peer.send(
3397 peer_id,
3398 proto::UpdateChannels {
3399 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3400 channel_id: channel_id.to_proto(),
3401 epoch: epoch as u64,
3402 version: version.clone(),
3403 }],
3404 ..Default::default()
3405 },
3406 )
3407 });
3408
3409 Ok(())
3410}
3411
3412/// Rejoin the channel notes after a connection blip
3413async fn rejoin_channel_buffers(
3414 request: proto::RejoinChannelBuffers,
3415 response: Response<proto::RejoinChannelBuffers>,
3416 session: MessageContext,
3417) -> Result<()> {
3418 let db = session.db().await;
3419 let buffers = db
3420 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3421 .await?;
3422
3423 for rejoined_buffer in &buffers {
3424 let collaborators_to_notify = rejoined_buffer
3425 .buffer
3426 .collaborators
3427 .iter()
3428 .filter_map(|c| Some(c.peer_id?.into()));
3429 channel_buffer_updated(
3430 session.connection_id,
3431 collaborators_to_notify,
3432 &proto::UpdateChannelBufferCollaborators {
3433 channel_id: rejoined_buffer.buffer.channel_id,
3434 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3435 },
3436 &session.peer,
3437 );
3438 }
3439
3440 response.send(proto::RejoinChannelBuffersResponse {
3441 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3442 })?;
3443
3444 Ok(())
3445}
3446
3447/// Stop editing the channel notes
3448async fn leave_channel_buffer(
3449 request: proto::LeaveChannelBuffer,
3450 response: Response<proto::LeaveChannelBuffer>,
3451 session: MessageContext,
3452) -> Result<()> {
3453 let db = session.db().await;
3454 let channel_id = ChannelId::from_proto(request.channel_id);
3455
3456 let left_buffer = db
3457 .leave_channel_buffer(channel_id, session.connection_id)
3458 .await?;
3459
3460 response.send(Ack {})?;
3461
3462 channel_buffer_updated(
3463 session.connection_id,
3464 left_buffer.connections,
3465 &proto::UpdateChannelBufferCollaborators {
3466 channel_id: channel_id.to_proto(),
3467 collaborators: left_buffer.collaborators,
3468 },
3469 &session.peer,
3470 );
3471
3472 Ok(())
3473}
3474
/// Send `message` to each of the given collaborator connections.
/// Passes `Some(sender_id)` to `broadcast`, which presumably excludes the
/// sender's own connection — confirm against `broadcast`'s definition.
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3485
3486fn send_notifications(
3487 connection_pool: &ConnectionPool,
3488 peer: &Peer,
3489 notifications: db::NotificationBatch,
3490) {
3491 for (user_id, notification) in notifications {
3492 for connection_id in connection_pool.user_connection_ids(user_id) {
3493 if let Err(error) = peer.send(
3494 connection_id,
3495 proto::AddNotification {
3496 notification: Some(notification.clone()),
3497 },
3498 ) {
3499 tracing::error!(
3500 "failed to send notification to {:?} {}",
3501 connection_id,
3502 error
3503 );
3504 }
3505 }
3506 }
3507}
3508
/// Send a message to the channel.
///
/// Chat has been removed from Zed; this handler remains only so older clients
/// get a descriptive error instead of an unrecognized-message failure.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3517
/// Delete a channel message.
///
/// Chat has been removed from Zed; always responds with an error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3526
/// Edit a previously-sent channel message.
///
/// Chat has been removed from Zed; always responds with an error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3534
/// Mark a channel message as read.
///
/// Chat has been removed from Zed; always responds with an error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3542
3543/// Mark a buffer version as synced
3544async fn acknowledge_buffer_version(
3545 request: proto::AckBufferOperation,
3546 session: MessageContext,
3547) -> Result<()> {
3548 let buffer_id = BufferId::from_proto(request.buffer_id);
3549 session
3550 .db()
3551 .await
3552 .observe_buffer_version(
3553 buffer_id,
3554 session.user_id(),
3555 request.epoch as i32,
3556 &request.version,
3557 )
3558 .await?;
3559 Ok(())
3560}
3561
/// Start receiving chat updates for a channel.
///
/// Chat has been removed from Zed; always responds with an error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3570
/// Stop receiving chat updates for a channel.
///
/// Chat has been removed from Zed; always responds with an error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3578
/// Retrieve the chat history for a channel.
///
/// Chat has been removed from Zed; always responds with an error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3587
/// Retrieve specific chat messages by id.
///
/// Chat has been removed from Zed; always responds with an error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3596
3597/// Retrieve the current users notifications
3598async fn get_notifications(
3599 request: proto::GetNotifications,
3600 response: Response<proto::GetNotifications>,
3601 session: MessageContext,
3602) -> Result<()> {
3603 let notifications = session
3604 .db()
3605 .await
3606 .get_notifications(
3607 session.user_id(),
3608 NOTIFICATION_COUNT_PER_PAGE,
3609 request.before_id.map(db::NotificationId::from_proto),
3610 )
3611 .await?;
3612 response.send(proto::GetNotificationsResponse {
3613 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3614 notifications,
3615 })?;
3616 Ok(())
3617}
3618
3619/// Mark notifications as read
3620async fn mark_notification_as_read(
3621 request: proto::MarkNotificationRead,
3622 response: Response<proto::MarkNotificationRead>,
3623 session: MessageContext,
3624) -> Result<()> {
3625 let database = &session.db().await;
3626 let notifications = database
3627 .mark_notification_as_read_by_id(
3628 session.user_id(),
3629 NotificationId::from_proto(request.notification_id),
3630 )
3631 .await?;
3632 send_notifications(
3633 &*session.connection_pool().await,
3634 &session.peer,
3635 notifications,
3636 );
3637 response.send(proto::Ack {})?;
3638 Ok(())
3639}
3640
3641fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3642 let message = match message {
3643 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3644 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3645 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3646 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3647 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3648 code: frame.code.into(),
3649 reason: frame.reason.as_str().to_owned().into(),
3650 })),
3651 // We should never receive a frame while reading the message, according
3652 // to the `tungstenite` maintainers:
3653 //
3654 // > It cannot occur when you read messages from the WebSocket, but it
3655 // > can be used when you want to send the raw frames (e.g. you want to
3656 // > send the frames to the WebSocket without composing the full message first).
3657 // >
3658 // > — https://github.com/snapview/tungstenite-rs/issues/268
3659 TungsteniteMessage::Frame(_) => {
3660 bail!("received an unexpected frame while reading the message")
3661 }
3662 };
3663
3664 Ok(message)
3665}
3666
3667fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3668 match message {
3669 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3670 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3671 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3672 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3673 AxumMessage::Close(frame) => {
3674 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3675 code: frame.code.into(),
3676 reason: frame.reason.as_ref().into(),
3677 }))
3678 }
3679 }
3680}
3681
3682fn notify_membership_updated(
3683 connection_pool: &mut ConnectionPool,
3684 result: MembershipUpdated,
3685 user_id: UserId,
3686 peer: &Peer,
3687) {
3688 for membership in &result.new_channels.channel_memberships {
3689 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3690 }
3691 for channel_id in &result.removed_channels {
3692 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3693 }
3694
3695 let user_channels_update = proto::UpdateUserChannels {
3696 channel_memberships: result
3697 .new_channels
3698 .channel_memberships
3699 .iter()
3700 .map(|cm| proto::ChannelMembership {
3701 channel_id: cm.channel_id.to_proto(),
3702 role: cm.role.into(),
3703 })
3704 .collect(),
3705 ..Default::default()
3706 };
3707
3708 let mut update = build_channels_update(result.new_channels);
3709 update.delete_channels = result
3710 .removed_channels
3711 .into_iter()
3712 .map(|id| id.to_proto())
3713 .collect();
3714 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3715
3716 for connection_id in connection_pool.user_connection_ids(user_id) {
3717 peer.send(connection_id, user_channels_update.clone())
3718 .trace_err();
3719 peer.send(connection_id, update.clone()).trace_err();
3720 }
3721}
3722
3723fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3724 proto::UpdateUserChannels {
3725 channel_memberships: channels
3726 .channel_memberships
3727 .iter()
3728 .map(|m| proto::ChannelMembership {
3729 channel_id: m.channel_id.to_proto(),
3730 role: m.role.into(),
3731 })
3732 .collect(),
3733 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3734 }
3735}
3736
3737fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3738 let mut update = proto::UpdateChannels::default();
3739
3740 for channel in channels.channels {
3741 update.channels.push(channel.to_proto());
3742 }
3743
3744 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3745
3746 for (channel_id, participants) in channels.channel_participants {
3747 update
3748 .channel_participants
3749 .push(proto::ChannelParticipants {
3750 channel_id: channel_id.to_proto(),
3751 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3752 });
3753 }
3754
3755 for channel in channels.invited_channels {
3756 update.channel_invitations.push(channel.to_proto());
3757 }
3758
3759 update
3760}
3761
3762fn build_initial_contacts_update(
3763 contacts: Vec<db::Contact>,
3764 pool: &ConnectionPool,
3765) -> proto::UpdateContacts {
3766 let mut update = proto::UpdateContacts::default();
3767
3768 for contact in contacts {
3769 match contact {
3770 db::Contact::Accepted { user_id, busy } => {
3771 update.contacts.push(contact_for_user(user_id, busy, pool));
3772 }
3773 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3774 db::Contact::Incoming { user_id } => {
3775 update
3776 .incoming_requests
3777 .push(proto::IncomingContactRequest {
3778 requester_id: user_id.to_proto(),
3779 })
3780 }
3781 }
3782 }
3783
3784 update
3785}
3786
3787fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3788 proto::Contact {
3789 user_id: user_id.to_proto(),
3790 online: pool.is_user_online(user_id),
3791 busy,
3792 }
3793}
3794
3795fn room_updated(room: &proto::Room, peer: &Peer) {
3796 broadcast(
3797 None,
3798 room.participants
3799 .iter()
3800 .filter_map(|participant| Some(participant.peer_id?.into())),
3801 |peer_id| {
3802 peer.send(
3803 peer_id,
3804 proto::RoomUpdated {
3805 room: Some(room.clone()),
3806 },
3807 )
3808 },
3809 );
3810}
3811
3812fn channel_updated(
3813 channel: &db::channel::Model,
3814 room: &proto::Room,
3815 peer: &Peer,
3816 pool: &ConnectionPool,
3817) {
3818 let participants = room
3819 .participants
3820 .iter()
3821 .map(|p| p.user_id)
3822 .collect::<Vec<_>>();
3823
3824 broadcast(
3825 None,
3826 pool.channel_connection_ids(channel.root_id())
3827 .filter_map(|(channel_id, role)| {
3828 role.can_see_channel(channel.visibility)
3829 .then_some(channel_id)
3830 }),
3831 |peer_id| {
3832 peer.send(
3833 peer_id,
3834 proto::UpdateChannels {
3835 channel_participants: vec![proto::ChannelParticipants {
3836 channel_id: channel.id.to_proto(),
3837 participant_user_ids: participants.clone(),
3838 }],
3839 ..Default::default()
3840 },
3841 )
3842 },
3843 );
3844}
3845
3846async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3847 let db = session.db().await;
3848
3849 let contacts = db.get_contacts(user_id).await?;
3850 let busy = db.is_user_busy(user_id).await?;
3851
3852 let pool = session.connection_pool().await;
3853 let updated_contact = contact_for_user(user_id, busy, &pool);
3854 for contact in contacts {
3855 if let db::Contact::Accepted {
3856 user_id: contact_user_id,
3857 ..
3858 } = contact
3859 {
3860 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3861 session
3862 .peer
3863 .send(
3864 contact_conn_id,
3865 proto::UpdateContacts {
3866 contacts: vec![updated_contact.clone()],
3867 remove_contacts: Default::default(),
3868 incoming_requests: Default::default(),
3869 remove_incoming_requests: Default::default(),
3870 outgoing_requests: Default::default(),
3871 remove_outgoing_requests: Default::default(),
3872 },
3873 )
3874 .trace_err();
3875 }
3876 }
3877 }
3878 Ok(())
3879}
3880
/// Tear down a connection's room membership: update the database, notify the
/// remaining participants and the channel, cancel outgoing calls, refresh
/// contacts' presence, and clean up the LiveKit room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move everything we need out of `left_room` so the database guard
        // (the temporary from `session.db().await`) isn't held for the rest
        // of the function.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to clean up.
        return Ok(());
    }

    // If the room belonged to a channel, refresh its participant list.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Tell every connection of each callee that their pending call from
        // this room was canceled.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        // The database marked the room as deleted, so drop the LiveKit room too.
        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3954
3955async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3956 let left_channel_buffers = session
3957 .db()
3958 .await
3959 .leave_channel_buffers(session.connection_id)
3960 .await?;
3961
3962 for left_buffer in left_channel_buffers {
3963 channel_buffer_updated(
3964 session.connection_id,
3965 left_buffer.connections,
3966 &proto::UpdateChannelBufferCollaborators {
3967 channel_id: left_buffer.channel_id.to_proto(),
3968 collaborators: left_buffer.collaborators,
3969 },
3970 &session.peer,
3971 );
3972 }
3973
3974 Ok(())
3975}
3976
3977fn project_left(project: &db::LeftProject, session: &Session) {
3978 for connection_id in &project.connection_ids {
3979 if project.should_unshare {
3980 session
3981 .peer
3982 .send(
3983 *connection_id,
3984 proto::UnshareProject {
3985 project_id: project.id.to_proto(),
3986 },
3987 )
3988 .trace_err();
3989 } else {
3990 session
3991 .peer
3992 .send(
3993 *connection_id,
3994 proto::RemoveProjectCollaborator {
3995 project_id: project.id.to_proto(),
3996 peer_id: Some(session.connection_id.into()),
3997 },
3998 )
3999 .trace_err();
4000 }
4001 }
4002}
4003
4004async fn share_agent_thread(
4005 request: proto::ShareAgentThread,
4006 response: Response<proto::ShareAgentThread>,
4007 session: MessageContext,
4008) -> Result<()> {
4009 let user_id = session.user_id();
4010
4011 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4012 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4013
4014 session
4015 .db()
4016 .await
4017 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4018 .await?;
4019
4020 response.send(proto::Ack {})?;
4021
4022 Ok(())
4023}
4024
4025async fn get_shared_agent_thread(
4026 request: proto::GetSharedAgentThread,
4027 response: Response<proto::GetSharedAgentThread>,
4028 session: MessageContext,
4029) -> Result<()> {
4030 let share_id = SharedThreadId::from_proto(request.session_id)
4031 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4032
4033 let result = session.db().await.get_shared_thread(share_id).await?;
4034
4035 match result {
4036 Some((thread, username)) => {
4037 response.send(proto::GetSharedAgentThreadResponse {
4038 title: thread.title,
4039 thread_data: thread.data,
4040 sharer_username: username,
4041 created_at: thread.created_at.and_utc().to_rfc3339(),
4042 })?;
4043 }
4044 None => {
4045 return Err(anyhow!("Shared thread not found").into());
4046 }
4047 }
4048
4049 Ok(())
4050}
4051
/// Extension trait for logging and discarding `Result` errors.
pub trait ResultExt {
    type Ok;

    /// Logs the error (if any) via `tracing::error!` and converts the result
    /// into an `Option`, discarding the error value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4057
4058impl<T, E> ResultExt for Result<T, E>
4059where
4060 E: std::fmt::Debug,
4061{
4062 type Ok = T;
4063
4064 #[track_caller]
4065 fn trace_err(self) -> Option<T> {
4066 match self {
4067 Ok(value) => Some(value),
4068 Err(error) => {
4069 tracing::error!("{:?}", error);
4070 None
4071 }
4072 }
4073 }
4074}