1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// Grace period a disconnected client has to reconnect before its
/// per-connection state is torn down (used by call sites outside this chunk).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification listing.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently held `ConnectionGuard`s (see `try_acquire`/`Drop`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record message-handling timings
// (recorded in `add_handler` and `MessageContext::forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one RPC message type; stored in `Server::handlers`
/// keyed by the message's `TypeId`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every message handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Let handlers use a `MessageContext` wherever a `&Session` is expected.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
160impl MessageContext {
161 pub fn forward_request<T: RequestMessage>(
162 &self,
163 receiver_id: ConnectionId,
164 request: T,
165 ) -> impl Future<Output = anyhow::Result<T::Response>> {
166 let request_start_time = Instant::now();
167 let span = self.span.clone();
168 tracing::info!("start forwarding request");
169 self.peer
170 .forward_request(self.connection_id, receiver_id, request)
171 .inspect(move |_| {
172 span.record(
173 HOST_WAITING_MS,
174 request_start_time.elapsed().as_micros() as f64 / 1000.0,
175 );
176 })
177 .inspect_err(|_| tracing::error!("error forwarding request"))
178 .inspect_ok(|_| tracing::info!("finished forwarding request"))
179 }
180}
181
/// Per-connection state handed (cloned) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is connected.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex, so handlers on this connection
    /// access the database one at a time.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
198impl Session {
199 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
200 #[cfg(feature = "test-support")]
201 tokio::task::yield_now().await;
202 let guard = self.db.lock().await;
203 #[cfg(feature = "test-support")]
204 tokio::task::yield_now().await;
205 guard
206 }
207
208 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
209 #[cfg(feature = "test-support")]
210 tokio::task::yield_now().await;
211 let guard = self.connection_pool.lock();
212 ConnectionPoolGuard {
213 guard,
214 _not_send: PhantomData,
215 }
216 }
217
218 #[expect(dead_code)]
219 fn is_staff(&self) -> bool {
220 match &self.principal {
221 Principal::User(user) => user.admin,
222 }
223 }
224
225 fn user_id(&self) -> UserId {
226 match &self.principal {
227 Principal::User(user) => user.id,
228 }
229 }
230}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The collaboration RPC server: owns the peer, the connection pool, and the
/// table of message handlers registered in `Server::new`.
pub struct Server {
    /// This server instance's id (mutable only under `test-support`).
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message `TypeId`; duplicate registration panics.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel flipped to `true` by `teardown` to stop connections.
    teardown: watch::Sender<bool>,
}
262
/// Lock guard over the connection pool, returned by `Session::connection_pool`.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `PhantomData<Rc<()>>` makes this guard `!Send`, preventing it from
    // being held across an `.await` that migrates to another thread.
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates the server and registers a handler for every supported RPC
    /// message type. Panics (via `add_handler`) if any message type is
    /// registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participants.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host; "read only" handlers
            // are allowed for guests, "mutating" ones require write access.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host-originated broadcasts to guests.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and acknowledgements.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Shared agent threads and project search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
450
    /// Spawns the background cleanup task. After `CLEANUP_TIMEOUT` (giving
    /// terminated pods time to shut down) it removes resources left behind by
    /// dead server instances: stale chat participants, stale room and
    /// channel-buffer collaborators, old worktree entries, and stale server
    /// rows, notifying affected clients along the way.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // First pass over stale chat participants; repeated again
                // below after room/buffer cleanup.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and
                    // push the refreshed collaborator list to remaining peers.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants belonging to dead servers, then
                        // broadcast the refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // the call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Refresh busy state for affected users on all of
                        // their accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // Second pass, now that room/buffer cleanup has settled.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
621
622 pub fn teardown(&self) {
623 self.peer.teardown();
624 self.connection_pool.lock().reset();
625 let _ = self.teardown.send(true);
626 }
627
628 #[cfg(feature = "test-support")]
629 pub fn reset(&self, id: ServerId) {
630 self.teardown();
631 *self.id.lock() = id;
632 self.peer.reset(id.0 as u32);
633 let _ = self.teardown.send(false);
634 }
635
636 #[cfg(feature = "test-support")]
637 pub fn id(&self) -> ServerId {
638 *self.id.lock()
639 }
640
    /// Registers `handler` for message type `M`, wrapping it in a type-erased
    /// closure that records timing fields on the span and logs the outcome.
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Downcast cannot fail: this closure is only looked up by
                // `M`'s TypeId.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = receipt -> completion; processing = handler
                    // start -> completion; queue = the difference.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
684
685 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
686 where
687 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
688 Fut: 'static + Send + Future<Output = Result<()>>,
689 M: EnvelopedMessage,
690 {
691 self.add_handler(move |envelope, session| handler(envelope.payload, session));
692 self
693 }
694
    /// Registers a handler for a request message. Wraps the handler so that:
    /// a handler returning `Ok` without calling `Response::send` becomes an
    /// error, and a handler error is converted to a proto error and sent back
    /// to the requester before being propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Set to true by `Response::send`; checked below to catch
                // handlers that never respond.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Report the failure to the requester, then propagate
                        // it so it is logged by the handler wrapper.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
733
    /// Returns a future that drives one client connection: sends the initial
    /// client update, then loops dispatching incoming messages to registered
    /// handlers until the connection closes, I/O fails, or the server tears
    /// down, and finally runs `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // NOTE(review): the connection-count guard is released once the
            // handshake completes, so MAX_CONCURRENT_CONNECTIONS appears to
            // bound connections being established, not established ones —
            // confirm this is intended.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Only pull the next message once a handler permit is free,
                // backpressuring clients that flood the connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler
                                // finishes, capping concurrent handlers.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
904
    /// Performs the post-connect handshake for a new connection: sends
    /// `Hello`, reports the connection id to the caller, registers the
    /// connection in the pool, and pushes initial contacts, channel
    /// subscriptions, and any pending incoming call to the client.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to whoever initiated the connection.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                // First-ever connection: prompt the contacts UI once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register in the pool and send initial contacts while
                // holding the pool lock, keeping the two consistent.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver a call that was ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
960}
961
/// Allows read access to the wrapped [`ConnectionPool`] directly through
/// the guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
969
/// Allows mutable access to the wrapped [`ConnectionPool`] directly
/// through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
975
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time a guard is
        // released; compiled out in production.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
