1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::db::LlmDatabase;
6use crate::llm::{
7 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
8 MIN_ACCOUNT_AGE_FOR_LLM_USE,
9};
10use crate::{
11 AppState, Error, Result, auth,
12 db::{
13 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
14 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
15 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
16 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
17 },
18 executor::Executor,
19};
20use anyhow::{Context as _, anyhow, bail};
21use async_tungstenite::tungstenite::{
22 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
23};
24use axum::headers::UserAgent;
25use axum::{
26 Extension, Router, TypedHeader,
27 body::Body,
28 extract::{
29 ConnectInfo, WebSocketUpgrade,
30 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
31 },
32 headers::{Header, HeaderName},
33 http::StatusCode,
34 middleware,
35 response::IntoResponse,
36 routing::get,
37};
38use chrono::Utc;
39use collections::{HashMap, HashSet};
40pub use connection_pool::{ConnectionPool, ZedVersion};
41use core::fmt::{self, Debug, Formatter};
42use futures::TryFutureExt as _;
43use reqwest_client::ReqwestClient;
44use rpc::proto::{MultiLspQuery, split_repository_update};
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46use tracing::Span;
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
// Grace period for a disconnected client to reconnect before its state is
// discarded — used at call sites outside this excerpt; confirm exact usage there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes for chat-message and notification queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Upper bound enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently outstanding `ConnectionGuard`s (see `ConnectionGuard`).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing-span field names for per-message timing, all fractional milliseconds.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

// Type-erased message handler as stored in `Server::handlers`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
104
105pub struct ConnectionGuard;
106
107impl ConnectionGuard {
108 pub fn try_acquire() -> Result<Self, ()> {
109 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
110 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
111 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
112 tracing::error!(
113 "too many concurrent connections: {}",
114 current_connections + 1
115 );
116 return Err(());
117 }
118 Ok(ConnectionGuard)
119 }
120}
121
122impl Drop for ConnectionGuard {
123 fn drop(&mut self) {
124 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
125 }
126}
127
/// One-shot responder handed to request handlers; sending consumes it.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    // Shared with the wrapper in `add_request_handler`, which checks it to
    // detect handlers that returned Ok without ever responding.
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the requesting peer.
    fn send(self, payload: R::Response) -> Result<()> {
        // Set the flag before attempting the send, so even a failed send
        // counts as an attempt to respond.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
141
142#[derive(Clone, Debug)]
143pub enum Principal {
144 User(User),
145 Impersonated { user: User, admin: User },
146}
147
148impl Principal {
149 fn user(&self) -> &User {
150 match self {
151 Principal::User(user) => user,
152 Principal::Impersonated { user, .. } => user,
153 }
154 }
155
156 fn update_span(&self, span: &tracing::Span) {
157 match &self {
158 Principal::User(user) => {
159 span.record("user_id", user.id.0);
160 span.record("login", &user.github_login);
161 }
162 Principal::Impersonated { user, admin } => {
163 span.record("user_id", user.id.0);
164 span.record("login", &user.github_login);
165 span.record("impersonator", &admin.github_login);
166 }
167 }
168 }
169}
170
171#[derive(Clone)]
172struct MessageContext {
173 session: Session,
174 span: tracing::Span,
175}
176
177impl Deref for MessageContext {
178 type Target = Session;
179
180 fn deref(&self) -> &Self::Target {
181 &self.session
182 }
183}
184
impl MessageContext {
    /// Forwards `request` to `receiver_id` (typically the project host) and
    /// returns a future resolving to the host's response.
    ///
    /// On completion — success or failure — the time spent waiting is
    /// recorded in the span's `host_waiting_ms` field; the outcome is then
    /// logged at error/info level respectively.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        // Logged when the future is created, not when it is first polled.
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    // micros -> fractional milliseconds
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
206
/// Per-connection state threaded through every message handler.
#[derive(Clone)]
struct Session {
    /// Who is connected (possibly an admin impersonating another user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; one mutex per connection
    /// (created in `handle_connection`).
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Server-wide pool of all live connections (shared with `Server`).
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
223
224impl Session {
225 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
226 #[cfg(test)]
227 tokio::task::yield_now().await;
228 let guard = self.db.lock().await;
229 #[cfg(test)]
230 tokio::task::yield_now().await;
231 guard
232 }
233
234 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
235 #[cfg(test)]
236 tokio::task::yield_now().await;
237 let guard = self.connection_pool.lock();
238 ConnectionPoolGuard {
239 guard,
240 _not_send: PhantomData,
241 }
242 }
243
244 fn is_staff(&self) -> bool {
245 match &self.principal {
246 Principal::User(user) => user.admin,
247 Principal::Impersonated { .. } => true,
248 }
249 }
250
251 fn user_id(&self) -> UserId {
252 match &self.principal {
253 Principal::User(user) => user.id,
254 Principal::Impersonated { user, .. } => user.id,
255 }
256 }
257
258 pub fn email(&self) -> Option<String> {
259 match &self.principal {
260 Principal::User(user) => user.email_address.clone(),
261 Principal::Impersonated { user, .. } => user.email_address.clone(),
262 }
263 }
264}
265
266impl Debug for Session {
267 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
268 let mut result = f.debug_struct("Session");
269 match &self.principal {
270 Principal::User(user) => {
271 result.field("user", &user.github_login);
272 }
273 Principal::Impersonated { user, admin } => {
274 result.field("user", &user.github_login);
275 result.field("impersonator", &admin.github_login);
276 }
277 }
278 result.field("connection_id", &self.connection_id).finish()
279 }
280}
281
282struct DbHandle(Arc<Database>);
283
284impl Deref for DbHandle {
285 type Target = Database;
286
287 fn deref(&self) -> &Self::Target {
288 self.0.as_ref()
289 }
290}
291
/// The collab RPC server: owns the peer, the live-connection pool, and the
/// registry of message handlers.
pub struct Server {
    /// Mutable so tests can re-assign it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message `TypeId` -> type-erased handler; populated once in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Watch channel flipped to `true` by `teardown` so every
    /// `handle_connection` loop exits.
    teardown: watch::Sender<bool>,
}
300
/// Guard over the locked connection pool. The `PhantomData<Rc<()>>` makes the
/// guard `!Send` (since `Rc` is `!Send`), so a future holding it across an
/// `.await` point becomes `!Send` itself.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable snapshot of live server state: the peer plus the (locked)
/// connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize what it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
312
313pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
314where
315 S: Serializer,
316 T: Deref<Target = U>,
317 U: Serialize,
318{
319 Serialize::serialize(value.deref(), serializer)
320}
321
322impl Server {
    /// Creates the server and registers a handler for every RPC message type.
    ///
    /// `add_handler` panics on duplicate registration, so each message type
    /// must appear exactly once below. The read-only vs. mutating forwarding
    /// split presumably maps to participant capabilities — see the
    /// `forward_*_project_request` helpers for the actual checks.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Liveness, rooms, and calls.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // LSP-style requests forwarded from guests to the project host.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            // NOTE(review): GetCodeLens looks read-only but is registered as
            // mutating — confirm whether this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host-to-guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and miscellaneous account functionality.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and git operations.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): GitReset and GitCheckoutFiles modify the working
            // tree but are registered as read-only — confirm this is intended.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
497
    /// Spawns the deferred startup cleanup task and returns immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT` (so pods from a previous deploy
    /// can finish shutting down), then removes state still attributed to
    /// stale servers — chat participants, channel-buffer collaborators, room
    /// participants, old worktree entries, and the stale server rows
    /// themselves — notifying affected clients along the way. All DB/RPC
    /// failures are logged via `trace_err` and do not abort the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale channel-buffer collaborators and broadcast the
                    // refreshed collaborator set to the remaining connections.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove stale participants, then broadcast the
                        // refreshed room (and channel) state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose pending calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed busy status to
                        // all of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
668
669 pub fn teardown(&self) {
670 self.peer.teardown();
671 self.connection_pool.lock().reset();
672 let _ = self.teardown.send(true);
673 }
674
675 #[cfg(test)]
676 pub fn reset(&self, id: ServerId) {
677 self.teardown();
678 *self.id.lock() = id;
679 self.peer.reset(id.0 as u32);
680 let _ = self.teardown.send(false);
681 }
682
    /// Test-only: returns the server's current id (which `reset` may change).
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
687
    /// Registers the type-erased handler for message type `M`.
    ///
    /// The stored closure downcasts the envelope, runs `handler`, and on
    /// completion records timing on the message span: total time since
    /// receipt, time spent inside the handler, and their difference (time
    /// spent queued). Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Dispatch is keyed by `TypeId::of::<M>()`, so this downcast
                // cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // All durations below are fractional milliseconds.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
731
732 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
733 where
734 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
735 Fut: 'static + Send + Future<Output = Result<()>>,
736 M: EnvelopedMessage,
737 {
738 self.add_handler(move |envelope, session| handler(envelope.payload, session));
739 self
740 }
741
    /// Registers a handler for a request/response message.
    ///
    /// Wraps `handler` so that:
    /// - a handler that returns `Ok` without ever calling `Response::send`
    ///   is turned into an error (every request must be answered), and
    /// - a handler error is forwarded to the requesting peer before being
    ///   propagated for logging.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Internal errors already carry a proto form;
                        // everything else becomes a generic internal error.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
780
    /// Drives one client connection from registration to disconnect.
    ///
    /// Registers the connection with the peer, builds the per-connection
    /// `Session`, sends the initial client update, and then processes
    /// incoming messages until the client disconnects, I/O fails, or the
    /// server tears down; finally runs `connection_lost` cleanup.
    ///
    /// `connection_guard` (the concurrent-connection slot) is dropped right
    /// after the initial client update succeeds, so the connection limit
    /// throttles connection setup rather than established connections.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client used for the Supermaven admin API below.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake complete: release the concurrent-connection slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Wait for a free handler slot before pulling the next
                // message, bounding in-flight handlers per connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler
                                // future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-running foreground handlers, then clean up.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
972
    /// Sends the messages a client expects immediately after its connection is
    /// accepted: a `Hello` carrying its peer id, first-connection setup, the
    /// initial contact list, channel subscriptions (for older clients), and any
    /// pending incoming call.
    ///
    /// Ordering matters here: `Hello` must be the first message on the wire,
    /// and the connection is added to the pool inside the same lock that builds
    /// the contacts update so online state cannot race.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller that supplied a oneshot
        // (ignore send failure: the receiver may have been dropped).
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection for this user: prompt the contacts UI
                // once and persist the flag so it never happens again.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and build the contacts update
                    // under one pool lock so the snapshot is consistent.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Older clients rely on the server subscribing them to their
                // channels; newer clients request subscriptions themselves.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver any call that was ringing while the user was away.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                // Tell this user's contacts that they are now online.
                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1030
1031 pub async fn invite_code_redeemed(
1032 self: &Arc<Self>,
1033 inviter_id: UserId,
1034 invitee_id: UserId,
1035 ) -> Result<()> {
1036 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1037 if let Some(code) = &user.invite_code {
1038 let pool = self.connection_pool.lock();
1039 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1040 for connection_id in pool.user_connection_ids(inviter_id) {
1041 self.peer.send(
1042 connection_id,
1043 proto::UpdateContacts {
1044 contacts: vec![invitee_contact.clone()],
1045 ..Default::default()
1046 },
1047 )?;
1048 self.peer.send(
1049 connection_id,
1050 proto::UpdateInviteInfo {
1051 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1052 count: user.invite_count as u32,
1053 },
1054 )?;
1055 }
1056 }
1057 }
1058 Ok(())
1059 }
1060
1061 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1062 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1063 if let Some(invite_code) = &user.invite_code {
1064 let pool = self.connection_pool.lock();
1065 for connection_id in pool.user_connection_ids(user_id) {
1066 self.peer.send(
1067 connection_id,
1068 proto::UpdateInviteInfo {
1069 url: format!(
1070 "{}{}",
1071 self.app_state.config.invite_link_prefix, invite_code
1072 ),
1073 count: user.invite_count as u32,
1074 },
1075 )?;
1076 }
1077 }
1078 }
1079 Ok(())
1080 }
1081
1082 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1083 ServerSnapshot {
1084 connection_pool: ConnectionPoolGuard {
1085 guard: self.connection_pool.lock(),
1086 _not_send: PhantomData,
1087 },
1088 peer: &self.peer,
1089 }
1090 }
1091}
1092
1093impl Deref for ConnectionPoolGuard<'_> {
1094 type Target = ConnectionPool;
1095
1096 fn deref(&self) -> &Self::Target {
1097 &self.guard
1098 }
1099}
1100
1101impl DerefMut for ConnectionPoolGuard<'_> {
1102 fn deref_mut(&mut self) -> &mut Self::Target {
1103 &mut self.guard
1104 }
1105}
1106
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds only, validate the pool's internal invariants each
        // time a guard is released so corruption is caught near its source.
        // (The `#[cfg(test)]` is on the statement: release builds compile
        // this drop to a no-op.)
        #[cfg(test)]
        self.check_invariants();
    }
}
1113
1114fn broadcast<F>(
1115 sender_id: Option<ConnectionId>,
1116 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1117 mut f: F,
1118) where
1119 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1120{
1121 for receiver_id in receiver_ids {
1122 if Some(receiver_id) != sender_id {
1123 if let Err(error) = f(receiver_id) {
1124 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1125 }
1126 }
1127 }
1128}
1129
1130pub struct ProtocolVersion(u32);
1131
1132impl Header for ProtocolVersion {
1133 fn name() -> &'static HeaderName {
1134 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1135 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1136 }
1137
1138 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1139 where
1140 Self: Sized,
1141 I: Iterator<Item = &'i axum::http::HeaderValue>,
1142 {
1143 let version = values
1144 .next()
1145 .ok_or_else(axum::headers::Error::invalid)?
1146 .to_str()
1147 .map_err(|_| axum::headers::Error::invalid())?
1148 .parse()
1149 .map_err(|_| axum::headers::Error::invalid())?;
1150 Ok(Self(version))
1151 }
1152
1153 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1154 values.extend([self.0.to_string().parse().unwrap()]);
1155 }
1156}
1157
1158pub struct AppVersionHeader(SemanticVersion);
1159impl Header for AppVersionHeader {
1160 fn name() -> &'static HeaderName {
1161 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1162 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1163 }
1164
1165 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1166 where
1167 Self: Sized,
1168 I: Iterator<Item = &'i axum::http::HeaderValue>,
1169 {
1170 let version = values
1171 .next()
1172 .ok_or_else(axum::headers::Error::invalid)?
1173 .to_str()
1174 .map_err(|_| axum::headers::Error::invalid())?
1175 .parse()
1176 .map_err(|_| axum::headers::Error::invalid())?;
1177 Ok(Self(version))
1178 }
1179
1180 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1181 values.extend([self.0.to_string().parse().unwrap()]);
1182 }
1183}
1184
1185#[derive(Debug)]
1186pub struct ReleaseChannelHeader(String);
1187
1188impl Header for ReleaseChannelHeader {
1189 fn name() -> &'static HeaderName {
1190 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1191 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1192 }
1193
1194 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1195 where
1196 Self: Sized,
1197 I: Iterator<Item = &'i axum::http::HeaderValue>,
1198 {
1199 Ok(Self(
1200 values
1201 .next()
1202 .ok_or_else(axum::headers::Error::invalid)?
1203 .to_str()
1204 .map_err(|_| axum::headers::Error::invalid())?
1205 .to_owned(),
1206 ))
1207 }
1208
1209 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1210 values.extend([self.0.parse().unwrap()]);
1211 }
1212}
1213
/// Builds the collab HTTP router: `/rpc` (WebSocket endpoint, auth-gated) and
/// `/metrics` (Prometheus scrape endpoint).
///
/// Layer ordering is significant in axum: the auth middleware and app-state
/// extension are layered before `/metrics` is added, so they apply only to
/// `/rpc`; the `Server` extension is layered last so both handlers see it.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1225
/// Axum handler that upgrades `/rpc` requests into collab WebSocket
/// connections.
///
/// Rejects the client before upgrading when its protocol version differs
/// from the server's, when the app-version header is missing, or when its
/// version is too old to collaborate. A connection slot is reserved *before*
/// the upgrade so overload is answered with a clean HTTP 503 instead of an
/// accepted-then-dropped socket.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // Protocol version must match exactly; otherwise the wire format is
    // incompatible and the client has to upgrade.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so that when the
    // server is at capacity we can still reply with a plain HTTP 503.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum's WebSocket message types to the tungstenite types the
        // rpc::Connection layer expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1303
