1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
// Names of the per-message span fields that record timings, in milliseconds.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Reconnection grace period (30 seconds).
///
/// Not used in this part of the file; presumably how long a dropped client may
/// take to reconnect before its state is discarded — confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a freshly started server cleans up stale resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a little
/// longer than that means the old pods are gone before we clean up after them.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for notification queries (presumably; not used in this part of the file).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Maximum number of `ConnectionGuard`s that may be held at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently held.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
146#[derive(Clone)]
147struct MessageContext {
148 session: Session,
149 span: tracing::Span,
150}
151
152impl Deref for MessageContext {
153 type Target = Session;
154
155 fn deref(&self) -> &Self::Target {
156 &self.session
157 }
158}
159
160impl MessageContext {
161 pub fn forward_request<T: RequestMessage>(
162 &self,
163 receiver_id: ConnectionId,
164 request: T,
165 ) -> impl Future<Output = anyhow::Result<T::Response>> {
166 let request_start_time = Instant::now();
167 let span = self.span.clone();
168 tracing::info!("start forwarding request");
169 self.peer
170 .forward_request(self.connection_id, receiver_id, request)
171 .inspect(move |_| {
172 span.record(
173 HOST_WAITING_MS,
174 request_start_time.elapsed().as_micros() as f64 / 1000.0,
175 );
176 })
177 .inspect_err(|_| tracing::error!("error forwarding request"))
178 .inspect_ok(|_| tracing::info!("finished forwarding request"))
179 }
180}
181
/// Per-connection state, created once when a connection is accepted and cloned
/// into the [`MessageContext`] of every message handled on that connection.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// This connection's id within the [`Peer`].
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with the owning [`Server`]; lock it through
    /// `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    ///
    /// Not read by any code in this part of the file, hence `#[allow(unused)]`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id supplied when the connection was opened.
    ///
    /// Likewise not read in this part of the file, hence `#[allow(unused)]`.
    #[allow(unused)]
    system_id: Option<String>,
    /// Executor the connection was started with; kept with the session but not
    /// otherwise read (hence the leading underscore).
    _executor: Executor,
}
197
198impl Session {
199 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
200 #[cfg(feature = "test-support")]
201 tokio::task::yield_now().await;
202 let guard = self.db.lock().await;
203 #[cfg(feature = "test-support")]
204 tokio::task::yield_now().await;
205 guard
206 }
207
208 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
209 #[cfg(feature = "test-support")]
210 tokio::task::yield_now().await;
211 let guard = self.connection_pool.lock();
212 ConnectionPoolGuard {
213 guard,
214 _not_send: PhantomData,
215 }
216 }
217
218 #[expect(dead_code)]
219 fn is_staff(&self) -> bool {
220 match &self.principal {
221 Principal::User(user) => user.admin,
222 }
223 }
224
225 fn user_id(&self) -> UserId {
226 match &self.principal {
227 Principal::User(user) => user.id,
228 }
229 }
230}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// RPC server that accepts peer connections and routes their incoming messages to
/// registered handlers.
pub struct Server {
    /// This server's id; behind a mutex so the test-support `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Connection pool, shared with every [`Session`] this server creates.
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the [`TypeId`] of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown signal: `teardown` sends `true`; connection loops return when it changes.
    teardown: watch::Sender<bool>,
}
262
/// Guard returned by `Session::connection_pool`; derefs to the locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so a `Send` future
/// cannot hold this synchronous lock across an `.await`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
269 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
270 let mut server = Self {
271 id: parking_lot::Mutex::new(id),
272 peer: Peer::new(id.0 as u32),
273 app_state,
274 connection_pool: Default::default(),
275 handlers: Default::default(),
276 teardown: watch::channel(false).0,
277 };
278
279 server
280 .add_request_handler(ping)
281 .add_request_handler(create_room)
282 .add_request_handler(join_room)
283 .add_request_handler(rejoin_room)
284 .add_request_handler(leave_room)
285 .add_request_handler(set_room_participant_role)
286 .add_request_handler(call)
287 .add_request_handler(cancel_call)
288 .add_message_handler(decline_call)
289 .add_request_handler(update_participant_location)
290 .add_request_handler(share_project)
291 .add_message_handler(unshare_project)
292 .add_request_handler(join_project)
293 .add_message_handler(leave_project)
294 .add_request_handler(update_project)
295 .add_request_handler(update_worktree)
296 .add_request_handler(update_repository)
297 .add_request_handler(remove_repository)
298 .add_message_handler(start_language_server)
299 .add_message_handler(update_language_server)
300 .add_message_handler(update_diagnostic_summary)
301 .add_message_handler(update_worktree_settings)
302 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
303 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
304 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
305 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
306 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
307 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
308 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
309 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
310 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
311 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
312 .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
313 .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
314 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
316 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
317 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
318 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
319 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
320 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
321 .add_request_handler(
322 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
323 )
324 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
325 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
326 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
327 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
328 .add_request_handler(
329 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
330 )
331 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
332 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
333 .add_request_handler(
334 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
335 )
336 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
337 .add_request_handler(
338 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
339 )
340 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
341 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
342 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
343 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
344 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
345 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
346 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
347 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
348 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
349 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
350 .add_request_handler(forward_mutating_project_request::<proto::TrashProjectEntry>)
351 .add_request_handler(forward_mutating_project_request::<proto::RestoreProjectEntry>)
352 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
353 .add_request_handler(
354 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
355 )
356 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
357 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
358 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
359 .add_request_handler(lsp_query)
360 .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
361 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
362 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
363 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
364 .add_message_handler(create_buffer_for_peer)
365 .add_message_handler(create_image_for_peer)
366 .add_request_handler(update_buffer)
367 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
368 .add_message_handler(
369 broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
370 )
371 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
372 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
373 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
374 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
375 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
376 .add_message_handler(
377 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
378 )
379 .add_request_handler(get_users)
380 .add_request_handler(fuzzy_search_users)
381 .add_request_handler(request_contact)
382 .add_request_handler(remove_contact)
383 .add_request_handler(respond_to_contact_request)
384 .add_message_handler(subscribe_to_channels)
385 .add_request_handler(create_channel)
386 .add_request_handler(delete_channel)
387 .add_request_handler(invite_channel_member)
388 .add_request_handler(remove_channel_member)
389 .add_request_handler(set_channel_member_role)
390 .add_request_handler(set_channel_visibility)
391 .add_request_handler(rename_channel)
392 .add_request_handler(join_channel_buffer)
393 .add_request_handler(leave_channel_buffer)
394 .add_message_handler(update_channel_buffer)
395 .add_request_handler(rejoin_channel_buffers)
396 .add_request_handler(get_channel_members)
397 .add_request_handler(respond_to_channel_invite)
398 .add_request_handler(join_channel)
399 .add_request_handler(join_channel_chat)
400 .add_message_handler(leave_channel_chat)
401 .add_request_handler(send_channel_message)
402 .add_request_handler(remove_channel_message)
403 .add_request_handler(update_channel_message)
404 .add_request_handler(get_channel_messages)
405 .add_request_handler(get_channel_messages_by_id)
406 .add_request_handler(get_notifications)
407 .add_request_handler(mark_notification_as_read)
408 .add_request_handler(move_channel)
409 .add_request_handler(reorder_channel)
410 .add_request_handler(follow)
411 .add_message_handler(unfollow)
412 .add_message_handler(update_followers)
413 .add_message_handler(acknowledge_channel_message)
414 .add_message_handler(acknowledge_buffer_version)
415 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
416 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
417 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
418 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
419 .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
420 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
421 .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
422 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
423 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
424 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
425 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
426 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
427 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
428 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
429 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
430 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
431 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
432 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
433 .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
434 .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
435 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
436 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
437 .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
438 .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
439 .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
440 .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
441 .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
442 .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
443 .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
444 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
445 .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
446 .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
447 .add_request_handler(share_agent_thread)
448 .add_request_handler(get_shared_agent_thread)
449 .add_request_handler(forward_project_search_chunk);
450
451 Arc::new(server)
452 }
453
454 pub async fn start(&self) -> Result<()> {
455 let server_id = *self.id.lock();
456 let app_state = self.app_state.clone();
457 let peer = self.peer.clone();
458 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
459 let pool = self.connection_pool.clone();
460 let livekit_client = self.app_state.livekit_client.clone();
461
462 let span = info_span!("start server");
463 self.app_state.executor.spawn_detached(
464 async move {
465 tracing::info!("waiting for cleanup timeout");
466 timeout.await;
467 tracing::info!("cleanup timeout expired, retrieving stale rooms");
468
469 app_state
470 .db
471 .delete_stale_channel_chat_participants(
472 &app_state.config.zed_environment,
473 server_id,
474 )
475 .await
476 .trace_err();
477
478 if let Some((room_ids, channel_ids)) = app_state
479 .db
480 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
481 .await
482 .trace_err()
483 {
484 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
485 tracing::info!(
486 stale_channel_buffer_count = channel_ids.len(),
487 "retrieved stale channel buffers"
488 );
489
490 for channel_id in channel_ids {
491 if let Some(refreshed_channel_buffer) = app_state
492 .db
493 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
494 .await
495 .trace_err()
496 {
497 for connection_id in refreshed_channel_buffer.connection_ids {
498 peer.send(
499 connection_id,
500 proto::UpdateChannelBufferCollaborators {
501 channel_id: channel_id.to_proto(),
502 collaborators: refreshed_channel_buffer
503 .collaborators
504 .clone(),
505 },
506 )
507 .trace_err();
508 }
509 }
510 }
511
512 for room_id in room_ids {
513 let mut contacts_to_update = HashSet::default();
514 let mut canceled_calls_to_user_ids = Vec::new();
515 let mut livekit_room = String::new();
516 let mut delete_livekit_room = false;
517
518 if let Some(mut refreshed_room) = app_state
519 .db
520 .clear_stale_room_participants(room_id, server_id)
521 .await
522 .trace_err()
523 {
524 tracing::info!(
525 room_id = room_id.0,
526 new_participant_count = refreshed_room.room.participants.len(),
527 "refreshed room"
528 );
529 room_updated(&refreshed_room.room, &peer);
530 if let Some(channel) = refreshed_room.channel.as_ref() {
531 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
532 }
533 contacts_to_update
534 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
535 contacts_to_update
536 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
537 canceled_calls_to_user_ids =
538 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
539 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
540 delete_livekit_room = refreshed_room.room.participants.is_empty();
541 }
542
543 {
544 let pool = pool.lock();
545 for canceled_user_id in canceled_calls_to_user_ids {
546 for connection_id in pool.user_connection_ids(canceled_user_id) {
547 peer.send(
548 connection_id,
549 proto::CallCanceled {
550 room_id: room_id.to_proto(),
551 },
552 )
553 .trace_err();
554 }
555 }
556 }
557
558 for user_id in contacts_to_update {
559 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
560 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
561 if let Some((busy, contacts)) = busy.zip(contacts) {
562 let pool = pool.lock();
563 let updated_contact = contact_for_user(user_id, busy, &pool);
564 for contact in contacts {
565 if let db::Contact::Accepted {
566 user_id: contact_user_id,
567 ..
568 } = contact
569 {
570 for contact_conn_id in
571 pool.user_connection_ids(contact_user_id)
572 {
573 peer.send(
574 contact_conn_id,
575 proto::UpdateContacts {
576 contacts: vec![updated_contact.clone()],
577 remove_contacts: Default::default(),
578 incoming_requests: Default::default(),
579 remove_incoming_requests: Default::default(),
580 outgoing_requests: Default::default(),
581 remove_outgoing_requests: Default::default(),
582 },
583 )
584 .trace_err();
585 }
586 }
587 }
588 }
589 }
590
591 if let Some(live_kit) = livekit_client.as_ref()
592 && delete_livekit_room
593 {
594 live_kit.delete_room(livekit_room).await.trace_err();
595 }
596 }
597 }
598
599 app_state
600 .db
601 .delete_stale_channel_chat_participants(
602 &app_state.config.zed_environment,
603 server_id,
604 )
605 .await
606 .trace_err();
607
608 app_state
609 .db
610 .clear_old_worktree_entries(server_id)
611 .await
612 .trace_err();
613
614 app_state
615 .db
616 .delete_stale_servers(&app_state.config.zed_environment, server_id)
617 .await
618 .trace_err();
619 }
620 .instrument(span),
621 );
622 Ok(())
623 }
624
625 pub fn teardown(&self) {
626 self.peer.teardown();
627 self.connection_pool.lock().reset();
628 let _ = self.teardown.send(true);
629 }
630
631 #[cfg(feature = "test-support")]
632 pub fn reset(&self, id: ServerId) {
633 self.teardown();
634 *self.id.lock() = id;
635 self.peer.reset(id.0 as u32);
636 let _ = self.teardown.send(false);
637 }
638
639 #[cfg(feature = "test-support")]
640 pub fn id(&self) -> ServerId {
641 *self.id.lock()
642 }
643
644 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
645 where
646 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
647 Fut: 'static + Send + Future<Output = Result<()>>,
648 M: EnvelopedMessage,
649 {
650 let prev_handler = self.handlers.insert(
651 TypeId::of::<M>(),
652 Box::new(move |envelope, session, span| {
653 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
654 let received_at = envelope.received_at;
655 tracing::info!("message received");
656 let start_time = Instant::now();
657 let future = (handler)(
658 *envelope,
659 MessageContext {
660 session,
661 span: span.clone(),
662 },
663 );
664 async move {
665 let result = future.await;
666 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
667 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
668 let queue_duration_ms = total_duration_ms - processing_duration_ms;
669 span.record(TOTAL_DURATION_MS, total_duration_ms);
670 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
671 span.record(QUEUE_DURATION_MS, queue_duration_ms);
672 match result {
673 Err(error) => {
674 tracing::error!(?error, "error handling message")
675 }
676 Ok(()) => tracing::info!("finished handling message"),
677 }
678 }
679 .boxed()
680 }),
681 );
682 if prev_handler.is_some() {
683 panic!("registered a handler for the same message twice");
684 }
685 self
686 }
687
688 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
689 where
690 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
691 Fut: 'static + Send + Future<Output = Result<()>>,
692 M: EnvelopedMessage,
693 {
694 self.add_handler(move |envelope, session| handler(envelope.payload, session));
695 self
696 }
697
698 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
699 where
700 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
701 Fut: Send + Future<Output = Result<()>>,
702 M: RequestMessage,
703 {
704 let handler = Arc::new(handler);
705 self.add_handler(move |envelope, session| {
706 let receipt = envelope.receipt();
707 let handler = handler.clone();
708 async move {
709 let peer = session.peer.clone();
710 let responded = Arc::new(AtomicBool::default());
711 let response = Response {
712 peer: peer.clone(),
713 responded: responded.clone(),
714 receipt,
715 };
716 match (handler)(envelope.payload, response, session).await {
717 Ok(()) => {
718 if responded.load(std::sync::atomic::Ordering::SeqCst) {
719 Ok(())
720 } else {
721 Err(anyhow!("handler did not send a response"))?
722 }
723 }
724 Err(error) => {
725 let proto_err = match &error {
726 Error::Internal(err) => err.to_proto(),
727 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
728 };
729 peer.respond_with_error(receipt, proto_err)?;
730 Err(error)
731 }
732 }
733 }
734 })
735 }
736
737 pub fn handle_connection(
738 self: &Arc<Self>,
739 connection: Connection,
740 address: String,
741 principal: Principal,
742 zed_version: ZedVersion,
743 release_channel: Option<String>,
744 user_agent: Option<String>,
745 geoip_country_code: Option<String>,
746 system_id: Option<String>,
747 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
748 executor: Executor,
749 connection_guard: Option<ConnectionGuard>,
750 ) -> impl Future<Output = ()> + use<> {
751 let this = self.clone();
752 let span = info_span!("handle connection", %address,
753 connection_id=field::Empty,
754 user_id=field::Empty,
755 login=field::Empty,
756 user_agent=field::Empty,
757 geoip_country_code=field::Empty,
758 release_channel=field::Empty,
759 );
760 principal.update_span(&span);
761 if let Some(user_agent) = user_agent {
762 span.record("user_agent", user_agent);
763 }
764 if let Some(release_channel) = release_channel {
765 span.record("release_channel", release_channel);
766 }
767
768 if let Some(country_code) = geoip_country_code.as_ref() {
769 span.record("geoip_country_code", country_code);
770 }
771
772 let mut teardown = self.teardown.subscribe();
773 async move {
774 if *teardown.borrow() {
775 tracing::error!("server is tearing down");
776 return;
777 }
778
779 let (connection_id, handle_io, mut incoming_rx) =
780 this.peer.add_connection(connection, {
781 let executor = executor.clone();
782 move |duration| executor.sleep(duration)
783 });
784 tracing::Span::current().record("connection_id", format!("{}", connection_id));
785
786 tracing::info!("connection opened");
787
788 let session = Session {
789 principal: principal.clone(),
790 connection_id,
791 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
792 peer: this.peer.clone(),
793 connection_pool: this.connection_pool.clone(),
794 app_state: this.app_state.clone(),
795 geoip_country_code,
796 system_id,
797 _executor: executor.clone(),
798 };
799
800 if let Err(error) = this
801 .send_initial_client_update(
802 connection_id,
803 zed_version,
804 send_connection_id,
805 &session,
806 )
807 .await
808 {
809 tracing::error!(?error, "failed to send initial client update");
810 return;
811 }
812 drop(connection_guard);
813
814 let handle_io = handle_io.fuse();
815 futures::pin_mut!(handle_io);
816
817 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
818 // This prevents deadlocks when e.g., client A performs a request to client B and
819 // client B performs a request to client A. If both clients stop processing further
820 // messages until their respective request completes, they won't have a chance to
821 // respond to the other client's request and cause a deadlock.
822 //
823 // This arrangement ensures we will attempt to process earlier messages first, but fall
824 // back to processing messages arrived later in the spirit of making progress.
825 const MAX_CONCURRENT_HANDLERS: usize = 256;
826 let mut foreground_message_handlers = FuturesUnordered::new();
827 let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
828 let get_concurrent_handlers = {
829 let concurrent_handlers = concurrent_handlers.clone();
830 move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
831 };
832 loop {
833 let next_message = async {
834 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
835 let message = incoming_rx.next().await;
836 // Cache the concurrent_handlers here, so that we know what the
837 // queue looks like as each handler starts
838 (permit, message, get_concurrent_handlers())
839 }
840 .fuse();
841 futures::pin_mut!(next_message);
842 futures::select_biased! {
843 _ = teardown.changed().fuse() => return,
844 result = handle_io => {
845 if let Err(error) = result {
846 tracing::error!(?error, "error handling I/O");
847 }
848 break;
849 }
850 _ = foreground_message_handlers.next() => {}
851 next_message = next_message => {
852 let (permit, message, concurrent_handlers) = next_message;
853 if let Some(message) = message {
854 let type_name = message.payload_type_name();
855 // note: we copy all the fields from the parent span so we can query them in the logs.
856 // (https://github.com/tokio-rs/tracing/issues/2670).
857 let span = tracing::info_span!("receive message",
858 %connection_id,
859 %address,
860 type_name,
861 concurrent_handlers,
862 user_id=field::Empty,
863 login=field::Empty,
864 lsp_query_request=field::Empty,
865 release_channel=field::Empty,
866 { TOTAL_DURATION_MS }=field::Empty,
867 { PROCESSING_DURATION_MS }=field::Empty,
868 { QUEUE_DURATION_MS }=field::Empty,
869 { HOST_WAITING_MS }=field::Empty
870 );
871 principal.update_span(&span);
872 let span_enter = span.enter();
873 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
874 let is_background = message.is_background();
875 let handle_message = (handler)(message, session.clone(), span.clone());
876 drop(span_enter);
877
878 let handle_message = async move {
879 handle_message.await;
880 drop(permit);
881 }.instrument(span);
882 if is_background {
883 executor.spawn_detached(handle_message);
884 } else {
885 foreground_message_handlers.push(handle_message);
886 }
887 } else {
888 tracing::error!("no message handler");
889 }
890 } else {
891 tracing::info!("connection closed");
892 break;
893 }
894 }
895 }
896 }
897
898 drop(foreground_message_handlers);
899 let concurrent_handlers = get_concurrent_handlers();
900 tracing::info!(concurrent_handlers, "signing out");
901 if let Err(error) = connection_lost(session, teardown, executor).await {
902 tracing::error!(?error, "error signing out");
903 }
904 }
905 .instrument(span)
906 }
907
908 async fn send_initial_client_update(
909 &self,
910 connection_id: ConnectionId,
911 zed_version: ZedVersion,
912 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
913 session: &Session,
914 ) -> Result<()> {
915 self.peer.send(
916 connection_id,
917 proto::Hello {
918 peer_id: Some(connection_id.into()),
919 },
920 )?;
921 tracing::info!("sent hello message");
922 if let Some(send_connection_id) = send_connection_id.take() {
923 let _ = send_connection_id.send(connection_id);
924 }
925
926 match &session.principal {
927 Principal::User(user) => {
928 if !user.connected_once {
929 self.peer.send(connection_id, proto::ShowContacts {})?;
930 self.app_state
931 .db
932 .set_user_connected_once(user.id, true)
933 .await?;
934 }
935
936 let contacts = self.app_state.db.get_contacts(user.id).await?;
937
938 {
939 let mut pool = self.connection_pool.lock();
940 pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
941 self.peer.send(
942 connection_id,
943 build_initial_contacts_update(contacts, &pool),
944 )?;
945 }
946
947 if should_auto_subscribe_to_channels(&zed_version) {
948 subscribe_user_to_channels(user.id, session).await?;
949 }
950
951 if let Some(incoming_call) =
952 self.app_state.db.incoming_call_for_user(user.id).await?
953 {
954 self.peer.send(connection_id, incoming_call)?;
955 }
956
957 update_user_contacts(user.id, session).await?;
958 }
959 }
960
961 Ok(())
962 }
963}
964
965impl Deref for ConnectionPoolGuard<'_> {
966 type Target = ConnectionPool;
967
968 fn deref(&self) -> &Self::Target {
969 &self.guard
970 }
971}
972
973impl DerefMut for ConnectionPoolGuard<'_> {
974 fn deref_mut(&mut self) -> &mut Self::Target {
975 &mut self.guard
976 }
977}
978
/// Runs when a `ConnectionPoolGuard` is dropped.
///
/// With the `test-support` feature enabled, the pool's `check_invariants` is called first.
/// Without it, this drop adds nothing beyond dropping the guard's fields.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only for `test-support` builds; other builds skip the invariant check.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
985
986fn broadcast<F>(
987 sender_id: Option<ConnectionId>,
988 receiver_ids: impl IntoIterator<Item = ConnectionId>,
989 mut f: F,
990) where
991 F: FnMut(ConnectionId) -> anyhow::Result<()>,
992{
993 for receiver_id in receiver_ids {
994 if Some(receiver_id) != sender_id
995 && let Err(error) = f(receiver_id)
996 {
997 tracing::error!("failed to send to {:?} {}", receiver_id, error);
998 }
999 }
1000}
1001
1002pub struct ProtocolVersion(u32);
1003
1004impl Header for ProtocolVersion {
1005 fn name() -> &'static HeaderName {
1006 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1007 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1008 }
1009
1010 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1011 where
1012 Self: Sized,
1013 I: Iterator<Item = &'i axum::http::HeaderValue>,
1014 {
1015 let version = values
1016 .next()
1017 .ok_or_else(axum::headers::Error::invalid)?
1018 .to_str()
1019 .map_err(|_| axum::headers::Error::invalid())?
1020 .parse()
1021 .map_err(|_| axum::headers::Error::invalid())?;
1022 Ok(Self(version))
1023 }
1024
1025 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1026 values.extend([self.0.to_string().parse().unwrap()]);
1027 }
1028}
1029
1030pub struct AppVersionHeader(Version);
1031impl Header for AppVersionHeader {
1032 fn name() -> &'static HeaderName {
1033 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1034 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1035 }
1036
1037 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1038 where
1039 Self: Sized,
1040 I: Iterator<Item = &'i axum::http::HeaderValue>,
1041 {
1042 let version = values
1043 .next()
1044 .ok_or_else(axum::headers::Error::invalid)?
1045 .to_str()
1046 .map_err(|_| axum::headers::Error::invalid())?
1047 .parse()
1048 .map_err(|_| axum::headers::Error::invalid())?;
1049 Ok(Self(version))
1050 }
1051
1052 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1053 values.extend([self.0.to_string().parse().unwrap()]);
1054 }
1055}
1056
1057#[derive(Debug)]
1058pub struct ReleaseChannelHeader(String);
1059
1060impl Header for ReleaseChannelHeader {
1061 fn name() -> &'static HeaderName {
1062 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1063 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1064 }
1065
1066 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1067 where
1068 Self: Sized,
1069 I: Iterator<Item = &'i axum::http::HeaderValue>,
1070 {
1071 Ok(Self(
1072 values
1073 .next()
1074 .ok_or_else(axum::headers::Error::invalid)?
1075 .to_str()
1076 .map_err(|_| axum::headers::Error::invalid())?
1077 .to_owned(),
1078 ))
1079 }
1080
1081 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1082 values.extend([self.0.parse().unwrap()]);
1083 }
1084}
1085
1086pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1087 Router::new()
1088 .route("/rpc", get(handle_websocket_request))
1089 .layer(
1090 ServiceBuilder::new()
1091 .layer(Extension(server.app_state.clone()))
1092 .layer(middleware::from_fn(auth::validate_header)),
1093 )
1094 .route("/metrics", get(handle_metrics))
1095 .layer(Extension(server))
1096}
1097
1098pub async fn handle_websocket_request(
1099 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1100 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1101 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1102 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1103 Extension(server): Extension<Arc<Server>>,
1104 Extension(principal): Extension<Principal>,
1105 user_agent: Option<TypedHeader<UserAgent>>,
1106 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1107 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1108 ws: WebSocketUpgrade,
1109) -> axum::response::Response {
1110 if protocol_version != rpc::PROTOCOL_VERSION {
1111 return (
1112 StatusCode::UPGRADE_REQUIRED,
1113 "client must be upgraded".to_string(),
1114 )
1115 .into_response();
1116 }
1117
1118 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1119 return (
1120 StatusCode::UPGRADE_REQUIRED,
1121 "no version header found".to_string(),
1122 )
1123 .into_response();
1124 };
1125
1126 let release_channel = release_channel_header.map(|header| header.0.0);
1127
1128 if !version.can_collaborate() {
1129 return (
1130 StatusCode::UPGRADE_REQUIRED,
1131 "client must be upgraded".to_string(),
1132 )
1133 .into_response();
1134 }
1135
1136 let socket_address = socket_address.to_string();
1137
1138 // Acquire connection guard before WebSocket upgrade
1139 let connection_guard = match ConnectionGuard::try_acquire() {
1140 Ok(guard) => guard,
1141 Err(()) => {
1142 return (
1143 StatusCode::SERVICE_UNAVAILABLE,
1144 "Too many concurrent connections",
1145 )
1146 .into_response();
1147 }
1148 };
1149
1150 ws.on_upgrade(move |socket| {
1151 let socket = socket
1152 .map_ok(to_tungstenite_message)
1153 .err_into()
1154 .with(|message| async move { to_axum_message(message) });
1155 let connection = Connection::new(Box::pin(socket));
1156 async move {
1157 server
1158 .handle_connection(
1159 connection,
1160 socket_address,
1161 principal,
1162 version,
1163 release_channel,
1164 user_agent.map(|header| header.to_string()),
1165 country_code_header.map(|header| header.to_string()),
1166 system_id_header.map(|header| header.to_string()),
1167 None,
1168 Executor::Production,
1169 Some(connection_guard),
1170 )
1171 .await;
1172 }
1173 })
1174}
1175
1176pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1177 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1178 let connections_metric = CONNECTIONS_METRIC
1179 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1180
1181 let connections = server
1182 .connection_pool
1183 .lock()
1184 .connections()
1185 .filter(|connection| !connection.admin)
1186 .count();
1187 connections_metric.set(connections as _);
1188
1189 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1190 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1191 register_int_gauge!(
1192 "shared_projects",
1193 "number of open projects with one or more guests"
1194 )
1195 .unwrap()
1196 });
1197
1198 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1199 shared_projects_metric.set(shared_projects as _);
1200
1201 let encoder = prometheus::TextEncoder::new();
1202 let metric_families = prometheus::gather();
1203 let encoded_metrics = encoder
1204 .encode_to_string(&metric_families)
1205 .map_err(|err| anyhow!("{err}"))?;
1206 Ok(encoded_metrics)
1207}
1208
1209#[instrument(err, skip(executor))]
1210async fn connection_lost(
1211 session: Session,
1212 mut teardown: watch::Receiver<bool>,
1213 executor: Executor,
1214) -> Result<()> {
1215 session.peer.disconnect(session.connection_id);
1216 session
1217 .connection_pool()
1218 .await
1219 .remove_connection(session.connection_id)?;
1220
1221 session
1222 .db()
1223 .await
1224 .connection_lost(session.connection_id)
1225 .await
1226 .trace_err();
1227
1228 futures::select_biased! {
1229 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1230
1231 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1232 leave_room_for_session(&session, session.connection_id).await.trace_err();
1233 leave_channel_buffers_for_session(&session)
1234 .await
1235 .trace_err();
1236
1237 if !session
1238 .connection_pool()
1239 .await
1240 .is_user_online(session.user_id())
1241 {
1242 let db = session.db().await;
1243 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1244 room_updated(&room, &session.peer);
1245 }
1246 }
1247
1248 update_user_contacts(session.user_id(), &session).await?;
1249 },
1250 _ = teardown.changed().fuse() => {}
1251 }
1252
1253 Ok(())
1254}
1255
1256/// Acknowledges a ping from a client, used to keep the connection alive.
1257async fn ping(
1258 _: proto::Ping,
1259 response: Response<proto::Ping>,
1260 _session: MessageContext,
1261) -> Result<()> {
1262 response.send(proto::Ack {})?;
1263 Ok(())
1264}
1265
1266/// Creates a new room for calling (outside of channels)
1267async fn create_room(
1268 _request: proto::CreateRoom,
1269 response: Response<proto::CreateRoom>,
1270 session: MessageContext,
1271) -> Result<()> {
1272 let livekit_room = nanoid::nanoid!(30);
1273
1274 let live_kit_connection_info = util::maybe!(async {
1275 let live_kit = session.app_state.livekit_client.as_ref();
1276 let live_kit = live_kit?;
1277 let user_id = session.user_id().to_string();
1278
1279 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1280
1281 Some(proto::LiveKitConnectionInfo {
1282 server_url: live_kit.url().into(),
1283 token,
1284 can_publish: true,
1285 })
1286 })
1287 .await;
1288
1289 let room = session
1290 .db()
1291 .await
1292 .create_room(session.user_id(), session.connection_id, &livekit_room)
1293 .await?;
1294
1295 response.send(proto::CreateRoomResponse {
1296 room: Some(room.clone()),
1297 live_kit_connection_info,
1298 })?;
1299
1300 update_user_contacts(session.user_id(), &session).await?;
1301 Ok(())
1302}
1303
1304/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1305async fn join_room(
1306 request: proto::JoinRoom,
1307 response: Response<proto::JoinRoom>,
1308 session: MessageContext,
1309) -> Result<()> {
1310 let room_id = RoomId::from_proto(request.id);
1311
1312 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1313
1314 if let Some(channel_id) = channel_id {
1315 return join_channel_internal(channel_id, Box::new(response), session).await;
1316 }
1317
1318 let joined_room = {
1319 let room = session
1320 .db()
1321 .await
1322 .join_room(room_id, session.user_id(), session.connection_id)
1323 .await?;
1324 room_updated(&room.room, &session.peer);
1325 room.into_inner()
1326 };
1327
1328 for connection_id in session
1329 .connection_pool()
1330 .await
1331 .user_connection_ids(session.user_id())
1332 {
1333 session
1334 .peer
1335 .send(
1336 connection_id,
1337 proto::CallCanceled {
1338 room_id: room_id.to_proto(),
1339 },
1340 )
1341 .trace_err();
1342 }
1343
1344 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1345 {
1346 live_kit
1347 .room_token(
1348 &joined_room.room.livekit_room,
1349 &session.user_id().to_string(),
1350 )
1351 .trace_err()
1352 .map(|token| proto::LiveKitConnectionInfo {
1353 server_url: live_kit.url().into(),
1354 token,
1355 can_publish: true,
1356 })
1357 } else {
1358 None
1359 };
1360
1361 response.send(proto::JoinRoomResponse {
1362 room: Some(joined_room.room),
1363 channel_id: None,
1364 live_kit_connection_info,
1365 })?;
1366
1367 update_user_contacts(session.user_id(), &session).await?;
1368 Ok(())
1369}
1370
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database re-associates this user and connection with the room described by `request`.
/// The caller is then sent the room together with its reshared and rejoined projects. After
/// that:
/// - room participants receive the updated room;
/// - for each reshared project, every collaborator is sent `UpdateProjectCollaborator`
///   mapping the project's old connection id to this connection (failures are logged);
/// - for each reshared project, every collaborator other than this connection is forwarded
///   the project's current worktree list;
/// - rejoined projects are streamed back through `notify_rejoined_projects`;
/// - if the room belongs to a channel, `channel_updated` is called with it;
/// - finally the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Filled inside the block below from the consumed rejoin result.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first so the client has the room and project state before further updates.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell every collaborator this project's peer moved from the old connection to
            // this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Consume the rejoin result, keeping only what is needed after this block.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1462
