1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::{
5 AppState, Error, Result, auth,
6 db::{
7 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser, Database,
8 InviteMemberResult, MembershipUpdated, NotificationId, ProjectId, RejoinedProject,
9 RemoveChannelMemberResult, RespondToChannelInvite, RoomId, ServerId, SharedThreadId, User,
10 UserId,
11 },
12 executor::Executor,
13};
14use anyhow::{Context as _, anyhow, bail};
15use async_tungstenite::tungstenite::{
16 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
17};
18use axum::headers::UserAgent;
19use axum::{
20 Extension, Router, TypedHeader,
21 body::Body,
22 extract::{
23 ConnectInfo, WebSocketUpgrade,
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use futures::TryFutureExt as _;
36use rpc::proto::split_repository_update;
37use tracing::Span;
38use util::paths::PathStyle;
39
40use futures::{
41 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
42 stream::FuturesUnordered,
43};
44use prometheus::{IntGauge, register_int_gauge};
45use rpc::{
46 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
49 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51};
52use semver::Version;
53use std::{
54 any::TypeId,
55 future::Future,
56 marker::PhantomData,
57 mem,
58 net::SocketAddr,
59 ops::{Deref, DerefMut},
60 rc::Rc,
61 sync::{
62 Arc, OnceLock,
63 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
64 },
65 time::{Duration, Instant},
66};
67use tokio::sync::{Semaphore, watch};
68use tower::ServiceBuilder;
69use tracing::{
70 Instrument,
71 field::{self},
72 info_span, instrument,
73};
74
/// NOTE(review): presumably the window a disconnected client has to reconnect before it is
/// treated as gone; it is not referenced in this part of the file — confirm against its users.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before clearing stale server state.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): not referenced in this part of the file; presumably the page size used when
// listing notifications — confirm against the notification handlers.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on the number of live `ConnectionGuard`s handed out by
/// `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of connection slots currently claimed; incremented by `ConnectionGuard::try_acquire`
/// and decremented when a `ConnectionGuard` is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the span fields that carry per-message timing information. The first three are
// recorded by `Server::add_handler`, the last one by `MessageContext::forward_request`.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
89
/// Type-erased message handler stored in `Server::handlers`, keyed by the `TypeId` of the
/// message it handles.
///
/// It is called with the boxed envelope, the `Session` of the connection the message arrived
/// on, and the tracing span created for that message; it returns a boxed future resolving to
/// `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
92
93pub struct ConnectionGuard;
94
95impl ConnectionGuard {
96 pub fn try_acquire() -> Result<Self, ()> {
97 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
98 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
99 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
100 tracing::error!(
101 "too many concurrent connections: {}",
102 current_connections + 1
103 );
104 return Err(());
105 }
106 Ok(ConnectionGuard)
107 }
108}
109
110impl Drop for ConnectionGuard {
111 fn drop(&mut self) {
112 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
113 }
114}
115
116struct Response<R> {
117 peer: Arc<Peer>,
118 receipt: Receipt<R>,
119 responded: Arc<AtomicBool>,
120}
121
122impl<R: RequestMessage> Response<R> {
123 fn send(self, payload: R::Response) -> Result<()> {
124 self.responded.store(true, SeqCst);
125 self.peer.respond(self.receipt, payload)?;
126 Ok(())
127 }
128}
129
130#[derive(Clone, Debug)]
131pub enum Principal {
132 User(User),
133}
134
135impl Principal {
136 fn update_span(&self, span: &tracing::Span) {
137 match &self {
138 Principal::User(user) => {
139 span.record("user_id", user.id.0);
140 span.record("login", &user.github_login);
141 }
142 }
143 }
144}
145
/// Everything a message handler receives besides the message itself: the session of the
/// connection the message arrived on and the tracing span created for that message.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

/// Lets handlers reach `Session` fields and methods directly through the context.
impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
159
160impl MessageContext {
161 pub fn forward_request<T: RequestMessage>(
162 &self,
163 receiver_id: ConnectionId,
164 request: T,
165 ) -> impl Future<Output = anyhow::Result<T::Response>> {
166 let request_start_time = Instant::now();
167 let span = self.span.clone();
168 tracing::info!("start forwarding request");
169 self.peer
170 .forward_request(self.connection_id, receiver_id, request)
171 .inspect(move |_| {
172 span.record(
173 HOST_WAITING_MS,
174 request_start_time.elapsed().as_micros() as f64 / 1000.0,
175 );
176 })
177 .inspect_err(|_| tracing::error!("error forwarding request"))
178 .inspect_ok(|_| tracing::info!("finished forwarding request"))
179 }
180}
181
/// Per-connection state handed to every message handler.
///
/// A `Session` is built once in `Server::handle_connection` and cloned for each incoming
/// message.
#[derive(Clone)]
struct Session {
    /// Identity of the connected client.
    principal: Principal,
    /// Id assigned to this connection by the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; access it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's peer, used to send messages and responses.
    peer: Arc<Peer>,
    /// The server-wide connection pool; access it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state.
    app_state: Arc<AppState>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id passed to `Server::handle_connection`, if any.
    #[allow(unused)]
    system_id: Option<String>,
    /// Executor the connection was started with; kept but not read in this part of the file.
    _executor: Executor,
}
197
198impl Session {
199 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
200 #[cfg(feature = "test-support")]
201 tokio::task::yield_now().await;
202 let guard = self.db.lock().await;
203 #[cfg(feature = "test-support")]
204 tokio::task::yield_now().await;
205 guard
206 }
207
208 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
209 #[cfg(feature = "test-support")]
210 tokio::task::yield_now().await;
211 let guard = self.connection_pool.lock();
212 ConnectionPoolGuard {
213 guard,
214 _not_send: PhantomData,
215 }
216 }
217
218 #[expect(dead_code)]
219 fn is_staff(&self) -> bool {
220 match &self.principal {
221 Principal::User(user) => user.admin,
222 }
223 }
224
225 fn user_id(&self) -> UserId {
226 match &self.principal {
227 Principal::User(user) => user.id,
228 }
229 }
230}
231
232impl Debug for Session {
233 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
234 let mut result = f.debug_struct("Session");
235 match &self.principal {
236 Principal::User(user) => {
237 result.field("user", &user.github_login);
238 }
239 }
240 result.field("connection_id", &self.connection_id).finish()
241 }
242}
243
244struct DbHandle(Arc<Database>);
245
246impl Deref for DbHandle {
247 type Target = Database;
248
249 fn deref(&self) -> &Self::Target {
250 self.0.as_ref()
251 }
252}
253
/// The RPC server: owns the peer, the connection pool and the table of message handlers.
pub struct Server {
    /// This server's id; behind a mutex because `reset` (test support) replaces it.
    id: parking_lot::Mutex<ServerId>,
    /// Peer that tracks connections and sends/receives messages.
    peer: Arc<Peer>,
    /// Pool of live connections, shared with every `Session`.
    pub connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Handlers registered in `Server::new`, keyed by the `TypeId` of the message type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag: `teardown` sends `true`, and connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
262
/// Guard returned by `Session::connection_pool`, wrapping the lock on the connection pool.
struct ConnectionPoolGuard<'a> {
    /// The underlying lock guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc<()>` is `!Send`, so this marker makes the whole guard `!Send`.
    /// NOTE(review): presumably this is to stop the guard from being held across an `.await`
    /// in a `Send` future — confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
267
268impl Server {
    /// Creates a server with the given id and registers a handler for every supported message
    /// type.
    ///
    /// Request handlers (`add_request_handler`) must answer the request; message handlers
    /// (`add_message_handler`) receive only the payload and send no response.
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state,
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; receivers are created on demand with `subscribe()`.
            teardown: watch::channel(false).0,
        };

        // NOTE(review): the `forward_read_only_project_request`,
        // `forward_mutating_project_request`, `disallow_guest_request` and
        // `broadcast_project_message_from_host` helpers are defined outside this part of the
        // file; their names suggest different permission checks — confirm there.
        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenImageByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::DownloadFileByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefaultBranch>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(lsp_query)
            .add_message_handler(broadcast_project_message_from_host::<proto::LspQueryResponse>)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_message_handler(create_image_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::RefreshSemanticTokens>,
            )
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::StashDrop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::RunGitHook>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetTreeDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GetBlobContent>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateRemote>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRemoveRemote>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetWorktrees>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetHeadSha>)
            .add_request_handler(forward_read_only_project_request::<proto::GetCommitData>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRemoveWorktree>)
            .add_request_handler(disallow_guest_request::<proto::GitRenameWorktree>)
            .add_request_handler(forward_mutating_project_request::<proto::GitEditRef>)
            .add_request_handler(forward_mutating_project_request::<proto::GitRepairWorktrees>)
            .add_request_handler(disallow_guest_request::<proto::GitCreateArchiveCheckpoint>)
            .add_request_handler(disallow_guest_request::<proto::GitRestoreArchiveCheckpoint>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleLspLogs>)
            .add_message_handler(broadcast_project_message_from_host::<proto::LanguageServerLog>)
            .add_request_handler(share_agent_thread)
            .add_request_handler(get_shared_agent_thread)
            .add_request_handler(forward_project_search_chunk);

        Arc::new(server)
    }
456
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT` has elapsed, clears
    /// resources the database reports as stale for this server's environment and id.
    ///
    /// The task removes stale channel chat participants, clears stale channel-buffer
    /// collaborators and room participants (notifying affected connections), deletes LiveKit
    /// rooms that ended up empty, clears old worktree entries and finally deletes stale server
    /// records. Every step is best-effort: failures are logged through `trace_err` and the
    /// task keeps going.
    ///
    /// This method itself only schedules the task and always returns `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here, before the task is spawned.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and send the refreshed
                    // collaborator list to every connection returned by the database.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then notify everyone affected.
                    for room_id in room_ids {
                        // These keep their defaults when clearing the room fails, which turns
                        // the notification steps below into no-ops for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is only deleted once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the next `.await`.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Send each affected user's refreshed contact entry to all connections
                        // of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Both lookups must succeed; otherwise this user is skipped.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref()
                            && delete_livekit_room
                        {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }

                // The chat-participant cleanup runs a second time here, after the room and
                // channel-buffer cleanup above.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
627
    /// Shuts the server down: tears down the peer, resets the connection pool and publishes
    /// `true` on the teardown channel that connection tasks subscribe to.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // The send result is deliberately ignored.
        let _ = self.teardown.send(true);
    }
633
    /// Test support: tears the server down, then gives it a new id and clears the teardown
    /// flag so it can accept connections again.
    #[cfg(feature = "test-support")]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Undo the `true` published by `teardown` above; the send result is ignored.
        let _ = self.teardown.send(false);
    }
641
    /// Test support: returns the server's current id.
    #[cfg(feature = "test-support")]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
646
647 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
648 where
649 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
650 Fut: 'static + Send + Future<Output = Result<()>>,
651 M: EnvelopedMessage,
652 {
653 let prev_handler = self.handlers.insert(
654 TypeId::of::<M>(),
655 Box::new(move |envelope, session, span| {
656 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
657 let received_at = envelope.received_at;
658 tracing::info!("message received");
659 let start_time = Instant::now();
660 let future = (handler)(
661 *envelope,
662 MessageContext {
663 session,
664 span: span.clone(),
665 },
666 );
667 async move {
668 let result = future.await;
669 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
670 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
671 let queue_duration_ms = total_duration_ms - processing_duration_ms;
672 span.record(TOTAL_DURATION_MS, total_duration_ms);
673 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
674 span.record(QUEUE_DURATION_MS, queue_duration_ms);
675 match result {
676 Err(error) => {
677 tracing::error!(?error, "error handling message")
678 }
679 Ok(()) => tracing::info!("finished handling message"),
680 }
681 }
682 .boxed()
683 }),
684 );
685 if prev_handler.is_some() {
686 panic!("registered a handler for the same message twice");
687 }
688 self
689 }
690
691 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
692 where
693 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
694 Fut: 'static + Send + Future<Output = Result<()>>,
695 M: EnvelopedMessage,
696 {
697 self.add_handler(move |envelope, session| handler(envelope.payload, session));
698 self
699 }
700
701 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
702 where
703 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
704 Fut: Send + Future<Output = Result<()>>,
705 M: RequestMessage,
706 {
707 let handler = Arc::new(handler);
708 self.add_handler(move |envelope, session| {
709 let receipt = envelope.receipt();
710 let handler = handler.clone();
711 async move {
712 let peer = session.peer.clone();
713 let responded = Arc::new(AtomicBool::default());
714 let response = Response {
715 peer: peer.clone(),
716 responded: responded.clone(),
717 receipt,
718 };
719 match (handler)(envelope.payload, response, session).await {
720 Ok(()) => {
721 if responded.load(std::sync::atomic::Ordering::SeqCst) {
722 Ok(())
723 } else {
724 Err(anyhow!("handler did not send a response"))?
725 }
726 }
727 Err(error) => {
728 let proto_err = match &error {
729 Error::Internal(err) => err.to_proto(),
730 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
731 };
732 peer.respond_with_error(receipt, proto_err)?;
733 Err(error)
734 }
735 }
736 }
737 })
738 }
739
    /// Builds the future that drives a single client connection from start to finish.
    ///
    /// The returned future registers the connection with the peer, sends the initial client
    /// update, then loops: it polls connection I/O, polls in-flight foreground handlers, and
    /// dispatches incoming messages to the registered handlers. When the connection closes or
    /// I/O fails, it calls `connection_lost` for the session.
    ///
    /// Parameters:
    /// - `connection`: the transport handed to the peer.
    /// - `address`: recorded on the tracing spans.
    /// - `principal`: identity of the client; stored in the `Session` and recorded on spans.
    /// - `zed_version`: forwarded to `send_initial_client_update`.
    /// - `release_channel`, `user_agent`: only recorded on the connection span.
    /// - `geoip_country_code`: recorded on the connection span and stored in the `Session`.
    /// - `system_id`: stored in the `Session`.
    /// - `send_connection_id`: if present, receives the assigned connection id during the
    ///   initial client update.
    /// - `executor`: used for peer timers and for spawning background handlers.
    /// - `connection_guard`: dropped once the initial client update has been sent.
    ///
    /// NOTE(review): the two early `return`s below (initial update failure, teardown signal)
    /// skip `connection_lost` — confirm cleanup for those paths happens elsewhere.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields are declared empty up front so they can be recorded once their values are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections while the server is tearing down.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection slot is released as soon as the initial update has gone out.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            // Number of permits currently held, i.e. handlers in flight (plus the one being
            // acquired for the next message, if any).
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // A permit is acquired *before* reading the next message, so no further message
                // is pulled off the connection while all permits are in use.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls the arms in the order written: teardown first, then
                // I/O, then in-flight foreground handlers, and only then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is held until the handler future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached on the executor; foreground
                                // ones are polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers that are still pending.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