982
983fn broadcast<F>(
984 sender_id: Option<ConnectionId>,
985 receiver_ids: impl IntoIterator<Item = ConnectionId>,
986 mut f: F,
987) where
988 F: FnMut(ConnectionId) -> anyhow::Result<()>,
989{
990 for receiver_id in receiver_ids {
991 if Some(receiver_id) != sender_id
992 && let Err(error) = f(receiver_id)
993 {
994 tracing::error!("failed to send to {:?} {}", receiver_id, error);
995 }
996 }
997}
998
999pub struct ProtocolVersion(u32);
1000
1001impl Header for ProtocolVersion {
1002 fn name() -> &'static HeaderName {
1003 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1004 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1005 }
1006
1007 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1008 where
1009 Self: Sized,
1010 I: Iterator<Item = &'i axum::http::HeaderValue>,
1011 {
1012 let version = values
1013 .next()
1014 .ok_or_else(axum::headers::Error::invalid)?
1015 .to_str()
1016 .map_err(|_| axum::headers::Error::invalid())?
1017 .parse()
1018 .map_err(|_| axum::headers::Error::invalid())?;
1019 Ok(Self(version))
1020 }
1021
1022 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1023 values.extend([self.0.to_string().parse().unwrap()]);
1024 }
1025}
1026
1027pub struct AppVersionHeader(Version);
1028impl Header for AppVersionHeader {
1029 fn name() -> &'static HeaderName {
1030 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1031 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1032 }
1033
1034 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1035 where
1036 Self: Sized,
1037 I: Iterator<Item = &'i axum::http::HeaderValue>,
1038 {
1039 let version = values
1040 .next()
1041 .ok_or_else(axum::headers::Error::invalid)?
1042 .to_str()
1043 .map_err(|_| axum::headers::Error::invalid())?
1044 .parse()
1045 .map_err(|_| axum::headers::Error::invalid())?;
1046 Ok(Self(version))
1047 }
1048
1049 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1050 values.extend([self.0.to_string().parse().unwrap()]);
1051 }
1052}
1053
1054#[derive(Debug)]
1055pub struct ReleaseChannelHeader(String);
1056
1057impl Header for ReleaseChannelHeader {
1058 fn name() -> &'static HeaderName {
1059 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1060 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1061 }
1062
1063 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1064 where
1065 Self: Sized,
1066 I: Iterator<Item = &'i axum::http::HeaderValue>,
1067 {
1068 Ok(Self(
1069 values
1070 .next()
1071 .ok_or_else(axum::headers::Error::invalid)?
1072 .to_str()
1073 .map_err(|_| axum::headers::Error::invalid())?
1074 .to_owned(),
1075 ))
1076 }
1077
1078 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1079 values.extend([self.0.parse().unwrap()]);
1080 }
1081}
1082
1083pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1084 Router::new()
1085 .route("/rpc", get(handle_websocket_request))
1086 .layer(
1087 ServiceBuilder::new()
1088 .layer(Extension(server.app_state.clone()))
1089 .layer(middleware::from_fn(auth::validate_header)),
1090 )
1091 .route("/metrics", get(handle_metrics))
1092 .layer(Extension(server))
1093}
1094
/// Axum handler for `GET /rpc`: validates protocol and app-version
/// headers, enforces the concurrent-connection limit, then upgrades to a
/// WebSocket and hands the connection to the server.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // A client speaking a different RPC protocol version cannot talk to
    // this server at all — tell it to upgrade.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app version header is mandatory: it feeds the can-collaborate
    // check below and the connection pool's version bookkeeping.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade — this lets us
    // refuse over-limit clients with a plain HTTP 503 instead of
    // accepting the socket and immediately dropping it.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket to the tungstenite message types that
        // rpc::Connection expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1172
1173pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1174 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1175 let connections_metric = CONNECTIONS_METRIC
1176 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1177
1178 let connections = server
1179 .connection_pool
1180 .lock()
1181 .connections()
1182 .filter(|connection| !connection.admin)
1183 .count();
1184 connections_metric.set(connections as _);
1185
1186 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1187 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1188 register_int_gauge!(
1189 "shared_projects",
1190 "number of open projects with one or more guests"
1191 )
1192 .unwrap()
1193 });
1194
1195 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1196 shared_projects_metric.set(shared_projects as _);
1197
1198 let encoder = prometheus::TextEncoder::new();
1199 let metric_families = prometheus::gather();
1200 let encoded_metrics = encoder
1201 .encode_to_string(&metric_families)
1202 .map_err(|err| anyhow!("{err}"))?;
1203 Ok(encoded_metrics)
1204}
1205
/// Cleans up after a client's connection drops.
///
/// The connection itself is removed immediately, but room and channel
/// resources are only torn down after `RECONNECT_TIMEOUT`, giving a
/// quickly reconnecting client a chance to rejoin via `rejoin_room`. A
/// signal on `teardown` cancels the delayed cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Drop the peer and pool entries right away so this connection no
    // longer counts as online.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best effort: record the lost connection in the database; errors are
    // logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The grace period elapsed without a reconnect: release all
            // of this connection's room and channel-buffer resources.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If the user has no other live connections, decline any call
            // that was still ringing them and broadcast the room change.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Server teardown: skip the delayed cleanup entirely.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1252
1253/// Acknowledges a ping from a client, used to keep the connection alive.
1254async fn ping(
1255 _: proto::Ping,
1256 response: Response<proto::Ping>,
1257 _session: MessageContext,
1258) -> Result<()> {
1259 response.send(proto::Ack {})?;
1260 Ok(())
1261}
1262
/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    // Random LiveKit room name; it only needs to be unique.
    let livekit_room = nanoid::nanoid!(30);

    // Mint a LiveKit token for the caller if a LiveKit client is
    // configured; token failures are logged and leave the info as None.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            // The room creator can always publish.
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    // Refresh the caller's state as seen by their contacts.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1300
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel go through the channel-join path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Join the room in the database and broadcast the updated room state;
    // the guard is confined to this block.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The call has been answered: tell all of this user's connections to
    // stop ringing for this room.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token if media is configured; if token creation
    // fails the client joins without a LiveKit connection.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1367
/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the database guard happens inside this block;
    // `room` and `channel` carry the results out.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply first so the client can begin restoring its state while
        // we notify the other participants.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // For each project this connection re-shares, give the guests the
        // host's new peer id and current worktree list.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also need their channel state rebroadcast.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1459