1304pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1305 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1306 let connections_metric = CONNECTIONS_METRIC
1307 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1308
1309 let connections = server
1310 .connection_pool
1311 .lock()
1312 .connections()
1313 .filter(|connection| !connection.admin)
1314 .count();
1315 connections_metric.set(connections as _);
1316
1317 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1318 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1319 register_int_gauge!(
1320 "shared_projects",
1321 "number of open projects with one or more guests"
1322 )
1323 .unwrap()
1324 });
1325
1326 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1327 shared_projects_metric.set(shared_projects as _);
1328
1329 let encoder = prometheus::TextEncoder::new();
1330 let metric_families = prometheus::gather();
1331 let encoded_metrics = encoder
1332 .encode_to_string(&metric_families)
1333 .map_err(|err| anyhow!("{err}"))?;
1334 Ok(encoded_metrics)
1335}
1336
/// Tears down server-side state after a connection drops.
///
/// The connection is removed from the pool and database immediately, but
/// room/channel resources are only released after `RECONNECT_TIMEOUT`, so a
/// client that reconnects quickly can rejoin seamlessly. The `teardown`
/// signal short-circuits that wait during server shutdown.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure to record the lost connection is only logged.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a ringing call if the user has no other live
            // connection — another window may still want to answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1383
1384/// Acknowledges a ping from a client, used to keep the connection alive.
1385async fn ping(
1386 _: proto::Ping,
1387 response: Response<proto::Ping>,
1388 _session: MessageContext,
1389) -> Result<()> {
1390 response.send(proto::Ack {})?;
1391 Ok(())
1392}
1393
1394/// Creates a new room for calling (outside of channels)
1395async fn create_room(
1396 _request: proto::CreateRoom,
1397 response: Response<proto::CreateRoom>,
1398 session: MessageContext,
1399) -> Result<()> {
1400 let livekit_room = nanoid::nanoid!(30);
1401
1402 let live_kit_connection_info = util::maybe!(async {
1403 let live_kit = session.app_state.livekit_client.as_ref();
1404 let live_kit = live_kit?;
1405 let user_id = session.user_id().to_string();
1406
1407 let token = live_kit
1408 .room_token(&livekit_room, &user_id.to_string())
1409 .trace_err()?;
1410
1411 Some(proto::LiveKitConnectionInfo {
1412 server_url: live_kit.url().into(),
1413 token,
1414 can_publish: true,
1415 })
1416 })
1417 .await;
1418
1419 let room = session
1420 .db()
1421 .await
1422 .create_room(session.user_id(), session.connection_id, &livekit_room)
1423 .await?;
1424
1425 response.send(proto::CreateRoomResponse {
1426 room: Some(room.clone()),
1427 live_kit_connection_info,
1428 })?;
1429
1430 update_user_contacts(session.user_id(), &session).await?;
1431 Ok(())
1432}
1433
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room is backed by a channel, joining is delegated wholesale to
/// `join_channel_internal`. Otherwise the user joins directly, their other
/// connections stop ringing, a LiveKit token is minted when possible, and
/// their contacts are refreshed.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        // Scope the DB guard: broadcast the room update while it is held,
        // then release it before touching the connection pool.
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Dismiss the ringing call UI on all of this user's other connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // A LiveKit token failure degrades to no audio/video, not a join failure.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1500
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-associates the new connection with the user's previous room state,
/// answers the client with its reshared/rejoined projects, rebroadcasts the
/// room, and tells each project's collaborators about the user's new peer id.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        // The DB guard stays alive for this whole block; all broadcasts below
        // happen while it is held.
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that the host's peer id changed from the
            // old (dropped) connection to this new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree metadata to everyone but the
            // rejoining connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel updates touch the connection pool, so they happen after the DB
    // guard above has been released.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1592