910
911 async fn send_initial_client_update(
912 &self,
913 connection_id: ConnectionId,
914 zed_version: ZedVersion,
915 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
916 session: &Session,
917 ) -> Result<()> {
918 self.peer.send(
919 connection_id,
920 proto::Hello {
921 peer_id: Some(connection_id.into()),
922 },
923 )?;
924 tracing::info!("sent hello message");
925 if let Some(send_connection_id) = send_connection_id.take() {
926 let _ = send_connection_id.send(connection_id);
927 }
928
929 match &session.principal {
930 Principal::User(user) => {
931 if !user.connected_once {
932 self.peer.send(connection_id, proto::ShowContacts {})?;
933 self.app_state
934 .db
935 .set_user_connected_once(user.id, true)
936 .await?;
937 }
938
939 let contacts = self.app_state.db.get_contacts(user.id).await?;
940
941 {
942 let mut pool = self.connection_pool.lock();
943 pool.add_connection(connection_id, user.id, user.admin, zed_version.clone());
944 self.peer.send(
945 connection_id,
946 build_initial_contacts_update(contacts, &pool),
947 )?;
948 }
949
950 if should_auto_subscribe_to_channels(&zed_version) {
951 subscribe_user_to_channels(user.id, session).await?;
952 }
953
954 if let Some(incoming_call) =
955 self.app_state.db.incoming_call_for_user(user.id).await?
956 {
957 self.peer.send(connection_id, incoming_call)?;
958 }
959
960 update_user_contacts(user.id, session).await?;
961 }
962 }
963
964 Ok(())
965 }
966}
967
968impl Deref for ConnectionPoolGuard<'_> {
969 type Target = ConnectionPool;
970
971 fn deref(&self) -> &Self::Target {
972 &self.guard
973 }
974}
975
976impl DerefMut for ConnectionPoolGuard<'_> {
977 fn deref_mut(&mut self) -> &mut Self::Target {
978 &mut self.guard
979 }
980}
981
982impl Drop for ConnectionPoolGuard<'_> {
983 fn drop(&mut self) {
984 #[cfg(feature = "test-support")]
985 self.check_invariants();
986 }
987}
988
989fn broadcast<F>(
990 sender_id: Option<ConnectionId>,
991 receiver_ids: impl IntoIterator<Item = ConnectionId>,
992 mut f: F,
993) where
994 F: FnMut(ConnectionId) -> anyhow::Result<()>,
995{
996 for receiver_id in receiver_ids {
997 if Some(receiver_id) != sender_id
998 && let Err(error) = f(receiver_id)
999 {
1000 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1001 }
1002 }
1003}
1004
1005pub struct ProtocolVersion(u32);
1006
1007impl Header for ProtocolVersion {
1008 fn name() -> &'static HeaderName {
1009 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1010 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1011 }
1012
1013 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1014 where
1015 Self: Sized,
1016 I: Iterator<Item = &'i axum::http::HeaderValue>,
1017 {
1018 let version = values
1019 .next()
1020 .ok_or_else(axum::headers::Error::invalid)?
1021 .to_str()
1022 .map_err(|_| axum::headers::Error::invalid())?
1023 .parse()
1024 .map_err(|_| axum::headers::Error::invalid())?;
1025 Ok(Self(version))
1026 }
1027
1028 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1029 values.extend([self.0.to_string().parse().unwrap()]);
1030 }
1031}
1032
1033pub struct AppVersionHeader(Version);
1034impl Header for AppVersionHeader {
1035 fn name() -> &'static HeaderName {
1036 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1037 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1038 }
1039
1040 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1041 where
1042 Self: Sized,
1043 I: Iterator<Item = &'i axum::http::HeaderValue>,
1044 {
1045 let version = values
1046 .next()
1047 .ok_or_else(axum::headers::Error::invalid)?
1048 .to_str()
1049 .map_err(|_| axum::headers::Error::invalid())?
1050 .parse()
1051 .map_err(|_| axum::headers::Error::invalid())?;
1052 Ok(Self(version))
1053 }
1054
1055 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1056 values.extend([self.0.to_string().parse().unwrap()]);
1057 }
1058}
1059
1060#[derive(Debug)]
1061pub struct ReleaseChannelHeader(String);
1062
1063impl Header for ReleaseChannelHeader {
1064 fn name() -> &'static HeaderName {
1065 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1066 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1067 }
1068
1069 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1070 where
1071 Self: Sized,
1072 I: Iterator<Item = &'i axum::http::HeaderValue>,
1073 {
1074 Ok(Self(
1075 values
1076 .next()
1077 .ok_or_else(axum::headers::Error::invalid)?
1078 .to_str()
1079 .map_err(|_| axum::headers::Error::invalid())?
1080 .to_owned(),
1081 ))
1082 }
1083
1084 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1085 values.extend([self.0.parse().unwrap()]);
1086 }
1087}
1088
1089pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1090 Router::new()
1091 .route("/rpc", get(handle_websocket_request))
1092 .layer(
1093 ServiceBuilder::new()
1094 .layer(Extension(server.app_state.clone()))
1095 .layer(middleware::from_fn(auth::validate_header)),
1096 )
1097 .route("/metrics", get(handle_metrics))
1098 .layer(Extension(server))
1099}
1100
1101pub async fn handle_websocket_request(
1102 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1103 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1104 release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
1105 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1106 Extension(server): Extension<Arc<Server>>,
1107 Extension(principal): Extension<Principal>,
1108 user_agent: Option<TypedHeader<UserAgent>>,
1109 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1110 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1111 ws: WebSocketUpgrade,
1112) -> axum::response::Response {
1113 if protocol_version != rpc::PROTOCOL_VERSION {
1114 return (
1115 StatusCode::UPGRADE_REQUIRED,
1116 "client must be upgraded".to_string(),
1117 )
1118 .into_response();
1119 }
1120
1121 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1122 return (
1123 StatusCode::UPGRADE_REQUIRED,
1124 "no version header found".to_string(),
1125 )
1126 .into_response();
1127 };
1128
1129 let release_channel = release_channel_header.map(|header| header.0.0);
1130
1131 if !version.can_collaborate() {
1132 return (
1133 StatusCode::UPGRADE_REQUIRED,
1134 "client must be upgraded".to_string(),
1135 )
1136 .into_response();
1137 }
1138
1139 let socket_address = socket_address.to_string();
1140
1141 // Acquire connection guard before WebSocket upgrade
1142 let connection_guard = match ConnectionGuard::try_acquire() {
1143 Ok(guard) => guard,
1144 Err(()) => {
1145 return (
1146 StatusCode::SERVICE_UNAVAILABLE,
1147 "Too many concurrent connections",
1148 )
1149 .into_response();
1150 }
1151 };
1152
1153 ws.on_upgrade(move |socket| {
1154 let socket = socket
1155 .map_ok(to_tungstenite_message)
1156 .err_into()
1157 .with(|message| async move { to_axum_message(message) });
1158 let connection = Connection::new(Box::pin(socket));
1159 async move {
1160 server
1161 .handle_connection(
1162 connection,
1163 socket_address,
1164 principal,
1165 version,
1166 release_channel,
1167 user_agent.map(|header| header.to_string()),
1168 country_code_header.map(|header| header.to_string()),
1169 system_id_header.map(|header| header.to_string()),
1170 None,
1171 Executor::Production,
1172 Some(connection_guard),
1173 )
1174 .await;
1175 }
1176 })
1177}
1178
1179pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1180 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1181 let connections_metric = CONNECTIONS_METRIC
1182 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1183
1184 let connections = server
1185 .connection_pool
1186 .lock()
1187 .connections()
1188 .filter(|connection| !connection.admin)
1189 .count();
1190 connections_metric.set(connections as _);
1191
1192 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1193 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1194 register_int_gauge!(
1195 "shared_projects",
1196 "number of open projects with one or more guests"
1197 )
1198 .unwrap()
1199 });
1200
1201 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1202 shared_projects_metric.set(shared_projects as _);
1203
1204 let encoder = prometheus::TextEncoder::new();
1205 let metric_families = prometheus::gather();
1206 let encoded_metrics = encoder
1207 .encode_to_string(&metric_families)
1208 .map_err(|err| anyhow!("{err}"))?;
1209 Ok(encoded_metrics)
1210}
1211
1212#[instrument(err, skip(executor))]
1213async fn connection_lost(
1214 session: Session,
1215 mut teardown: watch::Receiver<bool>,
1216 executor: Executor,
1217) -> Result<()> {
1218 session.peer.disconnect(session.connection_id);
1219 session
1220 .connection_pool()
1221 .await
1222 .remove_connection(session.connection_id)?;
1223
1224 session
1225 .db()
1226 .await
1227 .connection_lost(session.connection_id)
1228 .await
1229 .trace_err();
1230
1231 futures::select_biased! {
1232 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1233
1234 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1235 leave_room_for_session(&session, session.connection_id).await.trace_err();
1236 leave_channel_buffers_for_session(&session)
1237 .await
1238 .trace_err();
1239
1240 if !session
1241 .connection_pool()
1242 .await
1243 .is_user_online(session.user_id())
1244 {
1245 let db = session.db().await;
1246 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1247 room_updated(&room, &session.peer);
1248 }
1249 }
1250
1251 update_user_contacts(session.user_id(), &session).await?;
1252 },
1253 _ = teardown.changed().fuse() => {}
1254 }
1255
1256 Ok(())
1257}
1258
1259/// Acknowledges a ping from a client, used to keep the connection alive.
1260async fn ping(
1261 _: proto::Ping,
1262 response: Response<proto::Ping>,
1263 _session: MessageContext,
1264) -> Result<()> {
1265 response.send(proto::Ack {})?;
1266 Ok(())
1267}
1268
1269/// Creates a new room for calling (outside of channels)
1270async fn create_room(
1271 _request: proto::CreateRoom,
1272 response: Response<proto::CreateRoom>,
1273 session: MessageContext,
1274) -> Result<()> {
1275 let livekit_room = nanoid::nanoid!(30);
1276
1277 let live_kit_connection_info = util::maybe!(async {
1278 let live_kit = session.app_state.livekit_client.as_ref();
1279 let live_kit = live_kit?;
1280 let user_id = session.user_id().to_string();
1281
1282 let token = live_kit.room_token(&livekit_room, &user_id).trace_err()?;
1283
1284 Some(proto::LiveKitConnectionInfo {
1285 server_url: live_kit.url().into(),
1286 token,
1287 can_publish: true,
1288 })
1289 })
1290 .await;
1291
1292 let room = session
1293 .db()
1294 .await
1295 .create_room(session.user_id(), session.connection_id, &livekit_room)
1296 .await?;
1297
1298 response.send(proto::CreateRoomResponse {
1299 room: Some(room.clone()),
1300 live_kit_connection_info,
1301 })?;
1302
1303 update_user_contacts(session.user_id(), &session).await?;
1304 Ok(())
1305}
1306
1307/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1308async fn join_room(
1309 request: proto::JoinRoom,
1310 response: Response<proto::JoinRoom>,
1311 session: MessageContext,
1312) -> Result<()> {
1313 let room_id = RoomId::from_proto(request.id);
1314
1315 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1316
1317 if let Some(channel_id) = channel_id {
1318 return join_channel_internal(channel_id, Box::new(response), session).await;
1319 }
1320
1321 let joined_room = {
1322 let room = session
1323 .db()
1324 .await
1325 .join_room(room_id, session.user_id(), session.connection_id)
1326 .await?;
1327 room_updated(&room.room, &session.peer);
1328 room.into_inner()
1329 };
1330
1331 for connection_id in session
1332 .connection_pool()
1333 .await
1334 .user_connection_ids(session.user_id())
1335 {
1336 session
1337 .peer
1338 .send(
1339 connection_id,
1340 proto::CallCanceled {
1341 room_id: room_id.to_proto(),
1342 },
1343 )
1344 .trace_err();
1345 }
1346
1347 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1348 {
1349 live_kit
1350 .room_token(
1351 &joined_room.room.livekit_room,
1352 &session.user_id().to_string(),
1353 )
1354 .trace_err()
1355 .map(|token| proto::LiveKitConnectionInfo {
1356 server_url: live_kit.url().into(),
1357 token,
1358 can_publish: true,
1359 })
1360 } else {
1361 None
1362 };
1363
1364 response.send(proto::JoinRoomResponse {
1365 room: Some(joined_room.room),
1366 channel_id: None,
1367 live_kit_connection_info,
1368 })?;
1369
1370 update_user_contacts(session.user_id(), &session).await?;
1371 Ok(())
1372}
1373
1374/// Rejoin room is used to reconnect to a room after connection errors.
1375async fn rejoin_room(
1376 request: proto::RejoinRoom,
1377 response: Response<proto::RejoinRoom>,
1378 session: MessageContext,
1379) -> Result<()> {
1380 let room;
1381 let channel;
1382 {
1383 let mut rejoined_room = session
1384 .db()
1385 .await
1386 .rejoin_room(request, session.user_id(), session.connection_id)
1387 .await?;
1388
1389 response.send(proto::RejoinRoomResponse {
1390 room: Some(rejoined_room.room.clone()),
1391 reshared_projects: rejoined_room
1392 .reshared_projects
1393 .iter()
1394 .map(|project| proto::ResharedProject {
1395 id: project.id.to_proto(),
1396 collaborators: project
1397 .collaborators
1398 .iter()
1399 .map(|collaborator| collaborator.to_proto())
1400 .collect(),
1401 })
1402 .collect(),
1403 rejoined_projects: rejoined_room
1404 .rejoined_projects
1405 .iter()
1406 .map(|rejoined_project| rejoined_project.to_proto())
1407 .collect(),
1408 })?;
1409 room_updated(&rejoined_room.room, &session.peer);
1410
1411 for project in &rejoined_room.reshared_projects {
1412 for collaborator in &project.collaborators {
1413 session
1414 .peer
1415 .send(
1416 collaborator.connection_id,
1417 proto::UpdateProjectCollaborator {
1418 project_id: project.id.to_proto(),
1419 old_peer_id: Some(project.old_connection_id.into()),
1420 new_peer_id: Some(session.connection_id.into()),
1421 },
1422 )
1423 .trace_err();
1424 }
1425
1426 broadcast(
1427 Some(session.connection_id),
1428 project
1429 .collaborators
1430 .iter()
1431 .map(|collaborator| collaborator.connection_id),
1432 |connection_id| {
1433 session.peer.forward_send(
1434 session.connection_id,
1435 connection_id,
1436 proto::UpdateProject {
1437 project_id: project.id.to_proto(),
1438 worktrees: project.worktrees.clone(),
1439 },
1440 )
1441 },
1442 );
1443 }
1444
1445 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1446
1447 let rejoined_room = rejoined_room.into_inner();
1448
1449 room = rejoined_room.room;
1450 channel = rejoined_room.channel;
1451 }
1452
1453 if let Some(channel) = channel {
1454 channel_updated(
1455 &channel,
1456 &room,
1457 &session.peer,
1458 &*session.connection_pool().await,
1459 );
1460 }
1461
1462 update_user_contacts(session.user_id(), &session).await?;
1463 Ok(())
1464}
1465
1466fn notify_rejoined_projects(
1467 rejoined_projects: &mut Vec<RejoinedProject>,
1468 session: &Session,
1469) -> Result<()> {
1470 for project in rejoined_projects.iter() {
1471 for collaborator in &project.collaborators {
1472 session
1473 .peer
1474 .send(
1475 collaborator.connection_id,
1476 proto::UpdateProjectCollaborator {
1477 project_id: project.id.to_proto(),
1478 old_peer_id: Some(project.old_connection_id.into()),
1479 new_peer_id: Some(session.connection_id.into()),
1480 },
1481 )
1482 .trace_err();
1483 }
1484 }
1485
1486 for project in rejoined_projects {
1487 for worktree in mem::take(&mut project.worktrees) {
1488 // Stream this worktree's entries.
1489 let message = proto::UpdateWorktree {
1490 project_id: project.id.to_proto(),
1491 worktree_id: worktree.id,
1492 abs_path: worktree.abs_path.clone(),
1493 root_name: worktree.root_name,
1494 root_repo_common_dir: worktree.root_repo_common_dir,
1495 updated_entries: worktree.updated_entries,
1496 removed_entries: worktree.removed_entries,
1497 scan_id: worktree.scan_id,
1498 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1499 updated_repositories: worktree.updated_repositories,
1500 removed_repositories: worktree.removed_repositories,
1501 };
1502 for update in proto::split_worktree_update(message) {
1503 session.peer.send(session.connection_id, update)?;
1504 }
1505
1506 // Stream this worktree's diagnostics.
1507 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1508 if let Some(summary) = worktree_diagnostics.next() {
1509 let message = proto::UpdateDiagnosticSummary {
1510 project_id: project.id.to_proto(),
1511 worktree_id: worktree.id,
1512 summary: Some(summary),
1513 more_summaries: worktree_diagnostics.collect(),
1514 };
1515 session.peer.send(session.connection_id, message)?;
1516 }
1517
1518 for settings_file in worktree.settings_files {
1519 session.peer.send(
1520 session.connection_id,
1521 proto::UpdateWorktreeSettings {
1522 project_id: project.id.to_proto(),
1523 worktree_id: worktree.id,
1524 path: settings_file.path,
1525 content: Some(settings_file.content),
1526 kind: Some(settings_file.kind.to_proto().into()),
1527 outside_worktree: Some(settings_file.outside_worktree),
1528 },
1529 )?;
1530 }
1531 }
1532
1533 for repository in mem::take(&mut project.updated_repositories) {
1534 for update in split_repository_update(repository) {
1535 session.peer.send(session.connection_id, update)?;
1536 }
1537 }
1538
1539 for id in mem::take(&mut project.removed_repositories) {
1540 session.peer.send(
1541 session.connection_id,
1542 proto::RemoveRepository {
1543 project_id: project.id.to_proto(),
1544 id,
1545 },
1546 )?;
1547 }
1548 }
1549
1550 Ok(())
1551}
1552
1553/// leave room disconnects from the room.
1554async fn leave_room(
1555 _: proto::LeaveRoom,
1556 response: Response<proto::LeaveRoom>,
1557 session: MessageContext,
1558) -> Result<()> {
1559 leave_room_for_session(&session, session.connection_id).await?;
1560 response.send(proto::Ack {})?;
1561 Ok(())
1562}
1563
1564/// Updates the permissions of someone else in the room.
1565async fn set_room_participant_role(
1566 request: proto::SetRoomParticipantRole,
1567 response: Response<proto::SetRoomParticipantRole>,
1568 session: MessageContext,
1569) -> Result<()> {
1570 let user_id = UserId::from_proto(request.user_id);
1571 let role = ChannelRole::from(request.role());
1572
1573 let (livekit_room, can_publish) = {
1574 let room = session
1575 .db()
1576 .await
1577 .set_room_participant_role(
1578 session.user_id(),
1579 RoomId::from_proto(request.room_id),
1580 user_id,
1581 role,
1582 )
1583 .await?;
1584
1585 let livekit_room = room.livekit_room.clone();
1586 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1587 room_updated(&room, &session.peer);
1588 (livekit_room, can_publish)
1589 };
1590
1591 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1592 live_kit
1593 .update_participant(
1594 livekit_room.clone(),
1595 request.user_id.to_string(),
1596 livekit_api::proto::ParticipantPermission {
1597 can_subscribe: true,
1598 can_publish,
1599 can_publish_data: can_publish,
1600 hidden: false,
1601 recorder: false,
1602 },
1603 )
1604 .await
1605 .trace_err();
1606 }
1607
1608 response.send(proto::Ack {})?;
1609 Ok(())
1610}
1611
1612/// Call someone else into the current room
1613async fn call(
1614 request: proto::Call,
1615 response: Response<proto::Call>,
1616 session: MessageContext,
1617) -> Result<()> {
1618 let room_id = RoomId::from_proto(request.room_id);
1619 let calling_user_id = session.user_id();
1620 let calling_connection_id = session.connection_id;
1621 let called_user_id = UserId::from_proto(request.called_user_id);
1622 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1623 if !session
1624 .db()
1625 .await
1626 .has_contact(calling_user_id, called_user_id)
1627 .await?
1628 {
1629 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1630 }
1631
1632 let incoming_call = {
1633 let (room, incoming_call) = &mut *session
1634 .db()
1635 .await
1636 .call(
1637 room_id,
1638 calling_user_id,
1639 calling_connection_id,
1640 called_user_id,
1641 initial_project_id,
1642 )
1643 .await?;
1644 room_updated(room, &session.peer);
1645 mem::take(incoming_call)
1646 };
1647 update_user_contacts(called_user_id, &session).await?;
1648
1649 let mut calls = session
1650 .connection_pool()
1651 .await
1652 .user_connection_ids(called_user_id)
1653 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1654 .collect::<FuturesUnordered<_>>();
1655
1656 while let Some(call_response) = calls.next().await {
1657 match call_response.as_ref() {
1658 Ok(_) => {
1659 response.send(proto::Ack {})?;
1660 return Ok(());
1661 }
1662 Err(_) => {
1663 call_response.trace_err();
1664 }
1665 }
1666 }
1667
1668 {
1669 let room = session
1670 .db()
1671 .await
1672 .call_failed(room_id, called_user_id)
1673 .await?;
1674 room_updated(&room, &session.peer);
1675 }
1676 update_user_contacts(called_user_id, &session).await?;
1677
1678 Err(anyhow!("failed to ring user"))?
1679}
1680
1681/// Cancel an outgoing call.
1682async fn cancel_call(
1683 request: proto::CancelCall,
1684 response: Response<proto::CancelCall>,
1685 session: MessageContext,
1686) -> Result<()> {
1687 let called_user_id = UserId::from_proto(request.called_user_id);
1688 let room_id = RoomId::from_proto(request.room_id);
1689 {
1690 let room = session
1691 .db()
1692 .await
1693 .cancel_call(room_id, session.connection_id, called_user_id)
1694 .await?;
1695 room_updated(&room, &session.peer);
1696 }
1697
1698 for connection_id in session
1699 .connection_pool()
1700 .await
1701 .user_connection_ids(called_user_id)
1702 {
1703 session
1704 .peer
1705 .send(
1706 connection_id,
1707 proto::CallCanceled {
1708 room_id: room_id.to_proto(),
1709 },
1710 )
1711 .trace_err();
1712 }
1713 response.send(proto::Ack {})?;
1714
1715 update_user_contacts(called_user_id, &session).await?;
1716 Ok(())
1717}
1718
1719/// Decline an incoming call.
1720async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1721 let room_id = RoomId::from_proto(message.room_id);
1722 {
1723 let room = session
1724 .db()
1725 .await
1726 .decline_call(Some(room_id), session.user_id())
1727 .await?
1728 .context("declining call")?;
1729 room_updated(&room, &session.peer);
1730 }
1731
1732 for connection_id in session
1733 .connection_pool()
1734 .await
1735 .user_connection_ids(session.user_id())
1736 {
1737 session
1738 .peer
1739 .send(
1740 connection_id,
1741 proto::CallCanceled {
1742 room_id: room_id.to_proto(),
1743 },
1744 )
1745 .trace_err();
1746 }
1747 update_user_contacts(session.user_id(), &session).await?;
1748 Ok(())
1749}
1750
1751/// Updates other participants in the room with your current location.
1752async fn update_participant_location(
1753 request: proto::UpdateParticipantLocation,
1754 response: Response<proto::UpdateParticipantLocation>,
1755 session: MessageContext,
1756) -> Result<()> {
1757 let room_id = RoomId::from_proto(request.room_id);
1758 let location = request.location.context("invalid location")?;
1759
1760 let db = session.db().await;
1761 let room = db
1762 .update_room_participant_location(room_id, session.connection_id, location)
1763 .await?;
1764
1765 room_updated(&room, &session.peer);
1766 response.send(proto::Ack {})?;
1767 Ok(())
1768}
1769
1770/// Share a project into the room.
1771async fn share_project(
1772 request: proto::ShareProject,
1773 response: Response<proto::ShareProject>,
1774 session: MessageContext,
1775) -> Result<()> {
1776 let (project_id, room) = &*session
1777 .db()
1778 .await
1779 .share_project(
1780 RoomId::from_proto(request.room_id),
1781 session.connection_id,
1782 &request.worktrees,
1783 request.is_ssh_project,
1784 request.windows_paths.unwrap_or(false),
1785 &request.features,
1786 )
1787 .await?;
1788 response.send(proto::ShareProjectResponse {
1789 project_id: project_id.to_proto(),
1790 })?;
1791 room_updated(room, &session.peer);
1792
1793 Ok(())
1794}
1795
1796/// Unshare a project from the room.
1797async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1798 let project_id = ProjectId::from_proto(message.project_id);
1799 unshare_project_internal(project_id, session.connection_id, &session).await
1800}
1801
1802async fn unshare_project_internal(
1803 project_id: ProjectId,
1804 connection_id: ConnectionId,
1805 session: &Session,
1806) -> Result<()> {
1807 let delete = {
1808 let room_guard = session
1809 .db()
1810 .await
1811 .unshare_project(project_id, connection_id)
1812 .await?;
1813
1814 let (delete, room, guest_connection_ids) = &*room_guard;
1815
1816 let message = proto::UnshareProject {
1817 project_id: project_id.to_proto(),
1818 };
1819
1820 broadcast(
1821 Some(connection_id),
1822 guest_connection_ids.iter().copied(),
1823 |conn_id| session.peer.send(conn_id, message.clone()),
1824 );
1825 if let Some(room) = room {
1826 room_updated(room, &session.peer);
1827 }
1828
1829 *delete
1830 };
1831
1832 if delete {
1833 let db = session.db().await;
1834 db.delete_project(project_id).await?;
1835 }
1836
1837 Ok(())
1838}
1839
1840/// Join someone elses shared project.
1841async fn join_project(
1842 request: proto::JoinProject,
1843 response: Response<proto::JoinProject>,
1844 session: MessageContext,
1845) -> Result<()> {
1846 let project_id = ProjectId::from_proto(request.project_id);
1847
1848 tracing::info!(%project_id, "join project");
1849
1850 let db = session.db().await;
1851 let project_model = db.get_project(project_id).await?;
1852 let host_features: Vec<String> =
1853 serde_json::from_str(&project_model.features).unwrap_or_default();
1854 let guest_features: HashSet<_> = request.features.iter().collect();
1855 let host_features_set: HashSet<_> = host_features.iter().collect();
1856 if guest_features != host_features_set {
1857 let host_connection_id = project_model.host_connection()?;
1858 let mut pool = session.connection_pool().await;
1859 let host_version = pool
1860 .connection(host_connection_id)
1861 .map(|c| c.zed_version.to_string());
1862 let guest_version = pool
1863 .connection(session.connection_id)
1864 .map(|c| c.zed_version.to_string());
1865 drop(pool);
1866 Err(anyhow!(
1867 "The host (v{}) and guest (v{}) are using incompatible versions of Zed. The peer with the older version must update to collaborate.",
1868 host_version.as_deref().unwrap_or("unknown"),
1869 guest_version.as_deref().unwrap_or("unknown"),
1870 ))?;
1871 }
1872
1873 let (project, replica_id) = &mut *db
1874 .join_project(
1875 project_id,
1876 session.connection_id,
1877 session.user_id(),
1878 request.committer_name.clone(),
1879 request.committer_email.clone(),
1880 )
1881 .await?;
1882 drop(db);
1883
1884 tracing::info!(%project_id, "join remote project");
1885 let collaborators = project
1886 .collaborators
1887 .iter()
1888 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1889 .map(|collaborator| collaborator.to_proto())
1890 .collect::<Vec<_>>();
1891 let project_id = project.id;
1892 let guest_user_id = session.user_id();
1893
1894 let worktrees = project
1895 .worktrees
1896 .iter()
1897 .map(|(id, worktree)| proto::WorktreeMetadata {
1898 id: *id,
1899 root_name: worktree.root_name.clone(),
1900 visible: worktree.visible,
1901 abs_path: worktree.abs_path.clone(),
1902 root_repo_common_dir: None,
1903 })
1904 .collect::<Vec<_>>();
1905
1906 let add_project_collaborator = proto::AddProjectCollaborator {
1907 project_id: project_id.to_proto(),
1908 collaborator: Some(proto::Collaborator {
1909 peer_id: Some(session.connection_id.into()),
1910 replica_id: replica_id.0 as u32,
1911 user_id: guest_user_id.to_proto(),
1912 is_host: false,
1913 committer_name: request.committer_name.clone(),
1914 committer_email: request.committer_email.clone(),
1915 }),
1916 };
1917
1918 for collaborator in &collaborators {
1919 session
1920 .peer
1921 .send(
1922 collaborator.peer_id.unwrap().into(),
1923 add_project_collaborator.clone(),
1924 )
1925 .trace_err();
1926 }
1927
1928 // First, we send the metadata associated with each worktree.
1929 let (language_servers, language_server_capabilities) = project
1930 .language_servers
1931 .clone()
1932 .into_iter()
1933 .map(|server| (server.server, server.capabilities))
1934 .unzip();
1935 response.send(proto::JoinProjectResponse {
1936 project_id: project.id.0 as u64,
1937 worktrees,
1938 replica_id: replica_id.0 as u32,
1939 collaborators,
1940 language_servers,
1941 language_server_capabilities,
1942 role: project.role.into(),
1943 windows_paths: project.path_style == PathStyle::Windows,
1944 features: project.features.clone(),
1945 })?;
1946
1947 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1948 // Stream this worktree's entries.
1949 let message = proto::UpdateWorktree {
1950 project_id: project_id.to_proto(),
1951 worktree_id,
1952 abs_path: worktree.abs_path.clone(),
1953 root_name: worktree.root_name,
1954 root_repo_common_dir: worktree.root_repo_common_dir,
1955 updated_entries: worktree.entries,
1956 removed_entries: Default::default(),
1957 scan_id: worktree.scan_id,
1958 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1959 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1960 removed_repositories: Default::default(),
1961 };
1962 for update in proto::split_worktree_update(message) {
1963 session.peer.send(session.connection_id, update.clone())?;
1964 }
1965
1966 // Stream this worktree's diagnostics.
1967 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
1968 if let Some(summary) = worktree_diagnostics.next() {
1969 let message = proto::UpdateDiagnosticSummary {
1970 project_id: project.id.to_proto(),
1971 worktree_id: worktree.id,
1972 summary: Some(summary),
1973 more_summaries: worktree_diagnostics.collect(),
1974 };
1975 session.peer.send(session.connection_id, message)?;
1976 }
1977
1978 for settings_file in worktree.settings_files {
1979 session.peer.send(
1980 session.connection_id,
1981 proto::UpdateWorktreeSettings {
1982 project_id: project_id.to_proto(),
1983 worktree_id: worktree.id,
1984 path: settings_file.path,
1985 content: Some(settings_file.content),
1986 kind: Some(settings_file.kind.to_proto() as i32),
1987 outside_worktree: Some(settings_file.outside_worktree),
1988 },
1989 )?;
1990 }
1991 }
1992
1993 for repository in mem::take(&mut project.repositories) {
1994 for update in split_repository_update(repository) {
1995 session.peer.send(session.connection_id, update)?;
1996 }
1997 }
1998
1999 for language_server in &project.language_servers {
2000 session.peer.send(
2001 session.connection_id,
2002 proto::UpdateLanguageServer {
2003 project_id: project_id.to_proto(),
2004 server_name: Some(language_server.server.name.clone()),
2005 language_server_id: language_server.server.id,
2006 variant: Some(
2007 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2008 proto::LspDiskBasedDiagnosticsUpdated {},
2009 ),
2010 ),
2011 },
2012 )?;
2013 }
2014
2015 Ok(())
2016}
2017
2018/// Leave someone elses shared project.
2019async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2020 let sender_id = session.connection_id;
2021 let project_id = ProjectId::from_proto(request.project_id);
2022 let db = session.db().await;
2023
2024 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2025 tracing::info!(
2026 %project_id,
2027 "leave project"
2028 );
2029
2030 project_left(project, &session);
2031 if let Some(room) = room {
2032 room_updated(room, &session.peer);
2033 }
2034
2035 Ok(())
2036}
2037
2038/// Updates other participants with changes to the project
2039async fn update_project(
2040 request: proto::UpdateProject,
2041 response: Response<proto::UpdateProject>,
2042 session: MessageContext,
2043) -> Result<()> {
2044 let project_id = ProjectId::from_proto(request.project_id);
2045 let (room, guest_connection_ids) = &*session
2046 .db()
2047 .await
2048 .update_project(project_id, session.connection_id, &request.worktrees)
2049 .await?;
2050 broadcast(
2051 Some(session.connection_id),
2052 guest_connection_ids.iter().copied(),
2053 |connection_id| {
2054 session
2055 .peer
2056 .forward_send(session.connection_id, connection_id, request.clone())
2057 },
2058 );
2059 if let Some(room) = room {
2060 room_updated(room, &session.peer);
2061 }
2062 response.send(proto::Ack {})?;
2063
2064 Ok(())
2065}
2066
2067/// Updates other participants with changes to the worktree
2068async fn update_worktree(
2069 request: proto::UpdateWorktree,
2070 response: Response<proto::UpdateWorktree>,
2071 session: MessageContext,
2072) -> Result<()> {
2073 let guest_connection_ids = session
2074 .db()
2075 .await
2076 .update_worktree(&request, session.connection_id)
2077 .await?;
2078
2079 broadcast(
2080 Some(session.connection_id),
2081 guest_connection_ids.iter().copied(),
2082 |connection_id| {
2083 session
2084 .peer
2085 .forward_send(session.connection_id, connection_id, request.clone())
2086 },
2087 );
2088 response.send(proto::Ack {})?;
2089 Ok(())
2090}
2091
2092async fn update_repository(
2093 request: proto::UpdateRepository,
2094 response: Response<proto::UpdateRepository>,
2095 session: MessageContext,
2096) -> Result<()> {
2097 let guest_connection_ids = session
2098 .db()
2099 .await
2100 .update_repository(&request, session.connection_id)
2101 .await?;
2102
2103 broadcast(
2104 Some(session.connection_id),
2105 guest_connection_ids.iter().copied(),
2106 |connection_id| {
2107 session
2108 .peer
2109 .forward_send(session.connection_id, connection_id, request.clone())
2110 },
2111 );
2112 response.send(proto::Ack {})?;
2113 Ok(())
2114}
2115
2116async fn remove_repository(
2117 request: proto::RemoveRepository,
2118 response: Response<proto::RemoveRepository>,
2119 session: MessageContext,
2120) -> Result<()> {
2121 let guest_connection_ids = session
2122 .db()
2123 .await
2124 .remove_repository(&request, session.connection_id)
2125 .await?;
2126
2127 broadcast(
2128 Some(session.connection_id),
2129 guest_connection_ids.iter().copied(),
2130 |connection_id| {
2131 session
2132 .peer
2133 .forward_send(session.connection_id, connection_id, request.clone())
2134 },
2135 );
2136 response.send(proto::Ack {})?;
2137 Ok(())
2138}
2139
2140/// Updates other participants with changes to the diagnostics
2141async fn update_diagnostic_summary(
2142 message: proto::UpdateDiagnosticSummary,
2143 session: MessageContext,
2144) -> Result<()> {
2145 let guest_connection_ids = session
2146 .db()
2147 .await
2148 .update_diagnostic_summary(&message, session.connection_id)
2149 .await?;
2150
2151 broadcast(
2152 Some(session.connection_id),
2153 guest_connection_ids.iter().copied(),
2154 |connection_id| {
2155 session
2156 .peer
2157 .forward_send(session.connection_id, connection_id, message.clone())
2158 },
2159 );
2160
2161 Ok(())
2162}
2163
2164/// Updates other participants with changes to the worktree settings
2165async fn update_worktree_settings(
2166 message: proto::UpdateWorktreeSettings,
2167 session: MessageContext,
2168) -> Result<()> {
2169 let guest_connection_ids = session
2170 .db()
2171 .await
2172 .update_worktree_settings(&message, session.connection_id)
2173 .await?;
2174
2175 broadcast(
2176 Some(session.connection_id),
2177 guest_connection_ids.iter().copied(),
2178 |connection_id| {
2179 session
2180 .peer
2181 .forward_send(session.connection_id, connection_id, message.clone())
2182 },
2183 );
2184
2185 Ok(())
2186}
2187
2188/// Notify other participants that a language server has started.
2189async fn start_language_server(
2190 request: proto::StartLanguageServer,
2191 session: MessageContext,
2192) -> Result<()> {
2193 let guest_connection_ids = session
2194 .db()
2195 .await
2196 .start_language_server(&request, session.connection_id)
2197 .await?;
2198
2199 broadcast(
2200 Some(session.connection_id),
2201 guest_connection_ids.iter().copied(),
2202 |connection_id| {
2203 session
2204 .peer
2205 .forward_send(session.connection_id, connection_id, request.clone())
2206 },
2207 );
2208 Ok(())
2209}
2210
2211/// Notify other participants that a language server has changed.
2212async fn update_language_server(
2213 request: proto::UpdateLanguageServer,
2214 session: MessageContext,
2215) -> Result<()> {
2216 let project_id = ProjectId::from_proto(request.project_id);
2217 let db = session.db().await;
2218
2219 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2220 && let Some(capabilities) = update.capabilities.clone()
2221 {
2222 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2223 .await?;
2224 }
2225
2226 let project_connection_ids = db
2227 .project_connection_ids(project_id, session.connection_id, true)
2228 .await?;
2229 broadcast(
2230 Some(session.connection_id),
2231 project_connection_ids.iter().copied(),
2232 |connection_id| {
2233 session
2234 .peer
2235 .forward_send(session.connection_id, connection_id, request.clone())
2236 },
2237 );
2238 Ok(())
2239}
2240
2241/// forward a project request to the host. These requests should be read only
2242/// as guests are allowed to send them.
2243async fn forward_read_only_project_request<T>(
2244 request: T,
2245 response: Response<T>,
2246 session: MessageContext,
2247) -> Result<()>
2248where
2249 T: EntityMessage + RequestMessage,
2250{
2251 let project_id = ProjectId::from_proto(request.remote_entity_id());
2252 let host_connection_id = session
2253 .db()
2254 .await
2255 .host_for_read_only_project_request(project_id, session.connection_id)
2256 .await?;
2257 let payload = session.forward_request(host_connection_id, request).await?;
2258 response.send(payload)?;
2259 Ok(())
2260}
2261
2262/// forward a project request to the host. These requests are disallowed
2263/// for guests.
2264async fn forward_mutating_project_request<T>(
2265 request: T,
2266 response: Response<T>,
2267 session: MessageContext,
2268) -> Result<()>
2269where
2270 T: EntityMessage + RequestMessage,
2271{
2272 let project_id = ProjectId::from_proto(request.remote_entity_id());
2273
2274 let host_connection_id = session
2275 .db()
2276 .await
2277 .host_for_mutating_project_request(project_id, session.connection_id)
2278 .await?;
2279 let payload = session.forward_request(host_connection_id, request).await?;
2280 response.send(payload)?;
2281 Ok(())
2282}
2283
2284async fn disallow_guest_request<T>(
2285 _request: T,
2286 response: Response<T>,
2287 _session: MessageContext,
2288) -> Result<()>
2289where
2290 T: RequestMessage,
2291{
2292 response.peer.respond_with_error(
2293 response.receipt,
2294 ErrorCode::Forbidden
2295 .message("request is not allowed for guests".to_string())
2296 .to_proto(),
2297 )?;
2298 response.responded.store(true, SeqCst);
2299 Ok(())
2300}
2301
2302async fn lsp_query(
2303 request: proto::LspQuery,
2304 response: Response<proto::LspQuery>,
2305 session: MessageContext,
2306) -> Result<()> {
2307 let (name, should_write) = request.query_name_and_write_permissions();
2308 tracing::Span::current().record("lsp_query_request", name);
2309 tracing::info!("lsp_query message received");
2310 if should_write {
2311 forward_mutating_project_request(request, response, session).await
2312 } else {
2313 forward_read_only_project_request(request, response, session).await
2314 }
2315}
2316
2317/// Notify other participants that a new buffer has been created
2318async fn create_buffer_for_peer(
2319 request: proto::CreateBufferForPeer,
2320 session: MessageContext,
2321) -> Result<()> {
2322 session
2323 .db()
2324 .await
2325 .check_user_is_project_host(
2326 ProjectId::from_proto(request.project_id),
2327 session.connection_id,
2328 )
2329 .await?;
2330 let peer_id = request.peer_id.context("invalid peer id")?;
2331 session
2332 .peer
2333 .forward_send(session.connection_id, peer_id.into(), request)?;
2334 Ok(())
2335}
2336
2337/// Notify other participants that a new image has been created
2338async fn create_image_for_peer(
2339 request: proto::CreateImageForPeer,
2340 session: MessageContext,
2341) -> Result<()> {
2342 session
2343 .db()
2344 .await
2345 .check_user_is_project_host(
2346 ProjectId::from_proto(request.project_id),
2347 session.connection_id,
2348 )
2349 .await?;
2350 let peer_id = request.peer_id.context("invalid peer id")?;
2351 session
2352 .peer
2353 .forward_send(session.connection_id, peer_id.into(), request)?;
2354 Ok(())
2355}
2356
2357/// Notify other participants that a buffer has been updated. This is
2358/// allowed for guests as long as the update is limited to selections.
2359async fn update_buffer(
2360 request: proto::UpdateBuffer,
2361 response: Response<proto::UpdateBuffer>,
2362 session: MessageContext,
2363) -> Result<()> {
2364 let project_id = ProjectId::from_proto(request.project_id);
2365 let mut capability = Capability::ReadOnly;
2366
2367 for op in request.operations.iter() {
2368 match op.variant {
2369 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2370 Some(_) => capability = Capability::ReadWrite,
2371 }
2372 }
2373
2374 let host = {
2375 let guard = session
2376 .db()
2377 .await
2378 .connections_for_buffer_update(project_id, session.connection_id, capability)
2379 .await?;
2380
2381 let (host, guests) = &*guard;
2382
2383 broadcast(
2384 Some(session.connection_id),
2385 guests.clone(),
2386 |connection_id| {
2387 session
2388 .peer
2389 .forward_send(session.connection_id, connection_id, request.clone())
2390 },
2391 );
2392
2393 *host
2394 };
2395
2396 if host != session.connection_id {
2397 session.forward_request(host, request.clone()).await?;
2398 }
2399
2400 response.send(proto::Ack {})?;
2401 Ok(())
2402}
2403
2404async fn forward_project_search_chunk(
2405 message: proto::FindSearchCandidatesChunk,
2406 response: Response<proto::FindSearchCandidatesChunk>,
2407 session: MessageContext,
2408) -> Result<()> {
2409 let peer_id = message.peer_id.context("missing peer_id")?;
2410 let payload = session
2411 .peer
2412 .forward_request(session.connection_id, peer_id.into(), message)
2413 .await?;
2414 response.send(payload)?;
2415 Ok(())
2416}
2417
2418/// Notify other participants that a project has been updated.
2419async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2420 request: T,
2421 session: MessageContext,
2422) -> Result<()> {
2423 let project_id = ProjectId::from_proto(request.remote_entity_id());
2424 let project_connection_ids = session
2425 .db()
2426 .await
2427 .project_connection_ids(project_id, session.connection_id, false)
2428 .await?;
2429
2430 broadcast(
2431 Some(session.connection_id),
2432 project_connection_ids.iter().copied(),
2433 |connection_id| {
2434 session
2435 .peer
2436 .forward_send(session.connection_id, connection_id, request.clone())
2437 },
2438 );
2439 Ok(())
2440}
2441
2442/// Start following another user in a call.
2443async fn follow(
2444 request: proto::Follow,
2445 response: Response<proto::Follow>,
2446 session: MessageContext,
2447) -> Result<()> {
2448 let room_id = RoomId::from_proto(request.room_id);
2449 let project_id = request.project_id.map(ProjectId::from_proto);
2450 let leader_id = request.leader_id.context("invalid leader id")?.into();
2451 let follower_id = session.connection_id;
2452
2453 session
2454 .db()
2455 .await
2456 .check_room_participants(room_id, leader_id, session.connection_id)
2457 .await?;
2458
2459 let response_payload = session.forward_request(leader_id, request).await?;
2460 response.send(response_payload)?;
2461
2462 if let Some(project_id) = project_id {
2463 let room = session
2464 .db()
2465 .await
2466 .follow(room_id, project_id, leader_id, follower_id)
2467 .await?;
2468 room_updated(&room, &session.peer);
2469 }
2470
2471 Ok(())
2472}
2473
2474/// Stop following another user in a call.
2475async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2476 let room_id = RoomId::from_proto(request.room_id);
2477 let project_id = request.project_id.map(ProjectId::from_proto);
2478 let leader_id = request.leader_id.context("invalid leader id")?.into();
2479 let follower_id = session.connection_id;
2480
2481 session
2482 .db()
2483 .await
2484 .check_room_participants(room_id, leader_id, session.connection_id)
2485 .await?;
2486
2487 session
2488 .peer
2489 .forward_send(session.connection_id, leader_id, request)?;
2490
2491 if let Some(project_id) = project_id {
2492 let room = session
2493 .db()
2494 .await
2495 .unfollow(room_id, project_id, leader_id, follower_id)
2496 .await?;
2497 room_updated(&room, &session.peer);
2498 }
2499
2500 Ok(())
2501}
2502
2503/// Notify everyone following you of your current location.
2504async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2505 let room_id = RoomId::from_proto(request.room_id);
2506 let database = session.db.lock().await;
2507
2508 let connection_ids = if let Some(project_id) = request.project_id {
2509 let project_id = ProjectId::from_proto(project_id);
2510 database
2511 .project_connection_ids(project_id, session.connection_id, true)
2512 .await?
2513 } else {
2514 database
2515 .room_connection_ids(room_id, session.connection_id)
2516 .await?
2517 };
2518
2519 // For now, don't send view update messages back to that view's current leader.
2520 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2521 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2522 _ => None,
2523 });
2524
2525 for connection_id in connection_ids.iter().cloned() {
2526 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2527 session
2528 .peer
2529 .forward_send(session.connection_id, connection_id, request.clone())?;
2530 }
2531 }
2532 Ok(())
2533}
2534
2535/// Get public data about users.
2536async fn get_users(
2537 request: proto::GetUsers,
2538 response: Response<proto::GetUsers>,
2539 session: MessageContext,
2540) -> Result<()> {
2541 let user_ids = request
2542 .user_ids
2543 .into_iter()
2544 .map(UserId::from_proto)
2545 .collect();
2546 let users = session
2547 .db()
2548 .await
2549 .get_users_by_ids(user_ids)
2550 .await?
2551 .into_iter()
2552 .map(|user| proto::User {
2553 id: user.id.to_proto(),
2554 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2555 github_login: user.github_login,
2556 name: user.name,
2557 })
2558 .collect();
2559 response.send(proto::UsersResponse { users })?;
2560 Ok(())
2561}
2562
2563/// Search for users (to invite) buy Github login
2564async fn fuzzy_search_users(
2565 request: proto::FuzzySearchUsers,
2566 response: Response<proto::FuzzySearchUsers>,
2567 session: MessageContext,
2568) -> Result<()> {
2569 let query = request.query;
2570 let users = match query.len() {
2571 0 => vec![],
2572 1 | 2 => session
2573 .db()
2574 .await
2575 .get_user_by_github_login(&query)
2576 .await?
2577 .into_iter()
2578 .collect(),
2579 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2580 };
2581 let users = users
2582 .into_iter()
2583 .filter(|user| user.id != session.user_id())
2584 .map(|user| proto::User {
2585 id: user.id.to_proto(),
2586 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2587 github_login: user.github_login,
2588 name: user.name,
2589 })
2590 .collect();
2591 response.send(proto::UsersResponse { users })?;
2592 Ok(())
2593}
2594
2595/// Send a contact request to another user.
2596async fn request_contact(
2597 request: proto::RequestContact,
2598 response: Response<proto::RequestContact>,
2599 session: MessageContext,
2600) -> Result<()> {
2601 let requester_id = session.user_id();
2602 let responder_id = UserId::from_proto(request.responder_id);
2603 if requester_id == responder_id {
2604 return Err(anyhow!("cannot add yourself as a contact"))?;
2605 }
2606
2607 let notifications = session
2608 .db()
2609 .await
2610 .send_contact_request(requester_id, responder_id)
2611 .await?;
2612
2613 // Update outgoing contact requests of requester
2614 let mut update = proto::UpdateContacts::default();
2615 update.outgoing_requests.push(responder_id.to_proto());
2616 for connection_id in session
2617 .connection_pool()
2618 .await
2619 .user_connection_ids(requester_id)
2620 {
2621 session.peer.send(connection_id, update.clone())?;
2622 }
2623
2624 // Update incoming contact requests of responder
2625 let mut update = proto::UpdateContacts::default();
2626 update
2627 .incoming_requests
2628 .push(proto::IncomingContactRequest {
2629 requester_id: requester_id.to_proto(),
2630 });
2631 let connection_pool = session.connection_pool().await;
2632 for connection_id in connection_pool.user_connection_ids(responder_id) {
2633 session.peer.send(connection_id, update.clone())?;
2634 }
2635
2636 send_notifications(&connection_pool, &session.peer, notifications);
2637
2638 response.send(proto::Ack {})?;
2639 Ok(())
2640}
2641
2642/// Accept or decline a contact request
2643async fn respond_to_contact_request(
2644 request: proto::RespondToContactRequest,
2645 response: Response<proto::RespondToContactRequest>,
2646 session: MessageContext,
2647) -> Result<()> {
2648 let responder_id = session.user_id();
2649 let requester_id = UserId::from_proto(request.requester_id);
2650 let db = session.db().await;
2651 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2652 db.dismiss_contact_notification(responder_id, requester_id)
2653 .await?;
2654 } else {
2655 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2656
2657 let notifications = db
2658 .respond_to_contact_request(responder_id, requester_id, accept)
2659 .await?;
2660 let requester_busy = db.is_user_busy(requester_id).await?;
2661 let responder_busy = db.is_user_busy(responder_id).await?;
2662
2663 let pool = session.connection_pool().await;
2664 // Update responder with new contact
2665 let mut update = proto::UpdateContacts::default();
2666 if accept {
2667 update
2668 .contacts
2669 .push(contact_for_user(requester_id, requester_busy, &pool));
2670 }
2671 update
2672 .remove_incoming_requests
2673 .push(requester_id.to_proto());
2674 for connection_id in pool.user_connection_ids(responder_id) {
2675 session.peer.send(connection_id, update.clone())?;
2676 }
2677
2678 // Update requester with new contact
2679 let mut update = proto::UpdateContacts::default();
2680 if accept {
2681 update
2682 .contacts
2683 .push(contact_for_user(responder_id, responder_busy, &pool));
2684 }
2685 update
2686 .remove_outgoing_requests
2687 .push(responder_id.to_proto());
2688
2689 for connection_id in pool.user_connection_ids(requester_id) {
2690 session.peer.send(connection_id, update.clone())?;
2691 }
2692
2693 send_notifications(&pool, &session.peer, notifications);
2694 }
2695
2696 response.send(proto::Ack {})?;
2697 Ok(())
2698}
2699
2700/// Remove a contact.
2701async fn remove_contact(
2702 request: proto::RemoveContact,
2703 response: Response<proto::RemoveContact>,
2704 session: MessageContext,
2705) -> Result<()> {
2706 let requester_id = session.user_id();
2707 let responder_id = UserId::from_proto(request.user_id);
2708 let db = session.db().await;
2709 let (contact_accepted, deleted_notification_id) =
2710 db.remove_contact(requester_id, responder_id).await?;
2711
2712 let pool = session.connection_pool().await;
2713 // Update outgoing contact requests of requester
2714 let mut update = proto::UpdateContacts::default();
2715 if contact_accepted {
2716 update.remove_contacts.push(responder_id.to_proto());
2717 } else {
2718 update
2719 .remove_outgoing_requests
2720 .push(responder_id.to_proto());
2721 }
2722 for connection_id in pool.user_connection_ids(requester_id) {
2723 session.peer.send(connection_id, update.clone())?;
2724 }
2725
2726 // Update incoming contact requests of responder
2727 let mut update = proto::UpdateContacts::default();
2728 if contact_accepted {
2729 update.remove_contacts.push(requester_id.to_proto());
2730 } else {
2731 update
2732 .remove_incoming_requests
2733 .push(requester_id.to_proto());
2734 }
2735 for connection_id in pool.user_connection_ids(responder_id) {
2736 session.peer.send(connection_id, update.clone())?;
2737 if let Some(notification_id) = deleted_notification_id {
2738 session.peer.send(
2739 connection_id,
2740 proto::DeleteNotification {
2741 notification_id: notification_id.to_proto(),
2742 },
2743 )?;
2744 }
2745 }
2746
2747 response.send(proto::Ack {})?;
2748 Ok(())
2749}
2750
2751fn should_auto_subscribe_to_channels(version: &ZedVersion) -> bool {
2752 version.0.minor < 139
2753}
2754
2755async fn subscribe_to_channels(
2756 _: proto::SubscribeToChannels,
2757 session: MessageContext,
2758) -> Result<()> {
2759 subscribe_user_to_channels(session.user_id(), &session).await?;
2760 Ok(())
2761}
2762
2763async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2764 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2765 let mut pool = session.connection_pool().await;
2766 for membership in &channels_for_user.channel_memberships {
2767 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2768 }
2769 session.peer.send(
2770 session.connection_id,
2771 build_update_user_channels(&channels_for_user),
2772 )?;
2773 session.peer.send(
2774 session.connection_id,
2775 build_channels_update(channels_for_user),
2776 )?;
2777 Ok(())
2778}
2779
2780/// Creates a new channel.
2781async fn create_channel(
2782 request: proto::CreateChannel,
2783 response: Response<proto::CreateChannel>,
2784 session: MessageContext,
2785) -> Result<()> {
2786 let db = session.db().await;
2787
2788 let parent_id = request.parent_id.map(ChannelId::from_proto);
2789 let (channel, membership) = db
2790 .create_channel(&request.name, parent_id, session.user_id())
2791 .await?;
2792
2793 let root_id = channel.root_id();
2794 let channel = Channel::from_model(channel);
2795
2796 response.send(proto::CreateChannelResponse {
2797 channel: Some(channel.to_proto()),
2798 parent_id: request.parent_id,
2799 })?;
2800
2801 let mut connection_pool = session.connection_pool().await;
2802 if let Some(membership) = membership {
2803 connection_pool.subscribe_to_channel(
2804 membership.user_id,
2805 membership.channel_id,
2806 membership.role,
2807 );
2808 let update = proto::UpdateUserChannels {
2809 channel_memberships: vec![proto::ChannelMembership {
2810 channel_id: membership.channel_id.to_proto(),
2811 role: membership.role.into(),
2812 }],
2813 ..Default::default()
2814 };
2815 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2816 session.peer.send(connection_id, update.clone())?;
2817 }
2818 }
2819
2820 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2821 if !role.can_see_channel(channel.visibility) {
2822 continue;
2823 }
2824
2825 let update = proto::UpdateChannels {
2826 channels: vec![channel.to_proto()],
2827 ..Default::default()
2828 };
2829 session.peer.send(connection_id, update.clone())?;
2830 }
2831
2832 Ok(())
2833}
2834
2835/// Delete a channel
2836async fn delete_channel(
2837 request: proto::DeleteChannel,
2838 response: Response<proto::DeleteChannel>,
2839 session: MessageContext,
2840) -> Result<()> {
2841 let db = session.db().await;
2842
2843 let channel_id = request.channel_id;
2844 let (root_channel, removed_channels) = db
2845 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2846 .await?;
2847 response.send(proto::Ack {})?;
2848
2849 // Notify members of removed channels
2850 let mut update = proto::UpdateChannels::default();
2851 update
2852 .delete_channels
2853 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2854
2855 let connection_pool = session.connection_pool().await;
2856 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2857 session.peer.send(connection_id, update.clone())?;
2858 }
2859
2860 Ok(())
2861}
2862
2863/// Invite someone to join a channel.
2864async fn invite_channel_member(
2865 request: proto::InviteChannelMember,
2866 response: Response<proto::InviteChannelMember>,
2867 session: MessageContext,
2868) -> Result<()> {
2869 let db = session.db().await;
2870 let channel_id = ChannelId::from_proto(request.channel_id);
2871 let invitee_id = UserId::from_proto(request.user_id);
2872 let InviteMemberResult {
2873 channel,
2874 notifications,
2875 } = db
2876 .invite_channel_member(
2877 channel_id,
2878 invitee_id,
2879 session.user_id(),
2880 request.role().into(),
2881 )
2882 .await?;
2883
2884 let update = proto::UpdateChannels {
2885 channel_invitations: vec![channel.to_proto()],
2886 ..Default::default()
2887 };
2888
2889 let connection_pool = session.connection_pool().await;
2890 for connection_id in connection_pool.user_connection_ids(invitee_id) {
2891 session.peer.send(connection_id, update.clone())?;
2892 }
2893
2894 send_notifications(&connection_pool, &session.peer, notifications);
2895
2896 response.send(proto::Ack {})?;
2897 Ok(())
2898}
2899
2900/// remove someone from a channel
2901async fn remove_channel_member(
2902 request: proto::RemoveChannelMember,
2903 response: Response<proto::RemoveChannelMember>,
2904 session: MessageContext,
2905) -> Result<()> {
2906 let db = session.db().await;
2907 let channel_id = ChannelId::from_proto(request.channel_id);
2908 let member_id = UserId::from_proto(request.user_id);
2909
2910 let RemoveChannelMemberResult {
2911 membership_update,
2912 notification_id,
2913 } = db
2914 .remove_channel_member(channel_id, member_id, session.user_id())
2915 .await?;
2916
2917 let mut connection_pool = session.connection_pool().await;
2918 notify_membership_updated(
2919 &mut connection_pool,
2920 membership_update,
2921 member_id,
2922 &session.peer,
2923 );
2924 for connection_id in connection_pool.user_connection_ids(member_id) {
2925 if let Some(notification_id) = notification_id {
2926 session
2927 .peer
2928 .send(
2929 connection_id,
2930 proto::DeleteNotification {
2931 notification_id: notification_id.to_proto(),
2932 },
2933 )
2934 .trace_err();
2935 }
2936 }
2937
2938 response.send(proto::Ack {})?;
2939 Ok(())
2940}
2941
2942/// Toggle the channel between public and private.
2943/// Care is taken to maintain the invariant that public channels only descend from public channels,
2944/// (though members-only channels can appear at any point in the hierarchy).
2945async fn set_channel_visibility(
2946 request: proto::SetChannelVisibility,
2947 response: Response<proto::SetChannelVisibility>,
2948 session: MessageContext,
2949) -> Result<()> {
2950 let db = session.db().await;
2951 let channel_id = ChannelId::from_proto(request.channel_id);
2952 let visibility = request.visibility().into();
2953
2954 let channel_model = db
2955 .set_channel_visibility(channel_id, visibility, session.user_id())
2956 .await?;
2957 let root_id = channel_model.root_id();
2958 let channel = Channel::from_model(channel_model);
2959
2960 let mut connection_pool = session.connection_pool().await;
2961 for (user_id, role) in connection_pool
2962 .channel_user_ids(root_id)
2963 .collect::<Vec<_>>()
2964 .into_iter()
2965 {
2966 let update = if role.can_see_channel(channel.visibility) {
2967 connection_pool.subscribe_to_channel(user_id, channel_id, role);
2968 proto::UpdateChannels {
2969 channels: vec![channel.to_proto()],
2970 ..Default::default()
2971 }
2972 } else {
2973 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
2974 proto::UpdateChannels {
2975 delete_channels: vec![channel.id.to_proto()],
2976 ..Default::default()
2977 }
2978 };
2979
2980 for connection_id in connection_pool.user_connection_ids(user_id) {
2981 session.peer.send(connection_id, update.clone())?;
2982 }
2983 }
2984
2985 response.send(proto::Ack {})?;
2986 Ok(())
2987}
2988
2989/// Alter the role for a user in the channel.
2990async fn set_channel_member_role(
2991 request: proto::SetChannelMemberRole,
2992 response: Response<proto::SetChannelMemberRole>,
2993 session: MessageContext,
2994) -> Result<()> {
2995 let db = session.db().await;
2996 let channel_id = ChannelId::from_proto(request.channel_id);
2997 let member_id = UserId::from_proto(request.user_id);
2998 let result = db
2999 .set_channel_member_role(
3000 channel_id,
3001 session.user_id(),
3002 member_id,
3003 request.role().into(),
3004 )
3005 .await?;
3006
3007 match result {
3008 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3009 let mut connection_pool = session.connection_pool().await;
3010 notify_membership_updated(
3011 &mut connection_pool,
3012 membership_update,
3013 member_id,
3014 &session.peer,
3015 )
3016 }
3017 db::SetMemberRoleResult::InviteUpdated(channel) => {
3018 let update = proto::UpdateChannels {
3019 channel_invitations: vec![channel.to_proto()],
3020 ..Default::default()
3021 };
3022
3023 for connection_id in session
3024 .connection_pool()
3025 .await
3026 .user_connection_ids(member_id)
3027 {
3028 session.peer.send(connection_id, update.clone())?;
3029 }
3030 }
3031 }
3032
3033 response.send(proto::Ack {})?;
3034 Ok(())
3035}
3036
3037/// Change the name of a channel
3038async fn rename_channel(
3039 request: proto::RenameChannel,
3040 response: Response<proto::RenameChannel>,
3041 session: MessageContext,
3042) -> Result<()> {
3043 let db = session.db().await;
3044 let channel_id = ChannelId::from_proto(request.channel_id);
3045 let channel_model = db
3046 .rename_channel(channel_id, session.user_id(), &request.name)
3047 .await?;
3048 let root_id = channel_model.root_id();
3049 let channel = Channel::from_model(channel_model);
3050
3051 response.send(proto::RenameChannelResponse {
3052 channel: Some(channel.to_proto()),
3053 })?;
3054
3055 let connection_pool = session.connection_pool().await;
3056 let update = proto::UpdateChannels {
3057 channels: vec![channel.to_proto()],
3058 ..Default::default()
3059 };
3060 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3061 if role.can_see_channel(channel.visibility) {
3062 session.peer.send(connection_id, update.clone())?;
3063 }
3064 }
3065
3066 Ok(())
3067}
3068
3069/// Move a channel to a new parent.
3070async fn move_channel(
3071 request: proto::MoveChannel,
3072 response: Response<proto::MoveChannel>,
3073 session: MessageContext,
3074) -> Result<()> {
3075 let channel_id = ChannelId::from_proto(request.channel_id);
3076 let to = ChannelId::from_proto(request.to);
3077
3078 let (root_id, channels) = session
3079 .db()
3080 .await
3081 .move_channel(channel_id, to, session.user_id())
3082 .await?;
3083
3084 let connection_pool = session.connection_pool().await;
3085 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3086 let channels = channels
3087 .iter()
3088 .filter_map(|channel| {
3089 if role.can_see_channel(channel.visibility) {
3090 Some(channel.to_proto())
3091 } else {
3092 None
3093 }
3094 })
3095 .collect::<Vec<_>>();
3096 if channels.is_empty() {
3097 continue;
3098 }
3099
3100 let update = proto::UpdateChannels {
3101 channels,
3102 ..Default::default()
3103 };
3104
3105 session.peer.send(connection_id, update.clone())?;
3106 }
3107
3108 response.send(Ack {})?;
3109 Ok(())
3110}
3111
3112async fn reorder_channel(
3113 request: proto::ReorderChannel,
3114 response: Response<proto::ReorderChannel>,
3115 session: MessageContext,
3116) -> Result<()> {
3117 let channel_id = ChannelId::from_proto(request.channel_id);
3118 let direction = request.direction();
3119
3120 let updated_channels = session
3121 .db()
3122 .await
3123 .reorder_channel(channel_id, direction, session.user_id())
3124 .await?;
3125
3126 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3127 let connection_pool = session.connection_pool().await;
3128 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3129 let channels = updated_channels
3130 .iter()
3131 .filter_map(|channel| {
3132 if role.can_see_channel(channel.visibility) {
3133 Some(channel.to_proto())
3134 } else {
3135 None
3136 }
3137 })
3138 .collect::<Vec<_>>();
3139
3140 if channels.is_empty() {
3141 continue;
3142 }
3143
3144 let update = proto::UpdateChannels {
3145 channels,
3146 ..Default::default()
3147 };
3148
3149 session.peer.send(connection_id, update.clone())?;
3150 }
3151 }
3152
3153 response.send(Ack {})?;
3154 Ok(())
3155}
3156
3157/// Get the list of channel members
3158async fn get_channel_members(
3159 request: proto::GetChannelMembers,
3160 response: Response<proto::GetChannelMembers>,
3161 session: MessageContext,
3162) -> Result<()> {
3163 let db = session.db().await;
3164 let channel_id = ChannelId::from_proto(request.channel_id);
3165 let limit = if request.limit == 0 {
3166 u16::MAX as u64
3167 } else {
3168 request.limit
3169 };
3170 let (members, users) = db
3171 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3172 .await?;
3173 response.send(proto::GetChannelMembersResponse { members, users })?;
3174 Ok(())
3175}
3176
3177/// Accept or decline a channel invitation.
3178async fn respond_to_channel_invite(
3179 request: proto::RespondToChannelInvite,
3180 response: Response<proto::RespondToChannelInvite>,
3181 session: MessageContext,
3182) -> Result<()> {
3183 let db = session.db().await;
3184 let channel_id = ChannelId::from_proto(request.channel_id);
3185 let RespondToChannelInvite {
3186 membership_update,
3187 notifications,
3188 } = db
3189 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3190 .await?;
3191
3192 let mut connection_pool = session.connection_pool().await;
3193 if let Some(membership_update) = membership_update {
3194 notify_membership_updated(
3195 &mut connection_pool,
3196 membership_update,
3197 session.user_id(),
3198 &session.peer,
3199 );
3200 } else {
3201 let update = proto::UpdateChannels {
3202 remove_channel_invitations: vec![channel_id.to_proto()],
3203 ..Default::default()
3204 };
3205
3206 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3207 session.peer.send(connection_id, update.clone())?;
3208 }
3209 };
3210
3211 send_notifications(&connection_pool, &session.peer, notifications);
3212
3213 response.send(proto::Ack {})?;
3214
3215 Ok(())
3216}
3217
3218/// Join the channels' room
3219async fn join_channel(
3220 request: proto::JoinChannel,
3221 response: Response<proto::JoinChannel>,
3222 session: MessageContext,
3223) -> Result<()> {
3224 let channel_id = ChannelId::from_proto(request.channel_id);
3225 join_channel_internal(channel_id, Box::new(response), session).await
3226}
3227
/// Abstraction over the response handles that `join_channel_internal` can
/// reply through. Every implementor answers with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Consumes the response handle and sends `result` through it.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Lets a `JoinChannel` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Lets a `JoinRoom` response handle be used by `join_channel_internal`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully qualified so the call targets `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3241
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Reply to the requester with a `proto::JoinRoomResponse`.
/// 5. Propagate any membership update and broadcast the updated room.
/// 6. Broadcast the channel's participants and refresh the user's contacts.
///
/// # Errors
///
/// Returns an error if any database call, the stale-room cleanup, or sending
/// the response fails, or if the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    // The inner block scopes `db` and the connection pool handle so both are
    // released before the follow-up broadcasts below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // `leave_room_for_session` acquires the database handle itself, so
            // release ours first and re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests receive a guest token and may not publish; every other role
        // receives a room token and may publish. If creating the token fails,
        // `trace_err` turns the error into `None` and no connection info is sent.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // A missing channel on the joined room is reported as an error.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3335