/// Streams the state a reconnecting client needs for each project it was
/// in: peer-id updates go to the other collaborators, while worktree,
/// diagnostics, settings, and repository contents are streamed back to
/// the rejoining connection itself.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Tell each project's collaborators about this user's new connection id.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Only the final chunk of a completed scan is "last".
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into multiple messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Per-worktree settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are chunked like worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1545
1546/// leave room disconnects from the room.
1547async fn leave_room(
1548 _: proto::LeaveRoom,
1549 response: Response<proto::LeaveRoom>,
1550 session: MessageContext,
1551) -> Result<()> {
1552 leave_room_for_session(&session, session.connection_id).await?;
1553 response.send(proto::Ack {})?;
1554 Ok(())
1555}
1556
/// Updates the permissions of someone else in the room.
async fn set_room_participant_role(
    request: proto::SetRoomParticipantRole,
    response: Response<proto::SetRoomParticipantRole>,
    session: MessageContext,
) -> Result<()> {
    let user_id = UserId::from_proto(request.user_id);
    let role = ChannelRole::from(request.role());

    // Apply the role change in the database and broadcast the updated
    // room; everything LiveKit needs is extracted before the guard drops.
    let (livekit_room, can_publish) = {
        let room = session
            .db()
            .await
            .set_room_participant_role(
                session.user_id(),
                RoomId::from_proto(request.room_id),
                user_id,
                role,
            )
            .await?;

        let livekit_room = room.livekit_room.clone();
        // NOTE(review): re-derives the role from the request instead of
        // reusing `role` above — same value either way.
        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
        room_updated(&room, &session.peer);
        (livekit_room, can_publish)
    };

    // Mirror the permission change into LiveKit so media publishing
    // rights match the new role. Best effort: errors are only logged.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .update_participant(
                livekit_room.clone(),
                request.user_id.to_string(),
                livekit_api::proto::ParticipantPermission {
                    can_subscribe: true,
                    can_publish,
                    can_publish_data: can_publish,
                    hidden: false,
                    recorder: false,
                },
            )
            .await
            .trace_err();
    }

    response.send(proto::Ack {})?;
    Ok(())
}
1604
/// Call someone else into the current room
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may ring each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Record the pending call and broadcast the updated participant list
    // before the database guard is released.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently; the first one
    // that responds successfully wins.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log the failure and keep waiting on the remaining
                // connections.
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the call back and notify the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676 request: proto::CancelCall,
1677 response: Response<proto::CancelCall>,
1678 session: MessageContext,
1679) -> Result<()> {
1680 let called_user_id = UserId::from_proto(request.called_user_id);
1681 let room_id = RoomId::from_proto(request.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .cancel_call(room_id, session.connection_id, called_user_id)
1687 .await?;
1688 room_updated(&room, &session.peer);
1689 }
1690
1691 for connection_id in session
1692 .connection_pool()
1693 .await
1694 .user_connection_ids(called_user_id)
1695 {
1696 session
1697 .peer
1698 .send(
1699 connection_id,
1700 proto::CallCanceled {
1701 room_id: room_id.to_proto(),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 response.send(proto::Ack {})?;
1707
1708 update_user_contacts(called_user_id, &session).await?;
1709 Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1714 let room_id = RoomId::from_proto(message.room_id);
1715 {
1716 let room = session
1717 .db()
1718 .await
1719 .decline_call(Some(room_id), session.user_id())
1720 .await?
1721 .context("declining call")?;
1722 room_updated(&room, &session.peer);
1723 }
1724
1725 for connection_id in session
1726 .connection_pool()
1727 .await
1728 .user_connection_ids(session.user_id())
1729 {
1730 session
1731 .peer
1732 .send(
1733 connection_id,
1734 proto::CallCanceled {
1735 room_id: room_id.to_proto(),
1736 },
1737 )
1738 .trace_err();
1739 }
1740 update_user_contacts(session.user_id(), &session).await?;
1741 Ok(())
1742}
1743
1744/// Updates other participants in the room with your current location.
1745async fn update_participant_location(
1746 request: proto::UpdateParticipantLocation,
1747 response: Response<proto::UpdateParticipantLocation>,
1748 session: MessageContext,
1749) -> Result<()> {
1750 let room_id = RoomId::from_proto(request.room_id);
1751 let location = request.location.context("invalid location")?;
1752
1753 let db = session.db().await;
1754 let room = db
1755 .update_room_participant_location(room_id, session.connection_id, location)
1756 .await?;
1757
1758 room_updated(&room, &session.peer);
1759 response.send(proto::Ack {})?;
1760 Ok(())
1761}
1762
1763/// Share a project into the room.
1764async fn share_project(
1765 request: proto::ShareProject,
1766 response: Response<proto::ShareProject>,
1767 session: MessageContext,
1768) -> Result<()> {
1769 let (project_id, room) = &*session
1770 .db()
1771 .await
1772 .share_project(
1773 RoomId::from_proto(request.room_id),
1774 session.connection_id,
1775 &request.worktrees,
1776 request.is_ssh_project,
1777 request.windows_paths.unwrap_or(false),
1778 &request.features,
1779 )
1780 .await?;
1781 response.send(proto::ShareProjectResponse {
1782 project_id: project_id.to_proto(),
1783 })?;
1784 room_updated(room, &session.peer);
1785
1786 Ok(())
1787}
1788
1789/// Unshare a project from the room.
1790async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1791 let project_id = ProjectId::from_proto(message.project_id);
1792 unshare_project_internal(project_id, session.connection_id, &session).await
1793}
1794
/// Shared implementation for unsharing a project: updates the database,
/// notifies guests and the room, and optionally deletes the project
/// record afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // The room guard is confined to this block so it is released before
    // the deletion below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection) the project
        // is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // The database decided whether the project record should be deleted;
    // honor that after the guard is gone.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1832
/// Join someone elses shared project.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let project_model = db.get_project(project_id).await?;
    // Host and guest must agree on the feature set recorded when the
    // project was shared; otherwise joining is refused.
    let host_features: Vec<String> =
        serde_json::from_str(&project_model.features).unwrap_or_default();
    let guest_features: HashSet<_> = request.features.iter().collect();
    let host_features_set: HashSet<_> = host_features.iter().collect();
    if guest_features != host_features_set {
        // Look up both sides' versions to make the error actionable.
        let host_connection_id = project_model.host_connection()?;
        let mut pool = session.connection_pool().await;
        let host_version = pool
            .connection(host_connection_id)
            .map(|c| c.zed_version.to_string());
        let guest_version = pool
            .connection(session.connection_id)
            .map(|c| c.zed_version.to_string());
        drop(pool);
        Err(anyhow!(
            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
            host_version.as_deref().unwrap_or("unknown"),
            guest_version.as_deref().unwrap_or("unknown"),
        ))?;
    }

    // Register the guest as a collaborator; the db guard is dropped
    // before the long streaming section below.
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    drop(db);

    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, excluding the joining connection.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to everyone already in the project.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
        features: project.features.clone(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are split into multiple messages.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Per-worktree settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                    outside_worktree: Some(settings_file.outside_worktree),
                },
            )?;
        }
    }

    // Stream repository state, chunked like worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Tell the guest that each language server's disk-based diagnostics
    // have been updated, so it pulls current diagnostics state.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2008
2009/// Leave someone elses shared project.
2010async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2011 let sender_id = session.connection_id;
2012 let project_id = ProjectId::from_proto(request.project_id);
2013 let db = session.db().await;
2014
2015 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2016 tracing::info!(
2017 %project_id,
2018 "leave project"
2019 );
2020
2021 project_left(project, &session);
2022 if let Some(room) = room {
2023 room_updated(room, &session.peer);
2024 }
2025
2026 Ok(())
2027}
2028
/// Updates other participants with changes to the project
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Persist the new worktree list; the DB returns the room (if any) plus
    // the guest connections that should receive this update.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Forward the update to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2057
