1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// How long a disconnected client has to reconnect before its server-side
/// resources are reclaimed.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page size for notification pagination.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections, enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently open connections; incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing span field names used to record per-message timing breakdowns.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one protocol message type, stored in `Server::handlers`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Release the slot reserved by `try_acquire`.
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}
115
/// Handle a request handler uses to send exactly one response for a request.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`, which checks it after
    // the handler returns to detect handlers that never responded.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Send `payload` as this request's response, consuming the handle.
    fn send(self, payload: R::Response) -> Result<()> {
        // NOTE(review): the flag is set before `respond`, so it stays set even
        // if sending fails — in that case the `?` propagates the error anyway.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Per-message context handed to handlers: the connection's session plus the
/// tracing span of the message currently being processed.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

// Handlers can call `Session` methods directly on the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
impl MessageContext {
    /// Forward `request` to another connection (e.g. a project host) and await
    /// its response, recording how long we waited on this message's span.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs on completion regardless of outcome: record time spent
            // waiting for the receiver, in fractional milliseconds.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
181
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
197
impl Session {
    /// Acquire the connection's database handle.
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        // In tests, yield around the lock to surface ordering assumptions.
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        guard
    }

    /// Lock the shared connection pool. The returned guard is `!Send`
    /// (via its `PhantomData<Rc<()>>`), so it cannot be held across an `.await`.
    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(feature = "test-support")]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    /// Whether the connected user is a staff (admin) user.
    #[expect(dead_code)]
    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
        }
    }

    /// The id of the connected user.
    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
        }
    }
}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
/// Thin newtype over the shared [`Database`], exposing `&Database` via `Deref`.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
253
/// The collab RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in `Server::new`.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Dispatch table keyed by the message payload's `TypeId`; see `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` on teardown so connection tasks exit their loops.
    teardown: watch::Sender<bool>,
}
262
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` makes the
/// guard `!Send`, preventing it from being held across `.await` points.
struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Construct a server and register a handler for every protocol message it
    /// understands. Panics (in `add_handler`) if a message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only project requests, forwarded to the host.
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            // LSP extension requests.
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            // Mutating project requests, forwarded to the host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer lifecycle and host -> guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and presence.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            // Git operations, forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::GitEditRef>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRepairWorktrees>)
            .add_request_handler(disallow_guest_request::<proto::GitCreateArchiveCheckpoint>)
            .add_request_handler(disallow_guest_request::<proto::GitRestoreArchiveCheckpoint>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            // Agent threads and search.
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
455
    /// Spawn the startup cleanup task: after `CLEANUP_TIMEOUT` has elapsed
    /// (giving terminated pods time to finish shutting down), reclaim rooms,
    /// channel buffers, and other resources left behind by stale server
    /// instances, notifying affected clients as we go.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer, then
                    // push the refreshed collaborator list to the survivors.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants; on success, broadcast the
                        // refreshed room (and channel, if any) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // the call is gone. Scoped so the pool lock drops
                        // before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Recompute busy status for affected users and fan the
                        // update out to each of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
626
    /// Shut the server down: disconnect the peer, clear the connection pool,
    /// and broadcast `true` on the teardown watch so connection tasks exit.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
632
    /// Test-only: tear the server down and bring it back up under a new id.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Un-signal teardown so new connections are accepted again.
        let _ = self.teardown.send(false);
    }
640
    /// Test-only: the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
645
646 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
647 where
648 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
649 Fut: 'static + Send + Future<Output = Result<()>>,
650 M: EnvelopedMessage,
651 {
652 let prev_handler = self.handlers.insert(
653 TypeId::of::<M>(),
654 Box::new(move |envelope, session, span| {
655 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
656 let received_at = envelope.received_at;
657 tracing::info!("message received");
658 let start_time = Instant::now();
659 let future = (handler)(
660 *envelope,
661 MessageContext {
662 session,
663 span: span.clone(),
664 },
665 );
666 async move {
667 let result = future.await;
668 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
669 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
670 let queue_duration_ms = total_duration_ms - processing_duration_ms;
671 span.record(TOTAL_DURATION_MS, total_duration_ms);
672 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
673 span.record(QUEUE_DURATION_MS, queue_duration_ms);
674 match result {
675 Err(error) => {
676 tracing::error!(?error, "error handling message")
677 }
678 Ok(()) => tracing::info!("finished handling message"),
679 }
680 }
681 .boxed()
682 }),
683 );
684 if prev_handler.is_some() {
685 panic!("registered a handler for the same message twice");
686 }
687 self
688 }
689
690 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
691 where
692 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
693 Fut: 'static + Send + Future<Output = Result<()>>,
694 M: EnvelopedMessage,
695 {
696 self.add_handler(move |envelope, session| handler(envelope.payload, session));
697 self
698 }
699
    /// Register a handler for a request/response message. The wrapper ensures
    /// the client always hears back: on `Ok` the handler must have called
    /// `Response::send` (otherwise this is reported as an error), and on `Err`
    /// an error response is sent before the error is propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // Handler returned Ok without responding — surface
                            // as an error so the omission shows up in logs.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Convert the failure into a protocol error so the
                        // client's request future resolves instead of hanging.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
738
    /// Build the future that services one client connection: registers it with
    /// the peer, sends the initial client update, then loops dispatching
    /// incoming messages to their handlers until the connection closes, the
    /// server tears down, or I/O fails. Finally runs `connection_lost` cleanup.
    ///
    /// `connection_guard` (the concurrency slot) is held until the initial
    /// update has been sent, then dropped.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Initial handshake complete — release the connection-budget slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Backpressure: only take the next message once a handler
                // permit is available.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released when the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
909
    /// Sends the initial batch of messages a freshly-registered client needs:
    /// the `Hello` handshake, the contact list, channel subscriptions, and any
    /// pending incoming call.
    ///
    /// `send_connection_id` is an optional oneshot used by callers that need
    /// to learn the assigned `ConnectionId` as soon as the hello is sent.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // Receiver may already be gone; that is fine.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) => {
                if !user.connected_once {
                    // First-ever connection for this user: prompt the client
                    // to show the contacts UI, then persist the flag.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and build the contact update while
                    // holding the pool lock, so reported online state is consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(&zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
965}
966
// Read-only access to the wrapped pool through the guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
974
// Mutable access to the wrapped pool through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
980
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants whenever the guard is
        // released, so corruption is caught close to its source.
        #[cfg(feature = "test-support")]
        self.check_invariants();
    }
}
987
988fn broadcast<F>(
989 sender_id: Option<ConnectionId>,
990 receiver_ids: impl IntoIterator<Item = ConnectionId>,
991 mut f: F,
992) where
993 F: FnMut(ConnectionId) -> anyhow::Result<()>,
994{
995 for receiver_id in receiver_ids {
996 if Some(receiver_id) != sender_id
997 && let Err(error) = f(receiver_id)
998 {
999 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1000 }
1001 }
1002}
1003
/// Typed wrapper for the `x-zed-protocol-version` request header (a `u32`).
pub struct ProtocolVersion(u32);
1005
1006impl Header for ProtocolVersion {
1007 fn name() -> &'static HeaderName {
1008 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1009 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1010 }
1011
1012 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1013 where
1014 Self: Sized,
1015 I: Iterator<Item = &'i axum::http::HeaderValue>,
1016 {
1017 let version = values
1018 .next()
1019 .ok_or_else(axum::headers::Error::invalid)?
1020 .to_str()
1021 .map_err(|_| axum::headers::Error::invalid())?
1022 .parse()
1023 .map_err(|_| axum::headers::Error::invalid())?;
1024 Ok(Self(version))
1025 }
1026
1027 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1028 values.extend([self.0.to_string().parse().unwrap()]);
1029 }
1030}
1031
1032pub struct AppVersionHeader(Version);
1033impl Header for AppVersionHeader {
1034 fn name() -> &'static HeaderName {
1035 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1036 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1037 }
1038
1039 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1040 where
1041 Self: Sized,
1042 I: Iterator<Item = &'i axum::http::HeaderValue>,
1043 {
1044 let version = values
1045 .next()
1046 .ok_or_else(axum::headers::Error::invalid)?
1047 .to_str()
1048 .map_err(|_| axum::headers::Error::invalid())?
1049 .parse()
1050 .map_err(|_| axum::headers::Error::invalid())?;
1051 Ok(Self(version))
1052 }
1053
1054 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1055 values.extend([self.0.to_string().parse().unwrap()]);
1056 }
1057}
1058
1059#[derive(Debug)]
1060pub struct ReleaseChannelHeader(String);
1061
1062impl Header for ReleaseChannelHeader {
1063 fn name() -> &'static HeaderName {
1064 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1065 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1066 }
1067
1068 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1069 where
1070 Self: Sized,
1071 I: Iterator<Item = &'i axum::http::HeaderValue>,
1072 {
1073 Ok(Self(
1074 values
1075 .next()
1076 .ok_or_else(axum::headers::Error::invalid)?
1077 .to_str()
1078 .map_err(|_| axum::headers::Error::invalid())?
1079 .to_owned(),
1080 ))
1081 }
1082
1083 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1084 values.extend([self.0.parse().unwrap()]);
1085 }
1086}
1087
1088pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1089 Router::new()
1090 .route("/rpc", get(handle_websocket_request))
1091 .layer(
1092 ServiceBuilder::new()
1093 .layer(Extension(server.app_state.clone()))
1094 .layer(middleware::from_fn(auth::validate_header)),
1095 )
1096 .route("/metrics", get(handle_metrics))
1097 .layer(Extension(server))
1098}
1099
/// HTTP handler for `GET /rpc`: validates the protocol and app-version
/// headers, enforces the concurrent-connection limit, then upgrades to a
/// WebSocket and hands the connection to `Server::handle_connection`.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly; otherwise the client must
    // update before it can talk to this server.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app-version header is mandatory even though axum models it as optional.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // server rejects with 503 instead of accepting and then dropping.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's WebSocket message type to the tungstenite-based
        // `Connection` the RPC peer expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1177
1178pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1179 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1180 let connections_metric = CONNECTIONS_METRIC
1181 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1182
1183 let connections = server
1184 .connection_pool
1185 .lock()
1186 .connections()
1187 .filter(|connection| !connection.admin)
1188 .count();
1189 connections_metric.set(connections as _);
1190
1191 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1192 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1193 register_int_gauge!(
1194 "shared_projects",
1195 "number of open projects with one or more guests"
1196 )
1197 .unwrap()
1198 });
1199
1200 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1201 shared_projects_metric.set(shared_projects as _);
1202
1203 let encoder = prometheus::TextEncoder::new();
1204 let metric_families = prometheus::gather();
1205 let encoded_metrics = encoder
1206 .encode_to_string(&metric_families)
1207 .map_err(|err| anyhow!("{err}"))?;
1208 Ok(encoded_metrics)
1209}
1210
/// Tears down server-side state after a client's connection drops.
///
/// Immediately disconnects the peer and removes it from the connection pool,
/// then waits up to `RECONNECT_TIMEOUT` for the client to reconnect before
/// releasing rooms, channel buffers, and pending calls. The `teardown` signal
/// short-circuits that wait.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database; errors are
    // only logged.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not reconnect in time: release everything it held.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last connection.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1257
1258/// Acknowledges a ping from a client, used to keep the connection alive.
1259async fn ping(
1260 _: proto::Ping,
1261 response: Response<proto::Ping>,
1262 _session: MessageContext,
1263) -> Result<()> {
1264 response.send(proto::Ack {})?;
1265 Ok(())
1266}
1267
1268/// Creates a new room for calling (outside of channels)
1269async fn create_room(
1270 _request: proto::CreateRoom,
1271 response: Response<proto::CreateRoom>,
1272 session: MessageContext,
1273) -> Result<()> {
1274 let livekit_room = nanoid::nanoid!(30);
1275
1276 let live_kit_connection_info = util::maybe!(async {
1277 let live_kit = session.app_state.livekit_client.as_ref();
1278 let live_kit = live_kit?;
1279 let user_id = session.user_id().to_string();
1280
1281 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1282
1283 Some(proto::LiveKitConnectionInfo {
1284 server_url: live_kit.url().into(),
1285 token,
1286 can_publish: true,
1287 })
1288 })
1289 .await;
1290
1291 let room = session
1292 .db()
1293 .await
1294 .create_room(session.user_id(), session.connection_id, &livekit_room)
1295 .await?;
1296
1297 response.send(proto::CreateRoomResponse {
1298 room: Some(room.clone()),
1299 live_kit_connection_info,
1300 })?;
1301
1302 update_user_contacts(session.user_id(), &session).await?;
1303 Ok(())
1304}
1305
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Rooms backed by a channel are joined through the channel path instead.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Clear the ringing state on the user's other connections; send failures
    // are logged and ignored.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Mint a LiveKit token if a media client is configured; `None` means the
    // client joins without call media.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1372
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the caller's membership, re-announces reshared projects to their
/// collaborators, streams rejoined-project state back to the caller, and
/// refreshes channel and contact state derived from the room.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this peer's id changed from the old
            // (dropped) connection to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the reshared project's worktree list to the other
            // collaborators.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1464
/// After a rejoin, informs each project's collaborators of the session's new
/// peer id, then streams the accumulated project state (worktree entries,
/// diagnostics, settings, repositories) back to the rejoining client.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // First pass: update collaborators; send failures are logged and ignored.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    // Second pass: stream each project's state back to the rejoining client
    // itself (`session.connection_id`). `mem::take` drains the data so it is
    // sent at most once.
    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries, chunked to bounded message sizes.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                root_repo_common_dir: worktree.root_repo_common_dir,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                        outside_worktree: Some(settings_file.outside_worktree),
                    },
                )?;
            }
        }

        // Repository updates are chunked the same way as worktree updates.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1551
1552/// leave room disconnects from the room.
1553async fn leave_room(
1554 _: proto::LeaveRoom,
1555 response: Response<proto::LeaveRoom>,
1556 session: MessageContext,
1557) -> Result<()> {
1558 leave_room_for_session(&session, session.connection_id).await?;
1559 response.send(proto::Ack {})?;
1560 Ok(())
1561}
1562
1563/// Updates the permissions of someone else in the room.
1564async fn set_room_participant_role(
1565 request: proto::SetRoomParticipantRole,
1566 response: Response<proto::SetRoomParticipantRole>,
1567 session: MessageContext,
1568) -> Result<()> {
1569 let user_id = UserId::from_proto(request.user_id);
1570 let role = ChannelRole::from(request.role());
1571
1572 let (livekit_room, can_publish) = {
1573 let room = session
1574 .db()
1575 .await
1576 .set_room_participant_role(
1577 session.user_id(),
1578 RoomId::from_proto(request.room_id),
1579 user_id,
1580 role,
1581 )
1582 .await?;
1583
1584 let livekit_room = room.livekit_room.clone();
1585 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1586 room_updated(&room, &session.peer);
1587 (livekit_room, can_publish)
1588 };
1589
1590 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1591 live_kit
1592 .update_participant(
1593 livekit_room.clone(),
1594 request.user_id.to_string(),
1595 livekit_api::proto::ParticipantPermission {
1596 can_subscribe: true,
1597 can_publish,
1598 can_publish_data: can_publish,
1599 hidden: false,
1600 recorder: false,
1601 },
1602 )
1603 .await
1604 .trace_err();
1605 }
1606
1607 response.send(proto::Ack {})?;
1608 Ok(())
1609}
1610
/// Call someone else into the current room
///
/// Rings all of the callee's connections concurrently; the first successful
/// response acks this request. If every ring fails, the call is recorded as
/// failed and updated room state is broadcast.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the call message out of the guard so it can be sent below.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee; completion order (not connection
    // order) decides which response is handled first.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered: roll the room state forward to "call failed".
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1679
/// Cancel an outgoing call.
///
/// Removes the pending call from the room, clears the ringing UI on every
/// connection of the callee, then acks and refreshes contact state.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // Notify every connection of the callee; send failures are logged only.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1717
/// Decline an incoming call.
///
/// Marks the call declined in the database, clears the ringing UI on all of
/// the declining user's own connections, then refreshes contact state.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // The user may be ringing on several devices; cancel the call UI on all
    // of them. Send failures are logged only.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1749
1750/// Updates other participants in the room with your current location.
1751async fn update_participant_location(
1752 request: proto::UpdateParticipantLocation,
1753 response: Response<proto::UpdateParticipantLocation>,
1754 session: MessageContext,
1755) -> Result<()> {
1756 let room_id = RoomId::from_proto(request.room_id);
1757 let location = request.location.context("invalid location")?;
1758
1759 let db = session.db().await;
1760 let room = db
1761 .update_room_participant_location(room_id, session.connection_id, location)
1762 .await?;
1763
1764 room_updated(&room, &session.peer);
1765 response.send(proto::Ack {})?;
1766 Ok(())
1767}
1768
1769/// Share a project into the room.
1770async fn share_project(
1771 request: proto::ShareProject,
1772 response: Response<proto::ShareProject>,
1773 session: MessageContext,
1774) -> Result<()> {
1775 let (project_id, room) = &*session
1776 .db()
1777 .await
1778 .share_project(
1779 RoomId::from_proto(request.room_id),
1780 session.connection_id,
1781 &request.worktrees,
1782 request.is_ssh_project,
1783 request.windows_paths.unwrap_or(false),
1784 &request.features,
1785 )
1786 .await?;
1787 response.send(proto::ShareProjectResponse {
1788 project_id: project_id.to_proto(),
1789 })?;
1790 room_updated(room, &session.peer);
1791
1792 Ok(())
1793}
1794
1795/// Unshare a project from the room.
1796async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1797 let project_id = ProjectId::from_proto(message.project_id);
1798 unshare_project_internal(project_id, session.connection_id, &session).await
1799}
1800
/// Stops sharing `project_id` on behalf of `connection_id`: notifies guests,
/// refreshes the room, and deletes the project row entirely when the database
/// marks it for deletion.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest the project is gone (skipping the unsharing
        // connection itself); per-guest failures are logged.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Deletion happens outside the block above so the room guard is released
    // before acquiring the database again.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1838
/// Join someone elses shared project.
///
/// Verifies host/guest feature parity, registers the guest as a collaborator,
/// announces it to existing collaborators, then streams the full project
/// snapshot (worktrees, diagnostics, settings, repositories, language
/// servers) to the joining client.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let project_model = db.get_project(project_id).await?;
    // Host features are stored as a JSON array; an unreadable value is
    // treated as "no features".
    let host_features: Vec<String> =
        serde_json::from_str(&project_model.features).unwrap_or_default();
    let guest_features: HashSet<_> = request.features.iter().collect();
    let host_features_set: HashSet<_> = host_features.iter().collect();
    if guest_features != host_features_set {
        // Feature sets must match exactly; report both client versions so the
        // user knows which side must update.
        let host_connection_id = project_model.host_connection()?;
        let mut pool = session.connection_pool().await;
        let host_version = pool
            .connection(host_connection_id)
            .map(|c| c.zed_version.to_string());
        let guest_version = pool
            .connection(session.connection_id)
            .map(|c| c.zed_version.to_string());
        drop(pool);
        Err(anyhow!(
            "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
            host_version.as_deref().unwrap_or("unknown"),
            guest_version.as_deref().unwrap_or("unknown"),
        ))?;
    }

    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long streaming section below.
    drop(db);

    tracing::info!(%project_id, "join remote project");
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
            root_repo_common_dir: None,
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to everyone already in the project;
    // per-peer failures are logged only.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // Reply with the project snapshot metadata: worktrees, collaborators,
    // and language servers (split into server info and capabilities).
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees,
        replica_id: replica_id.0 as u32,
        collaborators,
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
        windows_paths: project.path_style == PathStyle::Windows,
        features: project.features.clone(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries, chunked to bounded message sizes.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            root_repo_common_dir: worktree.root_repo_common_dir,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                    outside_worktree: Some(settings_file.outside_worktree),
                },
            )?;
        }
    }

    // Repository updates are chunked the same way as worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2016
/// Leave someone else's shared project.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Detach this connection from the project; returns the room the project
    // was shared in (if any) plus the project's remaining state.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell the remaining collaborators that this peer left, and refresh the
    // room state for everyone in the call.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2036
2037/// Updates other participants with changes to the project
2038async fn update_project(
2039 request: proto::UpdateProject,
2040 response: Response<proto::UpdateProject>,
2041 session: MessageContext,
2042) -> Result<()> {
2043 let project_id = ProjectId::from_proto(request.project_id);
2044 let (room, guest_connection_ids) = &*session
2045 .db()
2046 .await
2047 .update_project(project_id, session.connection_id, &request.worktrees)
2048 .await?;
2049 broadcast(
2050 Some(session.connection_id),
2051 guest_connection_ids.iter().copied(),
2052 |connection_id| {
2053 session
2054 .peer
2055 .forward_send(session.connection_id, connection_id, request.clone())
2056 },
2057 );
2058 if let Some(room) = room {
2059 room_updated(room, &session.peer);
2060 }
2061 response.send(proto::Ack {})?;
2062
2063 Ok(())
2064}
2065
2066/// Updates other participants with changes to the worktree
2067async fn update_worktree(
2068 request: proto::UpdateWorktree,
2069 response: Response<proto::UpdateWorktree>,
2070 session: MessageContext,
2071) -> Result<()> {
2072 let guest_connection_ids = session
2073 .db()
2074 .await
2075 .update_worktree(&request, session.connection_id)
2076 .await?;
2077
2078 broadcast(
2079 Some(session.connection_id),
2080 guest_connection_ids.iter().copied(),
2081 |connection_id| {
2082 session
2083 .peer
2084 .forward_send(session.connection_id, connection_id, request.clone())
2085 },
2086 );
2087 response.send(proto::Ack {})?;
2088 Ok(())
2089}
2090
2091async fn update_repository(
2092 request: proto::UpdateRepository,
2093 response: Response<proto::UpdateRepository>,
2094 session: MessageContext,
2095) -> Result<()> {
2096 let guest_connection_ids = session
2097 .db()
2098 .await
2099 .update_repository(&request, session.connection_id)
2100 .await?;
2101
2102 broadcast(
2103 Some(session.connection_id),
2104 guest_connection_ids.iter().copied(),
2105 |connection_id| {
2106 session
2107 .peer
2108 .forward_send(session.connection_id, connection_id, request.clone())
2109 },
2110 );
2111 response.send(proto::Ack {})?;
2112 Ok(())
2113}
2114
2115async fn remove_repository(
2116 request: proto::RemoveRepository,
2117 response: Response<proto::RemoveRepository>,
2118 session: MessageContext,
2119) -> Result<()> {
2120 let guest_connection_ids = session
2121 .db()
2122 .await
2123 .remove_repository(&request, session.connection_id)
2124 .await?;
2125
2126 broadcast(
2127 Some(session.connection_id),
2128 guest_connection_ids.iter().copied(),
2129 |connection_id| {
2130 session
2131 .peer
2132 .forward_send(session.connection_id, connection_id, request.clone())
2133 },
2134 );
2135 response.send(proto::Ack {})?;
2136 Ok(())
2137}
2138
2139/// Updates other participants with changes to the diagnostics
2140async fn update_diagnostic_summary(
2141 message: proto::UpdateDiagnosticSummary,
2142 session: MessageContext,
2143) -> Result<()> {
2144 let guest_connection_ids = session
2145 .db()
2146 .await
2147 .update_diagnostic_summary(&message, session.connection_id)
2148 .await?;
2149
2150 broadcast(
2151 Some(session.connection_id),
2152 guest_connection_ids.iter().copied(),
2153 |connection_id| {
2154 session
2155 .peer
2156 .forward_send(session.connection_id, connection_id, message.clone())
2157 },
2158 );
2159
2160 Ok(())
2161}
2162
2163/// Updates other participants with changes to the worktree settings
2164async fn update_worktree_settings(
2165 message: proto::UpdateWorktreeSettings,
2166 session: MessageContext,
2167) -> Result<()> {
2168 let guest_connection_ids = session
2169 .db()
2170 .await
2171 .update_worktree_settings(&message, session.connection_id)
2172 .await?;
2173
2174 broadcast(
2175 Some(session.connection_id),
2176 guest_connection_ids.iter().copied(),
2177 |connection_id| {
2178 session
2179 .peer
2180 .forward_send(session.connection_id, connection_id, message.clone())
2181 },
2182 );
2183
2184 Ok(())
2185}
2186
2187/// Notify other participants that a language server has started.
2188async fn start_language_server(
2189 request: proto::StartLanguageServer,
2190 session: MessageContext,
2191) -> Result<()> {
2192 let guest_connection_ids = session
2193 .db()
2194 .await
2195 .start_language_server(&request, session.connection_id)
2196 .await?;
2197
2198 broadcast(
2199 Some(session.connection_id),
2200 guest_connection_ids.iter().copied(),
2201 |connection_id| {
2202 session
2203 .peer
2204 .forward_send(session.connection_id, connection_id, request.clone())
2205 },
2206 );
2207 Ok(())
2208}
2209
/// Notify other participants that a language server has changed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // For metadata updates that carry capabilities, persist the new server
    // capabilities before fanning the message out.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
        && let Some(capabilities) = update.capabilities.clone()
    {
        db.update_server_capabilities(project_id, request.language_server_id, capabilities)
            .await?;
    }

    // `true` here includes the host among the recipients (the other
    // broadcast handlers in this file pass `false`).
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2239
2240/// forward a project request to the host. These requests should be read only
2241/// as guests are allowed to send them.
2242async fn forward_read_only_project_request<T>(
2243 request: T,
2244 response: Response<T>,
2245 session: MessageContext,
2246) -> Result<()>
2247where
2248 T: EntityMessage + RequestMessage,
2249{
2250 let project_id = ProjectId::from_proto(request.remote_entity_id());
2251 let host_connection_id = session
2252 .db()
2253 .await
2254 .host_for_read_only_project_request(project_id, session.connection_id)
2255 .await?;
2256 let payload = session.forward_request(host_connection_id, request).await?;
2257 response.send(payload)?;
2258 Ok(())
2259}
2260
2261/// forward a project request to the host. These requests are disallowed
2262/// for guests.
2263async fn forward_mutating_project_request<T>(
2264 request: T,
2265 response: Response<T>,
2266 session: MessageContext,
2267) -> Result<()>
2268where
2269 T: EntityMessage + RequestMessage,
2270{
2271 let project_id = ProjectId::from_proto(request.remote_entity_id());
2272
2273 let host_connection_id = session
2274 .db()
2275 .await
2276 .host_for_mutating_project_request(project_id, session.connection_id)
2277 .await?;
2278 let payload = session.forward_request(host_connection_id, request).await?;
2279 response.send(payload)?;
2280 Ok(())
2281}
2282
/// Reject a request outright because the sender is a guest.
async fn disallow_guest_request<T>(
    _request: T,
    response: Response<T>,
    _session: MessageContext,
) -> Result<()>
where
    T: RequestMessage,
{
    // Reply with a Forbidden error instead of a normal response.
    response.peer.respond_with_error(
        response.receipt,
        ErrorCode::Forbidden
            .message("request is not allowed for guests".to_string())
            .to_proto(),
    )?;
    // Mark the response as already sent — presumably so the surrounding
    // dispatch machinery does not respond a second time; confirm against
    // `Response`'s Drop/handler logic.
    response.responded.store(true, SeqCst);
    Ok(())
}
2300
2301async fn lsp_query(
2302 request: proto::LspQuery,
2303 response: Response<proto::LspQuery>,
2304 session: MessageContext,
2305) -> Result<()> {
2306 let (name, should_write) = request.query_name_and_write_permissions();
2307 tracing::Span::current().record("lsp_query_request", name);
2308 tracing::info!("lsp_query message received");
2309 if should_write {
2310 forward_mutating_project_request(request, response, session).await
2311 } else {
2312 forward_read_only_project_request(request, response, session).await
2313 }
2314}
2315
2316/// Notify other participants that a new buffer has been created
2317async fn create_buffer_for_peer(
2318 request: proto::CreateBufferForPeer,
2319 session: MessageContext,
2320) -> Result<()> {
2321 session
2322 .db()
2323 .await
2324 .check_user_is_project_host(
2325 ProjectId::from_proto(request.project_id),
2326 session.connection_id,
2327 )
2328 .await?;
2329 let peer_id = request.peer_id.context("invalid peer id")?;
2330 session
2331 .peer
2332 .forward_send(session.connection_id, peer_id.into(), request)?;
2333 Ok(())
2334}
2335
2336/// Notify other participants that a new image has been created
2337async fn create_image_for_peer(
2338 request: proto::CreateImageForPeer,
2339 session: MessageContext,
2340) -> Result<()> {
2341 session
2342 .db()
2343 .await
2344 .check_user_is_project_host(
2345 ProjectId::from_proto(request.project_id),
2346 session.connection_id,
2347 )
2348 .await?;
2349 let peer_id = request.peer_id.context("invalid peer id")?;
2350 session
2351 .peer
2352 .forward_send(session.connection_id, peer_id.into(), request)?;
2353 Ok(())
2354}
2355
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Any operation other than a selection update escalates the required
    // capability to read-write.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the db guard so it is dropped before awaiting the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fire-and-forget the update to all guests.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If the sender is not the host, wait for the host to acknowledge the
    // update before acking the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2402
2403async fn forward_project_search_chunk(
2404 message: proto::FindSearchCandidatesChunk,
2405 response: Response<proto::FindSearchCandidatesChunk>,
2406 session: MessageContext,
2407) -> Result<()> {
2408 let peer_id = message.peer_id.context("missing peer_id")?;
2409 let payload = session
2410 .peer
2411 .forward_request(session.connection_id, peer_id.into(), message)
2412 .await?;
2413 response.send(payload)?;
2414 Ok(())
2415}
2416
2417/// Notify other participants that a project has been updated.
2418async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2419 request: T,
2420 session: MessageContext,
2421) -> Result<()> {
2422 let project_id = ProjectId::from_proto(request.remote_entity_id());
2423 let project_connection_ids = session
2424 .db()
2425 .await
2426 .project_connection_ids(project_id, session.connection_id, false)
2427 .await?;
2428
2429 broadcast(
2430 Some(session.connection_id),
2431 project_connection_ids.iter().copied(),
2432 |connection_id| {
2433 session
2434 .peer
2435 .forward_send(session.connection_id, connection_id, request.clone())
2436 },
2437 );
2438 Ok(())
2439}
2440
/// Start following another user in a call.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Validate that both leader and follower belong to the room before
    // forwarding anything.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader first; only after it responds do we record the follow.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Record the follow in the room state and fan out the updated room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2472
/// Stop following another user in a call.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Validate that both leader and follower belong to the room before
    // forwarding anything.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Unlike `follow`, this is fire-and-forget: no response from the leader
    // is awaited.
    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Remove the follow from the room state and fan out the updated room.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2501
/// Notify everyone following you of your current location.
async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE: the db guard is held for the rest of this function, including
    // the sends below.
    let database = session.db.lock().await;

    // Project-scoped updates fan out to the project's connections; otherwise
    // to the whole room.
    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // Forward to everyone except the omitted leader and the sender itself.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2533
2534/// Get public data about users.
2535async fn get_users(
2536 request: proto::GetUsers,
2537 response: Response<proto::GetUsers>,
2538 session: MessageContext,
2539) -> Result<()> {
2540 let user_ids = request
2541 .user_ids
2542 .into_iter()
2543 .map(UserId::from_proto)
2544 .collect();
2545 let users = session
2546 .db()
2547 .await
2548 .get_users_by_ids(user_ids)
2549 .await?
2550 .into_iter()
2551 .map(|user| proto::User {
2552 id: user.id.to_proto(),
2553 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2554 github_login: user.github_login,
2555 name: user.name,
2556 })
2557 .collect();
2558 response.send(proto::UsersResponse { users })?;
2559 Ok(())
2560}
2561
/// Search for users (to invite) by GitHub login
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // NOTE(review): `len()` counts bytes, not characters, so a one- or
    // two-character multi-byte query takes the fuzzy path. Harmless for
    // ASCII GitHub logins — confirm.
    let users = match query.len() {
        0 => vec![],
        // Very short queries: exact login lookup only.
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        // Longer queries: fuzzy match, capped at 10 results.
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the requester from their own results and convert to proto.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2593
2594/// Send a contact request to another user.
2595async fn request_contact(
2596 request: proto::RequestContact,
2597 response: Response<proto::RequestContact>,
2598 session: MessageContext,
2599) -> Result<()> {
2600 let requester_id = session.user_id();
2601 let responder_id = UserId::from_proto(request.responder_id);
2602 if requester_id == responder_id {
2603 return Err(anyhow!("cannot add yourself as a contact"))?;
2604 }
2605
2606 let notifications = session
2607 .db()
2608 .await
2609 .send_contact_request(requester_id, responder_id)
2610 .await?;
2611
2612 // Update outgoing contact requests of requester
2613 let mut update = proto::UpdateContacts::default();
2614 update.outgoing_requests.push(responder_id.to_proto());
2615 for connection_id in session
2616 .connection_pool()
2617 .await
2618 .user_connection_ids(requester_id)
2619 {
2620 session.peer.send(connection_id, update.clone())?;
2621 }
2622
2623 // Update incoming contact requests of responder
2624 let mut update = proto::UpdateContacts::default();
2625 update
2626 .incoming_requests
2627 .push(proto::IncomingContactRequest {
2628 requester_id: requester_id.to_proto(),
2629 });
2630 let connection_pool = session.connection_pool().await;
2631 for connection_id in connection_pool.user_connection_ids(responder_id) {
2632 session.peer.send(connection_id, update.clone())?;
2633 }
2634
2635 send_notifications(&connection_pool, &session.peer, notifications);
2636
2637 response.send(proto::Ack {})?;
2638 Ok(())
2639}
2640
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    // Dismiss only clears the notification; it neither accepts nor declines.
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The incoming request disappears regardless of accept/decline.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Mirror image: the requester's outgoing request disappears too.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2698
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes removing an established contact from
    // cancelling a still-pending request.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also delete the now-stale contact notification on each of the
        // responder's connections, if one existed.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2749
/// Whether the server should subscribe this client to channels implicitly.
/// NOTE(review): presumably clients with minor version >= 139 send
/// `SubscribeToChannels` themselves — confirm against the client protocol.
fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
    version.0.minor < 139
}
2753
2754async fn subscribe_to_channels(
2755 _: proto::SubscribeToChannels,
2756 session: MessageContext,
2757) -> Result<()> {
2758 subscribe_user_to_channels(session.user_id(), &session).await?;
2759 Ok(())
2760}
2761
/// Register `user_id`'s channel subscriptions in the connection pool and
/// send the initial channel state down this session's connection.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Record the subscriptions before pushing any state to the client.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    // First the user's memberships/roles…
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    // …then the channels themselves (consumes `channels_for_user`).
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
2778
2779/// Creates a new channel.
2780async fn create_channel(
2781 request: proto::CreateChannel,
2782 response: Response<proto::CreateChannel>,
2783 session: MessageContext,
2784) -> Result<()> {
2785 let db = session.db().await;
2786
2787 let parent_id = request.parent_id.map(ChannelId::from_proto);
2788 let (channel, membership) = db
2789 .create_channel(&request.name, parent_id, session.user_id())
2790 .await?;
2791
2792 let root_id = channel.root_id();
2793 let channel = Channel::from_model(channel);
2794
2795 response.send(proto::CreateChannelResponse {
2796 channel: Some(channel.to_proto()),
2797 parent_id: request.parent_id,
2798 })?;
2799
2800 let mut connection_pool = session.connection_pool().await;
2801 if let Some(membership) = membership {
2802 connection_pool.subscribe_to_channel(
2803 membership.user_id,
2804 membership.channel_id,
2805 membership.role,
2806 );
2807 let update = proto::UpdateUserChannels {
2808 channel_memberships: vec![proto::ChannelMembership {
2809 channel_id: membership.channel_id.to_proto(),
2810 role: membership.role.into(),
2811 }],
2812 ..Default::default()
2813 };
2814 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2815 session.peer.send(connection_id, update.clone())?;
2816 }
2817 }
2818
2819 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2820 if !role.can_see_channel(channel.visibility) {
2821 continue;
2822 }
2823
2824 let update = proto::UpdateChannels {
2825 channels: vec![channel.to_proto()],
2826 ..Default::default()
2827 };
2828 session.peer.send(connection_id, update.clone())?;
2829 }
2830
2831 Ok(())
2832}
2833
2834/// Delete a channel
2835async fn delete_channel(
2836 request: proto::DeleteChannel,
2837 response: Response<proto::DeleteChannel>,
2838 session: MessageContext,
2839) -> Result<()> {
2840 let db = session.db().await;
2841
2842 let channel_id = request.channel_id;
2843 let (root_channel, removed_channels) = db
2844 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2845 .await?;
2846 response.send(proto::Ack {})?;
2847
2848 // Notify members of removed channels
2849 let mut update = proto::UpdateChannels::default();
2850 update
2851 .delete_channels
2852 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2853
2854 let connection_pool = session.connection_pool().await;
2855 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2856 session.peer.send(connection_id, update.clone())?;
2857 }
2858
2859 Ok(())
2860}
2861
2862/// Invite someone to join a channel.
2863async fn invite_channel_member(
2864 request: proto::InviteChannelMember,
2865 response: Response<proto::InviteChannelMember>,
2866 session: MessageContext,
2867) -> Result<()> {
2868 let db = session.db().await;
2869 let channel_id = ChannelId::from_proto(request.channel_id);
2870 let invitee_id = UserId::from_proto(request.user_id);
2871 let InviteMemberResult {
2872 channel,
2873 notifications,
2874 } = db
2875 .invite_channel_member(
2876 channel_id,
2877 invitee_id,
2878 session.user_id(),
2879 request.role().into(),
2880 )
2881 .await?;
2882
2883 let update = proto::UpdateChannels {
2884 channel_invitations: vec![channel.to_proto()],
2885 ..Default::default()
2886 };
2887
2888 let connection_pool = session.connection_pool().await;
2889 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2890 session.peer.send(connection_id, update.clone())?;
2891 }
2892
2893 send_notifications(&connection_pool, &session.peer, notifications);
2894
2895 response.send(proto::Ack {})?;
2896 Ok(())
2897}
2898
/// remove someone from a channel
async fn remove_channel_member(
    request: proto::RemoveChannelMember,
    response: Response<proto::RemoveChannelMember>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);

    let RemoveChannelMemberResult {
        membership_update,
        notification_id,
    } = db
        .remove_channel_member(channel_id, member_id, session.user_id())
        .await?;

    // Unsubscribe the member's connections and broadcast the membership change.
    let mut connection_pool = session.connection_pool().await;
    notify_membership_updated(
        &mut connection_pool,
        membership_update,
        member_id,
        &session.peer,
    );
    // Best-effort: clear the (now obsolete) invite notification on each of
    // the removed member's connections; failures are only logged.
    for connection_id in connection_pool.user_connection_ids(member_id) {
        if let Some(notification_id) = notification_id {
            session
                .peer
                .send(
                    connection_id,
                    proto::DeleteNotification {
                        notification_id: notification_id.to_proto(),
                    },
                )
                .trace_err();
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2940
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the (user, role) pairs first: the loop body mutates the pool's
    // subscriptions, so we cannot keep the iterator's borrow alive.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can now see the channel get (re)subscribed and receive
        // the channel; everyone else is unsubscribed and told to delete it.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2987
/// Alter the role for a user in the channel.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        // The member had already joined: broadcast the membership change.
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        // The member only had a pending invite: resend the updated invitation.
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3035
3036/// Change the name of a channel
3037async fn rename_channel(
3038 request: proto::RenameChannel,
3039 response: Response<proto::RenameChannel>,
3040 session: MessageContext,
3041) -> Result<()> {
3042 let db = session.db().await;
3043 let channel_id = ChannelId::from_proto(request.channel_id);
3044 let channel_model = db
3045 .rename_channel(channel_id, session.user_id(), &request.name)
3046 .await?;
3047 let root_id = channel_model.root_id();
3048 let channel = Channel::from_model(channel_model);
3049
3050 response.send(proto::RenameChannelResponse {
3051 channel: Some(channel.to_proto()),
3052 })?;
3053
3054 let connection_pool = session.connection_pool().await;
3055 let update = proto::UpdateChannels {
3056 channels: vec![channel.to_proto()],
3057 ..Default::default()
3058 };
3059 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3060 if role.can_see_channel(channel.visibility) {
3061 session.peer.send(connection_id, update.clone())?;
3062 }
3063 }
3064
3065 Ok(())
3066}
3067
3068/// Move a channel to a new parent.
3069async fn move_channel(
3070 request: proto::MoveChannel,
3071 response: Response<proto::MoveChannel>,
3072 session: MessageContext,
3073) -> Result<()> {
3074 let channel_id = ChannelId::from_proto(request.channel_id);
3075 let to = ChannelId::from_proto(request.to);
3076
3077 let (root_id, channels) = session
3078 .db()
3079 .await
3080 .move_channel(channel_id, to, session.user_id())
3081 .await?;
3082
3083 let connection_pool = session.connection_pool().await;
3084 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3085 let channels = channels
3086 .iter()
3087 .filter_map(|channel| {
3088 if role.can_see_channel(channel.visibility) {
3089 Some(channel.to_proto())
3090 } else {
3091 None
3092 }
3093 })
3094 .collect::<Vec<_>>();
3095 if channels.is_empty() {
3096 continue;
3097 }
3098
3099 let update = proto::UpdateChannels {
3100 channels,
3101 ..Default::default()
3102 };
3103
3104 session.peer.send(connection_id, update.clone())?;
3105 }
3106
3107 response.send(Ack {})?;
3108 Ok(())
3109}
3110
3111async fn reorder_channel(
3112 request: proto::ReorderChannel,
3113 response: Response<proto::ReorderChannel>,
3114 session: MessageContext,
3115) -> Result<()> {
3116 let channel_id = ChannelId::from_proto(request.channel_id);
3117 let direction = request.direction();
3118
3119 let updated_channels = session
3120 .db()
3121 .await
3122 .reorder_channel(channel_id, direction, session.user_id())
3123 .await?;
3124
3125 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3126 let connection_pool = session.connection_pool().await;
3127 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3128 let channels = updated_channels
3129 .iter()
3130 .filter_map(|channel| {
3131 if role.can_see_channel(channel.visibility) {
3132 Some(channel.to_proto())
3133 } else {
3134 None
3135 }
3136 })
3137 .collect::<Vec<_>>();
3138
3139 if channels.is_empty() {
3140 continue;
3141 }
3142
3143 let update = proto::UpdateChannels {
3144 channels,
3145 ..Default::default()
3146 };
3147
3148 session.peer.send(connection_id, update.clone())?;
3149 }
3150 }
3151
3152 response.send(Ack {})?;
3153 Ok(())
3154}
3155
3156/// Get the list of channel members
3157async fn get_channel_members(
3158 request: proto::GetChannelMembers,
3159 response: Response<proto::GetChannelMembers>,
3160 session: MessageContext,
3161) -> Result<()> {
3162 let db = session.db().await;
3163 let channel_id = ChannelId::from_proto(request.channel_id);
3164 let limit = if request.limit == 0 {
3165 u16::MAX as u64
3166 } else {
3167 request.limit
3168 };
3169 let (members, users) = db
3170 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3171 .await?;
3172 response.send(proto::GetChannelMembersResponse { members, users })?;
3173 Ok(())
3174}
3175
3176/// Accept or decline a channel invitation.
3177async fn respond_to_channel_invite(
3178 request: proto::RespondToChannelInvite,
3179 response: Response<proto::RespondToChannelInvite>,
3180 session: MessageContext,
3181) -> Result<()> {
3182 let db = session.db().await;
3183 let channel_id = ChannelId::from_proto(request.channel_id);
3184 let RespondToChannelInvite {
3185 membership_update,
3186 notifications,
3187 } = db
3188 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3189 .await?;
3190
3191 let mut connection_pool = session.connection_pool().await;
3192 if let Some(membership_update) = membership_update {
3193 notify_membership_updated(
3194 &mut connection_pool,
3195 membership_update,
3196 session.user_id(),
3197 &session.peer,
3198 );
3199 } else {
3200 let update = proto::UpdateChannels {
3201 remove_channel_invitations: vec![channel_id.to_proto()],
3202 ..Default::default()
3203 };
3204
3205 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3206 session.peer.send(connection_id, update.clone())?;
3207 }
3208 };
3209
3210 send_notifications(&connection_pool, &session.peer, notifications);
3211
3212 response.send(proto::Ack {})?;
3213
3214 Ok(())
3215}
3216
/// Join the channels' room
///
/// Thin wrapper that delegates to `join_channel_internal`, which is shared
/// with the `JoinRoom` handler via the `JoinChannelInternalResponse` trait.
async fn join_channel(
    request: proto::JoinChannel,
    response: Response<proto::JoinChannel>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    join_channel_internal(channel_id, Box::new(response), session).await
}
3226
/// Abstraction over the two request types (`JoinChannel` and `JoinRoom`)
/// whose handlers share `join_channel_internal`. Both reply with a
/// `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3240
/// Shared implementation behind `join_channel` and the channel path of
/// `JoinRoom`: joins the caller to the channel's room, mints a LiveKit
/// connection token, replies, and notifies room participants and channel
/// subscribers.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database itself, so
            // release our handle first, then re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a subscribe-only token; all other roles may publish.
        // Token minting is best-effort: on failure the response simply
        // carries no LiveKit connection info.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        // Joining may have changed the user's channel memberships (e.g. a
        // public channel joined as guest); push those updates first.
        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Update the channel's participant list for subscribers of its root.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3334