3336/// Start editing the channel notes
3337async fn join_channel_buffer(
3338 request: proto::JoinChannelBuffer,
3339 response: Response<proto::JoinChannelBuffer>,
3340 session: MessageContext,
3341) -> Result<()> {
3342 let db = session.db().await;
3343 let channel_id = ChannelId::from_proto(request.channel_id);
3344
3345 let open_response = db
3346 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3347 .await?;
3348
3349 let collaborators = open_response.collaborators.clone();
3350 response.send(open_response)?;
3351
3352 let update = UpdateChannelBufferCollaborators {
3353 channel_id: channel_id.to_proto(),
3354 collaborators: collaborators.clone(),
3355 };
3356 channel_buffer_updated(
3357 session.connection_id,
3358 collaborators
3359 .iter()
3360 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3361 &update,
3362 &session.peer,
3363 );
3364
3365 Ok(())
3366}
3367
3368/// Edit the channel notes
3369async fn update_channel_buffer(
3370 request: proto::UpdateChannelBuffer,
3371 session: MessageContext,
3372) -> Result<()> {
3373 let db = session.db().await;
3374 let channel_id = ChannelId::from_proto(request.channel_id);
3375
3376 let (collaborators, epoch, version) = db
3377 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3378 .await?;
3379
3380 channel_buffer_updated(
3381 session.connection_id,
3382 collaborators.clone(),
3383 &proto::UpdateChannelBuffer {
3384 channel_id: channel_id.to_proto(),
3385 operations: request.operations,
3386 },
3387 &session.peer,
3388 );
3389
3390 let pool = &*session.connection_pool().await;
3391
3392 let non_collaborators =
3393 pool.channel_connection_ids(channel_id)
3394 .filter_map(|(connection_id, _)| {
3395 if collaborators.contains(&connection_id) {
3396 None
3397 } else {
3398 Some(connection_id)
3399 }
3400 });
3401
3402 broadcast(None, non_collaborators, |peer_id| {
3403 session.peer.send(
3404 peer_id,
3405 proto::UpdateChannels {
3406 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3407 channel_id: channel_id.to_proto(),
3408 epoch: epoch as u64,
3409 version: version.clone(),
3410 }],
3411 ..Default::default()
3412 },
3413 )
3414 });
3415
3416 Ok(())
3417}
3418
3419/// Rejoin the channel notes after a connection blip
3420async fn rejoin_channel_buffers(
3421 request: proto::RejoinChannelBuffers,
3422 response: Response<proto::RejoinChannelBuffers>,
3423 session: MessageContext,
3424) -> Result<()> {
3425 let db = session.db().await;
3426 let buffers = db
3427 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3428 .await?;
3429
3430 for rejoined_buffer in &buffers {
3431 let collaborators_to_notify = rejoined_buffer
3432 .buffer
3433 .collaborators
3434 .iter()
3435 .filter_map(|c| Some(c.peer_id?.into()));
3436 channel_buffer_updated(
3437 session.connection_id,
3438 collaborators_to_notify,
3439 &proto::UpdateChannelBufferCollaborators {
3440 channel_id: rejoined_buffer.buffer.channel_id,
3441 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3442 },
3443 &session.peer,
3444 );
3445 }
3446
3447 response.send(proto::RejoinChannelBuffersResponse {
3448 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3449 })?;
3450
3451 Ok(())
3452}
3453
3454/// Stop editing the channel notes
3455async fn leave_channel_buffer(
3456 request: proto::LeaveChannelBuffer,
3457 response: Response<proto::LeaveChannelBuffer>,
3458 session: MessageContext,
3459) -> Result<()> {
3460 let db = session.db().await;
3461 let channel_id = ChannelId::from_proto(request.channel_id);
3462
3463 let left_buffer = db
3464 .leave_channel_buffer(channel_id, session.connection_id)
3465 .await?;
3466
3467 response.send(Ack {})?;
3468
3469 channel_buffer_updated(
3470 session.connection_id,
3471 left_buffer.connections,
3472 &proto::UpdateChannelBufferCollaborators {
3473 channel_id: channel_id.to_proto(),
3474 collaborators: left_buffer.collaborators,
3475 },
3476 &session.peer,
3477 );
3478
3479 Ok(())
3480}
3481
3482fn channel_buffer_updated<T: EnvelopedMessage>(
3483 sender_id: ConnectionId,
3484 collaborators: impl IntoIterator<Item = ConnectionId>,
3485 message: &T,
3486 peer: &Peer,
3487) {
3488 broadcast(Some(sender_id), collaborators, |peer_id| {
3489 peer.send(peer_id, message.clone())
3490 });
3491}
3492
3493fn send_notifications(
3494 connection_pool: &ConnectionPool,
3495 peer: &Peer,
3496 notifications: db::NotificationBatch,
3497) {
3498 for (user_id, notification) in notifications {
3499 for connection_id in connection_pool.user_connection_ids(user_id) {
3500 if let Err(error) = peer.send(
3501 connection_id,
3502 proto::AddNotification {
3503 notification: Some(notification.clone()),
3504 },
3505 ) {
3506 tracing::error!(
3507 "failed to send notification to {:?} {}",
3508 connection_id,
3509 error
3510 );
3511 }
3512 }
3513 }
3514}
3515
3516/// Send a message to the channel
3517async fn send_channel_message(
3518 _request: proto::SendChannelMessage,
3519 _response: Response<proto::SendChannelMessage>,
3520 _session: MessageContext,
3521) -> Result<()> {
3522 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3523}
3524
3525/// Delete a channel message
3526async fn remove_channel_message(
3527 _request: proto::RemoveChannelMessage,
3528 _response: Response<proto::RemoveChannelMessage>,
3529 _session: MessageContext,
3530) -> Result<()> {
3531 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3532}
3533
3534async fn update_channel_message(
3535 _request: proto::UpdateChannelMessage,
3536 _response: Response<proto::UpdateChannelMessage>,
3537 _session: MessageContext,
3538) -> Result<()> {
3539 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3540}
3541
3542/// Mark a channel message as read
3543async fn acknowledge_channel_message(
3544 _request: proto::AckChannelMessage,
3545 _session: MessageContext,
3546) -> Result<()> {
3547 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3548}
3549
3550/// Mark a buffer version as synced
3551async fn acknowledge_buffer_version(
3552 request: proto::AckBufferOperation,
3553 session: MessageContext,
3554) -> Result<()> {
3555 let buffer_id = BufferId::from_proto(request.buffer_id);
3556 session
3557 .db()
3558 .await
3559 .observe_buffer_version(
3560 buffer_id,
3561 session.user_id(),
3562 request.epoch as i32,
3563 &request.version,
3564 )
3565 .await?;
3566 Ok(())
3567}
3568
3569/// Start receiving chat updates for a channel
3570async fn join_channel_chat(
3571 _request: proto::JoinChannelChat,
3572 _response: Response<proto::JoinChannelChat>,
3573 _session: MessageContext,
3574) -> Result<()> {
3575 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3576}
3577
3578/// Stop receiving chat updates for a channel
3579async fn leave_channel_chat(
3580 _request: proto::LeaveChannelChat,
3581 _session: MessageContext,
3582) -> Result<()> {
3583 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3584}
3585
3586/// Retrieve the chat history for a channel
3587async fn get_channel_messages(
3588 _request: proto::GetChannelMessages,
3589 _response: Response<proto::GetChannelMessages>,
3590 _session: MessageContext,
3591) -> Result<()> {
3592 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3593}
3594
3595/// Retrieve specific chat messages
3596async fn get_channel_messages_by_id(
3597 _request: proto::GetChannelMessagesById,
3598 _response: Response<proto::GetChannelMessagesById>,
3599 _session: MessageContext,
3600) -> Result<()> {
3601 Err(anyhow!("chat has been removed in the latest version of Zed").into())
3602}
3603
3604/// Retrieve the current users notifications
3605async fn get_notifications(
3606 request: proto::GetNotifications,
3607 response: Response<proto::GetNotifications>,
3608 session: MessageContext,
3609) -> Result<()> {
3610 let notifications = session
3611 .db()
3612 .await
3613 .get_notifications(
3614 session.user_id(),
3615 NOTIFICATION_COUNT_PER_PAGE,
3616 request.before_id.map(db::NotificationId::from_proto),
3617 )
3618 .await?;
3619 response.send(proto::GetNotificationsResponse {
3620 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
3621 notifications,
3622 })?;
3623 Ok(())
3624}
3625
3626/// Mark notifications as read
3627async fn mark_notification_as_read(
3628 request: proto::MarkNotificationRead,
3629 response: Response<proto::MarkNotificationRead>,
3630 session: MessageContext,
3631) -> Result<()> {
3632 let database = &session.db().await;
3633 let notifications = database
3634 .mark_notification_as_read_by_id(
3635 session.user_id(),
3636 NotificationId::from_proto(request.notification_id),
3637 )
3638 .await?;
3639 send_notifications(
3640 &*session.connection_pool().await,
3641 &session.peer,
3642 notifications,
3643 );
3644 response.send(proto::Ack {})?;
3645 Ok(())
3646}
3647
3648fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
3649 let message = match message {
3650 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
3651 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
3652 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
3653 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
3654 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
3655 code: frame.code.into(),
3656 reason: frame.reason.as_str().to_owned().into(),
3657 })),
3658 // We should never receive a frame while reading the message, according
3659 // to the `tungstenite` maintainers:
3660 //
3661 // > It cannot occur when you read messages from the WebSocket, but it
3662 // > can be used when you want to send the raw frames (e.g. you want to
3663 // > send the frames to the WebSocket without composing the full message first).
3664 // >
3665 // > — https://github.com/snapview/tungstenite-rs/issues/268
3666 TungsteniteMessage::Frame(_) => {
3667 bail!("received an unexpected frame while reading the message")
3668 }
3669 };
3670
3671 Ok(message)
3672}
3673
3674fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
3675 match message {
3676 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
3677 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
3678 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
3679 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
3680 AxumMessage::Close(frame) => {
3681 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
3682 code: frame.code.into(),
3683 reason: frame.reason.as_ref().into(),
3684 }))
3685 }
3686 }
3687}
3688
3689fn notify_membership_updated(
3690 connection_pool: &mut ConnectionPool,
3691 result: MembershipUpdated,
3692 user_id: UserId,
3693 peer: &Peer,
3694) {
3695 for membership in &result.new_channels.channel_memberships {
3696 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3697 }
3698 for channel_id in &result.removed_channels {
3699 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
3700 }
3701
3702 let user_channels_update = proto::UpdateUserChannels {
3703 channel_memberships: result
3704 .new_channels
3705 .channel_memberships
3706 .iter()
3707 .map(|cm| proto::ChannelMembership {
3708 channel_id: cm.channel_id.to_proto(),
3709 role: cm.role.into(),
3710 })
3711 .collect(),
3712 ..Default::default()
3713 };
3714
3715 let mut update = build_channels_update(result.new_channels);
3716 update.delete_channels = result
3717 .removed_channels
3718 .into_iter()
3719 .map(|id| id.to_proto())
3720 .collect();
3721 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
3722
3723 for connection_id in connection_pool.user_connection_ids(user_id) {
3724 peer.send(connection_id, user_channels_update.clone())
3725 .trace_err();
3726 peer.send(connection_id, update.clone()).trace_err();
3727 }
3728}
3729
3730fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
3731 proto::UpdateUserChannels {
3732 channel_memberships: channels
3733 .channel_memberships
3734 .iter()
3735 .map(|m| proto::ChannelMembership {
3736 channel_id: m.channel_id.to_proto(),
3737 role: m.role.into(),
3738 })
3739 .collect(),
3740 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
3741 }
3742}
3743
3744fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
3745 let mut update = proto::UpdateChannels::default();
3746
3747 for channel in channels.channels {
3748 update.channels.push(channel.to_proto());
3749 }
3750
3751 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
3752
3753 for (channel_id, participants) in channels.channel_participants {
3754 update
3755 .channel_participants
3756 .push(proto::ChannelParticipants {
3757 channel_id: channel_id.to_proto(),
3758 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
3759 });
3760 }
3761
3762 for channel in channels.invited_channels {
3763 update.channel_invitations.push(channel.to_proto());
3764 }
3765
3766 update
3767}
3768
3769fn build_initial_contacts_update(
3770 contacts: Vec<db::Contact>,
3771 pool: &ConnectionPool,
3772) -> proto::UpdateContacts {
3773 let mut update = proto::UpdateContacts::default();
3774
3775 for contact in contacts {
3776 match contact {
3777 db::Contact::Accepted { user_id, busy } => {
3778 update.contacts.push(contact_for_user(user_id, busy, pool));
3779 }
3780 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
3781 db::Contact::Incoming { user_id } => {
3782 update
3783 .incoming_requests
3784 .push(proto::IncomingContactRequest {
3785 requester_id: user_id.to_proto(),
3786 })
3787 }
3788 }
3789 }
3790
3791 update
3792}
3793
3794fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
3795 proto::Contact {
3796 user_id: user_id.to_proto(),
3797 online: pool.is_user_online(user_id),
3798 busy,
3799 }
3800}
3801
3802fn room_updated(room: &proto::Room, peer: &Peer) {
3803 broadcast(
3804 None,
3805 room.participants
3806 .iter()
3807 .filter_map(|participant| Some(participant.peer_id?.into())),
3808 |peer_id| {
3809 peer.send(
3810 peer_id,
3811 proto::RoomUpdated {
3812 room: Some(room.clone()),
3813 },
3814 )
3815 },
3816 );
3817}
3818
3819fn channel_updated(
3820 channel: &db::channel::Model,
3821 room: &proto::Room,
3822 peer: &Peer,
3823 pool: &ConnectionPool,
3824) {
3825 let participants = room
3826 .participants
3827 .iter()
3828 .map(|p| p.user_id)
3829 .collect::<Vec<_>>();
3830
3831 broadcast(
3832 None,
3833 pool.channel_connection_ids(channel.root_id())
3834 .filter_map(|(channel_id, role)| {
3835 role.can_see_channel(channel.visibility)
3836 .then_some(channel_id)
3837 }),
3838 |peer_id| {
3839 peer.send(
3840 peer_id,
3841 proto::UpdateChannels {
3842 channel_participants: vec![proto::ChannelParticipants {
3843 channel_id: channel.id.to_proto(),
3844 participant_user_ids: participants.clone(),
3845 }],
3846 ..Default::default()
3847 },
3848 )
3849 },
3850 );
3851}
3852
3853async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
3854 let db = session.db().await;
3855
3856 let contacts = db.get_contacts(user_id).await?;
3857 let busy = db.is_user_busy(user_id).await?;
3858
3859 let pool = session.connection_pool().await;
3860 let updated_contact = contact_for_user(user_id, busy, &pool);
3861 for contact in contacts {
3862 if let db::Contact::Accepted {
3863 user_id: contact_user_id,
3864 ..
3865 } = contact
3866 {
3867 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
3868 session
3869 .peer
3870 .send(
3871 contact_conn_id,
3872 proto::UpdateContacts {
3873 contacts: vec![updated_contact.clone()],
3874 remove_contacts: Default::default(),
3875 incoming_requests: Default::default(),
3876 remove_incoming_requests: Default::default(),
3877 outgoing_requests: Default::default(),
3878 remove_outgoing_requests: Default::default(),
3879 },
3880 )
3881 .trace_err();
3882 }
3883 }
3884 }
3885 Ok(())
3886}
3887
/// Removes `connection_id` from its room and notifies everyone affected.
///
/// If the database reports no room was left, this returns `Ok(())` without
/// doing anything else. Otherwise it, in order:
/// 1. notifies the connections of every project that was left,
/// 2. broadcasts the updated room to its participants,
/// 3. broadcasts the channel participants when the room belongs to a channel,
/// 4. sends `CallCanceled` to users whose calls were canceled,
/// 5. refreshes contacts for the session user and the canceled users,
/// 6. removes the participant from LiveKit (and deletes the LiveKit room when
///    the database marked the room as deleted), if a LiveKit client is
///    configured. LiveKit failures are passed to `trace_err` and ignored.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so that the
    // values outlive `left_room`, which is dropped at the end of that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `livekit_room` is taken out of `left_room.room` before
        // the room itself is taken, so `room` below carries a default
        // `livekit_room` — confirm this is intended.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection pool handle is released before
    // `update_user_contacts`, which acquires the pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