2058/// Updates other participants with changes to the worktree
2059async fn update_worktree(
2060 request: proto::UpdateWorktree,
2061 response: Response<proto::UpdateWorktree>,
2062 session: MessageContext,
2063) -> Result<()> {
2064 let guest_connection_ids = session
2065 .db()
2066 .await
2067 .update_worktree(&request, session.connection_id)
2068 .await?;
2069
2070 broadcast(
2071 Some(session.connection_id),
2072 guest_connection_ids.iter().copied(),
2073 |connection_id| {
2074 session
2075 .peer
2076 .forward_send(session.connection_id, connection_id, request.clone())
2077 },
2078 );
2079 response.send(proto::Ack {})?;
2080 Ok(())
2081}
2082
/// Updates other participants with changes to a repository in the project.
async fn update_repository(
    request: proto::UpdateRepository,
    response: Response<proto::UpdateRepository>,
    session: MessageContext,
) -> Result<()> {
    // Persist the repository state and learn which guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_repository(&request, session.connection_id)
        .await?;

    // Relay the update to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
2106
/// Notifies other participants that a repository was removed from the project.
async fn remove_repository(
    request: proto::RemoveRepository,
    response: Response<proto::RemoveRepository>,
    session: MessageContext,
) -> Result<()> {
    // Delete the repository record and learn which guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .remove_repository(&request, session.connection_id)
        .await?;

    // Relay the removal to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
2130
/// Updates other participants with changes to the diagnostics
async fn update_diagnostic_summary(
    message: proto::UpdateDiagnosticSummary,
    session: MessageContext,
) -> Result<()> {
    // Persist the summary so late joiners receive it, and learn which
    // guests to notify now.
    let guest_connection_ids = session
        .db()
        .await
        .update_diagnostic_summary(&message, session.connection_id)
        .await?;

    // Relay the summary to every guest except the sender. No ack is sent:
    // this message type has no response.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2154
/// Updates other participants with changes to the worktree settings
async fn update_worktree_settings(
    message: proto::UpdateWorktreeSettings,
    session: MessageContext,
) -> Result<()> {
    // Persist the settings-file change and learn which guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .update_worktree_settings(&message, session.connection_id)
        .await?;

    // Relay the change to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2178
/// Notify other participants that a language server has started.
async fn start_language_server(
    request: proto::StartLanguageServer,
    session: MessageContext,
) -> Result<()> {
    // Record the server in the DB and learn which guests to notify.
    let guest_connection_ids = session
        .db()
        .await
        .start_language_server(&request, session.connection_id)
        .await?;

    // Relay the event to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2201
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Capability updates are additionally persisted, so the server's current
    // capabilities can be replayed to participants later.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // NOTE(review): the trailing `true` flag presumably widens which project
    // connections are returned (siblings pass `false`) — confirm against
    // `project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2231
/// forward a project request to the host. These requests should be read only
/// as guests are allowed to send them.
async fn forward_read_only_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // Resolve the host's connection; this also validates that the sender is
    // allowed to make read-only requests against this project.
    let host_connection_id = session
        .db()
        .await
        .host_for_read_only_project_request(project_id, session.connection_id)
        .await?;
    // Relay the request to the host and echo its reply back to the sender.
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2252
/// forward a project request to the host. These requests are disallowed
/// for guests.
async fn forward_mutating_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());

    // Resolve the host's connection; this also validates that the sender has
    // write access to this project (guests are rejected here).
    let host_connection_id = session
        .db()
        .await
        .host_for_mutating_project_request(project_id, session.connection_id)
        .await?;
    // Relay the request to the host and echo its reply back to the sender.
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2274
/// Reject a request outright because the sender is a guest.
async fn disallow_guest_request<T>(
    _request: T,
    response: Response<T>,
    _session: MessageContext,
) -> Result<()>
where
    T: RequestMessage,
{
    // Answer the sender's receipt with a Forbidden error instead of
    // forwarding the request anywhere.
    response.peer.respond_with_error(
        response.receipt,
        ErrorCode::Forbidden
            .message("request is not allowed for guests".to_string())
            .to_proto(),
    )?;
    // Mark the response as already sent — presumably so the surrounding
    // response machinery doesn't reply to this receipt a second time
    // (confirm against Response's Drop/completion handling).
    response.responded.store(true, SeqCst);
    Ok(())
}
2292
2293async fn lsp_query(
2294 request: proto::LspQuery,
2295 response: Response<proto::LspQuery>,
2296 session: MessageContext,
2297) -> Result<()> {
2298 let (name, should_write) = request.query_name_and_write_permissions();
2299 tracing::Span::current().record("lsp_query_request", name);
2300 tracing::info!("lsp_query message received");
2301 if should_write {
2302 forward_mutating_project_request(request, response, session).await
2303 } else {
2304 forward_read_only_project_request(request, response, session).await
2305 }
2306}
2307
2308/// Notify other participants that a new buffer has been created
2309async fn create_buffer_for_peer(
2310 request: proto::CreateBufferForPeer,
2311 session: MessageContext,
2312) -> Result<()> {
2313 session
2314 .db()
2315 .await
2316 .check_user_is_project_host(
2317 ProjectId::from_proto(request.project_id),
2318 session.connection_id,
2319 )
2320 .await?;
2321 let peer_id = request.peer_id.context("invalid peer id")?;
2322 session
2323 .peer
2324 .forward_send(session.connection_id, peer_id.into(), request)?;
2325 Ok(())
2326}
2327
/// Notify other participants that a new image has been created
async fn create_image_for_peer(
    request: proto::CreateImageForPeer,
    session: MessageContext,
) -> Result<()> {
    // Only the project host is allowed to push images to peers.
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    // Relay the image to the peer named in the request.
    let peer_id = request.peer_id.context("invalid peer id")?;
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
2347
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only updates may come from read-only guests; any other
    // operation escalates the required capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the DB guard to this block so it is released before awaiting
    // the host's acknowledgement below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget the update to all guests except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Wait for the host to acknowledge the update before acking the sender,
    // unless the sender is the host itself.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2394
2395async fn forward_project_search_chunk(
2396 message: proto::FindSearchCandidatesChunk,
2397 response: Response<proto::FindSearchCandidatesChunk>,
2398 session: MessageContext,
2399) -> Result<()> {
2400 let peer_id = message.peer_id.context("missing peer_id")?;
2401 let payload = session
2402 .peer
2403 .forward_request(session.connection_id, peer_id.into(), message)
2404 .await?;
2405 response.send(payload)?;
2406 Ok(())
2407}
2408
/// Notify other participants that a project has been updated.
async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
    request: T,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // NOTE(review): the trailing `false` flag presumably narrows which
    // project connections are returned (update_language_server passes
    // `true`) — confirm against `project_connection_ids`.
    let project_connection_ids = session
        .db()
        .await
        .project_connection_ids(project_id, session.connection_id, false)
        .await?;

    // Relay the message to every other connection on the project.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2432
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify that both leader and follower are participants of this room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader for its current state and relay it back to the follower.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Only project-scoped follows are recorded in the room; record this one
    // and broadcast the room's updated follower lists.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2464
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Verify that both leader and follower are participants of this room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Tell the leader this follower stopped following (no reply expected).
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Only project-scoped follows are recorded in the room; remove this one
    // and broadcast the room's updated follower lists.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2493
2494/// Notify everyone following you of your current location.
2495async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2496 let room_id = RoomId::from_proto(request.room_id);
2497 let database = session.db.lock().await;
2498
2499 let connection_ids = if let Some(project_id) = request.project_id {
2500 let project_id = ProjectId::from_proto(project_id);
2501 database
2502 .project_connection_ids(project_id, session.connection_id, true)
2503 .await?
2504 } else {
2505 database
2506 .room_connection_ids(room_id, session.connection_id)
2507 .await?
2508 };
2509
2510 // For now, don't send view update messages back to that view's current leader.
2511 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2512 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2513 _ => None,
2514 });
2515
2516 for connection_id in connection_ids.iter().cloned() {
2517 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2518 session
2519 .peer
2520 .forward_send(session.connection_id, connection_id, request.clone())?;
2521 }
2522 }
2523 Ok(())
2524}
2525
2526/// Get public data about users.
2527async fn get_users(
2528 request: proto::GetUsers,
2529 response: Response<proto::GetUsers>,
2530 session: MessageContext,
2531) -> Result<()> {
2532 let user_ids = request
2533 .user_ids
2534 .into_iter()
2535 .map(UserId::from_proto)
2536 .collect();
2537 let users = session
2538 .db()
2539 .await
2540 .get_users_by_ids(user_ids)
2541 .await?
2542 .into_iter()
2543 .map(|user| proto::User {
2544 id: user.id.to_proto(),
2545 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2546 github_login: user.github_login,
2547 name: user.name,
2548 })
2549 .collect();
2550 response.send(proto::UsersResponse { users })?;
2551 Ok(())
2552}
2553
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // Very short queries are treated as exact-login lookups; longer queries
    // run a fuzzy search capped at 10 results.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the requester from their own results and convert each row
    // into its public wire shape.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2585
/// Send a contact request to another user.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    // Record the request; the DB produces the notifications to deliver
    // to the responder.
    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Deliver the notification to all of the responder's connections.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2632
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; the request itself stays pending.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        // Record the accept/decline; the DB produces notifications for the
        // requester's connections.
        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included in the contact entries sent to each side.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2690
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // retracting a still-pending request; the DB may also delete the
    // responder's pending notification.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the contact-request notification, if one existed.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2741
2742fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2743 version.0.minor < 139
2744}
2745
2746async fn subscribe_to_channels(
2747 _: proto::SubscribeToChannels,
2748 session: MessageContext,
2749) -> Result<()> {
2750 subscribe_user_to_channels(session.user_id(), &session).await?;
2751 Ok(())
2752}
2753
/// Subscribe `user_id` to every channel they belong to, then send this
/// session's connection its membership list and the full channels update.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register the user's channel subscriptions in the connection pool so
    // future channel broadcasts reach them.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // Memberships are sent first, then the channel data itself.
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2770
/// Creates a new channel.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    // A missing parent means a new root channel; membership is only
    // returned when the creator gained a new membership row.
    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator before fanning out updates to other members.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        // Subscribe the new member and tell their connections about the
        // membership they just gained.
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the new channel to every connection subscribed to the root
    // whose role can see channels of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2825