3335/// Start editing the channel notes
3336async fn join_channel_buffer(
3337 request: proto::JoinChannelBuffer,
3338 response: Response<proto::JoinChannelBuffer>,
3339 session: MessageContext,
3340) -> Result<()> {
3341 let db = session.db().await;
3342 let channel_id = ChannelId::from_proto(request.channel_id);
3343
3344 let open_response = db
3345 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3346 .await?;
3347
3348 let collaborators = open_response.collaborators.clone();
3349 response.send(open_response)?;
3350
3351 let update = UpdateChannelBufferCollaborators {
3352 channel_id: channel_id.to_proto(),
3353 collaborators: collaborators.clone(),
3354 };
3355 channel_buffer_updated(
3356 session.connection_id,
3357 collaborators
3358 .iter()
3359 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3360 &update,
3361 &session.peer,
3362 );
3363
3364 Ok(())
3365}
3366
3367/// Edit the channel notes
3368async fn update_channel_buffer(
3369 request: proto::UpdateChannelBuffer,
3370 session: MessageContext,
3371) -> Result<()> {
3372 let db = session.db().await;
3373 let channel_id = ChannelId::from_proto(request.channel_id);
3374
3375 let (collaborators, epoch, version) = db
3376 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3377 .await?;
3378
3379 channel_buffer_updated(
3380 session.connection_id,
3381 collaborators.clone(),
3382 &proto::UpdateChannelBuffer {
3383 channel_id: channel_id.to_proto(),
3384 operations: request.operations,
3385 },
3386 &session.peer,
3387 );
3388
3389 let pool = &*session.connection_pool().await;
3390
3391 let non_collaborators =
3392 pool.channel_connection_ids(channel_id)
3393 .filter_map(|(connection_id, _)| {
3394 if collaborators.contains(&connection_id) {
3395 None
3396 } else {
3397 Some(connection_id)
3398 }
3399 });
3400
3401 broadcast(None, non_collaborators, |peer_id| {
3402 session.peer.send(
3403 peer_id,
3404 proto::UpdateChannels {
3405 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3406 channel_id: channel_id.to_proto(),
3407 epoch: epoch as u64,
3408 version: version.clone(),
3409 }],
3410 ..Default::default()
3411 },
3412 )
3413 });
3414
3415 Ok(())
3416}
3417
3418/// Rejoin the channel notes after a connection blip
3419async fn rejoin_channel_buffers(
3420 request: proto::RejoinChannelBuffers,
3421 response: Response<proto::RejoinChannelBuffers>,
3422 session: MessageContext,
3423) -> Result<()> {
3424 let db = session.db().await;
3425 let buffers = db
3426 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3427 .await?;
3428
3429 for rejoined_buffer in &buffers {
3430 let collaborators_to_notify = rejoined_buffer
3431 .buffer
3432 .collaborators
3433 .iter()
3434 .filter_map(|c| Some(c.peer_id?.into()));
3435 channel_buffer_updated(
3436 session.connection_id,
3437 collaborators_to_notify,
3438 &proto::UpdateChannelBufferCollaborators {
3439 channel_id: rejoined_buffer.buffer.channel_id,
3440 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3441 },
3442 &session.peer,
3443 );
3444 }
3445
3446 response.send(proto::RejoinChannelBuffersResponse {
3447 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3448 })?;
3449
3450 Ok(())
3451}
3452
3453/// Stop editing the channel notes
3454async fn leave_channel_buffer(
3455 request: proto::LeaveChannelBuffer,
3456 response: Response<proto::LeaveChannelBuffer>,
3457 session: MessageContext,
3458) -> Result<()> {
3459 let db = session.db().await;
3460 let channel_id = ChannelId::from_proto(request.channel_id);
3461
3462 let left_buffer = db
3463 .leave_channel_buffer(channel_id, session.connection_id)
3464 .await?;
3465
3466 response.send(Ack {})?;
3467
3468 channel_buffer_updated(
3469 session.connection_id,
3470 left_buffer.connections,
3471 &proto::UpdateChannelBufferCollaborators {
3472 channel_id: channel_id.to_proto(),
3473 collaborators: left_buffer.collaborators,
3474 },
3475 &session.peer,
3476 );
3477
3478 Ok(())
3479}
3480
/// Broadcast `message` to the given collaborator connections, skipping
/// `sender_id` (the connection that originated the change).
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3491
3492fn send_notifications(
3493 connection_pool: &ConnectionPool,
3494 peer: &Peer,
3495 notifications: db::NotificationBatch,
3496) {
3497 for (user_id, notification) in notifications {
3498 for connection_id in connection_pool.user_connection_ids(user_id) {
3499 if let Err(error) = peer.send(
3500 connection_id,
3501 proto::AddNotification {
3502 notification: Some(notification.clone()),
3503 },
3504 ) {
3505 tracing::error!(
3506 "failed to send notification to {:?} {}",
3507 connection_id,
3508 error
3509 );
3510 }
3511 }
3512 }
3513}
3514
/// Send a message to the channel
///
/// Channel chat has been removed; this handler always responds with an
/// error so clients that still issue the request get a clear message.
async fn send_channel_message(
    _request: proto::SendChannelMessage,
    _response: Response<proto::SendChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3523
/// Delete a channel message
///
/// Channel chat has been removed; this handler always responds with an error.
async fn remove_channel_message(
    _request: proto::RemoveChannelMessage,
    _response: Response<proto::RemoveChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3532
/// Edit a channel message
///
/// Channel chat has been removed; this handler always responds with an error.
async fn update_channel_message(
    _request: proto::UpdateChannelMessage,
    _response: Response<proto::UpdateChannelMessage>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3540
/// Mark a channel message as read
///
/// Channel chat has been removed; this handler always responds with an error.
async fn acknowledge_channel_message(
    _request: proto::AckChannelMessage,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3548
3549/// Mark a buffer version as synced
3550async fn acknowledge_buffer_version(
3551 request: proto::AckBufferOperation,
3552 session: MessageContext,
3553) -> Result<()> {
3554 let buffer_id = BufferId::from_proto(request.buffer_id);
3555 session
3556 .db()
3557 .await
3558 .observe_buffer_version(
3559 buffer_id,
3560 session.user_id(),
3561 request.epoch as i32,
3562 &request.version,
3563 )
3564 .await?;
3565 Ok(())
3566}
3567
/// Start receiving chat updates for a channel
///
/// Channel chat has been removed; this handler always responds with an error.
async fn join_channel_chat(
    _request: proto::JoinChannelChat,
    _response: Response<proto::JoinChannelChat>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3576
/// Stop receiving chat updates for a channel
///
/// Channel chat has been removed; this handler always responds with an error.
async fn leave_channel_chat(
    _request: proto::LeaveChannelChat,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3584
/// Retrieve the chat history for a channel
///
/// Channel chat has been removed; this handler always responds with an error.
async fn get_channel_messages(
    _request: proto::GetChannelMessages,
    _response: Response<proto::GetChannelMessages>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3593
/// Retrieve specific chat messages
///
/// Channel chat has been removed; this handler always responds with an error.
async fn get_channel_messages_by_id(
    _request: proto::GetChannelMessagesById,
    _response: Response<proto::GetChannelMessagesById>,
    _session: MessageContext,
) -> Result<()> {
    Err(anyhow!("chat has been removed in the latest version of Zed").into())
}
3602
3603/// Retrieve the current users notifications
3604async fn get_notifications(
3605 request: proto::GetNotifications,
3606 response: Response<proto::GetNotifications>,
3607 session: MessageContext,
3608) -> Result<()> {
3609 let notifications = session
3610 .db()
3611 .await
3612 .get_notifications(
3613 session.user_id(),
3614 NOTIFICATION_COUNT_PER_PAGE,
3615 request.before_id.map(db::NotificationId::from_proto),
3616 )
3617 .await?;
3618 response.send(proto::GetNotificationsResponse {
3619 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3620 notifications,
3621 })?;
3622 Ok(())
3623}
3624
3625/// Mark notifications as read
3626async fn mark_notification_as_read(
3627 request: proto::MarkNotificationRead,
3628 response: Response<proto::MarkNotificationRead>,
3629 session: MessageContext,
3630) -> Result<()> {
3631 let database = &session.db().await;
3632 let notifications = database
3633 .mark_notification_as_read_by_id(
3634 session.user_id(),
3635 NotificationId::from_proto(request.notification_id),
3636 )
3637 .await?;
3638 send_notifications(
3639 &*session.connection_pool().await,
3640 &session.peer,
3641 notifications,
3642 );
3643 response.send(proto::Ack {})?;
3644 Ok(())
3645}
3646
3647fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3648 let message = match message {
3649 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3650 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3651 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3652 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3653 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3654 code: frame.code.into(),
3655 reason: frame.reason.as_str().to_owned().into(),
3656 })),
3657 // We should never receive a frame while reading the message, according
3658 // to the `tungstenite` maintainers:
3659 //
3660 // > It cannot occur when you read messages from the WebSocket, but it
3661 // > can be used when you want to send the raw frames (e.g. you want to
3662 // > send the frames to the WebSocket without composing the full message first).
3663 // >
3664 // > — https://github.com/snapview/tungstenite-rs/issues/268
3665 TungsteniteMessage::Frame(_) => {
3666 bail!("received an unexpected frame while reading the message")
3667 }
3668 };
3669
3670 Ok(message)
3671}
3672
3673fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3674 match message {
3675 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3676 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3677 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3678 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3679 AxumMessage::Close(frame) => {
3680 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3681 code: frame.code.into(),
3682 reason: frame.reason.as_ref().into(),
3683 }))
3684 }
3685 }
3686}
3687
3688fn notify_membership_updated(
3689 connection_pool: &mut ConnectionPool,
3690 result: MembershipUpdated,
3691 user_id: UserId,
3692 peer: &Peer,
3693) {
3694 for membership in &result.new_channels.channel_memberships {
3695 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3696 }
3697 for channel_id in &result.removed_channels {
3698 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3699 }
3700
3701 let user_channels_update = proto::UpdateUserChannels {
3702 channel_memberships: result
3703 .new_channels
3704 .channel_memberships
3705 .iter()
3706 .map(|cm| proto::ChannelMembership {
3707 channel_id: cm.channel_id.to_proto(),
3708 role: cm.role.into(),
3709 })
3710 .collect(),
3711 ..Default::default()
3712 };
3713
3714 let mut update = build_channels_update(result.new_channels);
3715 update.delete_channels = result
3716 .removed_channels
3717 .into_iter()
3718 .map(|id| id.to_proto())
3719 .collect();
3720 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3721
3722 for connection_id in connection_pool.user_connection_ids(user_id) {
3723 peer.send(connection_id, user_channels_update.clone())
3724 .trace_err();
3725 peer.send(connection_id, update.clone()).trace_err();
3726 }
3727}
3728
3729fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3730 proto::UpdateUserChannels {
3731 channel_memberships: channels
3732 .channel_memberships
3733 .iter()
3734 .map(|m| proto::ChannelMembership {
3735 channel_id: m.channel_id.to_proto(),
3736 role: m.role.into(),
3737 })
3738 .collect(),
3739 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3740 }
3741}
3742
3743fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3744 let mut update = proto::UpdateChannels::default();
3745
3746 for channel in channels.channels {
3747 update.channels.push(channel.to_proto());
3748 }
3749
3750 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3751
3752 for (channel_id, participants) in channels.channel_participants {
3753 update
3754 .channel_participants
3755 .push(proto::ChannelParticipants {
3756 channel_id: channel_id.to_proto(),
3757 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3758 });
3759 }
3760
3761 for channel in channels.invited_channels {
3762 update.channel_invitations.push(channel.to_proto());
3763 }
3764
3765 update
3766}
3767
3768fn build_initial_contacts_update(
3769 contacts: Vec<db::Contact>,
3770 pool: &ConnectionPool,
3771) -> proto::UpdateContacts {
3772 let mut update = proto::UpdateContacts::default();
3773
3774 for contact in contacts {
3775 match contact {
3776 db::Contact::Accepted { user_id, busy } => {
3777 update.contacts.push(contact_for_user(user_id, busy, pool));
3778 }
3779 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3780 db::Contact::Incoming { user_id } => {
3781 update
3782 .incoming_requests
3783 .push(proto::IncomingContactRequest {
3784 requester_id: user_id.to_proto(),
3785 })
3786 }
3787 }
3788 }
3789
3790 update
3791}
3792
/// Build the wire representation of a contact, combining the given busy
/// state with live online status from the connection pool.
fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
    proto::Contact {
        user_id: user_id.to_proto(),
        online: pool.is_user_online(user_id),
        busy,
    }
}
3800
3801fn room_updated(room: &proto::Room, peer: &Peer) {
3802 broadcast(
3803 None,
3804 room.participants
3805 .iter()
3806 .filter_map(|participant| Some(participant.peer_id?.into())),
3807 |peer_id| {
3808 peer.send(
3809 peer_id,
3810 proto::RoomUpdated {
3811 room: Some(room.clone()),
3812 },
3813 )
3814 },
3815 );
3816}
3817
3818fn channel_updated(
3819 channel: &db::channel::Model,
3820 room: &proto::Room,
3821 peer: &Peer,
3822 pool: &ConnectionPool,
3823) {
3824 let participants = room
3825 .participants
3826 .iter()
3827 .map(|p| p.user_id)
3828 .collect::<Vec<_>>();
3829
3830 broadcast(
3831 None,
3832 pool.channel_connection_ids(channel.root_id())
3833 .filter_map(|(channel_id, role)| {
3834 role.can_see_channel(channel.visibility)
3835 .then_some(channel_id)
3836 }),
3837 |peer_id| {
3838 peer.send(
3839 peer_id,
3840 proto::UpdateChannels {
3841 channel_participants: vec![proto::ChannelParticipants {
3842 channel_id: channel.id.to_proto(),
3843 participant_user_ids: participants.clone(),
3844 }],
3845 ..Default::default()
3846 },
3847 )
3848 },
3849 );
3850}
3851
3852async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3853 let db = session.db().await;
3854
3855 let contacts = db.get_contacts(user_id).await?;
3856 let busy = db.is_user_busy(user_id).await?;
3857
3858 let pool = session.connection_pool().await;
3859 let updated_contact = contact_for_user(user_id, busy, &pool);
3860 for contact in contacts {
3861 if let db::Contact::Accepted {
3862 user_id: contact_user_id,
3863 ..
3864 } = contact
3865 {
3866 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3867 session
3868 .peer
3869 .send(
3870 contact_conn_id,
3871 proto::UpdateContacts {
3872 contacts: vec![updated_contact.clone()],
3873 remove_contacts: Default::default(),
3874 incoming_requests: Default::default(),
3875 remove_incoming_requests: Default::default(),
3876 outgoing_requests: Default::default(),
3877 remove_outgoing_requests: Default::default(),
3878 },
3879 )
3880 .trace_err();
3881 }
3882 }
3883 }
3884 Ok(())
3885}
3886
/// Tear down room state for a connection that is leaving (or was detected as
/// stale): removes it from the room in the database, notifies the remaining
/// participants and channel subscribers, cancels pending outgoing calls,
/// refreshes contact presence, and cleans up the LiveKit room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // Everything needed after this block is moved out of `left_room` so the
    // database handle (the temporary from `session.db().await`) is released
    // before the notification phases below.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Unshare or drop collaborator state for each project this
        // connection had open in the room.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in any room; nothing to clean up.
        return Ok(());
    }

    // If the room belongs to a channel, refresh the channel's participant
    // list for its subscribers.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Notify every connection of each callee whose call was canceled,
        // and remember them so their contact state is refreshed below.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: remove this participant, and delete the
    // room entirely when the database marked it as deleted.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3960