1463fn notify_rejoined_projects(
1464 rejoined_projects: &mut Vec<RejoinedProject>,
1465 session: &Session,
1466) -> Result<()> {
1467 for project in rejoined_projects.iter() {
1468 for collaborator in &project.collaborators {
1469 session
1470 .peer
1471 .send(
1472 collaborator.connection_id,
1473 proto::UpdateProjectCollaborator {
1474 project_id: project.id.to_proto(),
1475 old_peer_id: Some(project.old_connection_id.into()),
1476 new_peer_id: Some(session.connection_id.into()),
1477 },
1478 )
1479 .trace_err();
1480 }
1481 }
1482
1483 for project in rejoined_projects {
1484 for worktree in mem::take(&mut project.worktrees) {
1485 // Stream this worktree's entries.
1486 let message = proto::UpdateWorktree {
1487 project_id: project.id.to_proto(),
1488 worktree_id: worktree.id,
1489 abs_path: worktree.abs_path.clone(),
1490 root_name: worktree.root_name,
1491 root_repo_common_dir: worktree.root_repo_common_dir,
1492 updated_entries: worktree.updated_entries,
1493 removed_entries: worktree.removed_entries,
1494 scan_id: worktree.scan_id,
1495 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1496 updated_repositories: worktree.updated_repositories,
1497 removed_repositories: worktree.removed_repositories,
1498 };
1499 for update in proto::split_worktree_update(message) {
1500 session.peer.send(session.connection_id, update)?;
1501 }
1502
1503 // Stream this worktree's diagnostics.
1504 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1505 if let Some(summary) = worktree_diagnostics.next() {
1506 let message = proto::UpdateDiagnosticSummary {
1507 project_id: project.id.to_proto(),
1508 worktree_id: worktree.id,
1509 summary: Some(summary),
1510 more_summaries: worktree_diagnostics.collect(),
1511 };
1512 session.peer.send(session.connection_id, message)?;
1513 }
1514
1515 for settings_file in worktree.settings_files {
1516 session.peer.send(
1517 session.connection_id,
1518 proto::UpdateWorktreeSettings {
1519 project_id: project.id.to_proto(),
1520 worktree_id: worktree.id,
1521 path: settings_file.path,
1522 content: Some(settings_file.content),
1523 kind: Some(settings_file.kind.to_proto().into()),
1524 outside_worktree: Some(settings_file.outside_worktree),
1525 },
1526 )?;
1527 }
1528 }
1529
1530 for repository in mem::take(&mut project.updated_repositories) {
1531 for update in split_repository_update(repository) {
1532 session.peer.send(session.connection_id, update)?;
1533 }
1534 }
1535
1536 for id in mem::take(&mut project.removed_repositories) {
1537 session.peer.send(
1538 session.connection_id,
1539 proto::RemoveRepository {
1540 project_id: project.id.to_proto(),
1541 id,
1542 },
1543 )?;
1544 }
1545 }
1546
1547 Ok(())
1548}
1549
1550/// leave room disconnects from the room.
1551async fn leave_room(
1552 _: proto::LeaveRoom,
1553 response: Response<proto::LeaveRoom>,
1554 session: MessageContext,
1555) -> Result<()> {
1556 leave_room_for_session(&session, session.connection_id).await?;
1557 response.send(proto::Ack {})?;
1558 Ok(())
1559}
1560
1561/// Updates the permissions of someone else in the room.
1562async fn set_room_participant_role(
1563 request: proto::SetRoomParticipantRole,
1564 response: Response<proto::SetRoomParticipantRole>,
1565 session: MessageContext,
1566) -> Result<()> {
1567 let user_id = UserId::from_proto(request.user_id);
1568 let role = ChannelRole::from(request.role());
1569
1570 let (livekit_room, can_publish) = {
1571 let room = session
1572 .db()
1573 .await
1574 .set_room_participant_role(
1575 session.user_id(),
1576 RoomId::from_proto(request.room_id),
1577 user_id,
1578 role,
1579 )
1580 .await?;
1581
1582 let livekit_room = room.livekit_room.clone();
1583 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1584 room_updated(&room, &session.peer);
1585 (livekit_room, can_publish)
1586 };
1587
1588 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1589 live_kit
1590 .update_participant(
1591 livekit_room.clone(),
1592 request.user_id.to_string(),
1593 livekit_api::proto::ParticipantPermission {
1594 can_subscribe: true,
1595 can_publish,
1596 can_publish_data: can_publish,
1597 hidden: false,
1598 recorder: false,
1599 },
1600 )
1601 .await
1602 .trace_err();
1603 }
1604
1605 response.send(proto::Ack {})?;
1606 Ok(())
1607}
1608
1609/// Call someone else into the current room
1610async fn call(
1611 request: proto::Call,
1612 response: Response<proto::Call>,
1613 session: MessageContext,
1614) -> Result<()> {
1615 let room_id = RoomId::from_proto(request.room_id);
1616 let calling_user_id = session.user_id();
1617 let calling_connection_id = session.connection_id;
1618 let called_user_id = UserId::from_proto(request.called_user_id);
1619 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1620 if !session
1621 .db()
1622 .await
1623 .has_contact(calling_user_id, called_user_id)
1624 .await?
1625 {
1626 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1627 }
1628
1629 let incoming_call = {
1630 let (room, incoming_call) = &mut *session
1631 .db()
1632 .await
1633 .call(
1634 room_id,
1635 calling_user_id,
1636 calling_connection_id,
1637 called_user_id,
1638 initial_project_id,
1639 )
1640 .await?;
1641 room_updated(room, &session.peer);
1642 mem::take(incoming_call)
1643 };
1644 update_user_contacts(called_user_id, &session).await?;
1645
1646 let mut calls = session
1647 .connection_pool()
1648 .await
1649 .user_connection_ids(called_user_id)
1650 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1651 .collect::<FuturesUnordered<_>>();
1652
1653 while let Some(call_response) = calls.next().await {
1654 match call_response.as_ref() {
1655 Ok(_) => {
1656 response.send(proto::Ack {})?;
1657 return Ok(());
1658 }
1659 Err(_) => {
1660 call_response.trace_err();
1661 }
1662 }
1663 }
1664
1665 {
1666 let room = session
1667 .db()
1668 .await
1669 .call_failed(room_id, called_user_id)
1670 .await?;
1671 room_updated(&room, &session.peer);
1672 }
1673 update_user_contacts(called_user_id, &session).await?;
1674
1675 Err(anyhow!("failed to ring user"))?
1676}
1677
1678/// Cancel an outgoing call.
1679async fn cancel_call(
1680 request: proto::CancelCall,
1681 response: Response<proto::CancelCall>,
1682 session: MessageContext,
1683) -> Result<()> {
1684 let called_user_id = UserId::from_proto(request.called_user_id);
1685 let room_id = RoomId::from_proto(request.room_id);
1686 {
1687 let room = session
1688 .db()
1689 .await
1690 .cancel_call(room_id, session.connection_id, called_user_id)
1691 .await?;
1692 room_updated(&room, &session.peer);
1693 }
1694
1695 for connection_id in session
1696 .connection_pool()
1697 .await
1698 .user_connection_ids(called_user_id)
1699 {
1700 session
1701 .peer
1702 .send(
1703 connection_id,
1704 proto::CallCanceled {
1705 room_id: room_id.to_proto(),
1706 },
1707 )
1708 .trace_err();
1709 }
1710 response.send(proto::Ack {})?;
1711
1712 update_user_contacts(called_user_id, &session).await?;
1713 Ok(())
1714}
1715
1716/// Decline an incoming call.
1717async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1718 let room_id = RoomId::from_proto(message.room_id);
1719 {
1720 let room = session
1721 .db()
1722 .await
1723 .decline_call(Some(room_id), session.user_id())
1724 .await?
1725 .context("declining call")?;
1726 room_updated(&room, &session.peer);
1727 }
1728
1729 for connection_id in session
1730 .connection_pool()
1731 .await
1732 .user_connection_ids(session.user_id())
1733 {
1734 session
1735 .peer
1736 .send(
1737 connection_id,
1738 proto::CallCanceled {
1739 room_id: room_id.to_proto(),
1740 },
1741 )
1742 .trace_err();
1743 }
1744 update_user_contacts(session.user_id(), &session).await?;
1745 Ok(())
1746}
1747
1748/// Updates other participants in the room with your current location.
1749async fn update_participant_location(
1750 request: proto::UpdateParticipantLocation,
1751 response: Response<proto::UpdateParticipantLocation>,
1752 session: MessageContext,
1753) -> Result<()> {
1754 let room_id = RoomId::from_proto(request.room_id);
1755 let location = request.location.context("invalid location")?;
1756
1757 let db = session.db().await;
1758 let room = db
1759 .update_room_participant_location(room_id, session.connection_id, location)
1760 .await?;
1761
1762 room_updated(&room, &session.peer);
1763 response.send(proto::Ack {})?;
1764 Ok(())
1765}
1766
1767/// Share a project into the room.
1768async fn share_project(
1769 request: proto::ShareProject,
1770 response: Response<proto::ShareProject>,
1771 session: MessageContext,
1772) -> Result<()> {
1773 let (project_id, room) = &*session
1774 .db()
1775 .await
1776 .share_project(
1777 RoomId::from_proto(request.room_id),
1778 session.connection_id,
1779 &request.worktrees,
1780 request.is_ssh_project,
1781 request.windows_paths.unwrap_or(false),
1782 &request.features,
1783 )
1784 .await?;
1785 response.send(proto::ShareProjectResponse {
1786 project_id: project_id.to_proto(),
1787 })?;
1788 room_updated(room, &session.peer);
1789
1790 Ok(())
1791}
1792
1793/// Unshare a project from the room.
1794async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1795 let project_id = ProjectId::from_proto(message.project_id);
1796 unshare_project_internal(project_id, session.connection_id, &session).await
1797}
1798
1799async fn unshare_project_internal(
1800 project_id: ProjectId,
1801 connection_id: ConnectionId,
1802 session: &Session,
1803) -> Result<()> {
1804 let delete = {
1805 let room_guard = session
1806 .db()
1807 .await
1808 .unshare_project(project_id, connection_id)
1809 .await?;
1810
1811 let (delete, room, guest_connection_ids) = &*room_guard;
1812
1813 let message = proto::UnshareProject {
1814 project_id: project_id.to_proto(),
1815 };
1816
1817 broadcast(
1818 Some(connection_id),
1819 guest_connection_ids.iter().copied(),
1820 |conn_id| session.peer.send(conn_id, message.clone()),
1821 );
1822 if let Some(room) = room {
1823 room_updated(room, &session.peer);
1824 }
1825
1826 *delete
1827 };
1828
1829 if delete {
1830 let db = session.db().await;
1831 db.delete_project(project_id).await?;
1832 }
1833
1834 Ok(())
1835}
1836
1837/// Join someone elses shared project.
1838async fn join_project(
1839 request: proto::JoinProject,
1840 response: Response<proto::JoinProject>,
1841 session: MessageContext,
1842) -> Result<()> {
1843 let project_id = ProjectId::from_proto(request.project_id);
1844
1845 tracing::info!(%project_id, "join project");
1846
1847 let db = session.db().await;
1848 let project_model = db.get_project(project_id).await?;
1849 let host_features: Vec<String> =
1850 serde_json::from_str(&project_model.features).unwrap_or_default();
1851 let guest_features: HashSet<_> = request.features.iter().collect();
1852 let host_features_set: HashSet<_> = host_features.iter().collect();
1853 if guest_features != host_features_set {
1854 let host_connection_id = project_model.host_connection()?;
1855 let mut pool = session.connection_pool().await;
1856 let host_version = pool
1857 .connection(host_connection_id)
1858 .map(|c| c.zed_version.to_string());
1859 let guest_version = pool
1860 .connection(session.connection_id)
1861 .map(|c| c.zed_version.to_string());
1862 drop(pool);
1863 Err(anyhow!(
1864 "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1865 host_version.as_deref().unwrap_or("unknown"),
1866 guest_version.as_deref().unwrap_or("unknown"),
1867 ))?;
1868 }
1869
1870 let (project, replica_id) = &mut *db
1871 .join_project(
1872 project_id,
1873 session.connection_id,
1874 session.user_id(),
1875 request.committer_name.clone(),
1876 request.committer_email.clone(),
1877 )
1878 .await?;
1879 drop(db);
1880
1881 tracing::info!(%project_id, "join remote project");
1882 let collaborators = project
1883 .collaborators
1884 .iter()
1885 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1886 .map(|collaborator| collaborator.to_proto())
1887 .collect::<Vec<_>>();
1888 let project_id = project.id;
1889 let guest_user_id = session.user_id();
1890
1891 let worktrees = project
1892 .worktrees
1893 .iter()
1894 .map(|(id, worktree)| proto::WorktreeMetadata {
1895 id: *id,
1896 root_name: worktree.root_name.clone(),
1897 visible: worktree.visible,
1898 abs_path: worktree.abs_path.clone(),
1899 root_repo_common_dir: None,
1900 })
1901 .collect::<Vec<_>>();
1902
1903 let add_project_collaborator = proto::AddProjectCollaborator {
1904 project_id: project_id.to_proto(),
1905 collaborator: Some(proto::Collaborator {
1906 peer_id: Some(session.connection_id.into()),
1907 replica_id: replica_id.0 as u32,
1908 user_id: guest_user_id.to_proto(),
1909 is_host: false,
1910 committer_name: request.committer_name.clone(),
1911 committer_email: request.committer_email.clone(),
1912 }),
1913 };
1914
1915 for collaborator in &collaborators {
1916 session
1917 .peer
1918 .send(
1919 collaborator.peer_id.unwrap().into(),
1920 add_project_collaborator.clone(),
1921 )
1922 .trace_err();
1923 }
1924
1925 // First, we send the metadata associated with each worktree.
1926 let (language_servers, language_server_capabilities) = project
1927 .language_servers
1928 .clone()
1929 .into_iter()
1930 .map(|server| (server.server, server.capabilities))
1931 .unzip();
1932 response.send(proto::JoinProjectResponse {
1933 project_id: project.id.0 as u64,
1934 worktrees,
1935 replica_id: replica_id.0 as u32,
1936 collaborators,
1937 language_servers,
1938 language_server_capabilities,
1939 role: project.role.into(),
1940 windows_paths: project.path_style == PathStyle::Windows,
1941 features: project.features.clone(),
1942 })?;
1943
1944 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1945 // Stream this worktree's entries.
1946 let message = proto::UpdateWorktree {
1947 project_id: project_id.to_proto(),
1948 worktree_id,
1949 abs_path: worktree.abs_path.clone(),
1950 root_name: worktree.root_name,
1951 root_repo_common_dir: worktree.root_repo_common_dir,
1952 updated_entries: worktree.entries,
1953 removed_entries: Default::default(),
1954 scan_id: worktree.scan_id,
1955 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1956 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1957 removed_repositories: Default::default(),
1958 };
1959 for update in proto::split_worktree_update(message) {
1960 session.peer.send(session.connection_id, update.clone())?;
1961 }
1962
1963 // Stream this worktree's diagnostics.
1964 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1965 if let Some(summary) = worktree_diagnostics.next() {
1966 let message = proto::UpdateDiagnosticSummary {
1967 project_id: project.id.to_proto(),
1968 worktree_id: worktree.id,
1969 summary: Some(summary),
1970 more_summaries: worktree_diagnostics.collect(),
1971 };
1972 session.peer.send(session.connection_id, message)?;
1973 }
1974
1975 for settings_file in worktree.settings_files {
1976 session.peer.send(
1977 session.connection_id,
1978 proto::UpdateWorktreeSettings {
1979 project_id: project_id.to_proto(),
1980 worktree_id: worktree.id,
1981 path: settings_file.path,
1982 content: Some(settings_file.content),
1983 kind: Some(settings_file.kind.to_proto() as i32),
1984 outside_worktree: Some(settings_file.outside_worktree),
1985 },
1986 )?;
1987 }
1988 }
1989
1990 for repository in mem::take(&mut project.repositories) {
1991 for update in split_repository_update(repository) {
1992 session.peer.send(session.connection_id, update)?;
1993 }
1994 }
1995
1996 for language_server in &project.language_servers {
1997 session.peer.send(
1998 session.connection_id,
1999 proto::UpdateLanguageServer {
2000 project_id: project_id.to_proto(),
2001 server_name: Some(language_server.server.name.clone()),
2002 language_server_id: language_server.server.id,
2003 variant: Some(
2004 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2005 proto::LspDiskBasedDiagnosticsUpdated {},
2006 ),
2007 ),
2008 },
2009 )?;
2010 }
2011
2012 Ok(())
2013}
2014
2015/// Leave someone elses shared project.
2016async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2017 let sender_id = session.connection_id;
2018 let project_id = ProjectId::from_proto(request.project_id);
2019 let db = session.db().await;
2020
2021 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2022 tracing::info!(
2023 %project_id,
2024 "leave project"
2025 );
2026
2027 project_left(project, &session);
2028 if let Some(room) = room {
2029 room_updated(room, &session.peer);
2030 }
2031
2032 Ok(())
2033}
2034
2035/// Updates other participants with changes to the project
2036async fn update_project(
2037 request: proto::UpdateProject,
2038 response: Response<proto::UpdateProject>,
2039 session: MessageContext,
2040) -> Result<()> {
2041 let project_id = ProjectId::from_proto(request.project_id);
2042 let (room, guest_connection_ids) = &*session
2043 .db()
2044 .await
2045 .update_project(project_id, session.connection_id, &request.worktrees)
2046 .await?;
2047 broadcast(
2048 Some(session.connection_id),
2049 guest_connection_ids.iter().copied(),
2050 |connection_id| {
2051 session
2052 .peer
2053 .forward_send(session.connection_id, connection_id, request.clone())
2054 },
2055 );
2056 if let Some(room) = room {
2057 room_updated(room, &session.peer);
2058 }
2059 response.send(proto::Ack {})?;
2060
2061 Ok(())
2062}
2063
2064/// Updates other participants with changes to the worktree
2065async fn update_worktree(
2066 request: proto::UpdateWorktree,
2067 response: Response<proto::UpdateWorktree>,
2068 session: MessageContext,
2069) -> Result<()> {
2070 let guest_connection_ids = session
2071 .db()
2072 .await
2073 .update_worktree(&request, session.connection_id)
2074 .await?;
2075
2076 broadcast(
2077 Some(session.connection_id),
2078 guest_connection_ids.iter().copied(),
2079 |connection_id| {
2080 session
2081 .peer
2082 .forward_send(session.connection_id, connection_id, request.clone())
2083 },
2084 );
2085 response.send(proto::Ack {})?;
2086 Ok(())
2087}
2088
2089async fn update_repository(
2090 request: proto::UpdateRepository,
2091 response: Response<proto::UpdateRepository>,
2092 session: MessageContext,
2093) -> Result<()> {
2094 let guest_connection_ids = session
2095 .db()
2096 .await
2097 .update_repository(&request, session.connection_id)
2098 .await?;
2099
2100 broadcast(
2101 Some(session.connection_id),
2102 guest_connection_ids.iter().copied(),
2103 |connection_id| {
2104 session
2105 .peer
2106 .forward_send(session.connection_id, connection_id, request.clone())
2107 },
2108 );
2109 response.send(proto::Ack {})?;
2110 Ok(())
2111}
2112
2113async fn remove_repository(
2114 request: proto::RemoveRepository,
2115 response: Response<proto::RemoveRepository>,
2116 session: MessageContext,
2117) -> Result<()> {
2118 let guest_connection_ids = session
2119 .db()
2120 .await
2121 .remove_repository(&request, session.connection_id)
2122 .await?;
2123
2124 broadcast(
2125 Some(session.connection_id),
2126 guest_connection_ids.iter().copied(),
2127 |connection_id| {
2128 session
2129 .peer
2130 .forward_send(session.connection_id, connection_id, request.clone())
2131 },
2132 );
2133 response.send(proto::Ack {})?;
2134 Ok(())
2135}
2136
2137/// Updates other participants with changes to the diagnostics
2138async fn update_diagnostic_summary(
2139 message: proto::UpdateDiagnosticSummary,
2140 session: MessageContext,
2141) -> Result<()> {
2142 let guest_connection_ids = session
2143 .db()
2144 .await
2145 .update_diagnostic_summary(&message, session.connection_id)
2146 .await?;
2147
2148 broadcast(
2149 Some(session.connection_id),
2150 guest_connection_ids.iter().copied(),
2151 |connection_id| {
2152 session
2153 .peer
2154 .forward_send(session.connection_id, connection_id, message.clone())
2155 },
2156 );
2157
2158 Ok(())
2159}
2160
2161/// Updates other participants with changes to the worktree settings
2162async fn update_worktree_settings(
2163 message: proto::UpdateWorktreeSettings,
2164 session: MessageContext,
2165) -> Result<()> {
2166 let guest_connection_ids = session
2167 .db()
2168 .await
2169 .update_worktree_settings(&message, session.connection_id)
2170 .await?;
2171
2172 broadcast(
2173 Some(session.connection_id),
2174 guest_connection_ids.iter().copied(),
2175 |connection_id| {
2176 session
2177 .peer
2178 .forward_send(session.connection_id, connection_id, message.clone())
2179 },
2180 );
2181
2182 Ok(())
2183}
2184
2185/// Notify other participants that a language server has started.
2186async fn start_language_server(
2187 request: proto::StartLanguageServer,
2188 session: MessageContext,
2189) -> Result<()> {
2190 let guest_connection_ids = session
2191 .db()
2192 .await
2193 .start_language_server(&request, session.connection_id)
2194 .await?;
2195
2196 broadcast(
2197 Some(session.connection_id),
2198 guest_connection_ids.iter().copied(),
2199 |connection_id| {
2200 session
2201 .peer
2202 .forward_send(session.connection_id, connection_id, request.clone())
2203 },
2204 );
2205 Ok(())
2206}
2207
2208/// Notify other participants that a language server has changed.
2209async fn update_language_server(
2210 request: proto::UpdateLanguageServer,
2211 session: MessageContext,
2212) -> Result<()> {
2213 let project_id = ProjectId::from_proto(request.project_id);
2214 let db = session.db().await;
2215
2216 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2217 && let Some(capabilities) = update.capabilities.clone()
2218 {
2219 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2220 .await?;
2221 }
2222
2223 let project_connection_ids = db
2224 .project_connection_ids(project_id, session.connection_id, true)
2225 .await?;
2226 broadcast(
2227 Some(session.connection_id),
2228 project_connection_ids.iter().copied(),
2229 |connection_id| {
2230 session
2231 .peer
2232 .forward_send(session.connection_id, connection_id, request.clone())
2233 },
2234 );
2235 Ok(())
2236}
2237
2238/// forward a project request to the host. These requests should be read only
2239/// as guests are allowed to send them.
2240async fn forward_read_only_project_request<T>(
2241 request: T,
2242 response: Response<T>,
2243 session: MessageContext,
2244) -> Result<()>
2245where
2246 T: EntityMessage + RequestMessage,
2247{
2248 let project_id = ProjectId::from_proto(request.remote_entity_id());
2249 let host_connection_id = session
2250 .db()
2251 .await
2252 .host_for_read_only_project_request(project_id, session.connection_id)
2253 .await?;
2254 let payload = session.forward_request(host_connection_id, request).await?;
2255 response.send(payload)?;
2256 Ok(())
2257}
2258
2259/// forward a project request to the host. These requests are disallowed
2260/// for guests.
2261async fn forward_mutating_project_request<T>(
2262 request: T,
2263 response: Response<T>,
2264 session: MessageContext,
2265) -> Result<()>
2266where
2267 T: EntityMessage + RequestMessage,
2268{
2269 let project_id = ProjectId::from_proto(request.remote_entity_id());
2270
2271 let host_connection_id = session
2272 .db()
2273 .await
2274 .host_for_mutating_project_request(project_id, session.connection_id)
2275 .await?;
2276 let payload = session.forward_request(host_connection_id, request).await?;
2277 response.send(payload)?;
2278 Ok(())
2279}
2280
2281async fn disallow_guest_request<T>(
2282 _request: T,
2283 response: Response<T>,
2284 _session: MessageContext,
2285) -> Result<()>
2286where
2287 T: RequestMessage,
2288{
2289 response.peer.respond_with_error(
2290 response.receipt,
2291 ErrorCode::Forbidden
2292 .message("request is not allowed for guests".to_string())
2293 .to_proto(),
2294 )?;
2295 response.responded.store(true, SeqCst);
2296 Ok(())
2297}
2298
2299async fn lsp_query(
2300 request: proto::LspQuery,
2301 response: Response<proto::LspQuery>,
2302 session: MessageContext,
2303) -> Result<()> {
2304 let (name, should_write) = request.query_name_and_write_permissions();
2305 tracing::Span::current().record("lsp_query_request", name);
2306 tracing::info!("lsp_query message received");
2307 if should_write {
2308 forward_mutating_project_request(request, response, session).await
2309 } else {
2310 forward_read_only_project_request(request, response, session).await
2311 }
2312}
2313
2314/// Notify other participants that a new buffer has been created
2315async fn create_buffer_for_peer(
2316 request: proto::CreateBufferForPeer,
2317 session: MessageContext,
2318) -> Result<()> {
2319 session
2320 .db()
2321 .await
2322 .check_user_is_project_host(
2323 ProjectId::from_proto(request.project_id),
2324 session.connection_id,
2325 )
2326 .await?;
2327 let peer_id = request.peer_id.context("invalid peer id")?;
2328 session
2329 .peer
2330 .forward_send(session.connection_id, peer_id.into(), request)?;
2331 Ok(())
2332}
2333
2334/// Notify other participants that a new image has been created
2335async fn create_image_for_peer(
2336 request: proto::CreateImageForPeer,
2337 session: MessageContext,
2338) -> Result<()> {
2339 session
2340 .db()
2341 .await
2342 .check_user_is_project_host(
2343 ProjectId::from_proto(request.project_id),
2344 session.connection_id,
2345 )
2346 .await?;
2347 let peer_id = request.peer_id.context("invalid peer id")?;
2348 session
2349 .peer
2350 .forward_send(session.connection_id, peer_id.into(), request)?;
2351 Ok(())
2352}
2353
2354/// Notify other participants that a buffer has been updated. This is
2355/// allowed for guests as long as the update is limited to selections.
2356async fn update_buffer(
2357 request: proto::UpdateBuffer,
2358 response: Response<proto::UpdateBuffer>,
2359 session: MessageContext,
2360) -> Result<()> {
2361 let project_id = ProjectId::from_proto(request.project_id);
2362 let mut capability = Capability::ReadOnly;
2363
2364 for op in request.operations.iter() {
2365 match op.variant {
2366 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2367 Some(_) => capability = Capability::ReadWrite,
2368 }
2369 }
2370
2371 let host = {
2372 let guard = session
2373 .db()
2374 .await
2375 .connections_for_buffer_update(project_id, session.connection_id, capability)
2376 .await?;
2377
2378 let (host, guests) = &*guard;
2379
2380 broadcast(
2381 Some(session.connection_id),
2382 guests.clone(),
2383 |connection_id| {
2384 session
2385 .peer
2386 .forward_send(session.connection_id, connection_id, request.clone())
2387 },
2388 );
2389
2390 *host
2391 };
2392
2393 if host != session.connection_id {
2394 session.forward_request(host, request.clone()).await?;
2395 }
2396
2397 response.send(proto::Ack {})?;
2398 Ok(())
2399}
2400
2401async fn forward_project_search_chunk(
2402 message: proto::FindSearchCandidatesChunk,
2403 response: Response<proto::FindSearchCandidatesChunk>,
2404 session: MessageContext,
2405) -> Result<()> {
2406 let peer_id = message.peer_id.context("missing peer_id")?;
2407 let payload = session
2408 .peer
2409 .forward_request(session.connection_id, peer_id.into(), message)
2410 .await?;
2411 response.send(payload)?;
2412 Ok(())
2413}
2414
2415/// Notify other participants that a project has been updated.
2416async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2417 request: T,
2418 session: MessageContext,
2419) -> Result<()> {
2420 let project_id = ProjectId::from_proto(request.remote_entity_id());
2421 let project_connection_ids = session
2422 .db()
2423 .await
2424 .project_connection_ids(project_id, session.connection_id, false)
2425 .await?;
2426
2427 broadcast(
2428 Some(session.connection_id),
2429 project_connection_ids.iter().copied(),
2430 |connection_id| {
2431 session
2432 .peer
2433 .forward_send(session.connection_id, connection_id, request.clone())
2434 },
2435 );
2436 Ok(())
2437}
2438
2439/// Start following another user in a call.
2440async fn follow(
2441 request: proto::Follow,
2442 response: Response<proto::Follow>,
2443 session: MessageContext,
2444) -> Result<()> {
2445 let room_id = RoomId::from_proto(request.room_id);
2446 let project_id = request.project_id.map(ProjectId::from_proto);
2447 let leader_id = request.leader_id.context("invalid leader id")?.into();
2448 let follower_id = session.connection_id;
2449
2450 session
2451 .db()
2452 .await
2453 .check_room_participants(room_id, leader_id, session.connection_id)
2454 .await?;
2455
2456 let response_payload = session.forward_request(leader_id, request).await?;
2457 response.send(response_payload)?;
2458
2459 if let Some(project_id) = project_id {
2460 let room = session
2461 .db()
2462 .await
2463 .follow(room_id, project_id, leader_id, follower_id)
2464 .await?;
2465 room_updated(&room, &session.peer);
2466 }
2467
2468 Ok(())
2469}
2470
2471/// Stop following another user in a call.
2472async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2473 let room_id = RoomId::from_proto(request.room_id);
2474 let project_id = request.project_id.map(ProjectId::from_proto);
2475 let leader_id = request.leader_id.context("invalid leader id")?.into();
2476 let follower_id = session.connection_id;
2477
2478 session
2479 .db()
2480 .await
2481 .check_room_participants(room_id, leader_id, session.connection_id)
2482 .await?;
2483
2484 session
2485 .peer
2486 .forward_send(session.connection_id, leader_id, request)?;
2487
2488 if let Some(project_id) = project_id {
2489 let room = session
2490 .db()
2491 .await
2492 .unfollow(room_id, project_id, leader_id, follower_id)
2493 .await?;
2494 room_updated(&room, &session.peer);
2495 }
2496
2497 Ok(())
2498}
2499
/// Notify everyone following you of your current location.
///
/// Recipients are chosen as follows:
/// - If `request.project_id` is set, the project's connections.
/// - Otherwise, the connections in `request.room_id`.
///
/// The sender never receives the update. For `UpdateView` variants, the
/// peer named by the view's `leader_id` is also skipped.
///
/// The database guard stays held for the whole function, including the sends.
/// A failed send returns its error immediately via `?`, skipping any remaining
/// recipients.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of going through
    // `session.db()` like the other handlers — confirm that is intentional.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for connection_id in connection_ids.iter().cloned() {
        // Skip the omitted leader (compared as a peer id) and the sender itself.
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2531
2532/// Get public data about users.
2533async fn get_users(
2534 request: proto::GetUsers,
2535 response: Response<proto::GetUsers>,
2536 session: MessageContext,
2537) -> Result<()> {
2538 let user_ids = request
2539 .user_ids
2540 .into_iter()
2541 .map(UserId::from_proto)
2542 .collect();
2543 let users = session
2544 .db()
2545 .await
2546 .get_users_by_ids(user_ids)
2547 .await?
2548 .into_iter()
2549 .map(|user| proto::User {
2550 id: user.id.to_proto(),
2551 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2552 github_login: user.github_login,
2553 name: user.name,
2554 })
2555 .collect();
2556 response.send(proto::UsersResponse { users })?;
2557 Ok(())
2558}
2559
2560/// Search for users (to invite) buy Github login
2561async fn fuzzy_search_users(
2562 request: proto::FuzzySearchUsers,
2563 response: Response<proto::FuzzySearchUsers>,
2564 session: MessageContext,
2565) -> Result<()> {
2566 let query = request.query;
2567 let users = match query.len() {
2568 0 => vec![],
2569 1 | 2 => session
2570 .db()
2571 .await
2572 .get_user_by_github_login(&query)
2573 .await?
2574 .into_iter()
2575 .collect(),
2576 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2577 };
2578 let users = users
2579 .into_iter()
2580 .filter(|user| user.id != session.user_id())
2581 .map(|user| proto::User {
2582 id: user.id.to_proto(),
2583 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584 github_login: user.github_login,
2585 name: user.name,
2586 })
2587 .collect();
2588 response.send(proto::UsersResponse { users })?;
2589 Ok(())
2590}
2591
2592/// Send a contact request to another user.
2593async fn request_contact(
2594 request: proto::RequestContact,
2595 response: Response<proto::RequestContact>,
2596 session: MessageContext,
2597) -> Result<()> {
2598 let requester_id = session.user_id();
2599 let responder_id = UserId::from_proto(request.responder_id);
2600 if requester_id == responder_id {
2601 return Err(anyhow!("cannot add yourself as a contact"))?;
2602 }
2603
2604 let notifications = session
2605 .db()
2606 .await
2607 .send_contact_request(requester_id, responder_id)
2608 .await?;
2609
2610 // Update outgoing contact requests of requester
2611 let mut update = proto::UpdateContacts::default();
2612 update.outgoing_requests.push(responder_id.to_proto());
2613 for connection_id in session
2614 .connection_pool()
2615 .await
2616 .user_connection_ids(requester_id)
2617 {
2618 session.peer.send(connection_id, update.clone())?;
2619 }
2620
2621 // Update incoming contact requests of responder
2622 let mut update = proto::UpdateContacts::default();
2623 update
2624 .incoming_requests
2625 .push(proto::IncomingContactRequest {
2626 requester_id: requester_id.to_proto(),
2627 });
2628 let connection_pool = session.connection_pool().await;
2629 for connection_id in connection_pool.user_connection_ids(responder_id) {
2630 session.peer.send(connection_id, update.clone())?;
2631 }
2632
2633 send_notifications(&connection_pool, &session.peer, notifications);
2634
2635 response.send(proto::Ack {})?;
2636 Ok(())
2637}
2638
/// Accept, decline, or dismiss a contact request sent to the current user.
///
/// - `Dismiss` only dismisses the contact notification.
/// - `Accept` accepts the request.
/// - Any other response value declines it (`accept` is false).
///
/// When accepting or declining:
/// - The responder's connections drop the incoming request.
/// - The requester's connections drop the outgoing request.
/// - On accept, each side also receives the other as a new contact.
/// - Database notifications are delivered.
///
/// The sender is acked in every case.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The db guard is held for the rest of the function, including while the
    // connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status feeds into the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2696