/// Streams the current state of each rejoined project back to the rejoining
/// connection, after announcing its new peer id to the other collaborators.
///
/// Worktree entries, repositories, diagnostics, and settings are drained out
/// of `rejoined_projects` (via `mem::take`) and sent as chunked updates.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    // Announce the peer-id swap to every collaborator first, so they can
    // re-route traffic before state updates start flowing.
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // The scan is complete only when the completed scan caught up
                // with the latest scan id.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Split into protocol-sized chunks before sending.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1677
1678/// leave room disconnects from the room.
1679async fn leave_room(
1680 _: proto::LeaveRoom,
1681 response: Response<proto::LeaveRoom>,
1682 session: MessageContext,
1683) -> Result<()> {
1684 leave_room_for_session(&session, session.connection_id).await?;
1685 response.send(proto::Ack {})?;
1686 Ok(())
1687}
1688
1689/// Updates the permissions of someone else in the room.
1690async fn set_room_participant_role(
1691 request: proto::SetRoomParticipantRole,
1692 response: Response<proto::SetRoomParticipantRole>,
1693 session: MessageContext,
1694) -> Result<()> {
1695 let user_id = UserId::from_proto(request.user_id);
1696 let role = ChannelRole::from(request.role());
1697
1698 let (livekit_room, can_publish) = {
1699 let room = session
1700 .db()
1701 .await
1702 .set_room_participant_role(
1703 session.user_id(),
1704 RoomId::from_proto(request.room_id),
1705 user_id,
1706 role,
1707 )
1708 .await?;
1709
1710 let livekit_room = room.livekit_room.clone();
1711 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1712 room_updated(&room, &session.peer);
1713 (livekit_room, can_publish)
1714 };
1715
1716 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1717 live_kit
1718 .update_participant(
1719 livekit_room.clone(),
1720 request.user_id.to_string(),
1721 livekit_api::proto::ParticipantPermission {
1722 can_subscribe: true,
1723 can_publish,
1724 can_publish_data: can_publish,
1725 hidden: false,
1726 recorder: false,
1727 },
1728 )
1729 .await
1730 .trace_err();
1731 }
1732
1733 response.send(proto::Ack {})?;
1734 Ok(())
1735}
1736
/// Call someone else into the current room
///
/// Rings every connection of the callee in parallel; the first one to accept
/// wins and the handler acks immediately. If no connection accepts, the call
/// is marked failed in the database and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may be called.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        // Scoped so the DB guard is dropped before we start ringing.
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // First successful response wins; failures are logged and we keep waiting
    // on the remaining connections.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: roll back the pending call and rebroadcast the room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1805
1806/// Cancel an outgoing call.
1807async fn cancel_call(
1808 request: proto::CancelCall,
1809 response: Response<proto::CancelCall>,
1810 session: MessageContext,
1811) -> Result<()> {
1812 let called_user_id = UserId::from_proto(request.called_user_id);
1813 let room_id = RoomId::from_proto(request.room_id);
1814 {
1815 let room = session
1816 .db()
1817 .await
1818 .cancel_call(room_id, session.connection_id, called_user_id)
1819 .await?;
1820 room_updated(&room, &session.peer);
1821 }
1822
1823 for connection_id in session
1824 .connection_pool()
1825 .await
1826 .user_connection_ids(called_user_id)
1827 {
1828 session
1829 .peer
1830 .send(
1831 connection_id,
1832 proto::CallCanceled {
1833 room_id: room_id.to_proto(),
1834 },
1835 )
1836 .trace_err();
1837 }
1838 response.send(proto::Ack {})?;
1839
1840 update_user_contacts(called_user_id, &session).await?;
1841 Ok(())
1842}
1843
1844/// Decline an incoming call.
1845async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1846 let room_id = RoomId::from_proto(message.room_id);
1847 {
1848 let room = session
1849 .db()
1850 .await
1851 .decline_call(Some(room_id), session.user_id())
1852 .await?
1853 .context("declining call")?;
1854 room_updated(&room, &session.peer);
1855 }
1856
1857 for connection_id in session
1858 .connection_pool()
1859 .await
1860 .user_connection_ids(session.user_id())
1861 {
1862 session
1863 .peer
1864 .send(
1865 connection_id,
1866 proto::CallCanceled {
1867 room_id: room_id.to_proto(),
1868 },
1869 )
1870 .trace_err();
1871 }
1872 update_user_contacts(session.user_id(), &session).await?;
1873 Ok(())
1874}
1875
/// Updates other participants in the room with your current location.
///
/// Persists the location and rebroadcasts the room while the DB guard is
/// still held, then acks the sender.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // A missing location in the proto message is a client error.
    let location = request.location.context("invalid location")?;

    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1894
/// Share a project into the room.
///
/// Registers the project and its worktree metadata in the database, replies
/// with the newly assigned project id, and rebroadcasts the room so other
/// participants see the shared project.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // The `&*` keeps the DB guard (the temporary behind the borrow) alive for
    // the rest of this function, so the broadcast below happens under it.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1918
1919/// Unshare a project from the room.
1920async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1921 let project_id = ProjectId::from_proto(message.project_id);
1922 unshare_project_internal(project_id, session.connection_id, &session).await
1923}
1924
/// Removes a shared project from its room: notifies guests, rebroadcasts the
/// room, and — when the database says so — deletes the project entirely.
///
/// The notification phase runs inside the scope of the DB guard returned by
/// `unshare_project`; the delete happens afterwards with a fresh guard.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection itself).
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1962
/// Join someone elses shared project.
///
/// Registers the guest as a collaborator, announces them to the existing
/// collaborators, answers with the project's metadata, then streams the full
/// project state (worktree entries, diagnostics, settings, repositories,
/// language-server status) to the joining connection in chunked updates.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the DB guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Everyone already in the project, except the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Announce the new collaborator; failures to reach an individual peer are
    // only logged.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Split into protocol-sized chunks before sending.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2112