3961async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3962 let left_channel_buffers = session
3963 .db()
3964 .await
3965 .leave_channel_buffers(session.connection_id)
3966 .await?;
3967
3968 for left_buffer in left_channel_buffers {
3969 channel_buffer_updated(
3970 session.connection_id,
3971 left_buffer.connections,
3972 &proto::UpdateChannelBufferCollaborators {
3973 channel_id: left_buffer.channel_id.to_proto(),
3974 collaborators: left_buffer.collaborators,
3975 },
3976 &session.peer,
3977 );
3978 }
3979
3980 Ok(())
3981}
3982
3983fn project_left(project: &db::LeftProject, session: &Session) {
3984 for connection_id in &project.connection_ids {
3985 if project.should_unshare {
3986 session
3987 .peer
3988 .send(
3989 *connection_id,
3990 proto::UnshareProject {
3991 project_id: project.id.to_proto(),
3992 },
3993 )
3994 .trace_err();
3995 } else {
3996 session
3997 .peer
3998 .send(
3999 *connection_id,
4000 proto::RemoveProjectCollaborator {
4001 project_id: project.id.to_proto(),
4002 peer_id: Some(session.connection_id.into()),
4003 },
4004 )
4005 .trace_err();
4006 }
4007 }
4008}
4009
4010async fn share_agent_thread(
4011 request: proto::ShareAgentThread,
4012 response: Response<proto::ShareAgentThread>,
4013 session: MessageContext,
4014) -> Result<()> {
4015 let user_id = session.user_id();
4016
4017 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4018 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4019
4020 session
4021 .db()
4022 .await
4023 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4024 .await?;
4025
4026 response.send(proto::Ack {})?;
4027
4028 Ok(())
4029}
4030
4031async fn get_shared_agent_thread(
4032 request: proto::GetSharedAgentThread,
4033 response: Response<proto::GetSharedAgentThread>,
4034 session: MessageContext,
4035) -> Result<()> {
4036 let share_id = SharedThreadId::from_proto(request.session_id)
4037 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4038
4039 let result = session.db().await.get_shared_thread(share_id).await?;
4040
4041 match result {
4042 Some((thread, username)) => {
4043 response.send(proto::GetSharedAgentThreadResponse {
4044 title: thread.title,
4045 thread_data: thread.data,
4046 sharer_username: username,
4047 created_at: thread.created_at.and_utc().to_rfc3339(),
4048 })?;
4049 }
4050 None => {
4051 return Err(anyhow!("Shared thread not found").into());
4052 }
4053 }
4054
4055 Ok(())
4056}
4057
/// Extension trait for "log and discard" error handling: converts a result
/// into an `Option`, tracing the error case instead of propagating it.
pub trait ResultExt {
    type Ok;

    fn trace_err(self) -> Option<Self::Ok>;
}
4063
4064impl<T, E> ResultExt for Result<T, E>
4065where
4066 E: std::fmt::Debug,
4067{
4068 type Ok = T;
4069
4070 #[track_caller]
4071 fn trace_err(self) -> Option<T> {
4072 match self {
4073 Ok(value) => Some(value),
4074 Err(error) => {
4075 tracing::error!("{:?}", error);
4076 None
4077 }
4078 }
4079 }
4080}