/// Delete a channel
async fn delete_channel(
    request: proto::DeleteChannel,
    response: Response<proto::DeleteChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    // Deleting a channel also removes its descendants; the DB returns the
    // root of the affected tree and every removed channel id.
    let channel_id = request.channel_id;
    let (root_channel, removed_channels) = db
        .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
        .await?;
    response.send(proto::Ack {})?;

    // Notify members of removed channels
    let mut update = proto::UpdateChannels::default();
    update
        .delete_channels
        .extend(removed_channels.into_iter().map(|id| id.to_proto()));

    let connection_pool = session.connection_pool().await;
    for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
2853
/// Invite someone to join a channel.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    // Record the invitation; the DB returns the channel to show the invitee
    // and the notifications to deliver.
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    // Show the pending invitation on all of the invitee's connections.
    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2890
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    // Remove the membership; the DB returns the membership change to
    // broadcast plus the invite notification to retract, if any.
    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Retract the now-stale invite notification on each of the removed
    // member's connections; failures here are logged, not fatal.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2932
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user list up front: the loop body mutates the pool's
    // subscriptions, so we can't keep borrowing its iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel get the updated channel;
        // users who no longer can are unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2979
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    // The DB distinguishes changing an accepted member's role from changing
    // the role on a still-pending invitation.
    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            // Re-send the invitation with its new role to the invitee.
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3027
/// Change the name of a channel
async fn rename_channel(
    request: proto::RenameChannel,
    response: Response<proto::RenameChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let channel_model = db
        .rename_channel(channel_id, session.user_id(), &request.name)
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    // Reply to the renamer first, then fan out to other members.
    response.send(proto::RenameChannelResponse {
        channel: Some(channel.to_proto()),
    })?;

    // Broadcast the renamed channel to everyone subscribed to the root
    // whose role can see channels of this visibility.
    let connection_pool = session.connection_pool().await;
    let update = proto::UpdateChannels {
        channels: vec![channel.to_proto()],
        ..Default::default()
    };
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if role.can_see_channel(channel.visibility) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    Ok(())
}
3059
/// Move a channel to a new parent.
async fn move_channel(
    request: proto::MoveChannel,
    response: Response<proto::MoveChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let to = ChannelId::from_proto(request.to);

    // The DB returns the root of the affected tree plus every channel whose
    // position changed as a result of the move.
    let (root_id, channels) = session
        .db()
        .await
        .move_channel(channel_id, to, session.user_id())
        .await?;

    let connection_pool = session.connection_pool().await;
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        // Each connection only receives the moved channels its role can see.
        let channels = channels
            .iter()
            .filter_map(|channel| {
                if role.can_see_channel(channel.visibility) {
                    Some(channel.to_proto())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        if channels.is_empty() {
            continue;
        }

        let update = proto::UpdateChannels {
            channels,
            ..Default::default()
        };

        session.peer.send(connection_id, update.clone())?;
    }

    response.send(Ack {})?;
    Ok(())
}
3102
/// Move a channel up or down among its siblings.
async fn reorder_channel(
    request: proto::ReorderChannel,
    response: Response<proto::ReorderChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let direction = request.direction();

    // The DB returns every channel whose ordering changed; an empty result
    // means nothing moved (e.g. already at the edge).
    let updated_channels = session
        .db()
        .await
        .reorder_channel(channel_id, direction, session.user_id())
        .await?;

    if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
        let connection_pool = session.connection_pool().await;
        for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
            // Each connection only receives the channels its role can see.
            let channels = updated_channels
                .iter()
                .filter_map(|channel| {
                    if role.can_see_channel(channel.visibility) {
                        Some(channel.to_proto())
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if channels.is_empty() {
                continue;
            }

            let update = proto::UpdateChannels {
                channels,
                ..Default::default()
            };

            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(Ack {})?;
    Ok(())
}
3147
3148/// Get the list of channel members
3149async fn get_channel_members(
3150 request: proto::GetChannelMembers,
3151 response: Response<proto::GetChannelMembers>,
3152 session: MessageContext,
3153) -> Result<()> {
3154 let db = session.db().await;
3155 let channel_id = ChannelId::from_proto(request.channel_id);
3156 let limit = if request.limit == 0 {
3157 u16::MAX as u64
3158 } else {
3159 request.limit
3160 };
3161 let (members, users) = db
3162 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3163 .await?;
3164 response.send(proto::GetChannelMembersResponse { members, users })?;
3165 Ok(())
3166}
3167
3168/// Accept or decline a channel invitation.
3169async fn respond_to_channel_invite(
3170 request: proto::RespondToChannelInvite,
3171 response: Response<proto::RespondToChannelInvite>,
3172 session: MessageContext,
3173) -> Result<()> {
3174 let db = session.db().await;
3175 let channel_id = ChannelId::from_proto(request.channel_id);
3176 let RespondToChannelInvite {
3177 membership_update,
3178 notifications,
3179 } = db
3180 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3181 .await?;
3182
3183 let mut connection_pool = session.connection_pool().await;
3184 if let Some(membership_update) = membership_update {
3185 notify_membership_updated(
3186 &mut connection_pool,
3187 membership_update,
3188 session.user_id(),
3189 &session.peer,
3190 );
3191 } else {
3192 let update = proto::UpdateChannels {
3193 remove_channel_invitations: vec![channel_id.to_proto()],
3194 ..Default::default()
3195 };
3196
3197 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3198 session.peer.send(connection_id, update.clone())?;
3199 }
3200 };
3201
3202 send_notifications(&connection_pool, &session.peer, notifications);
3203
3204 response.send(proto::Ack {})?;
3205
3206 Ok(())
3207}
3208
3209/// Join the channels' room
3210async fn join_channel(
3211 request: proto::JoinChannel,
3212 response: Response<proto::JoinChannel>,
3213 session: MessageContext,
3214) -> Result<()> {
3215 let channel_id = ChannelId::from_proto(request.channel_id);
3216 join_channel_internal(channel_id, Box::new(response), session).await
3217}
3218
/// Abstracts over the two request types that resolve to a
/// `proto::JoinRoomResponse` — `JoinChannel` and `JoinRoom` — so that
/// `join_channel_internal` can serve both handlers with one implementation.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3232
/// Shared implementation for joining a channel's room (used by both the
/// `JoinChannel` and `JoinRoom` handlers via `JoinChannelInternalResponse`).
///
/// Cleans up any stale room connection, joins the channel in the database,
/// attaches LiveKit connection info when available, responds to the caller,
/// and broadcasts the resulting membership/room/channel updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so our
            // guard must be released first and re-acquired afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a non-publishing LiveKit token; other roles may
        // publish. A token failure is logged (`trace_err`) and results in
        // `None` connection info rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the caller before broadcasting to other participants.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The channel's participant list changed; tell everyone who can see it.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3326
3327/// Start editing the channel notes
3328async fn join_channel_buffer(
3329 request: proto::JoinChannelBuffer,
3330 response: Response<proto::JoinChannelBuffer>,
3331 session: MessageContext,
3332) -> Result<()> {
3333 let db = session.db().await;
3334 let channel_id = ChannelId::from_proto(request.channel_id);
3335
3336 let open_response = db
3337 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3338 .await?;
3339
3340 let collaborators = open_response.collaborators.clone();
3341 response.send(open_response)?;
3342
3343 let update = UpdateChannelBufferCollaborators {
3344 channel_id: channel_id.to_proto(),
3345 collaborators: collaborators.clone(),
3346 };
3347 channel_buffer_updated(
3348 session.connection_id,
3349 collaborators
3350 .iter()
3351 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3352 &update,
3353 &session.peer,
3354 );
3355
3356 Ok(())
3357}
3358
/// Edit the channel notes
///
/// Persists the buffer operations, relays them verbatim to the other buffer
/// collaborators, and sends only the new buffer version to channel members
/// who are not currently collaborating.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Relay the raw operations to the buffer's collaborators; the sender's
    // own connection id is passed so it can be skipped (see `broadcast`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel members who are not editing the buffer only need the updated
    // version vector, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3409
3410/// Rejoin the channel notes after a connection blip
3411async fn rejoin_channel_buffers(
3412 request: proto::RejoinChannelBuffers,
3413 response: Response<proto::RejoinChannelBuffers>,
3414 session: MessageContext,
3415) -> Result<()> {
3416 let db = session.db().await;
3417 let buffers = db
3418 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3419 .await?;
3420
3421 for rejoined_buffer in &buffers {
3422 let collaborators_to_notify = rejoined_buffer
3423 .buffer
3424 .collaborators
3425 .iter()
3426 .filter_map(|c| Some(c.peer_id?.into()));
3427 channel_buffer_updated(
3428 session.connection_id,
3429 collaborators_to_notify,
3430 &proto::UpdateChannelBufferCollaborators {
3431 channel_id: rejoined_buffer.buffer.channel_id,
3432 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3433 },
3434 &session.peer,
3435 );
3436 }
3437
3438 response.send(proto::RejoinChannelBuffersResponse {
3439 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3440 })?;
3441
3442 Ok(())
3443}
3444
3445/// Stop editing the channel notes
3446async fn leave_channel_buffer(
3447 request: proto::LeaveChannelBuffer,
3448 response: Response<proto::LeaveChannelBuffer>,
3449 session: MessageContext,
3450) -> Result<()> {
3451 let db = session.db().await;
3452 let channel_id = ChannelId::from_proto(request.channel_id);
3453
3454 let left_buffer = db
3455 .leave_channel_buffer(channel_id, session.connection_id)
3456 .await?;
3457
3458 response.send(Ack {})?;
3459
3460 channel_buffer_updated(
3461 session.connection_id,
3462 left_buffer.connections,
3463 &proto::UpdateChannelBufferCollaborators {
3464 channel_id: channel_id.to_proto(),
3465 collaborators: left_buffer.collaborators,
3466 },
3467 &session.peer,
3468 );
3469
3470 Ok(())
3471}
3472
3473fn channel_buffer_updated<T: EnvelopedMessage>(
3474 sender_id: ConnectionId,
3475 collaborators: impl IntoIterator<Item = ConnectionId>,
3476 message: &T,
3477 peer: &Peer,
3478) {
3479 broadcast(Some(sender_id), collaborators, |peer_id| {
3480 peer.send(peer_id, message.clone())
3481 });
3482}
3483
3484fn send_notifications(
3485 connection_pool: &ConnectionPool,
3486 peer: &Peer,
3487 notifications: db::NotificationBatch,
3488) {
3489 for (user_id, notification) in notifications {
3490 for connection_id in connection_pool.user_connection_ids(user_id) {
3491 if let Err(error) = peer.send(
3492 connection_id,
3493 proto::AddNotification {
3494 notification: Some(notification.clone()),
3495 },
3496 ) {
3497 tracing::error!(
3498 "failed to send notification to {:?} {}",
3499 connection_id,
3500 error
3501 );
3502 }
3503 }
3504 }
3505}
3506
/// Send a message to the channel
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error to any client that still sends the message.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3515
/// Delete a channel message
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3524
/// Edit a previously sent channel message
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3532
/// Mark a channel message as read
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3540
3541/// Mark a buffer version as synced
3542async fn acknowledge_buffer_version(
3543 request: proto::AckBufferOperation,
3544 session: MessageContext,
3545) -> Result<()> {
3546 let buffer_id = BufferId::from_proto(request.buffer_id);
3547 session
3548 .db()
3549 .await
3550 .observe_buffer_version(
3551 buffer_id,
3552 session.user_id(),
3553 request.epoch as i32,
3554 &request.version,
3555 )
3556 .await?;
3557 Ok(())
3558}
3559
/// Start receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3568
/// Stop receiving chat updates for a channel
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3576
/// Retrieve the chat history for a channel
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3585
/// Retrieve specific chat messages
///
/// Chat has been removed from Zed; this handler unconditionally returns an
/// error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3594
3595/// Retrieve the current users notifications
3596async fn get_notifications(
3597 request: proto::GetNotifications,
3598 response: Response<proto::GetNotifications>,
3599 session: MessageContext,
3600) -> Result<()> {
3601 let notifications = session
3602 .db()
3603 .await
3604 .get_notifications(
3605 session.user_id(),
3606 NOTIFICATION_COUNT_PER_PAGE,
3607 request.before_id.map(db::NotificationId::from_proto),
3608 )
3609 .await?;
3610 response.send(proto::GetNotificationsResponse {
3611 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3612 notifications,
3613 })?;
3614 Ok(())
3615}
3616
3617/// Mark notifications as read
3618async fn mark_notification_as_read(
3619 request: proto::MarkNotificationRead,
3620 response: Response<proto::MarkNotificationRead>,
3621 session: MessageContext,
3622) -> Result<()> {
3623 let database = &session.db().await;
3624 let notifications = database
3625 .mark_notification_as_read_by_id(
3626 session.user_id(),
3627 NotificationId::from_proto(request.notification_id),
3628 )
3629 .await?;
3630 send_notifications(
3631 &*session.connection_pool().await,
3632 &session.peer,
3633 notifications,
3634 );
3635 response.send(proto::Ack {})?;
3636 Ok(())
3637}
3638
3639fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3640 let message = match message {
3641 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3642 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3643 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3644 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3645 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3646 code: frame.code.into(),
3647 reason: frame.reason.as_str().to_owned().into(),
3648 })),
3649 // We should never receive a frame while reading the message, according
3650 // to the `tungstenite` maintainers:
3651 //
3652 // > It cannot occur when you read messages from the WebSocket, but it
3653 // > can be used when you want to send the raw frames (e.g. you want to
3654 // > send the frames to the WebSocket without composing the full message first).
3655 // >
3656 // > — https://github.com/snapview/tungstenite-rs/issues/268
3657 TungsteniteMessage::Frame(_) => {
3658 bail!("received an unexpected frame while reading the message")
3659 }
3660 };
3661
3662 Ok(message)
3663}
3664
3665fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3666 match message {
3667 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3668 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3669 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3670 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3671 AxumMessage::Close(frame) => {
3672 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3673 code: frame.code.into(),
3674 reason: frame.reason.as_ref().into(),
3675 }))
3676 }
3677 }
3678}
3679
3680fn notify_membership_updated(
3681 connection_pool: &mut ConnectionPool,
3682 result: MembershipUpdated,
3683 user_id: UserId,
3684 peer: &Peer,
3685) {
3686 for membership in &result.new_channels.channel_memberships {
3687 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3688 }
3689 for channel_id in &result.removed_channels {
3690 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3691 }
3692
3693 let user_channels_update = proto::UpdateUserChannels {
3694 channel_memberships: result
3695 .new_channels
3696 .channel_memberships
3697 .iter()
3698 .map(|cm| proto::ChannelMembership {
3699 channel_id: cm.channel_id.to_proto(),
3700 role: cm.role.into(),
3701 })
3702 .collect(),
3703 ..Default::default()
3704 };
3705
3706 let mut update = build_channels_update(result.new_channels);
3707 update.delete_channels = result
3708 .removed_channels
3709 .into_iter()
3710 .map(|id| id.to_proto())
3711 .collect();
3712 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3713
3714 for connection_id in connection_pool.user_connection_ids(user_id) {
3715 peer.send(connection_id, user_channels_update.clone())
3716 .trace_err();
3717 peer.send(connection_id, update.clone()).trace_err();
3718 }
3719}
3720
3721fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3722 proto::UpdateUserChannels {
3723 channel_memberships: channels
3724 .channel_memberships
3725 .iter()
3726 .map(|m| proto::ChannelMembership {
3727 channel_id: m.channel_id.to_proto(),
3728 role: m.role.into(),
3729 })
3730 .collect(),
3731 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3732 }
3733}
3734
3735fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3736 let mut update = proto::UpdateChannels::default();
3737
3738 for channel in channels.channels {
3739 update.channels.push(channel.to_proto());
3740 }
3741
3742 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3743
3744 for (channel_id, participants) in channels.channel_participants {
3745 update
3746 .channel_participants
3747 .push(proto::ChannelParticipants {
3748 channel_id: channel_id.to_proto(),
3749 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3750 });
3751 }
3752
3753 for channel in channels.invited_channels {
3754 update.channel_invitations.push(channel.to_proto());
3755 }
3756
3757 update
3758}
3759
3760fn build_initial_contacts_update(
3761 contacts: Vec<db::Contact>,
3762 pool: &ConnectionPool,
3763) -> proto::UpdateContacts {
3764 let mut update = proto::UpdateContacts::default();
3765
3766 for contact in contacts {
3767 match contact {
3768 db::Contact::Accepted { user_id, busy } => {
3769 update.contacts.push(contact_for_user(user_id, busy, pool));
3770 }
3771 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3772 db::Contact::Incoming { user_id } => {
3773 update
3774 .incoming_requests
3775 .push(proto::IncomingContactRequest {
3776 requester_id: user_id.to_proto(),
3777 })
3778 }
3779 }
3780 }
3781
3782 update
3783}
3784
3785fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3786 proto::Contact {
3787 user_id: user_id.to_proto(),
3788 online: pool.is_user_online(user_id),
3789 busy,
3790 }
3791}
3792
3793fn room_updated(room: &proto::Room, peer: &Peer) {
3794 broadcast(
3795 None,
3796 room.participants
3797 .iter()
3798 .filter_map(|participant| Some(participant.peer_id?.into())),
3799 |peer_id| {
3800 peer.send(
3801 peer_id,
3802 proto::RoomUpdated {
3803 room: Some(room.clone()),
3804 },
3805 )
3806 },
3807 );
3808}
3809
3810fn channel_updated(
3811 channel: &db::channel::Model,
3812 room: &proto::Room,
3813 peer: &Peer,
3814 pool: &ConnectionPool,
3815) {
3816 let participants = room
3817 .participants
3818 .iter()
3819 .map(|p| p.user_id)
3820 .collect::<Vec<_>>();
3821
3822 broadcast(
3823 None,
3824 pool.channel_connection_ids(channel.root_id())
3825 .filter_map(|(channel_id, role)| {
3826 role.can_see_channel(channel.visibility)
3827 .then_some(channel_id)
3828 }),
3829 |peer_id| {
3830 peer.send(
3831 peer_id,
3832 proto::UpdateChannels {
3833 channel_participants: vec![proto::ChannelParticipants {
3834 channel_id: channel.id.to_proto(),
3835 participant_user_ids: participants.clone(),
3836 }],
3837 ..Default::default()
3838 },
3839 )
3840 },
3841 );
3842}
3843
3844async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3845 let db = session.db().await;
3846
3847 let contacts = db.get_contacts(user_id).await?;
3848 let busy = db.is_user_busy(user_id).await?;
3849
3850 let pool = session.connection_pool().await;
3851 let updated_contact = contact_for_user(user_id, busy, &pool);
3852 for contact in contacts {
3853 if let db::Contact::Accepted {
3854 user_id: contact_user_id,
3855 ..
3856 } = contact
3857 {
3858 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3859 session
3860 .peer
3861 .send(
3862 contact_conn_id,
3863 proto::UpdateContacts {
3864 contacts: vec![updated_contact.clone()],
3865 remove_contacts: Default::default(),
3866 incoming_requests: Default::default(),
3867 remove_incoming_requests: Default::default(),
3868 outgoing_requests: Default::default(),
3869 remove_outgoing_requests: Default::default(),
3870 },
3871 )
3872 .trace_err();
3873 }
3874 }
3875 }
3876 Ok(())
3877}
3878
/// Remove `connection_id` from the room it occupies and broadcast the
/// fallout: project departures, room and channel updates, call
/// cancellations, contact refreshes, and LiveKit cleanup.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // Leave the room in the database, capturing everything needed for the
    // notifications below. If the connection wasn't in a room, there is
    // nothing to do.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    // If the room belonged to a channel, refresh that channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Cancel any calls this user had pending, and remember those users so
    // their contact state is refreshed as well.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Finally, detach the user from LiveKit, deleting the LiveKit room when
    // the database reported the room as deleted. Failures are logged only.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3952