3961
3962async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
3963 let left_channel_buffers = session
3964 .db()
3965 .await
3966 .leave_channel_buffers(session.connection_id)
3967 .await?;
3968
3969 for left_buffer in left_channel_buffers {
3970 channel_buffer_updated(
3971 session.connection_id,
3972 left_buffer.connections,
3973 &proto::UpdateChannelBufferCollaborators {
3974 channel_id: left_buffer.channel_id.to_proto(),
3975 collaborators: left_buffer.collaborators,
3976 },
3977 &session.peer,
3978 );
3979 }
3980
3981 Ok(())
3982}
3983
3984fn project_left(project: &db::LeftProject, session: &Session) {
3985 for connection_id in &project.connection_ids {
3986 if project.should_unshare {
3987 session
3988 .peer
3989 .send(
3990 *connection_id,
3991 proto::UnshareProject {
3992 project_id: project.id.to_proto(),
3993 },
3994 )
3995 .trace_err();
3996 } else {
3997 session
3998 .peer
3999 .send(
4000 *connection_id,
4001 proto::RemoveProjectCollaborator {
4002 project_id: project.id.to_proto(),
4003 peer_id: Some(session.connection_id.into()),
4004 },
4005 )
4006 .trace_err();
4007 }
4008 }
4009}
4010
4011async fn share_agent_thread(
4012 request: proto::ShareAgentThread,
4013 response: Response<proto::ShareAgentThread>,
4014 session: MessageContext,
4015) -> Result<()> {
4016 let user_id = session.user_id();
4017
4018 let share_id = SharedThreadId::from_proto(request.session_id.clone())
4019 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4020
4021 session
4022 .db()
4023 .await
4024 .upsert_shared_thread(share_id, user_id, &request.title, request.thread_data)
4025 .await?;
4026
4027 response.send(proto::Ack {})?;
4028
4029 Ok(())
4030}
4031
4032async fn get_shared_agent_thread(
4033 request: proto::GetSharedAgentThread,
4034 response: Response<proto::GetSharedAgentThread>,
4035 session: MessageContext,
4036) -> Result<()> {
4037 let share_id = SharedThreadId::from_proto(request.session_id)
4038 .ok_or_else(|| anyhow!("Invalid session ID format"))?;
4039
4040 let result = session.db().await.get_shared_thread(share_id).await?;
4041
4042 match result {
4043 Some((thread, username)) => {
4044 response.send(proto::GetSharedAgentThreadResponse {
4045 title: thread.title,
4046 thread_data: thread.data,
4047 sharer_username: username,
4048 created_at: thread.created_at.and_utc().to_rfc3339(),
4049 })?;
4050 }
4051 None => {
4052 return Err(anyhow!("Shared thread not found").into());
4053 }
4054 }
4055
4056 Ok(())
4057}
4058
/// Extension trait for turning a fallible value into an `Option` while
/// reporting the error instead of propagating it.
pub trait ResultExt {
    /// The success type yielded when no error occurred.
    type Ok;

    /// Returns the success value as `Some`, or `None` after the error has been
    /// reported by the implementation.
    fn trace_err(self) -> Option<Self::Ok>;
}
4064
4065impl<T, E> ResultExt for Result<T, E>
4066where
4067 E: std::fmt::Debug,
4068{
4069 type Ok = T;
4070
4071 #[track_caller]
4072 fn trace_err(self) -> Option<T> {
4073 match self {
4074 Ok(value) => Some(value),
4075 Err(error) => {
4076 tracing::error!("{:?}", error);
4077 None
4078 }
4079 }
4080 }
4081}