2697/// Remove a contact.
2698async fn remove_contact(
2699 request: proto::RemoveContact,
2700 response: Response<proto::RemoveContact>,
2701 session: MessageContext,
2702) -> Result<()> {
2703 let requester_id = session.user_id();
2704 let responder_id = UserId::from_proto(request.user_id);
2705 let db = session.db().await;
2706 let (contact_accepted, deleted_notification_id) =
2707 db.remove_contact(requester_id, responder_id).await?;
2708
2709 let pool = session.connection_pool().await;
2710 // Update outgoing contact requests of requester
2711 let mut update = proto::UpdateContacts::default();
2712 if contact_accepted {
2713 update.remove_contacts.push(responder_id.to_proto());
2714 } else {
2715 update
2716 .remove_outgoing_requests
2717 .push(responder_id.to_proto());
2718 }
2719 for connection_id in pool.user_connection_ids(requester_id) {
2720 session.peer.send(connection_id, update.clone())?;
2721 }
2722
2723 // Update incoming contact requests of responder
2724 let mut update = proto::UpdateContacts::default();
2725 if contact_accepted {
2726 update.remove_contacts.push(requester_id.to_proto());
2727 } else {
2728 update
2729 .remove_incoming_requests
2730 .push(requester_id.to_proto());
2731 }
2732 for connection_id in pool.user_connection_ids(responder_id) {
2733 session.peer.send(connection_id, update.clone())?;
2734 if let Some(notification_id) = deleted_notification_id {
2735 session.peer.send(
2736 connection_id,
2737 proto::DeleteNotification {
2738 notification_id: notification_id.to_proto(),
2739 },
2740 )?;
2741 }
2742 }
2743
2744 response.send(proto::Ack {})?;
2745 Ok(())
2746}
2747
2748fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2749 version.0.minor < 139
2750}
2751
2752async fn subscribe_to_channels(
2753 _: proto::SubscribeToChannels,
2754 session: MessageContext,
2755) -> Result<()> {
2756 subscribe_user_to_channels(session.user_id(), &session).await?;
2757 Ok(())
2758}
2759
2760async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2761 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2762 let mut pool = session.connection_pool().await;
2763 for membership in &channels_for_user.channel_memberships {
2764 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2765 }
2766 session.peer.send(
2767 session.connection_id,
2768 build_update_user_channels(&channels_for_user),
2769 )?;
2770 session.peer.send(
2771 session.connection_id,
2772 build_channels_update(channels_for_user),
2773 )?;
2774 Ok(())
2775}
2776
2777/// Creates a new channel.
2778async fn create_channel(
2779 request: proto::CreateChannel,
2780 response: Response<proto::CreateChannel>,
2781 session: MessageContext,
2782) -> Result<()> {
2783 let db = session.db().await;
2784
2785 let parent_id = request.parent_id.map(ChannelId::from_proto);
2786 let (channel, membership) = db
2787 .create_channel(&request.name, parent_id, session.user_id())
2788 .await?;
2789
2790 let root_id = channel.root_id();
2791 let channel = Channel::from_model(channel);
2792
2793 response.send(proto::CreateChannelResponse {
2794 channel: Some(channel.to_proto()),
2795 parent_id: request.parent_id,
2796 })?;
2797
2798 let mut connection_pool = session.connection_pool().await;
2799 if let Some(membership) = membership {
2800 connection_pool.subscribe_to_channel(
2801 membership.user_id,
2802 membership.channel_id,
2803 membership.role,
2804 );
2805 let update = proto::UpdateUserChannels {
2806 channel_memberships: vec![proto::ChannelMembership {
2807 channel_id: membership.channel_id.to_proto(),
2808 role: membership.role.into(),
2809 }],
2810 ..Default::default()
2811 };
2812 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2813 session.peer.send(connection_id, update.clone())?;
2814 }
2815 }
2816
2817 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2818 if !role.can_see_channel(channel.visibility) {
2819 continue;
2820 }
2821
2822 let update = proto::UpdateChannels {
2823 channels: vec![channel.to_proto()],
2824 ..Default::default()
2825 };
2826 session.peer.send(connection_id, update.clone())?;
2827 }
2828
2829 Ok(())
2830}
2831
2832/// Delete a channel
2833async fn delete_channel(
2834 request: proto::DeleteChannel,
2835 response: Response<proto::DeleteChannel>,
2836 session: MessageContext,
2837) -> Result<()> {
2838 let db = session.db().await;
2839
2840 let channel_id = request.channel_id;
2841 let (root_channel, removed_channels) = db
2842 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2843 .await?;
2844 response.send(proto::Ack {})?;
2845
2846 // Notify members of removed channels
2847 let mut update = proto::UpdateChannels::default();
2848 update
2849 .delete_channels
2850 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2851
2852 let connection_pool = session.connection_pool().await;
2853 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2854 session.peer.send(connection_id, update.clone())?;
2855 }
2856
2857 Ok(())
2858}
2859
2860/// Invite someone to join a channel.
2861async fn invite_channel_member(
2862 request: proto::InviteChannelMember,
2863 response: Response<proto::InviteChannelMember>,
2864 session: MessageContext,
2865) -> Result<()> {
2866 let db = session.db().await;
2867 let channel_id = ChannelId::from_proto(request.channel_id);
2868 let invitee_id = UserId::from_proto(request.user_id);
2869 let InviteMemberResult {
2870 channel,
2871 notifications,
2872 } = db
2873 .invite_channel_member(
2874 channel_id,
2875 invitee_id,
2876 session.user_id(),
2877 request.role().into(),
2878 )
2879 .await?;
2880
2881 let update = proto::UpdateChannels {
2882 channel_invitations: vec![channel.to_proto()],
2883 ..Default::default()
2884 };
2885
2886 let connection_pool = session.connection_pool().await;
2887 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2888 session.peer.send(connection_id, update.clone())?;
2889 }
2890
2891 send_notifications(&connection_pool, &session.peer, notifications);
2892
2893 response.send(proto::Ack {})?;
2894 Ok(())
2895}
2896
2897/// remove someone from a channel
2898async fn remove_channel_member(
2899 request: proto::RemoveChannelMember,
2900 response: Response<proto::RemoveChannelMember>,
2901 session: MessageContext,
2902) -> Result<()> {
2903 let db = session.db().await;
2904 let channel_id = ChannelId::from_proto(request.channel_id);
2905 let member_id = UserId::from_proto(request.user_id);
2906
2907 let RemoveChannelMemberResult {
2908 membership_update,
2909 notification_id,
2910 } = db
2911 .remove_channel_member(channel_id, member_id, session.user_id())
2912 .await?;
2913
2914 let mut connection_pool = session.connection_pool().await;
2915 notify_membership_updated(
2916 &mut connection_pool,
2917 membership_update,
2918 member_id,
2919 &session.peer,
2920 );
2921 for connection_id in connection_pool.user_connection_ids(member_id) {
2922 if let Some(notification_id) = notification_id {
2923 session
2924 .peer
2925 .send(
2926 connection_id,
2927 proto::DeleteNotification {
2928 notification_id: notification_id.to_proto(),
2929 },
2930 )
2931 .trace_err();
2932 }
2933 }
2934
2935 response.send(proto::Ack {})?;
2936 Ok(())
2937}
2938
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
///
/// The rest of the call is handled per user subscribed to the channel's root:
/// - If the user's role can see the channel under its new visibility, the user
///   is (re)subscribed to it and sent the channel.
/// - Otherwise, the user is unsubscribed and told to delete it.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates
    // `connection_pool` via subscribe/unsubscribe.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2985