2113/// Leave someone elses shared project.
2114async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2115 let sender_id = session.connection_id;
2116 let project_id = ProjectId::from_proto(request.project_id);
2117 let db = session.db().await;
2118
2119 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2120 tracing::info!(
2121 %project_id,
2122 "leave project"
2123 );
2124
2125 project_left(project, &session);
2126 if let Some(room) = room {
2127 room_updated(room, &session.peer);
2128 }
2129
2130 Ok(())
2131}
2132
2133/// Updates other participants with changes to the project
2134async fn update_project(
2135 request: proto::UpdateProject,
2136 response: Response<proto::UpdateProject>,
2137 session: MessageContext,
2138) -> Result<()> {
2139 let project_id = ProjectId::from_proto(request.project_id);
2140 let (room, guest_connection_ids) = &*session
2141 .db()
2142 .await
2143 .update_project(project_id, session.connection_id, &request.worktrees)
2144 .await?;
2145 broadcast(
2146 Some(session.connection_id),
2147 guest_connection_ids.iter().copied(),
2148 |connection_id| {
2149 session
2150 .peer
2151 .forward_send(session.connection_id, connection_id, request.clone())
2152 },
2153 );
2154 if let Some(room) = room {
2155 room_updated(room, &session.peer);
2156 }
2157 response.send(proto::Ack {})?;
2158
2159 Ok(())
2160}
2161
2162/// Updates other participants with changes to the worktree
2163async fn update_worktree(
2164 request: proto::UpdateWorktree,
2165 response: Response<proto::UpdateWorktree>,
2166 session: MessageContext,
2167) -> Result<()> {
2168 let guest_connection_ids = session
2169 .db()
2170 .await
2171 .update_worktree(&request, session.connection_id)
2172 .await?;
2173
2174 broadcast(
2175 Some(session.connection_id),
2176 guest_connection_ids.iter().copied(),
2177 |connection_id| {
2178 session
2179 .peer
2180 .forward_send(session.connection_id, connection_id, request.clone())
2181 },
2182 );
2183 response.send(proto::Ack {})?;
2184 Ok(())
2185}
2186
2187async fn update_repository(
2188 request: proto::UpdateRepository,
2189 response: Response<proto::UpdateRepository>,
2190 session: MessageContext,
2191) -> Result<()> {
2192 let guest_connection_ids = session
2193 .db()
2194 .await
2195 .update_repository(&request, session.connection_id)
2196 .await?;
2197
2198 broadcast(
2199 Some(session.connection_id),
2200 guest_connection_ids.iter().copied(),
2201 |connection_id| {
2202 session
2203 .peer
2204 .forward_send(session.connection_id, connection_id, request.clone())
2205 },
2206 );
2207 response.send(proto::Ack {})?;
2208 Ok(())
2209}
2210
2211async fn remove_repository(
2212 request: proto::RemoveRepository,
2213 response: Response<proto::RemoveRepository>,
2214 session: MessageContext,
2215) -> Result<()> {
2216 let guest_connection_ids = session
2217 .db()
2218 .await
2219 .remove_repository(&request, session.connection_id)
2220 .await?;
2221
2222 broadcast(
2223 Some(session.connection_id),
2224 guest_connection_ids.iter().copied(),
2225 |connection_id| {
2226 session
2227 .peer
2228 .forward_send(session.connection_id, connection_id, request.clone())
2229 },
2230 );
2231 response.send(proto::Ack {})?;
2232 Ok(())
2233}
2234
2235/// Updates other participants with changes to the diagnostics
2236async fn update_diagnostic_summary(
2237 message: proto::UpdateDiagnosticSummary,
2238 session: MessageContext,
2239) -> Result<()> {
2240 let guest_connection_ids = session
2241 .db()
2242 .await
2243 .update_diagnostic_summary(&message, session.connection_id)
2244 .await?;
2245
2246 broadcast(
2247 Some(session.connection_id),
2248 guest_connection_ids.iter().copied(),
2249 |connection_id| {
2250 session
2251 .peer
2252 .forward_send(session.connection_id, connection_id, message.clone())
2253 },
2254 );
2255
2256 Ok(())
2257}
2258
2259/// Updates other participants with changes to the worktree settings
2260async fn update_worktree_settings(
2261 message: proto::UpdateWorktreeSettings,
2262 session: MessageContext,
2263) -> Result<()> {
2264 let guest_connection_ids = session
2265 .db()
2266 .await
2267 .update_worktree_settings(&message, session.connection_id)
2268 .await?;
2269
2270 broadcast(
2271 Some(session.connection_id),
2272 guest_connection_ids.iter().copied(),
2273 |connection_id| {
2274 session
2275 .peer
2276 .forward_send(session.connection_id, connection_id, message.clone())
2277 },
2278 );
2279
2280 Ok(())
2281}
2282
2283/// Notify other participants that a language server has started.
2284async fn start_language_server(
2285 request: proto::StartLanguageServer,
2286 session: MessageContext,
2287) -> Result<()> {
2288 let guest_connection_ids = session
2289 .db()
2290 .await
2291 .start_language_server(&request, session.connection_id)
2292 .await?;
2293
2294 broadcast(
2295 Some(session.connection_id),
2296 guest_connection_ids.iter().copied(),
2297 |connection_id| {
2298 session
2299 .peer
2300 .forward_send(session.connection_id, connection_id, request.clone())
2301 },
2302 );
2303 Ok(())
2304}
2305
2306/// Notify other participants that a language server has changed.
2307async fn update_language_server(
2308 request: proto::UpdateLanguageServer,
2309 session: MessageContext,
2310) -> Result<()> {
2311 let project_id = ProjectId::from_proto(request.project_id);
2312 let db = session.db().await;
2313
2314 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2315 {
2316 if let Some(capabilities) = update.capabilities.clone() {
2317 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2318 .await?;
2319 }
2320 }
2321
2322 let project_connection_ids = db
2323 .project_connection_ids(project_id, session.connection_id, true)
2324 .await?;
2325 broadcast(
2326 Some(session.connection_id),
2327 project_connection_ids.iter().copied(),
2328 |connection_id| {
2329 session
2330 .peer
2331 .forward_send(session.connection_id, connection_id, request.clone())
2332 },
2333 );
2334 Ok(())
2335}
2336
2337/// forward a project request to the host. These requests should be read only
2338/// as guests are allowed to send them.
2339async fn forward_read_only_project_request<T>(
2340 request: T,
2341 response: Response<T>,
2342 session: MessageContext,
2343) -> Result<()>
2344where
2345 T: EntityMessage + RequestMessage,
2346{
2347 let project_id = ProjectId::from_proto(request.remote_entity_id());
2348 let host_connection_id = session
2349 .db()
2350 .await
2351 .host_for_read_only_project_request(project_id, session.connection_id)
2352 .await?;
2353 let payload = session.forward_request(host_connection_id, request).await?;
2354 response.send(payload)?;
2355 Ok(())
2356}
2357
2358/// forward a project request to the host. These requests are disallowed
2359/// for guests.
2360async fn forward_mutating_project_request<T>(
2361 request: T,
2362 response: Response<T>,
2363 session: MessageContext,
2364) -> Result<()>
2365where
2366 T: EntityMessage + RequestMessage,
2367{
2368 let project_id = ProjectId::from_proto(request.remote_entity_id());
2369
2370 let host_connection_id = session
2371 .db()
2372 .await
2373 .host_for_mutating_project_request(project_id, session.connection_id)
2374 .await?;
2375 let payload = session.forward_request(host_connection_id, request).await?;
2376 response.send(payload)?;
2377 Ok(())
2378}
2379
2380async fn multi_lsp_query(
2381 request: MultiLspQuery,
2382 response: Response<MultiLspQuery>,
2383 session: MessageContext,
2384) -> Result<()> {
2385 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2386 tracing::info!("multi_lsp_query message received");
2387 forward_mutating_project_request(request, response, session).await
2388}
2389
2390/// Notify other participants that a new buffer has been created
2391async fn create_buffer_for_peer(
2392 request: proto::CreateBufferForPeer,
2393 session: MessageContext,
2394) -> Result<()> {
2395 session
2396 .db()
2397 .await
2398 .check_user_is_project_host(
2399 ProjectId::from_proto(request.project_id),
2400 session.connection_id,
2401 )
2402 .await?;
2403 let peer_id = request.peer_id.context("invalid peer id")?;
2404 session
2405 .peer
2406 .forward_send(session.connection_id, peer_id.into(), request)?;
2407 Ok(())
2408}
2409
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only (or empty) operations may come from read-only guests;
    // any other operation kind requires write access to the project.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // This block scopes the DB guard so it is released before we await
        // the host's acknowledgment below.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Forward the update to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // If the sender is not the host itself, wait for the host to process the
    // update before acknowledging the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2456
2457async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2458 let project_id = ProjectId::from_proto(message.project_id);
2459
2460 let operation = message.operation.as_ref().context("invalid operation")?;
2461 let capability = match operation.variant.as_ref() {
2462 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2463 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2464 match buffer_op.variant {
2465 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2466 Capability::ReadOnly
2467 }
2468 _ => Capability::ReadWrite,
2469 }
2470 } else {
2471 Capability::ReadWrite
2472 }
2473 }
2474 Some(_) => Capability::ReadWrite,
2475 None => Capability::ReadOnly,
2476 };
2477
2478 let guard = session
2479 .db()
2480 .await
2481 .connections_for_buffer_update(project_id, session.connection_id, capability)
2482 .await?;
2483
2484 let (host, guests) = &*guard;
2485
2486 broadcast(
2487 Some(session.connection_id),
2488 guests.iter().chain([host]).copied(),
2489 |connection_id| {
2490 session
2491 .peer
2492 .forward_send(session.connection_id, connection_id, message.clone())
2493 },
2494 );
2495
2496 Ok(())
2497}
2498
2499/// Notify other participants that a project has been updated.
2500async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2501 request: T,
2502 session: MessageContext,
2503) -> Result<()> {
2504 let project_id = ProjectId::from_proto(request.remote_entity_id());
2505 let project_connection_ids = session
2506 .db()
2507 .await
2508 .project_connection_ids(project_id, session.connection_id, false)
2509 .await?;
2510
2511 broadcast(
2512 Some(session.connection_id),
2513 project_connection_ids.iter().copied(),
2514 |connection_id| {
2515 session
2516 .peer
2517 .forward_send(session.connection_id, connection_id, request.clone())
2518 },
2519 );
2520 Ok(())
2521}
2522
2523/// Start following another user in a call.
2524async fn follow(
2525 request: proto::Follow,
2526 response: Response<proto::Follow>,
2527 session: MessageContext,
2528) -> Result<()> {
2529 let room_id = RoomId::from_proto(request.room_id);
2530 let project_id = request.project_id.map(ProjectId::from_proto);
2531 let leader_id = request.leader_id.context("invalid leader id")?.into();
2532 let follower_id = session.connection_id;
2533
2534 session
2535 .db()
2536 .await
2537 .check_room_participants(room_id, leader_id, session.connection_id)
2538 .await?;
2539
2540 let response_payload = session.forward_request(leader_id, request).await?;
2541 response.send(response_payload)?;
2542
2543 if let Some(project_id) = project_id {
2544 let room = session
2545 .db()
2546 .await
2547 .follow(room_id, project_id, leader_id, follower_id)
2548 .await?;
2549 room_updated(&room, &session.peer);
2550 }
2551
2552 Ok(())
2553}
2554
2555/// Stop following another user in a call.
2556async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2557 let room_id = RoomId::from_proto(request.room_id);
2558 let project_id = request.project_id.map(ProjectId::from_proto);
2559 let leader_id = request.leader_id.context("invalid leader id")?.into();
2560 let follower_id = session.connection_id;
2561
2562 session
2563 .db()
2564 .await
2565 .check_room_participants(room_id, leader_id, session.connection_id)
2566 .await?;
2567
2568 session
2569 .peer
2570 .forward_send(session.connection_id, leader_id, request)?;
2571
2572 if let Some(project_id) = project_id {
2573 let room = session
2574 .db()
2575 .await
2576 .unfollow(room_id, project_id, leader_id, follower_id)
2577 .await?;
2578 room_updated(&room, &session.peer);
2579 }
2580
2581 Ok(())
2582}
2583
2584/// Notify everyone following you of your current location.
2585async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2586 let room_id = RoomId::from_proto(request.room_id);
2587 let database = session.db.lock().await;
2588
2589 let connection_ids = if let Some(project_id) = request.project_id {
2590 let project_id = ProjectId::from_proto(project_id);
2591 database
2592 .project_connection_ids(project_id, session.connection_id, true)
2593 .await?
2594 } else {
2595 database
2596 .room_connection_ids(room_id, session.connection_id)
2597 .await?
2598 };
2599
2600 // For now, don't send view update messages back to that view's current leader.
2601 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2602 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2603 _ => None,
2604 });
2605
2606 for connection_id in connection_ids.iter().cloned() {
2607 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2608 session
2609 .peer
2610 .forward_send(session.connection_id, connection_id, request.clone())?;
2611 }
2612 }
2613 Ok(())
2614}
2615
2616/// Get public data about users.
2617async fn get_users(
2618 request: proto::GetUsers,
2619 response: Response<proto::GetUsers>,
2620 session: MessageContext,
2621) -> Result<()> {
2622 let user_ids = request
2623 .user_ids
2624 .into_iter()
2625 .map(UserId::from_proto)
2626 .collect();
2627 let users = session
2628 .db()
2629 .await
2630 .get_users_by_ids(user_ids)
2631 .await?
2632 .into_iter()
2633 .map(|user| proto::User {
2634 id: user.id.to_proto(),
2635 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2636 github_login: user.github_login,
2637 name: user.name,
2638 })
2639 .collect();
2640 response.send(proto::UsersResponse { users })?;
2641 Ok(())
2642}
2643
/// Search for users (to invite) by GitHub login.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // One- or two-character queries use an exact login lookup; longer queries
    // run a fuzzy search capped at 10 results. Empty queries return nothing.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    // Exclude the searching user from their own results and convert to the
    // public wire representation.
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2675
2676/// Send a contact request to another user.
2677async fn request_contact(
2678 request: proto::RequestContact,
2679 response: Response<proto::RequestContact>,
2680 session: MessageContext,
2681) -> Result<()> {
2682 let requester_id = session.user_id();
2683 let responder_id = UserId::from_proto(request.responder_id);
2684 if requester_id == responder_id {
2685 return Err(anyhow!("cannot add yourself as a contact"))?;
2686 }
2687
2688 let notifications = session
2689 .db()
2690 .await
2691 .send_contact_request(requester_id, responder_id)
2692 .await?;
2693
2694 // Update outgoing contact requests of requester
2695 let mut update = proto::UpdateContacts::default();
2696 update.outgoing_requests.push(responder_id.to_proto());
2697 for connection_id in session
2698 .connection_pool()
2699 .await
2700 .user_connection_ids(requester_id)
2701 {
2702 session.peer.send(connection_id, update.clone())?;
2703 }
2704
2705 // Update incoming contact requests of responder
2706 let mut update = proto::UpdateContacts::default();
2707 update
2708 .incoming_requests
2709 .push(proto::IncomingContactRequest {
2710 requester_id: requester_id.to_proto(),
2711 });
2712 let connection_pool = session.connection_pool().await;
2713 for connection_id in connection_pool.user_connection_ids(responder_id) {
2714 session.peer.send(connection_id, update.clone())?;
2715 }
2716
2717 send_notifications(&connection_pool, &session.peer, notifications);
2718
2719 response.send(proto::Ack {})?;
2720 Ok(())
2721}
2722
/// Accept or decline a contact request
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        // Dismiss only clears the notification; the request stays pending.
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any non-Accept response here is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is included so clients can show availability for the
        // newly added contact.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request disappears.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Mirror: the requester's outgoing request disappears either way.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2780