3953async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3954 let left_channel_buffers = session
3955 .db()
3956 .await
3957 .leave_channel_buffers(session.connection_id)
3958 .await?;
3959
3960 for left_buffer in left_channel_buffers {
3961 channel_buffer_updated(
3962 session.connection_id,
3963 left_buffer.connections,
3964 &proto::UpdateChannelBufferCollaborators {
3965 channel_id: left_buffer.channel_id.to_proto(),
3966 collaborators: left_buffer.collaborators,
3967 },
3968 &session.peer,
3969 );
3970 }
3971
3972 Ok(())
3973}
3974
3975fn project_left(project: &db::LeftProject, session: &Session) {
3976 for connection_id in &project.connection_ids {
3977 if project.should_unshare {
3978 session
3979 .peer
3980 .send(
3981 *connection_id,
3982 proto::UnshareProject {
3983 project_id: project.id.to_proto(),
3984 },
3985 )
3986 .trace_err();
3987 } else {
3988 session
3989 .peer
3990 .send(
3991 *connection_id,
3992 proto::RemoveProjectCollaborator {
3993 project_id: project.id.to_proto(),
3994 peer_id: Some(session.connection_id.into()),
3995 },
3996 )
3997 .trace_err();
3998 }
3999 }
4000}
4001
4002async fn share_agent_thread(
4003 request: proto::ShareAgentThread,
4004 response: Response<proto::ShareAgentThread>,
4005 session: MessageContext,
4006) -> Result<()> {
4007 let user_id = session.user_id();
4008
4009 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4010 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4011
4012 session
4013 .db()
4014 .await
4015 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4016 .await?;
4017
4018 response.send(proto::Ack {})?;
4019
4020 Ok(())
4021}
4022
4023async fn get_shared_agent_thread(
4024 request: proto::GetSharedAgentThread,
4025 response: Response<proto::GetSharedAgentThread>,
4026 session: MessageContext,
4027) -> Result<()> {
4028 let share_id = SharedThreadId::from_proto(request.session_id)
4029 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4030
4031 let result = session.db().await.get_shared_thread(share_id).await?;
4032
4033 match result {
4034 Some((thread, username)) => {
4035 response.send(proto::GetSharedAgentThreadResponse {
4036 title: thread.title,
4037 thread_data: thread.data,
4038 sharer_username: username,
4039 created_at: thread.created_at.and_utc().to_rfc3339(),
4040 })?;
4041 }
4042 None => {
4043 return Err(anyhow!("Shared thread not found").into());
4044 }
4045 }
4046
4047 Ok(())
4048}
4049
/// Extension trait adding `trace_err`, which logs an `Err` via `tracing`
/// and converts the result into an `Option`.
pub trait ResultExt {
    type Ok;

    /// Log the error (if any) and discard it, returning `Some` on success
    /// and `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4055
4056impl<T, E> ResultExt for Result<T, E>
4057where
4058 E: std::fmt::Debug,
4059{
4060 type Ok = T;
4061
4062 #[track_caller]
4063 fn trace_err(self) -> Option<T> {
4064 match self {
4065 Ok(value) => Some(value),
4066 Err(error) => {
4067 tracing::error!("{:?}", error);
4068 None
4069 }
4070 }
4071 }
4072}