2986/// Alter the role for a user in the channel.
2987async fn set_channel_member_role(
2988 request: proto::SetChannelMemberRole,
2989 response: Response<proto::SetChannelMemberRole>,
2990 session: MessageContext,
2991) -> Result<()> {
2992 let db = session.db().await;
2993 let channel_id = ChannelId::from_proto(request.channel_id);
2994 let member_id = UserId::from_proto(request.user_id);
2995 let result = db
2996 .set_channel_member_role(
2997 channel_id,
2998 session.user_id(),
2999 member_id,
3000 request.role().into(),
3001 )
3002 .await?;
3003
3004 match result {
3005 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3006 let mut connection_pool = session.connection_pool().await;
3007 notify_membership_updated(
3008 &mut connection_pool,
3009 membership_update,
3010 member_id,
3011 &session.peer,
3012 )
3013 }
3014 db::SetMemberRoleResult::InviteUpdated(channel) => {
3015 let update = proto::UpdateChannels {
3016 channel_invitations: vec![channel.to_proto()],
3017 ..Default::default()
3018 };
3019
3020 for connection_id in session
3021 .connection_pool()
3022 .await
3023 .user_connection_ids(member_id)
3024 {
3025 session.peer.send(connection_id, update.clone())?;
3026 }
3027 }
3028 }
3029
3030 response.send(proto::Ack {})?;
3031 Ok(())
3032}
3033
3034/// Change the name of a channel
3035async fn rename_channel(
3036 request: proto::RenameChannel,
3037 response: Response<proto::RenameChannel>,
3038 session: MessageContext,
3039) -> Result<()> {
3040 let db = session.db().await;
3041 let channel_id = ChannelId::from_proto(request.channel_id);
3042 let channel_model = db
3043 .rename_channel(channel_id, session.user_id(), &request.name)
3044 .await?;
3045 let root_id = channel_model.root_id();
3046 let channel = Channel::from_model(channel_model);
3047
3048 response.send(proto::RenameChannelResponse {
3049 channel: Some(channel.to_proto()),
3050 })?;
3051
3052 let connection_pool = session.connection_pool().await;
3053 let update = proto::UpdateChannels {
3054 channels: vec![channel.to_proto()],
3055 ..Default::default()
3056 };
3057 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3058 if role.can_see_channel(channel.visibility) {
3059 session.peer.send(connection_id, update.clone())?;
3060 }
3061 }
3062
3063 Ok(())
3064}
3065
3066/// Move a channel to a new parent.
3067async fn move_channel(
3068 request: proto::MoveChannel,
3069 response: Response<proto::MoveChannel>,
3070 session: MessageContext,
3071) -> Result<()> {
3072 let channel_id = ChannelId::from_proto(request.channel_id);
3073 let to = ChannelId::from_proto(request.to);
3074
3075 let (root_id, channels) = session
3076 .db()
3077 .await
3078 .move_channel(channel_id, to, session.user_id())
3079 .await?;
3080
3081 let connection_pool = session.connection_pool().await;
3082 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3083 let channels = channels
3084 .iter()
3085 .filter_map(|channel| {
3086 if role.can_see_channel(channel.visibility) {
3087 Some(channel.to_proto())
3088 } else {
3089 None
3090 }
3091 })
3092 .collect::<Vec<_>>();
3093 if channels.is_empty() {
3094 continue;
3095 }
3096
3097 let update = proto::UpdateChannels {
3098 channels,
3099 ..Default::default()
3100 };
3101
3102 session.peer.send(connection_id, update.clone())?;
3103 }
3104
3105 response.send(Ack {})?;
3106 Ok(())
3107}
3108
3109async fn reorder_channel(
3110 request: proto::ReorderChannel,
3111 response: Response<proto::ReorderChannel>,
3112 session: MessageContext,
3113) -> Result<()> {
3114 let channel_id = ChannelId::from_proto(request.channel_id);
3115 let direction = request.direction();
3116
3117 let updated_channels = session
3118 .db()
3119 .await
3120 .reorder_channel(channel_id, direction, session.user_id())
3121 .await?;
3122
3123 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3124 let connection_pool = session.connection_pool().await;
3125 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3126 let channels = updated_channels
3127 .iter()
3128 .filter_map(|channel| {
3129 if role.can_see_channel(channel.visibility) {
3130 Some(channel.to_proto())
3131 } else {
3132 None
3133 }
3134 })
3135 .collect::<Vec<_>>();
3136
3137 if channels.is_empty() {
3138 continue;
3139 }
3140
3141 let update = proto::UpdateChannels {
3142 channels,
3143 ..Default::default()
3144 };
3145
3146 session.peer.send(connection_id, update.clone())?;
3147 }
3148 }
3149
3150 response.send(Ack {})?;
3151 Ok(())
3152}
3153
3154/// Get the list of channel members
3155async fn get_channel_members(
3156 request: proto::GetChannelMembers,
3157 response: Response<proto::GetChannelMembers>,
3158 session: MessageContext,
3159) -> Result<()> {
3160 let db = session.db().await;
3161 let channel_id = ChannelId::from_proto(request.channel_id);
3162 let limit = if request.limit == 0 {
3163 u16::MAX as u64
3164 } else {
3165 request.limit
3166 };
3167 let (members, users) = db
3168 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3169 .await?;
3170 response.send(proto::GetChannelMembersResponse { members, users })?;
3171 Ok(())
3172}
3173
3174/// Accept or decline a channel invitation.
3175async fn respond_to_channel_invite(
3176 request: proto::RespondToChannelInvite,
3177 response: Response<proto::RespondToChannelInvite>,
3178 session: MessageContext,
3179) -> Result<()> {
3180 let db = session.db().await;
3181 let channel_id = ChannelId::from_proto(request.channel_id);
3182 let RespondToChannelInvite {
3183 membership_update,
3184 notifications,
3185 } = db
3186 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3187 .await?;
3188
3189 let mut connection_pool = session.connection_pool().await;
3190 if let Some(membership_update) = membership_update {
3191 notify_membership_updated(
3192 &mut connection_pool,
3193 membership_update,
3194 session.user_id(),
3195 &session.peer,
3196 );
3197 } else {
3198 let update = proto::UpdateChannels {
3199 remove_channel_invitations: vec![channel_id.to_proto()],
3200 ..Default::default()
3201 };
3202
3203 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3204 session.peer.send(connection_id, update.clone())?;
3205 }
3206 };
3207
3208 send_notifications(&connection_pool, &session.peer, notifications);
3209
3210 response.send(proto::Ack {})?;
3211
3212 Ok(())
3213}
3214
3215/// Join the channels' room
3216async fn join_channel(
3217 request: proto::JoinChannel,
3218 response: Response<proto::JoinChannel>,
3219 session: MessageContext,
3220) -> Result<()> {
3221 let channel_id = ChannelId::from_proto(request.channel_id);
3222 join_channel_internal(channel_id, Box::new(response), session).await
3223}
3224
/// A response handle that can answer a channel-join request.
///
/// Implemented below for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`. Both reply with a `proto::JoinRoomResponse`,
/// so `join_channel_internal` can serve either request.
trait JoinChannelInternalResponse {
    /// Send the join result to the requesting client, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forwards to the `send` that `Response` already provides (the one
        // handlers call as `response.send(...)`). The explicit type path is
        // presumably there so the call resolves to that method rather than
        // recursing into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forwards to `Response`'s own `send`. The explicit type path
        // presumably keeps the call from resolving back to this trait method.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3238
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. Evicts a stale room connection left over for this user, if any.
/// 2. Joins the channel in the database.
/// 3. Builds LiveKit credentials (guests get a non-publishing guest token).
/// 4. Replies through `response`.
/// 5. Applies any membership change and broadcasts the room.
/// 6. Refreshes the channel's participant list for subscribers.
/// 7. Pushes updated contact state for the user.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // Both the db guard and the connection-pool guard taken inside this block
    // are released when it ends, before the pool is locked again below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` locks `session.db()` itself, so the
            // guard must be released first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or when minting the
        // token fails (the failure is logged by `trace_err`). Guests receive a
        // guest token and may not publish; every other role gets a room token.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // The requester is answered before any other client is notified.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails the request if the database did not return the channel with the room.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3332
3333/// Start editing the channel notes
3334async fn join_channel_buffer(
3335 request: proto::JoinChannelBuffer,
3336 response: Response<proto::JoinChannelBuffer>,
3337 session: MessageContext,
3338) -> Result<()> {
3339 let db = session.db().await;
3340 let channel_id = ChannelId::from_proto(request.channel_id);
3341
3342 let open_response = db
3343 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3344 .await?;
3345
3346 let collaborators = open_response.collaborators.clone();
3347 response.send(open_response)?;
3348
3349 let update = UpdateChannelBufferCollaborators {
3350 channel_id: channel_id.to_proto(),
3351 collaborators: collaborators.clone(),
3352 };
3353 channel_buffer_updated(
3354 session.connection_id,
3355 collaborators
3356 .iter()
3357 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3358 &update,
3359 &session.peer,
3360 );
3361
3362 Ok(())
3363}
3364
3365/// Edit the channel notes
3366async fn update_channel_buffer(
3367 request: proto::UpdateChannelBuffer,
3368 session: MessageContext,
3369) -> Result<()> {
3370 let db = session.db().await;
3371 let channel_id = ChannelId::from_proto(request.channel_id);
3372
3373 let (collaborators, epoch, version) = db
3374 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3375 .await?;
3376
3377 channel_buffer_updated(
3378 session.connection_id,
3379 collaborators.clone(),
3380 &proto::UpdateChannelBuffer {
3381 channel_id: channel_id.to_proto(),
3382 operations: request.operations,
3383 },
3384 &session.peer,
3385 );
3386
3387 let pool = &*session.connection_pool().await;
3388
3389 let non_collaborators =
3390 pool.channel_connection_ids(channel_id)
3391 .filter_map(|(connection_id, _)| {
3392 if collaborators.contains(&connection_id) {
3393 None
3394 } else {
3395 Some(connection_id)
3396 }
3397 });
3398
3399 broadcast(None, non_collaborators, |peer_id| {
3400 session.peer.send(
3401 peer_id,
3402 proto::UpdateChannels {
3403 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3404 channel_id: channel_id.to_proto(),
3405 epoch: epoch as u64,
3406 version: version.clone(),
3407 }],
3408 ..Default::default()
3409 },
3410 )
3411 });
3412
3413 Ok(())
3414}
3415
3416/// Rejoin the channel notes after a connection blip
3417async fn rejoin_channel_buffers(
3418 request: proto::RejoinChannelBuffers,
3419 response: Response<proto::RejoinChannelBuffers>,
3420 session: MessageContext,
3421) -> Result<()> {
3422 let db = session.db().await;
3423 let buffers = db
3424 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3425 .await?;
3426
3427 for rejoined_buffer in &buffers {
3428 let collaborators_to_notify = rejoined_buffer
3429 .buffer
3430 .collaborators
3431 .iter()
3432 .filter_map(|c| Some(c.peer_id?.into()));
3433 channel_buffer_updated(
3434 session.connection_id,
3435 collaborators_to_notify,
3436 &proto::UpdateChannelBufferCollaborators {
3437 channel_id: rejoined_buffer.buffer.channel_id,
3438 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3439 },
3440 &session.peer,
3441 );
3442 }
3443
3444 response.send(proto::RejoinChannelBuffersResponse {
3445 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3446 })?;
3447
3448 Ok(())
3449}
3450
3451/// Stop editing the channel notes
3452async fn leave_channel_buffer(
3453 request: proto::LeaveChannelBuffer,
3454 response: Response<proto::LeaveChannelBuffer>,
3455 session: MessageContext,
3456) -> Result<()> {
3457 let db = session.db().await;
3458 let channel_id = ChannelId::from_proto(request.channel_id);
3459
3460 let left_buffer = db
3461 .leave_channel_buffer(channel_id, session.connection_id)
3462 .await?;
3463
3464 response.send(Ack {})?;
3465
3466 channel_buffer_updated(
3467 session.connection_id,
3468 left_buffer.connections,
3469 &proto::UpdateChannelBufferCollaborators {
3470 channel_id: channel_id.to_proto(),
3471 collaborators: left_buffer.collaborators,
3472 },
3473 &session.peer,
3474 );
3475
3476 Ok(())
3477}
3478
3479fn channel_buffer_updated<T: EnvelopedMessage>(
3480 sender_id: ConnectionId,
3481 collaborators: impl IntoIterator<Item = ConnectionId>,
3482 message: &T,
3483 peer: &Peer,
3484) {
3485 broadcast(Some(sender_id), collaborators, |peer_id| {
3486 peer.send(peer_id, message.clone())
3487 });
3488}
3489
3490fn send_notifications(
3491 connection_pool: &ConnectionPool,
3492 peer: &Peer,
3493 notifications: db::NotificationBatch,
3494) {
3495 for (user_id, notification) in notifications {
3496 for connection_id in connection_pool.user_connection_ids(user_id) {
3497 if let Err(error) = peer.send(
3498 connection_id,
3499 proto::AddNotification {
3500 notification: Some(notification.clone()),
3501 },
3502 ) {
3503 tracing::error!(
3504 "failed to send notification to {:?} {}",
3505 connection_id,
3506 error
3507 );
3508 }
3509 }
3510 }
3511}
3512
3513/// Send a message to the channel
3514async fn send_channel_message(
3515 _request: proto::SendChannelMessage,
3516 _response: Response<proto::SendChannelMessage>,
3517 _session: MessageContext,
3518) -> Result<()> {
3519 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3520}
3521
3522/// Delete a channel message
3523async fn remove_channel_message(
3524 _request: proto::RemoveChannelMessage,
3525 _response: Response<proto::RemoveChannelMessage>,
3526 _session: MessageContext,
3527) -> Result<()> {
3528 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3529}
3530
3531async fn update_channel_message(
3532 _request: proto::UpdateChannelMessage,
3533 _response: Response<proto::UpdateChannelMessage>,
3534 _session: MessageContext,
3535) -> Result<()> {
3536 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3537}
3538
3539/// Mark a channel message as read
3540async fn acknowledge_channel_message(
3541 _request: proto::AckChannelMessage,
3542 _session: MessageContext,
3543) -> Result<()> {
3544 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3545}
3546
3547/// Mark a buffer version as synced
3548async fn acknowledge_buffer_version(
3549 request: proto::AckBufferOperation,
3550 session: MessageContext,
3551) -> Result<()> {
3552 let buffer_id = BufferId::from_proto(request.buffer_id);
3553 session
3554 .db()
3555 .await
3556 .observe_buffer_version(
3557 buffer_id,
3558 session.user_id(),
3559 request.epoch as i32,
3560 &request.version,
3561 )
3562 .await?;
3563 Ok(())
3564}
3565
3566/// Start receiving chat updates for a channel
3567async fn join_channel_chat(
3568 _request: proto::JoinChannelChat,
3569 _response: Response<proto::JoinChannelChat>,
3570 _session: MessageContext,
3571) -> Result<()> {
3572 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3573}
3574
3575/// Stop receiving chat updates for a channel
3576async fn leave_channel_chat(
3577 _request: proto::LeaveChannelChat,
3578 _session: MessageContext,
3579) -> Result<()> {
3580 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3581}
3582
3583/// Retrieve the chat history for a channel
3584async fn get_channel_messages(
3585 _request: proto::GetChannelMessages,
3586 _response: Response<proto::GetChannelMessages>,
3587 _session: MessageContext,
3588) -> Result<()> {
3589 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3590}
3591
3592/// Retrieve specific chat messages
3593async fn get_channel_messages_by_id(
3594 _request: proto::GetChannelMessagesById,
3595 _response: Response<proto::GetChannelMessagesById>,
3596 _session: MessageContext,
3597) -> Result<()> {
3598 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3599}
3600
3601/// Retrieve the current users notifications
3602async fn get_notifications(
3603 request: proto::GetNotifications,
3604 response: Response<proto::GetNotifications>,
3605 session: MessageContext,
3606) -> Result<()> {
3607 let notifications = session
3608 .db()
3609 .await
3610 .get_notifications(
3611 session.user_id(),
3612 NOTIFICATION_COUNT_PER_PAGE,
3613 request.before_id.map(db::NotificationId::from_proto),
3614 )
3615 .await?;
3616 response.send(proto::GetNotificationsResponse {
3617 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3618 notifications,
3619 })?;
3620 Ok(())
3621}
3622
3623/// Mark notifications as read
3624async fn mark_notification_as_read(
3625 request: proto::MarkNotificationRead,
3626 response: Response<proto::MarkNotificationRead>,
3627 session: MessageContext,
3628) -> Result<()> {
3629 let database = &session.db().await;
3630 let notifications = database
3631 .mark_notification_as_read_by_id(
3632 session.user_id(),
3633 NotificationId::from_proto(request.notification_id),
3634 )
3635 .await?;
3636 send_notifications(
3637 &*session.connection_pool().await,
3638 &session.peer,
3639 notifications,
3640 );
3641 response.send(proto::Ack {})?;
3642 Ok(())
3643}
3644
3645fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3646 let message = match message {
3647 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3648 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3649 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3650 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3651 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3652 code: frame.code.into(),
3653 reason: frame.reason.as_str().to_owned().into(),
3654 })),
3655 // We should never receive a frame while reading the message, according
3656 // to the `tungstenite` maintainers:
3657 //
3658 // > It cannot occur when you read messages from the WebSocket, but it
3659 // > can be used when you want to send the raw frames (e.g. you want to
3660 // > send the frames to the WebSocket without composing the full message first).
3661 // >
3662 // > — https://github.com/snapview/tungstenite-rs/issues/268
3663 TungsteniteMessage::Frame(_) => {
3664 bail!("received an unexpected frame while reading the message")
3665 }
3666 };
3667
3668 Ok(message)
3669}
3670
3671fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3672 match message {
3673 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3674 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3675 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3676 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3677 AxumMessage::Close(frame) => {
3678 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3679 code: frame.code.into(),
3680 reason: frame.reason.as_ref().into(),
3681 }))
3682 }
3683 }
3684}
3685
3686fn notify_membership_updated(
3687 connection_pool: &mut ConnectionPool,
3688 result: MembershipUpdated,
3689 user_id: UserId,
3690 peer: &Peer,
3691) {
3692 for membership in &result.new_channels.channel_memberships {
3693 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3694 }
3695 for channel_id in &result.removed_channels {
3696 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3697 }
3698
3699 let user_channels_update = proto::UpdateUserChannels {
3700 channel_memberships: result
3701 .new_channels
3702 .channel_memberships
3703 .iter()
3704 .map(|cm| proto::ChannelMembership {
3705 channel_id: cm.channel_id.to_proto(),
3706 role: cm.role.into(),
3707 })
3708 .collect(),
3709 ..Default::default()
3710 };
3711
3712 let mut update = build_channels_update(result.new_channels);
3713 update.delete_channels = result
3714 .removed_channels
3715 .into_iter()
3716 .map(|id| id.to_proto())
3717 .collect();
3718 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3719
3720 for connection_id in connection_pool.user_connection_ids(user_id) {
3721 peer.send(connection_id, user_channels_update.clone())
3722 .trace_err();
3723 peer.send(connection_id, update.clone()).trace_err();
3724 }
3725}
3726
3727fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3728 proto::UpdateUserChannels {
3729 channel_memberships: channels
3730 .channel_memberships
3731 .iter()
3732 .map(|m| proto::ChannelMembership {
3733 channel_id: m.channel_id.to_proto(),
3734 role: m.role.into(),
3735 })
3736 .collect(),
3737 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3738 }
3739}
3740
3741fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3742 let mut update = proto::UpdateChannels::default();
3743
3744 for channel in channels.channels {
3745 update.channels.push(channel.to_proto());
3746 }
3747
3748 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3749
3750 for (channel_id, participants) in channels.channel_participants {
3751 update
3752 .channel_participants
3753 .push(proto::ChannelParticipants {
3754 channel_id: channel_id.to_proto(),
3755 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3756 });
3757 }
3758
3759 for channel in channels.invited_channels {
3760 update.channel_invitations.push(channel.to_proto());
3761 }
3762
3763 update
3764}
3765
3766fn build_initial_contacts_update(
3767 contacts: Vec<db::Contact>,
3768 pool: &ConnectionPool,
3769) -> proto::UpdateContacts {
3770 let mut update = proto::UpdateContacts::default();
3771
3772 for contact in contacts {
3773 match contact {
3774 db::Contact::Accepted { user_id, busy } => {
3775 update.contacts.push(contact_for_user(user_id, busy, pool));
3776 }
3777 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3778 db::Contact::Incoming { user_id } => {
3779 update
3780 .incoming_requests
3781 .push(proto::IncomingContactRequest {
3782 requester_id: user_id.to_proto(),
3783 })
3784 }
3785 }
3786 }
3787
3788 update
3789}
3790
3791fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3792 proto::Contact {
3793 user_id: user_id.to_proto(),
3794 online: pool.is_user_online(user_id),
3795 busy,
3796 }
3797}
3798
3799fn room_updated(room: &proto::Room, peer: &Peer) {
3800 broadcast(
3801 None,
3802 room.participants
3803 .iter()
3804 .filter_map(|participant| Some(participant.peer_id?.into())),
3805 |peer_id| {
3806 peer.send(
3807 peer_id,
3808 proto::RoomUpdated {
3809 room: Some(room.clone()),
3810 },
3811 )
3812 },
3813 );
3814}
3815
3816fn channel_updated(
3817 channel: &db::channel::Model,
3818 room: &proto::Room,
3819 peer: &Peer,
3820 pool: &ConnectionPool,
3821) {
3822 let participants = room
3823 .participants
3824 .iter()
3825 .map(|p| p.user_id)
3826 .collect::<Vec<_>>();
3827
3828 broadcast(
3829 None,
3830 pool.channel_connection_ids(channel.root_id())
3831 .filter_map(|(channel_id, role)| {
3832 role.can_see_channel(channel.visibility)
3833 .then_some(channel_id)
3834 }),
3835 |peer_id| {
3836 peer.send(
3837 peer_id,
3838 proto::UpdateChannels {
3839 channel_participants: vec![proto::ChannelParticipants {
3840 channel_id: channel.id.to_proto(),
3841 participant_user_ids: participants.clone(),
3842 }],
3843 ..Default::default()
3844 },
3845 )
3846 },
3847 );
3848}
3849
3850async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3851 let db = session.db().await;
3852
3853 let contacts = db.get_contacts(user_id).await?;
3854 let busy = db.is_user_busy(user_id).await?;
3855
3856 let pool = session.connection_pool().await;
3857 let updated_contact = contact_for_user(user_id, busy, &pool);
3858 for contact in contacts {
3859 if let db::Contact::Accepted {
3860 user_id: contact_user_id,
3861 ..
3862 } = contact
3863 {
3864 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3865 session
3866 .peer
3867 .send(
3868 contact_conn_id,
3869 proto::UpdateContacts {
3870 contacts: vec![updated_contact.clone()],
3871 remove_contacts: Default::default(),
3872 incoming_requests: Default::default(),
3873 remove_incoming_requests: Default::default(),
3874 outgoing_requests: Default::default(),
3875 remove_outgoing_requests: Default::default(),
3876 },
3877 )
3878 .trace_err();
3879 }
3880 }
3881 }
3882 Ok(())
3883}
3884
/// Remove `connection_id` from whatever room it is in and propagate the change.
///
/// Returns `Ok(())` without doing anything when the database reports that the
/// connection was not in a room. Otherwise, in order:
/// 1. Notifies the connections of the projects that were left.
/// 2. Broadcasts the updated room.
/// 3. Refreshes the channel's participant list, if the room belongs to a channel.
/// 4. Sends `CallCanceled` to users whose pending calls were canceled.
/// 5. Pushes updated contact state for the session's user and those users.
/// 6. Removes the session's user from the LiveKit room, deleting the room when
///    the database marks it as deleted. LiveKit failures are only logged.
///
/// NOTE(review): contacts and the LiveKit participant are keyed on
/// `session.user_id()`, not on the owner of `connection_id`. Callers appear to
/// pass a connection belonging to the session's own user (e.g. a stale one);
/// confirm this.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned in the `if let` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // The `session.db()` guard is a temporary of this `if let` scrutinee, so
    // it remains held while the block below runs.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken out before `room` is moved out, so the `room` broadcast below
        // carries the default value for `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is released before `update_user_contacts`,
    // which locks the connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3958