/// Remove a contact.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    // `contact_accepted` distinguishes an established contact from a still
    // pending request; the fan-out below differs accordingly.
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Also retract the responder's contact-request notification, if the
        // removal deleted one.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2831
2832fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2833 version.0.minor() < 139
2834}
2835
2836async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2837 if is_staff {
2838 return Ok(proto::Plan::ZedPro);
2839 }
2840
2841 let subscription = db.get_active_billing_subscription(user_id).await?;
2842 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2843
2844 let plan = if let Some(subscription_kind) = subscription_kind {
2845 match subscription_kind {
2846 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2847 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2848 SubscriptionKind::ZedFree => proto::Plan::Free,
2849 }
2850 } else {
2851 proto::Plan::Free
2852 };
2853
2854 Ok(plan)
2855}
2856
/// Builds the `UpdateUserPlan` message describing a user's plan, billing
/// state, and current usage, for delivery to that user's connections.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Subscription period and usage can only be computed when the LLM
    // database handle is available.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Accounts younger than MIN_ACCOUNT_AGE_FOR_LLM_USE are flagged unless
    // the user is on Pro or carries the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always get usage-based billing; everyone else follows their
        // saved billing preference (None when no preferences are stored).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Fall back to a zeroed usage record when none exists for the period.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2921
2922fn model_requests_limit(
2923 plan: cloud_llm_client::Plan,
2924 feature_flags: &Vec<String>,
2925) -> cloud_llm_client::UsageLimit {
2926 match plan.model_requests_limit() {
2927 cloud_llm_client::UsageLimit::Limited(limit) => {
2928 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2929 && feature_flags
2930 .iter()
2931 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2932 {
2933 1_000
2934 } else {
2935 limit
2936 };
2937
2938 cloud_llm_client::UsageLimit::Limited(limit)
2939 }
2940 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2941 }
2942}
2943
2944fn subscription_usage_to_proto(
2945 plan: proto::Plan,
2946 usage: crate::llm::db::subscription_usage::Model,
2947 feature_flags: &Vec<String>,
2948) -> proto::SubscriptionUsage {
2949 let plan = match plan {
2950 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2951 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2952 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2953 };
2954
2955 proto::SubscriptionUsage {
2956 model_requests_usage_amount: usage.model_requests as u32,
2957 model_requests_usage_limit: Some(proto::UsageLimit {
2958 variant: Some(match model_requests_limit(plan, feature_flags) {
2959 cloud_llm_client::UsageLimit::Limited(limit) => {
2960 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2961 limit: limit as u32,
2962 })
2963 }
2964 cloud_llm_client::UsageLimit::Unlimited => {
2965 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2966 }
2967 }),
2968 }),
2969 edit_predictions_usage_amount: usage.edit_predictions as u32,
2970 edit_predictions_usage_limit: Some(proto::UsageLimit {
2971 variant: Some(match plan.edit_predictions_limit() {
2972 cloud_llm_client::UsageLimit::Limited(limit) => {
2973 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2974 limit: limit as u32,
2975 })
2976 }
2977 cloud_llm_client::UsageLimit::Unlimited => {
2978 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2979 }
2980 }),
2981 }),
2982 }
2983}
2984
2985fn make_default_subscription_usage(
2986 plan: proto::Plan,
2987 feature_flags: &Vec<String>,
2988) -> proto::SubscriptionUsage {
2989 let plan = match plan {
2990 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2991 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2992 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2993 };
2994
2995 proto::SubscriptionUsage {
2996 model_requests_usage_amount: 0,
2997 model_requests_usage_limit: Some(proto::UsageLimit {
2998 variant: Some(match model_requests_limit(plan, feature_flags) {
2999 cloud_llm_client::UsageLimit::Limited(limit) => {
3000 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3001 limit: limit as u32,
3002 })
3003 }
3004 cloud_llm_client::UsageLimit::Unlimited => {
3005 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3006 }
3007 }),
3008 }),
3009 edit_predictions_usage_amount: 0,
3010 edit_predictions_usage_limit: Some(proto::UsageLimit {
3011 variant: Some(match plan.edit_predictions_limit() {
3012 cloud_llm_client::UsageLimit::Limited(limit) => {
3013 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3014 limit: limit as u32,
3015 })
3016 }
3017 cloud_llm_client::UsageLimit::Unlimited => {
3018 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3019 }
3020 }),
3021 }),
3022 }
3023}
3024
3025async fn update_user_plan(session: &Session) -> Result<()> {
3026 let db = session.db().await;
3027
3028 let update_user_plan = make_update_user_plan_message(
3029 session.principal.user(),
3030 session.is_staff(),
3031 &db.0,
3032 session.app_state.llm_db.clone(),
3033 )
3034 .await?;
3035
3036 session
3037 .peer
3038 .send(session.connection_id, update_user_plan)
3039 .trace_err();
3040
3041 Ok(())
3042}
3043
3044async fn subscribe_to_channels(
3045 _: proto::SubscribeToChannels,
3046 session: MessageContext,
3047) -> Result<()> {
3048 subscribe_user_to_channels(session.user_id(), &session).await?;
3049 Ok(())
3050}
3051
3052async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3053 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3054 let mut pool = session.connection_pool().await;
3055 for membership in &channels_for_user.channel_memberships {
3056 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3057 }
3058 session.peer.send(
3059 session.connection_id,
3060 build_update_user_channels(&channels_for_user),
3061 )?;
3062 session.peer.send(
3063 session.connection_id,
3064 build_channels_update(channels_for_user),
3065 )?;
3066 Ok(())
3067}
3068
3069/// Creates a new channel.
3070async fn create_channel(
3071 request: proto::CreateChannel,
3072 response: Response<proto::CreateChannel>,
3073 session: MessageContext,
3074) -> Result<()> {
3075 let db = session.db().await;
3076
3077 let parent_id = request.parent_id.map(ChannelId::from_proto);
3078 let (channel, membership) = db
3079 .create_channel(&request.name, parent_id, session.user_id())
3080 .await?;
3081
3082 let root_id = channel.root_id();
3083 let channel = Channel::from_model(channel);
3084
3085 response.send(proto::CreateChannelResponse {
3086 channel: Some(channel.to_proto()),
3087 parent_id: request.parent_id,
3088 })?;
3089
3090 let mut connection_pool = session.connection_pool().await;
3091 if let Some(membership) = membership {
3092 connection_pool.subscribe_to_channel(
3093 membership.user_id,
3094 membership.channel_id,
3095 membership.role,
3096 );
3097 let update = proto::UpdateUserChannels {
3098 channel_memberships: vec![proto::ChannelMembership {
3099 channel_id: membership.channel_id.to_proto(),
3100 role: membership.role.into(),
3101 }],
3102 ..Default::default()
3103 };
3104 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3105 session.peer.send(connection_id, update.clone())?;
3106 }
3107 }
3108
3109 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3110 if !role.can_see_channel(channel.visibility) {
3111 continue;
3112 }
3113
3114 let update = proto::UpdateChannels {
3115 channels: vec![channel.to_proto()],
3116 ..Default::default()
3117 };
3118 session.peer.send(connection_id, update.clone())?;
3119 }
3120
3121 Ok(())
3122}
3123
3124/// Delete a channel
3125async fn delete_channel(
3126 request: proto::DeleteChannel,
3127 response: Response<proto::DeleteChannel>,
3128 session: MessageContext,
3129) -> Result<()> {
3130 let db = session.db().await;
3131
3132 let channel_id = request.channel_id;
3133 let (root_channel, removed_channels) = db
3134 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3135 .await?;
3136 response.send(proto::Ack {})?;
3137
3138 // Notify members of removed channels
3139 let mut update = proto::UpdateChannels::default();
3140 update
3141 .delete_channels
3142 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3143
3144 let connection_pool = session.connection_pool().await;
3145 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3146 session.peer.send(connection_id, update.clone())?;
3147 }
3148
3149 Ok(())
3150}
3151
3152/// Invite someone to join a channel.
3153async fn invite_channel_member(
3154 request: proto::InviteChannelMember,
3155 response: Response<proto::InviteChannelMember>,
3156 session: MessageContext,
3157) -> Result<()> {
3158 let db = session.db().await;
3159 let channel_id = ChannelId::from_proto(request.channel_id);
3160 let invitee_id = UserId::from_proto(request.user_id);
3161 let InviteMemberResult {
3162 channel,
3163 notifications,
3164 } = db
3165 .invite_channel_member(
3166 channel_id,
3167 invitee_id,
3168 session.user_id(),
3169 request.role().into(),
3170 )
3171 .await?;
3172
3173 let update = proto::UpdateChannels {
3174 channel_invitations: vec![channel.to_proto()],
3175 ..Default::default()
3176 };
3177
3178 let connection_pool = session.connection_pool().await;
3179 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3180 session.peer.send(connection_id, update.clone())?;
3181 }
3182
3183 send_notifications(&connection_pool, &session.peer, notifications);
3184
3185 response.send(proto::Ack {})?;
3186 Ok(())
3187}
3188
3189/// remove someone from a channel
3190async fn remove_channel_member(
3191 request: proto::RemoveChannelMember,
3192 response: Response<proto::RemoveChannelMember>,
3193 session: MessageContext,
3194) -> Result<()> {
3195 let db = session.db().await;
3196 let channel_id = ChannelId::from_proto(request.channel_id);
3197 let member_id = UserId::from_proto(request.user_id);
3198
3199 let RemoveChannelMemberResult {
3200 membership_update,
3201 notification_id,
3202 } = db
3203 .remove_channel_member(channel_id, member_id, session.user_id())
3204 .await?;
3205
3206 let mut connection_pool = session.connection_pool().await;
3207 notify_membership_updated(
3208 &mut connection_pool,
3209 membership_update,
3210 member_id,
3211 &session.peer,
3212 );
3213 for connection_id in connection_pool.user_connection_ids(member_id) {
3214 if let Some(notification_id) = notification_id {
3215 session
3216 .peer
3217 .send(
3218 connection_id,
3219 proto::DeleteNotification {
3220 notification_id: notification_id.to_proto(),
3221 },
3222 )
3223 .trace_err();
3224 }
3225 }
3226
3227 response.send(proto::Ack {})?;
3228 Ok(())
3229}
3230
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    // Persist the change; permission and hierarchy checks happen in the
    // database layer, which returns the updated channel model.
    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user ids into a Vec first: the loop body mutates the
    // pool's channel subscriptions, so we cannot iterate it borrowed.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can see the channel under its new visibility get
        // the updated record (and stay subscribed); everyone else is
        // unsubscribed and told to delete the channel locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3277
3278/// Alter the role for a user in the channel.
3279async fn set_channel_member_role(
3280 request: proto::SetChannelMemberRole,
3281 response: Response<proto::SetChannelMemberRole>,
3282 session: MessageContext,
3283) -> Result<()> {
3284 let db = session.db().await;
3285 let channel_id = ChannelId::from_proto(request.channel_id);
3286 let member_id = UserId::from_proto(request.user_id);
3287 let result = db
3288 .set_channel_member_role(
3289 channel_id,
3290 session.user_id(),
3291 member_id,
3292 request.role().into(),
3293 )
3294 .await?;
3295
3296 match result {
3297 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3298 let mut connection_pool = session.connection_pool().await;
3299 notify_membership_updated(
3300 &mut connection_pool,
3301 membership_update,
3302 member_id,
3303 &session.peer,
3304 )
3305 }
3306 db::SetMemberRoleResult::InviteUpdated(channel) => {
3307 let update = proto::UpdateChannels {
3308 channel_invitations: vec![channel.to_proto()],
3309 ..Default::default()
3310 };
3311
3312 for connection_id in session
3313 .connection_pool()
3314 .await
3315 .user_connection_ids(member_id)
3316 {
3317 session.peer.send(connection_id, update.clone())?;
3318 }
3319 }
3320 }
3321
3322 response.send(proto::Ack {})?;
3323 Ok(())
3324}
3325
3326/// Change the name of a channel
3327async fn rename_channel(
3328 request: proto::RenameChannel,
3329 response: Response<proto::RenameChannel>,
3330 session: MessageContext,
3331) -> Result<()> {
3332 let db = session.db().await;
3333 let channel_id = ChannelId::from_proto(request.channel_id);
3334 let channel_model = db
3335 .rename_channel(channel_id, session.user_id(), &request.name)
3336 .await?;
3337 let root_id = channel_model.root_id();
3338 let channel = Channel::from_model(channel_model);
3339
3340 response.send(proto::RenameChannelResponse {
3341 channel: Some(channel.to_proto()),
3342 })?;
3343
3344 let connection_pool = session.connection_pool().await;
3345 let update = proto::UpdateChannels {
3346 channels: vec![channel.to_proto()],
3347 ..Default::default()
3348 };
3349 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3350 if role.can_see_channel(channel.visibility) {
3351 session.peer.send(connection_id, update.clone())?;
3352 }
3353 }
3354
3355 Ok(())
3356}
3357
3358/// Move a channel to a new parent.
3359async fn move_channel(
3360 request: proto::MoveChannel,
3361 response: Response<proto::MoveChannel>,
3362 session: MessageContext,
3363) -> Result<()> {
3364 let channel_id = ChannelId::from_proto(request.channel_id);
3365 let to = ChannelId::from_proto(request.to);
3366
3367 let (root_id, channels) = session
3368 .db()
3369 .await
3370 .move_channel(channel_id, to, session.user_id())
3371 .await?;
3372
3373 let connection_pool = session.connection_pool().await;
3374 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3375 let channels = channels
3376 .iter()
3377 .filter_map(|channel| {
3378 if role.can_see_channel(channel.visibility) {
3379 Some(channel.to_proto())
3380 } else {
3381 None
3382 }
3383 })
3384 .collect::<Vec<_>>();
3385 if channels.is_empty() {
3386 continue;
3387 }
3388
3389 let update = proto::UpdateChannels {
3390 channels,
3391 ..Default::default()
3392 };
3393
3394 session.peer.send(connection_id, update.clone())?;
3395 }
3396
3397 response.send(Ack {})?;
3398 Ok(())
3399}
3400
3401async fn reorder_channel(
3402 request: proto::ReorderChannel,
3403 response: Response<proto::ReorderChannel>,
3404 session: MessageContext,
3405) -> Result<()> {
3406 let channel_id = ChannelId::from_proto(request.channel_id);
3407 let direction = request.direction();
3408
3409 let updated_channels = session
3410 .db()
3411 .await
3412 .reorder_channel(channel_id, direction, session.user_id())
3413 .await?;
3414
3415 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3416 let connection_pool = session.connection_pool().await;
3417 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3418 let channels = updated_channels
3419 .iter()
3420 .filter_map(|channel| {
3421 if role.can_see_channel(channel.visibility) {
3422 Some(channel.to_proto())
3423 } else {
3424 None
3425 }
3426 })
3427 .collect::<Vec<_>>();
3428
3429 if channels.is_empty() {
3430 continue;
3431 }
3432
3433 let update = proto::UpdateChannels {
3434 channels,
3435 ..Default::default()
3436 };
3437
3438 session.peer.send(connection_id, update.clone())?;
3439 }
3440 }
3441
3442 response.send(Ack {})?;
3443 Ok(())
3444}
3445
3446/// Get the list of channel members
3447async fn get_channel_members(
3448 request: proto::GetChannelMembers,
3449 response: Response<proto::GetChannelMembers>,
3450 session: MessageContext,
3451) -> Result<()> {
3452 let db = session.db().await;
3453 let channel_id = ChannelId::from_proto(request.channel_id);
3454 let limit = if request.limit == 0 {
3455 u16::MAX as u64
3456 } else {
3457 request.limit
3458 };
3459 let (members, users) = db
3460 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3461 .await?;
3462 response.send(proto::GetChannelMembersResponse { members, users })?;
3463 Ok(())
3464}
3465
/// Accept or decline a channel invitation.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    // If the response produced a membership update, fan it out to the
    // affected connections; otherwise just clear the pending invitation from
    // all of this user's connections.
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    // Deliver any notifications the database generated for this response.
    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3506
3507/// Join the channels' room
3508async fn join_channel(
3509 request: proto::JoinChannel,
3510 response: Response<proto::JoinChannel>,
3511 session: MessageContext,
3512) -> Result<()> {
3513 let channel_id = ChannelId::from_proto(request.channel_id);
3514 join_channel_internal(channel_id, Box::new(response), session).await
3515}
3516
/// Abstraction over the two response handles that can reply to a room join
/// (`JoinChannel` and `JoinRoom`), allowing `join_channel_internal` to answer
/// either request kind with a `JoinRoomResponse`.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forward to the concrete `Response::send` for `JoinChannel` requests.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Forward to the concrete `Response::send` for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3530
/// Shared implementation behind `JoinChannel` and `JoinRoom`: joins the
/// channel's room, replies with the room state (plus LiveKit connection info
/// when configured), and fans out the resulting room, channel, and contact
/// updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database guard before the cleanup call, then
            // re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit access token when a LiveKit client is configured:
        // guests get a non-publishing guest token, all other roles get a
        // publishing room token. Token minting failures are logged via
        // `trace_err` and result in no LiveKit info being sent.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the joiner before broadcasting updates to others.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The database guard from the block above is dropped here; channel and
    // contact updates happen outside of it.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3624
3625/// Start editing the channel notes
3626async fn join_channel_buffer(
3627 request: proto::JoinChannelBuffer,
3628 response: Response<proto::JoinChannelBuffer>,
3629 session: MessageContext,
3630) -> Result<()> {
3631 let db = session.db().await;
3632 let channel_id = ChannelId::from_proto(request.channel_id);
3633
3634 let open_response = db
3635 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3636 .await?;
3637
3638 let collaborators = open_response.collaborators.clone();
3639 response.send(open_response)?;
3640
3641 let update = UpdateChannelBufferCollaborators {
3642 channel_id: channel_id.to_proto(),
3643 collaborators: collaborators.clone(),
3644 };
3645 channel_buffer_updated(
3646 session.connection_id,
3647 collaborators
3648 .iter()
3649 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3650 &update,
3651 &session.peer,
3652 );
3653
3654 Ok(())
3655}
3656
/// Edit the channel notes
///
/// Applies the client's buffer operations, relays them verbatim to the other
/// collaborators, and tells every non-collaborating channel subscriber only
/// that the buffer's epoch/version advanced.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the raw operations to everyone editing this buffer (the sender
    // is excluded by `channel_buffer_updated`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers who are not collaborating on the buffer only get a
    // lightweight version bump.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3707