3959async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3960 let left_channel_buffers = session
3961 .db()
3962 .await
3963 .leave_channel_buffers(session.connection_id)
3964 .await?;
3965
3966 for left_buffer in left_channel_buffers {
3967 channel_buffer_updated(
3968 session.connection_id,
3969 left_buffer.connections,
3970 &proto::UpdateChannelBufferCollaborators {
3971 channel_id: left_buffer.channel_id.to_proto(),
3972 collaborators: left_buffer.collaborators,
3973 },
3974 &session.peer,
3975 );
3976 }
3977
3978 Ok(())
3979}
3980
3981fn project_left(project: &db::LeftProject, session: &Session) {
3982 for connection_id in &project.connection_ids {
3983 if project.should_unshare {
3984 session
3985 .peer
3986 .send(
3987 *connection_id,
3988 proto::UnshareProject {
3989 project_id: project.id.to_proto(),
3990 },
3991 )
3992 .trace_err();
3993 } else {
3994 session
3995 .peer
3996 .send(
3997 *connection_id,
3998 proto::RemoveProjectCollaborator {
3999 project_id: project.id.to_proto(),
4000 peer_id: Some(session.connection_id.into()),
4001 },
4002 )
4003 .trace_err();
4004 }
4005 }
4006}
4007
4008async fn share_agent_thread(
4009 request: proto::ShareAgentThread,
4010 response: Response<proto::ShareAgentThread>,
4011 session: MessageContext,
4012) -> Result<()> {
4013 let user_id = session.user_id();
4014
4015 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4016 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4017
4018 session
4019 .db()
4020 .await
4021 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4022 .await?;
4023
4024 response.send(proto::Ack {})?;
4025
4026 Ok(())
4027}
4028
4029async fn get_shared_agent_thread(
4030 request: proto::GetSharedAgentThread,
4031 response: Response<proto::GetSharedAgentThread>,
4032 session: MessageContext,
4033) -> Result<()> {
4034 let share_id = SharedThreadId::from_proto(request.session_id)
4035 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4036
4037 let result = session.db().await.get_shared_thread(share_id).await?;
4038
4039 match result {
4040 Some((thread, username)) => {
4041 response.send(proto::GetSharedAgentThreadResponse {
4042 title: thread.title,
4043 thread_data: thread.data,
4044 sharer_username: username,
4045 created_at: thread.created_at.and_utc().to_rfc3339(),
4046 })?;
4047 }
4048 None => {
4049 return Err(anyhow!("Shared thread not found").into());
4050 }
4051 }
4052
4053 Ok(())
4054}
4055
/// Extension for fallible values that turns an error into a logged `None`.
pub trait ResultExt {
    /// The success type carried by the extended value.
    type Ok;

    /// Return the success value, or report the error and return `None`.
    ///
    /// The `Result` implementation in this file reports via `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4061
4062impl<T, E> ResultExt for Result<T, E>
4063where
4064 E: std::fmt::Debug,
4065{
4066 type Ok = T;
4067
4068 #[track_caller]
4069 fn trace_err(self) -> Option<T> {
4070 match self {
4071 Ok(value) => Some(value),
4072 Err(error) => {
4073 tracing::error!("{:?}", error);
4074 None
4075 }
4076 }
4077 }
4078}