3708/// Rejoin the channel notes after a connection blip
3709async fn rejoin_channel_buffers(
3710 request: proto::RejoinChannelBuffers,
3711 response: Response<proto::RejoinChannelBuffers>,
3712 session: MessageContext,
3713) -> Result<()> {
3714 let db = session.db().await;
3715 let buffers = db
3716 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3717 .await?;
3718
3719 for rejoined_buffer in &buffers {
3720 let collaborators_to_notify = rejoined_buffer
3721 .buffer
3722 .collaborators
3723 .iter()
3724 .filter_map(|c| Some(c.peer_id?.into()));
3725 channel_buffer_updated(
3726 session.connection_id,
3727 collaborators_to_notify,
3728 &proto::UpdateChannelBufferCollaborators {
3729 channel_id: rejoined_buffer.buffer.channel_id,
3730 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3731 },
3732 &session.peer,
3733 );
3734 }
3735
3736 response.send(proto::RejoinChannelBuffersResponse {
3737 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3738 })?;
3739
3740 Ok(())
3741}
3742
3743/// Stop editing the channel notes
3744async fn leave_channel_buffer(
3745 request: proto::LeaveChannelBuffer,
3746 response: Response<proto::LeaveChannelBuffer>,
3747 session: MessageContext,
3748) -> Result<()> {
3749 let db = session.db().await;
3750 let channel_id = ChannelId::from_proto(request.channel_id);
3751
3752 let left_buffer = db
3753 .leave_channel_buffer(channel_id, session.connection_id)
3754 .await?;
3755
3756 response.send(Ack {})?;
3757
3758 channel_buffer_updated(
3759 session.connection_id,
3760 left_buffer.connections,
3761 &proto::UpdateChannelBufferCollaborators {
3762 channel_id: channel_id.to_proto(),
3763 collaborators: left_buffer.collaborators,
3764 },
3765 &session.peer,
3766 );
3767
3768 Ok(())
3769}
3770
/// Sends `message` to each collaborator connection via `broadcast`, which
/// excludes `sender_id` (the originating connection) from delivery.
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3781
3782fn send_notifications(
3783 connection_pool: &ConnectionPool,
3784 peer: &Peer,
3785 notifications: db::NotificationBatch,
3786) {
3787 for (user_id, notification) in notifications {
3788 for connection_id in connection_pool.user_connection_ids(user_id) {
3789 if let Err(error) = peer.send(
3790 connection_id,
3791 proto::AddNotification {
3792 notification: Some(notification.clone()),
3793 },
3794 ) {
3795 tracing::error!(
3796 "failed to send notification to {:?} {}",
3797 connection_id,
3798 error
3799 );
3800 }
3801 }
3802 }
3803}
3804
/// Send a message to the channel
///
/// Validates and stores the message, echoes it to the chat participants,
/// replies to the sender, bumps the latest-message id for non-participating
/// channel subscribers, and delivers any generated notifications.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // The nonce is required and client-supplied; it is stored alongside the
    // message and echoed back in the broadcast below.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    // Build the message as other clients should see it (with the server-
    // assigned id and timestamp).
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Relay the full message to the chat participants, excluding the sender
    // (who receives it through the response instead).
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only get the
    // latest-message-id bump.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3897
3898/// Delete a channel message
3899async fn remove_channel_message(
3900 request: proto::RemoveChannelMessage,
3901 response: Response<proto::RemoveChannelMessage>,
3902 session: MessageContext,
3903) -> Result<()> {
3904 let channel_id = ChannelId::from_proto(request.channel_id);
3905 let message_id = MessageId::from_proto(request.message_id);
3906 let (connection_ids, existing_notification_ids) = session
3907 .db()
3908 .await
3909 .remove_channel_message(channel_id, message_id, session.user_id())
3910 .await?;
3911
3912 broadcast(
3913 Some(session.connection_id),
3914 connection_ids,
3915 move |connection| {
3916 session.peer.send(connection, request.clone())?;
3917
3918 for notification_id in &existing_notification_ids {
3919 session.peer.send(
3920 connection,
3921 proto::DeleteNotification {
3922 notification_id: (*notification_id).to_proto(),
3923 },
3924 )?;
3925 }
3926
3927 Ok(())
3928 },
3929 );
3930 response.send(proto::Ack {})?;
3931 Ok(())
3932}
3933
/// Edit an existing channel message.
///
/// Persists the edit, acks the editor, relays the updated message to the
/// other participants, and reconciles mention notifications: stale ones are
/// deleted, changed ones are updated, and new ones are delivered.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: the destructured `message_id` below shadows the one parsed above
    // with the id returned by the database.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Rebuild the message as other clients should now see it. `timestamp`
    // comes from the database record; `edited_at` records this edit.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // Mention notifications removed by the edit: delete them.
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // Mention notifications whose content changed: update in place.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4017
4018/// Mark a channel message as read
4019async fn acknowledge_channel_message(
4020 request: proto::AckChannelMessage,
4021 session: MessageContext,
4022) -> Result<()> {
4023 let channel_id = ChannelId::from_proto(request.channel_id);
4024 let message_id = MessageId::from_proto(request.message_id);
4025 let notifications = session
4026 .db()
4027 .await
4028 .observe_channel_message(channel_id, session.user_id(), message_id)
4029 .await?;
4030 send_notifications(
4031 &*session.connection_pool().await,
4032 &session.peer,
4033 notifications,
4034 );
4035 Ok(())
4036}
4037
4038/// Mark a buffer version as synced
4039async fn acknowledge_buffer_version(
4040 request: proto::AckBufferOperation,
4041 session: MessageContext,
4042) -> Result<()> {
4043 let buffer_id = BufferId::from_proto(request.buffer_id);
4044 session
4045 .db()
4046 .await
4047 .observe_buffer_version(
4048 buffer_id,
4049 session.user_id(),
4050 request.epoch as i32,
4051 &request.version,
4052 )
4053 .await?;
4054 Ok(())
4055}
4056
4057/// Get a Supermaven API key for the user
4058async fn get_supermaven_api_key(
4059 _request: proto::GetSupermavenApiKey,
4060 response: Response<proto::GetSupermavenApiKey>,
4061 session: MessageContext,
4062) -> Result<()> {
4063 let user_id: String = session.user_id().to_string();
4064 if !session.is_staff() {
4065 return Err(anyhow!("supermaven not enabled for this account"))?;
4066 }
4067
4068 let email = session.email().context("user must have an email")?;
4069
4070 let supermaven_admin_api = session
4071 .supermaven_client
4072 .as_ref()
4073 .context("supermaven not configured")?;
4074
4075 let result = supermaven_admin_api
4076 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4077 .await?;
4078
4079 response.send(proto::GetSupermavenApiKeyResponse {
4080 api_key: result.api_key,
4081 })?;
4082
4083 Ok(())
4084}
4085
4086/// Start receiving chat updates for a channel
4087async fn join_channel_chat(
4088 request: proto::JoinChannelChat,
4089 response: Response<proto::JoinChannelChat>,
4090 session: MessageContext,
4091) -> Result<()> {
4092 let channel_id = ChannelId::from_proto(request.channel_id);
4093
4094 let db = session.db().await;
4095 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4096 .await?;
4097 let messages = db
4098 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4099 .await?;
4100 response.send(proto::JoinChannelChatResponse {
4101 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4102 messages,
4103 })?;
4104 Ok(())
4105}
4106
4107/// Stop receiving chat updates for a channel
4108async fn leave_channel_chat(
4109 request: proto::LeaveChannelChat,
4110 session: MessageContext,
4111) -> Result<()> {
4112 let channel_id = ChannelId::from_proto(request.channel_id);
4113 session
4114 .db()
4115 .await
4116 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4117 .await?;
4118 Ok(())
4119}
4120
4121/// Retrieve the chat history for a channel
4122async fn get_channel_messages(
4123 request: proto::GetChannelMessages,
4124 response: Response<proto::GetChannelMessages>,
4125 session: MessageContext,
4126) -> Result<()> {
4127 let channel_id = ChannelId::from_proto(request.channel_id);
4128 let messages = session
4129 .db()
4130 .await
4131 .get_channel_messages(
4132 channel_id,
4133 session.user_id(),
4134 MESSAGE_COUNT_PER_PAGE,
4135 Some(MessageId::from_proto(request.before_message_id)),
4136 )
4137 .await?;
4138 response.send(proto::GetChannelMessagesResponse {
4139 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4140 messages,
4141 })?;
4142 Ok(())
4143}
4144
4145/// Retrieve specific chat messages
4146async fn get_channel_messages_by_id(
4147 request: proto::GetChannelMessagesById,
4148 response: Response<proto::GetChannelMessagesById>,
4149 session: MessageContext,
4150) -> Result<()> {
4151 let message_ids = request
4152 .message_ids
4153 .iter()
4154 .map(|id| MessageId::from_proto(*id))
4155 .collect::<Vec<_>>();
4156 let messages = session
4157 .db()
4158 .await
4159 .get_channel_messages_by_id(session.user_id(), &message_ids)
4160 .await?;
4161 response.send(proto::GetChannelMessagesResponse {
4162 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4163 messages,
4164 })?;
4165 Ok(())
4166}
4167
4168/// Retrieve the current users notifications
4169async fn get_notifications(
4170 request: proto::GetNotifications,
4171 response: Response<proto::GetNotifications>,
4172 session: MessageContext,
4173) -> Result<()> {
4174 let notifications = session
4175 .db()
4176 .await
4177 .get_notifications(
4178 session.user_id(),
4179 NOTIFICATION_COUNT_PER_PAGE,
4180 request.before_id.map(db::NotificationId::from_proto),
4181 )
4182 .await?;
4183 response.send(proto::GetNotificationsResponse {
4184 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4185 notifications,
4186 })?;
4187 Ok(())
4188}
4189
4190/// Mark notifications as read
4191async fn mark_notification_as_read(
4192 request: proto::MarkNotificationRead,
4193 response: Response<proto::MarkNotificationRead>,
4194 session: MessageContext,
4195) -> Result<()> {
4196 let database = &session.db().await;
4197 let notifications = database
4198 .mark_notification_as_read_by_id(
4199 session.user_id(),
4200 NotificationId::from_proto(request.notification_id),
4201 )
4202 .await?;
4203 send_notifications(
4204 &*session.connection_pool().await,
4205 &session.peer,
4206 notifications,
4207 );
4208 response.send(proto::Ack {})?;
4209 Ok(())
4210}
4211
4212/// Get the current users information
4213async fn get_private_user_info(
4214 _request: proto::GetPrivateUserInfo,
4215 response: Response<proto::GetPrivateUserInfo>,
4216 session: MessageContext,
4217) -> Result<()> {
4218 let db = session.db().await;
4219
4220 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4221 let user = db
4222 .get_user_by_id(session.user_id())
4223 .await?
4224 .context("user not found")?;
4225 let flags = db.get_user_flags(session.user_id()).await?;
4226
4227 response.send(proto::GetPrivateUserInfoResponse {
4228 metrics_id,
4229 staff: user.admin,
4230 flags,
4231 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4232 })?;
4233 Ok(())
4234}
4235
4236/// Accept the terms of service (tos) on behalf of the current user
4237async fn accept_terms_of_service(
4238 _request: proto::AcceptTermsOfService,
4239 response: Response<proto::AcceptTermsOfService>,
4240 session: MessageContext,
4241) -> Result<()> {
4242 let db = session.db().await;
4243
4244 let accepted_tos_at = Utc::now();
4245 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4246 .await?;
4247
4248 response.send(proto::AcceptTermsOfServiceResponse {
4249 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4250 })?;
4251
4252 Ok(())
4253}
4254
4255fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4256 let message = match message {
4257 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4258 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4259 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4260 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4261 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4262 code: frame.code.into(),
4263 reason: frame.reason.as_str().to_owned().into(),
4264 })),
4265 // We should never receive a frame while reading the message, according
4266 // to the `tungstenite` maintainers:
4267 //
4268 // > It cannot occur when you read messages from the WebSocket, but it
4269 // > can be used when you want to send the raw frames (e.g. you want to
4270 // > send the frames to the WebSocket without composing the full message first).
4271 // >
4272 // > — https://github.com/snapview/tungstenite-rs/issues/268
4273 TungsteniteMessage::Frame(_) => {
4274 bail!("received an unexpected frame while reading the message")
4275 }
4276 };
4277
4278 Ok(message)
4279}
4280
4281fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4282 match message {
4283 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4284 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4285 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4286 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4287 AxumMessage::Close(frame) => {
4288 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4289 code: frame.code.into(),
4290 reason: frame.reason.as_ref().into(),
4291 }))
4292 }
4293 }
4294}
4295
4296fn notify_membership_updated(
4297 connection_pool: &mut ConnectionPool,
4298 result: MembershipUpdated,
4299 user_id: UserId,
4300 peer: &Peer,
4301) {
4302 for membership in &result.new_channels.channel_memberships {
4303 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4304 }
4305 for channel_id in &result.removed_channels {
4306 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4307 }
4308
4309 let user_channels_update = proto::UpdateUserChannels {
4310 channel_memberships: result
4311 .new_channels
4312 .channel_memberships
4313 .iter()
4314 .map(|cm| proto::ChannelMembership {
4315 channel_id: cm.channel_id.to_proto(),
4316 role: cm.role.into(),
4317 })
4318 .collect(),
4319 ..Default::default()
4320 };
4321
4322 let mut update = build_channels_update(result.new_channels);
4323 update.delete_channels = result
4324 .removed_channels
4325 .into_iter()
4326 .map(|id| id.to_proto())
4327 .collect();
4328 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4329
4330 for connection_id in connection_pool.user_connection_ids(user_id) {
4331 peer.send(connection_id, user_channels_update.clone())
4332 .trace_err();
4333 peer.send(connection_id, update.clone()).trace_err();
4334 }
4335}
4336
4337fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4338 proto::UpdateUserChannels {
4339 channel_memberships: channels
4340 .channel_memberships
4341 .iter()
4342 .map(|m| proto::ChannelMembership {
4343 channel_id: m.channel_id.to_proto(),
4344 role: m.role.into(),
4345 })
4346 .collect(),
4347 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4348 observed_channel_message_id: channels.observed_channel_messages.clone(),
4349 }
4350}
4351
4352fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4353 let mut update = proto::UpdateChannels::default();
4354
4355 for channel in channels.channels {
4356 update.channels.push(channel.to_proto());
4357 }
4358
4359 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4360 update.latest_channel_message_ids = channels.latest_channel_messages;
4361
4362 for (channel_id, participants) in channels.channel_participants {
4363 update
4364 .channel_participants
4365 .push(proto::ChannelParticipants {
4366 channel_id: channel_id.to_proto(),
4367 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4368 });
4369 }
4370
4371 for channel in channels.invited_channels {
4372 update.channel_invitations.push(channel.to_proto());
4373 }
4374
4375 update
4376}
4377
4378fn build_initial_contacts_update(
4379 contacts: Vec<db::Contact>,
4380 pool: &ConnectionPool,
4381) -> proto::UpdateContacts {
4382 let mut update = proto::UpdateContacts::default();
4383
4384 for contact in contacts {
4385 match contact {
4386 db::Contact::Accepted { user_id, busy } => {
4387 update.contacts.push(contact_for_user(user_id, busy, pool));
4388 }
4389 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4390 db::Contact::Incoming { user_id } => {
4391 update
4392 .incoming_requests
4393 .push(proto::IncomingContactRequest {
4394 requester_id: user_id.to_proto(),
4395 })
4396 }
4397 }
4398 }
4399
4400 update
4401}
4402
4403fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4404 proto::Contact {
4405 user_id: user_id.to_proto(),
4406 online: pool.is_user_online(user_id),
4407 busy,
4408 }
4409}
4410
4411fn room_updated(room: &proto::Room, peer: &Peer) {
4412 broadcast(
4413 None,
4414 room.participants
4415 .iter()
4416 .filter_map(|participant| Some(participant.peer_id?.into())),
4417 |peer_id| {
4418 peer.send(
4419 peer_id,
4420 proto::RoomUpdated {
4421 room: Some(room.clone()),
4422 },
4423 )
4424 },
4425 );
4426}
4427
4428fn channel_updated(
4429 channel: &db::channel::Model,
4430 room: &proto::Room,
4431 peer: &Peer,
4432 pool: &ConnectionPool,
4433) {
4434 let participants = room
4435 .participants
4436 .iter()
4437 .map(|p| p.user_id)
4438 .collect::<Vec<_>>();
4439
4440 broadcast(
4441 None,
4442 pool.channel_connection_ids(channel.root_id())
4443 .filter_map(|(channel_id, role)| {
4444 role.can_see_channel(channel.visibility)
4445 .then_some(channel_id)
4446 }),
4447 |peer_id| {
4448 peer.send(
4449 peer_id,
4450 proto::UpdateChannels {
4451 channel_participants: vec![proto::ChannelParticipants {
4452 channel_id: channel.id.to_proto(),
4453 participant_user_ids: participants.clone(),
4454 }],
4455 ..Default::default()
4456 },
4457 )
4458 },
4459 );
4460}
4461
4462async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4463 let db = session.db().await;
4464
4465 let contacts = db.get_contacts(user_id).await?;
4466 let busy = db.is_user_busy(user_id).await?;
4467
4468 let pool = session.connection_pool().await;
4469 let updated_contact = contact_for_user(user_id, busy, &pool);
4470 for contact in contacts {
4471 if let db::Contact::Accepted {
4472 user_id: contact_user_id,
4473 ..
4474 } = contact
4475 {
4476 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4477 session
4478 .peer
4479 .send(
4480 contact_conn_id,
4481 proto::UpdateContacts {
4482 contacts: vec![updated_contact.clone()],
4483 remove_contacts: Default::default(),
4484 incoming_requests: Default::default(),
4485 remove_incoming_requests: Default::default(),
4486 outgoing_requests: Default::default(),
4487 remove_outgoing_requests: Default::default(),
4488 },
4489 )
4490 .trace_err();
4491 }
4492 }
4493 }
4494 Ok(())
4495}
4496
/// Tear down `connection_id`'s participation in its current room:
/// remove it from the room in the database, notify remaining room and
/// channel participants, cancel any calls it had pending, refresh the
/// contact entries of affected users, and clean up the LiveKit room.
/// Returns `Ok(())` immediately if the connection was not in a room.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Extracted out of `left_room` below; declared here so they remain
    // usable after the database guard inside the `if let` is released.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell each left project's remaining collaborators about the departure.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        // `mem::take` moves the fields out of `left_room` without cloning,
        // leaving defaults behind.
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to do.
        return Ok(());
    }

    // If the room belongs to a channel, refresh the channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Notify every callee whose pending call was canceled by this departure.
    // Scoped so the connection-pool guard is dropped before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // Busy/online state may have changed for everyone touched above.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup; failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4570
4571async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4572 let left_channel_buffers = session
4573 .db()
4574 .await
4575 .leave_channel_buffers(session.connection_id)
4576 .await?;
4577
4578 for left_buffer in left_channel_buffers {
4579 channel_buffer_updated(
4580 session.connection_id,
4581 left_buffer.connections,
4582 &proto::UpdateChannelBufferCollaborators {
4583 channel_id: left_buffer.channel_id.to_proto(),
4584 collaborators: left_buffer.collaborators,
4585 },
4586 &session.peer,
4587 );
4588 }
4589
4590 Ok(())
4591}
4592
4593fn project_left(project: &db::LeftProject, session: &Session) {
4594 for connection_id in &project.connection_ids {
4595 if project.should_unshare {
4596 session
4597 .peer
4598 .send(
4599 *connection_id,
4600 proto::UnshareProject {
4601 project_id: project.id.to_proto(),
4602 },
4603 )
4604 .trace_err();
4605 } else {
4606 session
4607 .peer
4608 .send(
4609 *connection_id,
4610 proto::RemoveProjectCollaborator {
4611 project_id: project.id.to_proto(),
4612 peer_id: Some(session.connection_id.into()),
4613 },
4614 )
4615 .trace_err();
4616 }
4617 }
4618}
4619
/// Extension trait for logging an error instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Consume the result, returning `Some(ok)` on success; on failure,
    /// log the error and return `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4625
4626impl<T, E> ResultExt for Result<T, E>
4627where
4628 E: std::fmt::Debug,
4629{
4630 type Ok = T;
4631
4632 #[track_caller]
4633 fn trace_err(self) -> Option<T> {
4634 match self {
4635 Ok(value) => Some(value),
4636 Err(error) => {
4637 tracing::error!("{:?}", error);
4638 None
4639 }
4640 }
4641 }
4642}