1mod connection_pool;
2
3use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
4use crate::db::billing_subscription::SubscriptionKind;
5use crate::llm::db::LlmDatabase;
6use crate::llm::{
7 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG,
8 MIN_ACCOUNT_AGE_FOR_LLM_USE,
9};
10use crate::{
11 AppState, Error, Result, auth,
12 db::{
13 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
14 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
15 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
16 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
17 },
18 executor::Executor,
19};
20use anyhow::{Context as _, anyhow, bail};
21use async_tungstenite::tungstenite::{
22 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
23};
24use axum::headers::UserAgent;
25use axum::{
26 Extension, Router, TypedHeader,
27 body::Body,
28 extract::{
29 ConnectInfo, WebSocketUpgrade,
30 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
31 },
32 headers::{Header, HeaderName},
33 http::StatusCode,
34 middleware,
35 response::IntoResponse,
36 routing::get,
37};
38use chrono::Utc;
39use collections::{HashMap, HashSet};
40pub use connection_pool::{ConnectionPool, ZedVersion};
41use core::fmt::{self, Debug, Formatter};
42use futures::TryFutureExt as _;
43use reqwest_client::ReqwestClient;
44use rpc::proto::{MultiLspQuery, split_repository_update};
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46use tracing::Span;
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
/// How long a disconnected client has to reconnect — presumably before its
/// server-side state is discarded; confirm against the reconnect logic.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes for chat-message and notification queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Maximum accepted length of a chat message.
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections, maintained by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Names of the tracing-span fields used to record per-message timings
// (see `add_handler` and `forward_request`).
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`: receives the
/// envelope, the per-connection `Session`, and the message's tracing span.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
104
105pub struct ConnectionGuard;
106
107impl ConnectionGuard {
108 pub fn try_acquire() -> Result<Self, ()> {
109 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
110 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
111 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
112 tracing::error!(
113 "too many concurrent connections: {}",
114 current_connections + 1
115 );
116 return Err(());
117 }
118 Ok(ConnectionGuard)
119 }
120}
121
122impl Drop for ConnectionGuard {
123 fn drop(&mut self) {
124 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
125 }
126}
127
128struct Response<R> {
129 peer: Arc<Peer>,
130 receipt: Receipt<R>,
131 responded: Arc<AtomicBool>,
132}
133
134impl<R: RequestMessage> Response<R> {
135 fn send(self, payload: R::Response) -> Result<()> {
136 self.responded.store(true, SeqCst);
137 self.peer.respond(self.receipt, payload)?;
138 Ok(())
139 }
140}
141
142#[derive(Clone, Debug)]
143pub enum Principal {
144 User(User),
145 Impersonated { user: User, admin: User },
146}
147
148impl Principal {
149 fn user(&self) -> &User {
150 match self {
151 Principal::User(user) => user,
152 Principal::Impersonated { user, .. } => user,
153 }
154 }
155
156 fn update_span(&self, span: &tracing::Span) {
157 match &self {
158 Principal::User(user) => {
159 span.record("user_id", user.id.0);
160 span.record("login", &user.github_login);
161 }
162 Principal::Impersonated { user, admin } => {
163 span.record("user_id", user.id.0);
164 span.record("login", &user.github_login);
165 span.record("impersonator", &admin.github_login);
166 }
167 }
168 }
169}
170
/// A `Session` paired with the tracing span of the message currently being
/// handled; this is what every registered handler receives.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    // Lets handlers use a `MessageContext` wherever a `&Session` is expected.
    fn deref(&self) -> &Self::Target {
        &self.session
    }
}
184
impl MessageContext {
    /// Forwards `request` to `receiver_id` on behalf of this connection and
    /// returns the eventual response.
    ///
    /// The underlying `peer.forward_request` call happens immediately; only
    /// awaiting the response is deferred to the returned future. On
    /// completion (success or failure) the elapsed time is recorded on the
    /// message span as `HOST_WAITING_MS`, in fractional milliseconds.
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // `inspect` runs regardless of outcome, so the timing is recorded
            // for both successful and failed forwards.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
206
/// Per-connection state shared with every message handler.
#[derive(Clone)]
struct Session {
    // Who this connection is authenticated as.
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; see `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Pool of all live connections, shared with the `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    #[allow(unused)]
    system_id: Option<String>,
    _executor: Executor,
}
223
224impl Session {
225 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
226 #[cfg(test)]
227 tokio::task::yield_now().await;
228 let guard = self.db.lock().await;
229 #[cfg(test)]
230 tokio::task::yield_now().await;
231 guard
232 }
233
234 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
235 #[cfg(test)]
236 tokio::task::yield_now().await;
237 let guard = self.connection_pool.lock();
238 ConnectionPoolGuard {
239 guard,
240 _not_send: PhantomData,
241 }
242 }
243
244 fn is_staff(&self) -> bool {
245 match &self.principal {
246 Principal::User(user) => user.admin,
247 Principal::Impersonated { .. } => true,
248 }
249 }
250
251 fn user_id(&self) -> UserId {
252 match &self.principal {
253 Principal::User(user) => user.id,
254 Principal::Impersonated { user, .. } => user.id,
255 }
256 }
257
258 pub fn email(&self) -> Option<String> {
259 match &self.principal {
260 Principal::User(user) => user.email_address.clone(),
261 Principal::Impersonated { user, .. } => user.email_address.clone(),
262 }
263 }
264}
265
266impl Debug for Session {
267 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
268 let mut result = f.debug_struct("Session");
269 match &self.principal {
270 Principal::User(user) => {
271 result.field("user", &user.github_login);
272 }
273 Principal::Impersonated { user, admin } => {
274 result.field("user", &user.github_login);
275 result.field("impersonator", &admin.github_login);
276 }
277 }
278 result.field("connection_id", &self.connection_id).finish()
279 }
280}
281
282struct DbHandle(Arc<Database>);
283
284impl Deref for DbHandle {
285 type Target = Database;
286
287 fn deref(&self) -> &Self::Target {
288 self.0.as_ref()
289 }
290}
291
/// The collab RPC server: owns the peer, the pool of live client
/// connections, and the registry of message handlers.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Maps a message payload's `TypeId` to its handler; populated in `new`.
    handlers: HashMap<TypeId, MessageHandler>,
    // Broadcasts `true` on shutdown; every connection task subscribes to it.
    teardown: watch::Sender<bool>,
}
300
/// Guard over the connection-pool mutex. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`, preventing it from being held across an `.await`
/// that could move the task to another thread.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
305
/// Serializable view of the server's peer and connection-pool state —
/// presumably for an admin/debug endpoint; confirm against callers.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard serializes as the pool it protects, via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
312
313pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
314where
315 S: Serializer,
316 T: Deref<Target = U>,
317 U: Serialize,
318{
319 Serialize::serialize(value.deref(), serializer)
320}
321
322impl Server {
    /// Creates the server and registers a handler for every RPC message type
    /// it understands. Panics (via `add_handler`) if any message type is
    /// registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and participant state.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree/repository sync.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // LSP-style requests forwarded to the project host. Read-only
            // requests require read access; mutating ones require write access.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            // Channels, channel buffers, chat, and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            // Assistant contexts and Git operations forwarded to the host.
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
496
    /// Spawns a detached background task that, after `CLEANUP_TIMEOUT`,
    /// cleans up state left behind by a previous server instance: stale chat
    /// participants, channel-buffer collaborators, room participants, old
    /// worktree entries, and stale server rows. All failures are logged via
    /// `trace_err` and do not abort the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Tell remaining collaborators about the refreshed
                    // collaborator list of each stale channel buffer.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // Only delete the LiveKit room once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Notify callees whose incoming calls were canceled
                        // by the room refresh. The pool lock is scoped to
                        // this block.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push refreshed busy-state to each affected user's
                        // accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
667
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every connection task through the `teardown` watch.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // Send fails only if no receivers remain, which is fine here.
        let _ = self.teardown.send(true);
    }
673
    /// Test-only: tears the server down and re-initializes it under a new
    /// server id, then clears the teardown flag so connections are accepted
    /// again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    /// Test-only: the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
686
    /// Registers `handler` for message type `M`, wrapping it with envelope
    /// downcasting, span timing bookkeeping, and result logging.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // Dispatch looks this closure up by the payload's `TypeId`,
                // so the downcast to `TypedEnvelope<M>` cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Millisecond timings with microsecond resolution:
                    // total (since receipt) = queue time + processing time.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
730
731 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
732 where
733 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
734 Fut: 'static + Send + Future<Output = Result<()>>,
735 M: EnvelopedMessage,
736 {
737 self.add_handler(move |envelope, session| handler(envelope.payload, session));
738 self
739 }
740
    /// Registers a handler for a request/response message. Guarantees the
    /// requester always hears back: a handler that returns `Ok` without
    /// calling `Response::send` is reported as an error, and a handler error
    /// is forwarded to the requester as a protocol error before being
    /// propagated.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // `Response::send` flips this flag, letting us detect
                // handlers that "succeed" without actually responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Preserve structured error codes; wrap everything
                        // else as an internal error.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
779
    /// Drives one client connection end to end: registers it with the peer,
    /// sends the initial client update, then dispatches incoming messages to
    /// the registered handlers until the connection closes or the server
    /// tears down, finally running `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Dropping the guard decrements CONCURRENT_CONNECTIONS, so the
            // cap only applies during the handshake — NOTE(review): confirm
            // this is the intended admission-control window.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Each in-flight handler holds one permit until it completes,
            // bounding per-connection concurrency.
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Priority order: teardown first, then I/O errors, then
                // draining finished handlers, then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only once the handler future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any still-running foreground handlers before cleanup.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
971
    /// Sends the initial batch of messages to a freshly established client
    /// connection: the `Hello` handshake, first-connection prompt, plan
    /// update, contact list, channel subscriptions, and any pending
    /// incoming call.
    ///
    /// Also registers the connection in the connection pool and, if a
    /// `send_connection_id` channel was supplied, reports the new
    /// connection id back through it.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // The `Hello` message tells the client its own peer id.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiver may already be gone; that's fine.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    // First-ever connection for this user: prompt the client
                    // to show the contacts UI, then persist the flag.
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    // Register the connection and send the initial contact
                    // list while holding the pool lock, so reported online
                    // states are consistent with the pool contents.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1029
1030 pub async fn invite_code_redeemed(
1031 self: &Arc<Self>,
1032 inviter_id: UserId,
1033 invitee_id: UserId,
1034 ) -> Result<()> {
1035 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1036 if let Some(code) = &user.invite_code {
1037 let pool = self.connection_pool.lock();
1038 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1039 for connection_id in pool.user_connection_ids(inviter_id) {
1040 self.peer.send(
1041 connection_id,
1042 proto::UpdateContacts {
1043 contacts: vec![invitee_contact.clone()],
1044 ..Default::default()
1045 },
1046 )?;
1047 self.peer.send(
1048 connection_id,
1049 proto::UpdateInviteInfo {
1050 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1051 count: user.invite_count as u32,
1052 },
1053 )?;
1054 }
1055 }
1056 }
1057 Ok(())
1058 }
1059
1060 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1061 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1062 if let Some(invite_code) = &user.invite_code {
1063 let pool = self.connection_pool.lock();
1064 for connection_id in pool.user_connection_ids(user_id) {
1065 self.peer.send(
1066 connection_id,
1067 proto::UpdateInviteInfo {
1068 url: format!(
1069 "{}{}",
1070 self.app_state.config.invite_link_prefix, invite_code
1071 ),
1072 count: user.invite_count as u32,
1073 },
1074 )?;
1075 }
1076 }
1077 }
1078 Ok(())
1079 }
1080
    /// Captures a read-only snapshot of the server's state: a guard over the
    /// connection pool plus a reference to the peer.
    ///
    /// The returned guard holds the connection-pool lock for its lifetime,
    /// so keep the snapshot short-lived to avoid blocking other tasks.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker keeping the guard !Send; see ConnectionPoolGuard.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1090}
1091
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    /// Read-only access to the underlying pool through the guard.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1099
impl DerefMut for ConnectionPoolGuard<'_> {
    /// Mutable access to the underlying pool through the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1105
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In tests, verify the pool's internal invariants every time a
        // guard is released, so corruption is caught close to its source.
        #[cfg(test)]
        self.check_invariants();
    }
}
1112
1113fn broadcast<F>(
1114 sender_id: Option<ConnectionId>,
1115 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1116 mut f: F,
1117) where
1118 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1119{
1120 for receiver_id in receiver_ids {
1121 if Some(receiver_id) != sender_id {
1122 if let Err(error) = f(receiver_id) {
1123 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1124 }
1125 }
1126 }
1127}
1128
1129pub struct ProtocolVersion(u32);
1130
1131impl Header for ProtocolVersion {
1132 fn name() -> &'static HeaderName {
1133 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1135 }
1136
1137 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138 where
1139 Self: Sized,
1140 I: Iterator<Item = &'i axum::http::HeaderValue>,
1141 {
1142 let version = values
1143 .next()
1144 .ok_or_else(axum::headers::Error::invalid)?
1145 .to_str()
1146 .map_err(|_| axum::headers::Error::invalid())?
1147 .parse()
1148 .map_err(|_| axum::headers::Error::invalid())?;
1149 Ok(Self(version))
1150 }
1151
1152 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153 values.extend([self.0.to_string().parse().unwrap()]);
1154 }
1155}
1156
1157pub struct AppVersionHeader(SemanticVersion);
1158impl Header for AppVersionHeader {
1159 fn name() -> &'static HeaderName {
1160 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1161 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1162 }
1163
1164 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1165 where
1166 Self: Sized,
1167 I: Iterator<Item = &'i axum::http::HeaderValue>,
1168 {
1169 let version = values
1170 .next()
1171 .ok_or_else(axum::headers::Error::invalid)?
1172 .to_str()
1173 .map_err(|_| axum::headers::Error::invalid())?
1174 .parse()
1175 .map_err(|_| axum::headers::Error::invalid())?;
1176 Ok(Self(version))
1177 }
1178
1179 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1180 values.extend([self.0.to_string().parse().unwrap()]);
1181 }
1182}
1183
1184#[derive(Debug)]
1185pub struct ReleaseChannelHeader(String);
1186
1187impl Header for ReleaseChannelHeader {
1188 fn name() -> &'static HeaderName {
1189 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1190 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1191 }
1192
1193 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1194 where
1195 Self: Sized,
1196 I: Iterator<Item = &'i axum::http::HeaderValue>,
1197 {
1198 Ok(Self(
1199 values
1200 .next()
1201 .ok_or_else(axum::headers::Error::invalid)?
1202 .to_str()
1203 .map_err(|_| axum::headers::Error::invalid())?
1204 .to_owned(),
1205 ))
1206 }
1207
1208 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1209 values.extend([self.0.parse().unwrap()]);
1210 }
1211}
1212
/// Builds the HTTP router for the collaboration server.
///
/// `/rpc` upgrades to the client WebSocket and sits behind the
/// header-validation auth middleware; `/metrics` is registered *after* that
/// layer and is therefore intentionally unauthenticated.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1224
/// Axum handler for `/rpc`: validates client version headers, then upgrades
/// the request to the collaboration WebSocket.
///
/// Rejects clients with a mismatched protocol version, a missing app-version
/// header, or an app version too old to collaborate (all `426 Upgrade
/// Required`), and returns `503` when the server is at its connection limit.
/// The connection guard is acquired *before* the upgrade so an over-capacity
/// server can still answer with a plain HTTP status.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match the server's exactly.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so the capacity
    // check can still be reported over plain HTTP.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's WebSocket message types into the tungstenite-based
        // transport that `Connection` expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1302
/// Serves Prometheus metrics in text exposition format.
///
/// Refreshes the `connections` gauge (non-admin connections only) and the
/// `shared_projects` gauge before encoding the full metric registry.
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    // Admin connections are excluded so the gauge reflects real users.
    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}
1335
/// Cleans up after a client connection is lost.
///
/// Immediately removes the connection from the peer and the pool, then waits
/// up to `RECONNECT_TIMEOUT` before tearing down the user's room and channel
/// buffer state — giving the client a window to reconnect. A teardown signal
/// (server shutdown) short-circuits the wait and skips the teardown work.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            // The client did not reconnect in time: release its resources.
            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline pending calls once the user has no other
            // connections left anywhere.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1382
1383/// Acknowledges a ping from a client, used to keep the connection alive.
1384async fn ping(
1385 _: proto::Ping,
1386 response: Response<proto::Ping>,
1387 _session: MessageContext,
1388) -> Result<()> {
1389 response.send(proto::Ack {})?;
1390 Ok(())
1391}
1392
1393/// Creates a new room for calling (outside of channels)
1394async fn create_room(
1395 _request: proto::CreateRoom,
1396 response: Response<proto::CreateRoom>,
1397 session: MessageContext,
1398) -> Result<()> {
1399 let livekit_room = nanoid::nanoid!(30);
1400
1401 let live_kit_connection_info = util::maybe!(async {
1402 let live_kit = session.app_state.livekit_client.as_ref();
1403 let live_kit = live_kit?;
1404 let user_id = session.user_id().to_string();
1405
1406 let token = live_kit
1407 .room_token(&livekit_room, &user_id.to_string())
1408 .trace_err()?;
1409
1410 Some(proto::LiveKitConnectionInfo {
1411 server_url: live_kit.url().into(),
1412 token,
1413 can_publish: true,
1414 })
1415 })
1416 .await;
1417
1418 let room = session
1419 .db()
1420 .await
1421 .create_room(session.user_id(), session.connection_id, &livekit_room)
1422 .await?;
1423
1424 response.send(proto::CreateRoomResponse {
1425 room: Some(room.clone()),
1426 live_kit_connection_info,
1427 })?;
1428
1429 update_user_contacts(session.user_id(), &session).await?;
1430 Ok(())
1431}
1432
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room is backed by a channel, delegates to `join_channel_internal`.
/// Otherwise joins the ad-hoc room, clears the incoming-call UI on the
/// user's other connections, mints a LiveKit token when a client is
/// configured, and refreshes the user's contact state.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        // Channel-backed rooms have their own join path.
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The call was answered on this connection; cancel the ringing UI on
    // all of the user's connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort: a LiveKit token failure leaves audio/video unavailable
    // but doesn't prevent joining the room.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1499
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Responds with the reshared/rejoined project state, announces the
/// caller's new peer id to other collaborators, replays missed project
/// state, and re-broadcasts room (and, if applicable, channel) updates.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator the host's peer id changed from the
            // old (lost) connection to this new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree list to everyone else.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // Channel-backed rooms also need their channel state re-broadcast.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1591
/// Streams post-rejoin project state to the rejoining connection and
/// announces its new peer id to the other collaborators of each project.
///
/// Worktree entries, diagnostics, settings files, and repository updates
/// are sent to the rejoining client in chunked messages; collaborator
/// updates go to everyone else.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        // Announce the peer-id change to the other collaborators first.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // Large updates are split into multiple messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1676
1677/// leave room disconnects from the room.
1678async fn leave_room(
1679 _: proto::LeaveRoom,
1680 response: Response<proto::LeaveRoom>,
1681 session: MessageContext,
1682) -> Result<()> {
1683 leave_room_for_session(&session, session.connection_id).await?;
1684 response.send(proto::Ack {})?;
1685 Ok(())
1686}
1687
1688/// Updates the permissions of someone else in the room.
1689async fn set_room_participant_role(
1690 request: proto::SetRoomParticipantRole,
1691 response: Response<proto::SetRoomParticipantRole>,
1692 session: MessageContext,
1693) -> Result<()> {
1694 let user_id = UserId::from_proto(request.user_id);
1695 let role = ChannelRole::from(request.role());
1696
1697 let (livekit_room, can_publish) = {
1698 let room = session
1699 .db()
1700 .await
1701 .set_room_participant_role(
1702 session.user_id(),
1703 RoomId::from_proto(request.room_id),
1704 user_id,
1705 role,
1706 )
1707 .await?;
1708
1709 let livekit_room = room.livekit_room.clone();
1710 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1711 room_updated(&room, &session.peer);
1712 (livekit_room, can_publish)
1713 };
1714
1715 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1716 live_kit
1717 .update_participant(
1718 livekit_room.clone(),
1719 request.user_id.to_string(),
1720 livekit_api::proto::ParticipantPermission {
1721 can_subscribe: true,
1722 can_publish,
1723 can_publish_data: can_publish,
1724 hidden: false,
1725 recorder: false,
1726 },
1727 )
1728 .await
1729 .trace_err();
1730 }
1731
1732 response.send(proto::Ack {})?;
1733 Ok(())
1734}
1735
/// Call someone else into the current room.
///
/// Rings every connection of the called user concurrently; the first one to
/// accept wins and the request is acked. If no connection accepts, the call
/// is marked failed in the database and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    // Only contacts may call each other.
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the called user's connections at once.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // One connection accepted; the call succeeded.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody picked up: roll back the ring state and refresh contacts.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1804
1805/// Cancel an outgoing call.
1806async fn cancel_call(
1807 request: proto::CancelCall,
1808 response: Response<proto::CancelCall>,
1809 session: MessageContext,
1810) -> Result<()> {
1811 let called_user_id = UserId::from_proto(request.called_user_id);
1812 let room_id = RoomId::from_proto(request.room_id);
1813 {
1814 let room = session
1815 .db()
1816 .await
1817 .cancel_call(room_id, session.connection_id, called_user_id)
1818 .await?;
1819 room_updated(&room, &session.peer);
1820 }
1821
1822 for connection_id in session
1823 .connection_pool()
1824 .await
1825 .user_connection_ids(called_user_id)
1826 {
1827 session
1828 .peer
1829 .send(
1830 connection_id,
1831 proto::CallCanceled {
1832 room_id: room_id.to_proto(),
1833 },
1834 )
1835 .trace_err();
1836 }
1837 response.send(proto::Ack {})?;
1838
1839 update_user_contacts(called_user_id, &session).await?;
1840 Ok(())
1841}
1842
/// Decline an incoming call.
///
/// Removes the pending call from the room, clears the incoming-call UI on
/// all of the user's connections, and refreshes their contact state.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // The user declined on one connection; stop ringing on all of them.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1874
1875/// Updates other participants in the room with your current location.
1876async fn update_participant_location(
1877 request: proto::UpdateParticipantLocation,
1878 response: Response<proto::UpdateParticipantLocation>,
1879 session: MessageContext,
1880) -> Result<()> {
1881 let room_id = RoomId::from_proto(request.room_id);
1882 let location = request.location.context("invalid location")?;
1883
1884 let db = session.db().await;
1885 let room = db
1886 .update_room_participant_location(room_id, session.connection_id, location)
1887 .await?;
1888
1889 room_updated(&room, &session.peer);
1890 response.send(proto::Ack {})?;
1891 Ok(())
1892}
1893
1894/// Share a project into the room.
1895async fn share_project(
1896 request: proto::ShareProject,
1897 response: Response<proto::ShareProject>,
1898 session: MessageContext,
1899) -> Result<()> {
1900 let (project_id, room) = &*session
1901 .db()
1902 .await
1903 .share_project(
1904 RoomId::from_proto(request.room_id),
1905 session.connection_id,
1906 &request.worktrees,
1907 request.is_ssh_project,
1908 )
1909 .await?;
1910 response.send(proto::ShareProjectResponse {
1911 project_id: project_id.to_proto(),
1912 })?;
1913 room_updated(room, &session.peer);
1914
1915 Ok(())
1916}
1917
1918/// Unshare a project from the room.
1919async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1920 let project_id = ProjectId::from_proto(message.project_id);
1921 unshare_project_internal(project_id, session.connection_id, &session).await
1922}
1923
/// Core unshare logic, shared by the RPC handler and other teardown paths.
///
/// Removes the connection's share of the project, notifies guests,
/// broadcasts updated room state, and finally deletes the project row when
/// the database reports it should no longer exist.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest (but not the unsharing connection itself) that
        // the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    // Perform the deletion after the room guard has been released.
    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1961
/// Join someone else's shared project.
///
/// Registers the caller as a collaborator, notifies the existing
/// collaborators, responds with project metadata, then streams worktree
/// entries, diagnostics, settings files, repositories, and language-server
/// state to the joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before the long streaming phase below.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    // Tell the existing collaborators about the new guest.
    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // Large updates are split into multiple messages.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Prompt the client to refresh disk-based diagnostics for each server.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2111
/// Leave someone else's shared project.
///
/// Removes the sender's connection from the project's collaborators,
/// notifies the remaining collaborators, and refreshes the room state
/// if the project belongs to an active call.
async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Drop the sender from the project in the database; `room` is None when
    // the project is not associated with a call.
    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    // Tell the remaining collaborators this peer is gone.
    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2131
/// Updates other participants with changes to the project
/// (e.g. its set of worktrees).
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Persist the new worktree metadata and learn which guest connections
    // need to hear about it. Only the project host may send this message;
    // the db call validates the sender's connection.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    // Relay the update verbatim to every guest except the sender.
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    // The project may be visible in a call's room state (e.g. its name),
    // so refresh the room too when one exists.
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2160
2161/// Updates other participants with changes to the worktree
2162async fn update_worktree(
2163 request: proto::UpdateWorktree,
2164 response: Response<proto::UpdateWorktree>,
2165 session: MessageContext,
2166) -> Result<()> {
2167 let guest_connection_ids = session
2168 .db()
2169 .await
2170 .update_worktree(&request, session.connection_id)
2171 .await?;
2172
2173 broadcast(
2174 Some(session.connection_id),
2175 guest_connection_ids.iter().copied(),
2176 |connection_id| {
2177 session
2178 .peer
2179 .forward_send(session.connection_id, connection_id, request.clone())
2180 },
2181 );
2182 response.send(proto::Ack {})?;
2183 Ok(())
2184}
2185
2186async fn update_repository(
2187 request: proto::UpdateRepository,
2188 response: Response<proto::UpdateRepository>,
2189 session: MessageContext,
2190) -> Result<()> {
2191 let guest_connection_ids = session
2192 .db()
2193 .await
2194 .update_repository(&request, session.connection_id)
2195 .await?;
2196
2197 broadcast(
2198 Some(session.connection_id),
2199 guest_connection_ids.iter().copied(),
2200 |connection_id| {
2201 session
2202 .peer
2203 .forward_send(session.connection_id, connection_id, request.clone())
2204 },
2205 );
2206 response.send(proto::Ack {})?;
2207 Ok(())
2208}
2209
2210async fn remove_repository(
2211 request: proto::RemoveRepository,
2212 response: Response<proto::RemoveRepository>,
2213 session: MessageContext,
2214) -> Result<()> {
2215 let guest_connection_ids = session
2216 .db()
2217 .await
2218 .remove_repository(&request, session.connection_id)
2219 .await?;
2220
2221 broadcast(
2222 Some(session.connection_id),
2223 guest_connection_ids.iter().copied(),
2224 |connection_id| {
2225 session
2226 .peer
2227 .forward_send(session.connection_id, connection_id, request.clone())
2228 },
2229 );
2230 response.send(proto::Ack {})?;
2231 Ok(())
2232}
2233
2234/// Updates other participants with changes to the diagnostics
2235async fn update_diagnostic_summary(
2236 message: proto::UpdateDiagnosticSummary,
2237 session: MessageContext,
2238) -> Result<()> {
2239 let guest_connection_ids = session
2240 .db()
2241 .await
2242 .update_diagnostic_summary(&message, session.connection_id)
2243 .await?;
2244
2245 broadcast(
2246 Some(session.connection_id),
2247 guest_connection_ids.iter().copied(),
2248 |connection_id| {
2249 session
2250 .peer
2251 .forward_send(session.connection_id, connection_id, message.clone())
2252 },
2253 );
2254
2255 Ok(())
2256}
2257
2258/// Updates other participants with changes to the worktree settings
2259async fn update_worktree_settings(
2260 message: proto::UpdateWorktreeSettings,
2261 session: MessageContext,
2262) -> Result<()> {
2263 let guest_connection_ids = session
2264 .db()
2265 .await
2266 .update_worktree_settings(&message, session.connection_id)
2267 .await?;
2268
2269 broadcast(
2270 Some(session.connection_id),
2271 guest_connection_ids.iter().copied(),
2272 |connection_id| {
2273 session
2274 .peer
2275 .forward_send(session.connection_id, connection_id, message.clone())
2276 },
2277 );
2278
2279 Ok(())
2280}
2281
2282/// Notify other participants that a language server has started.
2283async fn start_language_server(
2284 request: proto::StartLanguageServer,
2285 session: MessageContext,
2286) -> Result<()> {
2287 let guest_connection_ids = session
2288 .db()
2289 .await
2290 .start_language_server(&request, session.connection_id)
2291 .await?;
2292
2293 broadcast(
2294 Some(session.connection_id),
2295 guest_connection_ids.iter().copied(),
2296 |connection_id| {
2297 session
2298 .peer
2299 .forward_send(session.connection_id, connection_id, request.clone())
2300 },
2301 );
2302 Ok(())
2303}
2304
/// Notify other participants that a language server has changed.
///
/// Capability changes (`MetadataUpdated` with `capabilities` set) are also
/// persisted server-side; every other variant is just relayed.
async fn update_language_server(
    request: proto::UpdateLanguageServer,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    // Persist updated server capabilities before broadcasting, so late
    // joiners see the same state the current participants are told about.
    if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
    {
        if let Some(capabilities) = update.capabilities.clone() {
            db.update_server_capabilities(project_id, request.language_server_id, capabilities)
                .await?;
        }
    }

    // NOTE(review): the third argument presumably means "exclude requesting
    // connection" or similar — confirm against `project_connection_ids`.
    let project_connection_ids = db
        .project_connection_ids(project_id, session.connection_id, true)
        .await?;
    // Relay the update verbatim to the other project participants.
    broadcast(
        Some(session.connection_id),
        project_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    Ok(())
}
2335
/// Forward a project request to the host. These requests should be read only
/// as guests are allowed to send them.
///
/// Resolves the project's host from the request's entity id, forwards the
/// request to the host's connection, and relays the host's reply back to
/// the original sender.
async fn forward_read_only_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());
    // Validates that the sender participates in the project with at least
    // read access, and yields the host's connection.
    let host_connection_id = session
        .db()
        .await
        .host_for_read_only_project_request(project_id, session.connection_id)
        .await?;
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2356
/// Forward a project request to the host. These requests are disallowed
/// for guests.
///
/// Same shape as [`forward_read_only_project_request`], but the db lookup
/// additionally requires the sender to have write access to the project.
async fn forward_mutating_project_request<T>(
    request: T,
    response: Response<T>,
    session: MessageContext,
) -> Result<()>
where
    T: EntityMessage + RequestMessage,
{
    let project_id = ProjectId::from_proto(request.remote_entity_id());

    // Validates write access and yields the host's connection.
    let host_connection_id = session
        .db()
        .await
        .host_for_mutating_project_request(project_id, session.connection_id)
        .await?;
    let payload = session.forward_request(host_connection_id, request).await?;
    response.send(payload)?;
    Ok(())
}
2378
/// Handles a batched LSP query by forwarding it to the project host.
///
/// Records the concrete query kind on the current tracing span before
/// delegating to the mutating-request forwarding path.
async fn multi_lsp_query(
    request: MultiLspQuery,
    response: Response<MultiLspQuery>,
    session: MessageContext,
) -> Result<()> {
    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
    tracing::info!("multi_lsp_query message received");
    forward_mutating_project_request(request, response, session).await
}
2388
/// Notify other participants that a new buffer has been created.
///
/// Only the project host may send this; after verifying that, the message
/// is forwarded directly to the single peer it is addressed to.
async fn create_buffer_for_peer(
    request: proto::CreateBufferForPeer,
    session: MessageContext,
) -> Result<()> {
    // Reject the message unless the sender hosts the project.
    session
        .db()
        .await
        .check_user_is_project_host(
            ProjectId::from_proto(request.project_id),
            session.connection_id,
        )
        .await?;
    let peer_id = request.peer_id.context("invalid peer id")?;
    session
        .peer
        .forward_send(session.connection_id, peer_id.into(), request)?;
    Ok(())
}
2408
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Infer the capability this update requires: selection-only updates are
    // read-only; any other operation needs write access.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the db guard to this block so it is released before we await
    // the (potentially slow) forwarded request to the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Relay the update to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host applies updates via a request (not fire-and-forget) so we only
    // ack the sender once the host has accepted the operations. Skip this
    // when the sender *is* the host.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2455
/// Relays an assistant-context operation to the other project participants.
///
/// Like `update_buffer`, selection-only buffer operations are treated as
/// read-only; everything else requires write access.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Determine the capability required by this operation.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included as a plain broadcast
    // recipient here rather than being sent a request.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2497
2498/// Notify other participants that a project has been updated.
2499async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2500 request: T,
2501 session: MessageContext,
2502) -> Result<()> {
2503 let project_id = ProjectId::from_proto(request.remote_entity_id());
2504 let project_connection_ids = session
2505 .db()
2506 .await
2507 .project_connection_ids(project_id, session.connection_id, false)
2508 .await?;
2509
2510 broadcast(
2511 Some(session.connection_id),
2512 project_connection_ids.iter().copied(),
2513 |connection_id| {
2514 session
2515 .peer
2516 .forward_send(session.connection_id, connection_id, request.clone())
2517 },
2518 );
2519 Ok(())
2520}
2521
/// Start following another user in a call.
///
/// Verifies both peers are in the room, forwards the follow request to the
/// leader, relays their reply, and then records the follow state so the
/// room can reflect it.
async fn follow(
    request: proto::Follow,
    response: Response<proto::Follow>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both the leader and the follower must be participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    // Ask the leader first; only on success do we persist the follow below.
    let response_payload = session.forward_request(leader_id, request).await?;
    response.send(response_payload)?;

    // Follow state is only tracked per-project; a follow without a project
    // id is not persisted.
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .follow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2553
/// Stop following another user in a call.
///
/// Mirror image of [`follow`]: notifies the leader (fire-and-forget rather
/// than a request), then removes the persisted follow state and refreshes
/// the room.
async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let project_id = request.project_id.map(ProjectId::from_proto);
    let leader_id = request.leader_id.context("invalid leader id")?.into();
    let follower_id = session.connection_id;

    // Both the leader and the follower must be participants of the room.
    session
        .db()
        .await
        .check_room_participants(room_id, leader_id, session.connection_id)
        .await?;

    session
        .peer
        .forward_send(session.connection_id, leader_id, request)?;

    // Follow state is only tracked per-project (see `follow`).
    if let Some(project_id) = project_id {
        let room = session
            .db()
            .await
            .unfollow(room_id, project_id, leader_id, follower_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    Ok(())
}
2582
2583/// Notify everyone following you of your current location.
2584async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2585 let room_id = RoomId::from_proto(request.room_id);
2586 let database = session.db.lock().await;
2587
2588 let connection_ids = if let Some(project_id) = request.project_id {
2589 let project_id = ProjectId::from_proto(project_id);
2590 database
2591 .project_connection_ids(project_id, session.connection_id, true)
2592 .await?
2593 } else {
2594 database
2595 .room_connection_ids(room_id, session.connection_id)
2596 .await?
2597 };
2598
2599 // For now, don't send view update messages back to that view's current leader.
2600 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2601 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2602 _ => None,
2603 });
2604
2605 for connection_id in connection_ids.iter().cloned() {
2606 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2607 session
2608 .peer
2609 .forward_send(session.connection_id, connection_id, request.clone())?;
2610 }
2611 }
2612 Ok(())
2613}
2614
/// Get public data about users.
///
/// Looks up the requested user ids and returns their public profile
/// fields (id, GitHub login, display name, and a derived avatar URL).
async fn get_users(
    request: proto::GetUsers,
    response: Response<proto::GetUsers>,
    session: MessageContext,
) -> Result<()> {
    let user_ids = request
        .user_ids
        .into_iter()
        .map(UserId::from_proto)
        .collect();
    let users = session
        .db()
        .await
        .get_users_by_ids(user_ids)
        .await?
        .into_iter()
        .map(|user| proto::User {
            id: user.id.to_proto(),
            // Avatars are served by GitHub, keyed by login.
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2642
/// Search for users (to invite) by GitHub login.
///
/// Empty queries return nothing; one- or two-character queries fall back to
/// an exact login lookup; longer queries run a fuzzy search capped at 10
/// results. The requesting user is always filtered out.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: MessageContext,
) -> Result<()> {
    let query = request.query;
    // NOTE(review): `len()` is the byte length, so a 1-char multibyte query
    // takes the fuzzy path — confirm this is acceptable.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    let users = users
        .into_iter()
        // Never suggest the requester to themselves.
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
            name: user.name,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
2674
/// Send a contact request to another user.
///
/// Persists the request, then pushes contact-list updates to every active
/// connection of both the requester (outgoing request) and the responder
/// (incoming request), plus any notifications the db produced.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2721
/// Accept or decline a contact request.
///
/// A `Dismiss` response only clears the notification. Accept/decline
/// updates the stored request and fans out contact-list updates to every
/// connection of both users, including each other's busy status on accept.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is embedded in the contact entries sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Whether accepted or declined, the incoming request goes away.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2779
/// Remove a contact.
///
/// Handles both established contacts and pending requests: the db reports
/// whether the relationship had been accepted, which determines whether the
/// fan-out removes a contact or a pending request on each side. Also deletes
/// the responder's stale contact-request notification when one existed.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Retract the original contact-request notification, if any.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2830
/// Whether a client of this version should be automatically subscribed to
/// channel updates on connect. Versions before 0.139 get the automatic
/// behavior; presumably newer clients request it explicitly via
/// `SubscribeToChannels` (see `subscribe_to_channels`) — confirm.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2834
2835async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2836 if is_staff {
2837 return Ok(proto::Plan::ZedPro);
2838 }
2839
2840 let subscription = db.get_active_billing_subscription(user_id).await?;
2841 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2842
2843 let plan = if let Some(subscription_kind) = subscription_kind {
2844 match subscription_kind {
2845 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2846 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2847 SubscriptionKind::ZedFree => proto::Plan::Free,
2848 }
2849 } else {
2850 proto::Plan::Free
2851 };
2852
2853 Ok(plan)
2854}
2855
/// Builds the `UpdateUserPlan` message describing a user's plan, billing
/// state, and current-period usage.
///
/// Usage is only looked up when an LLM database handle is available and the
/// user has a current subscription period; otherwise a default usage record
/// is synthesized for the plan.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is tracked per subscription period, so without a period
        // there is nothing to look up.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Accounts younger than the minimum age are blocked from LLM use unless
    // they are on Zed Pro or carry the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always have usage-based billing enabled; other users follow
        // their stored preference (None when no preferences exist).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2920
2921fn model_requests_limit(
2922 plan: cloud_llm_client::Plan,
2923 feature_flags: &Vec<String>,
2924) -> cloud_llm_client::UsageLimit {
2925 match plan.model_requests_limit() {
2926 cloud_llm_client::UsageLimit::Limited(limit) => {
2927 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2928 && feature_flags
2929 .iter()
2930 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2931 {
2932 1_000
2933 } else {
2934 limit
2935 };
2936
2937 cloud_llm_client::UsageLimit::Limited(limit)
2938 }
2939 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2940 }
2941}
2942
/// Converts a subscription-usage db record into its proto representation,
/// attaching the plan's model-request and edit-prediction limits.
fn subscription_usage_to_proto(
    plan: proto::Plan,
    usage: crate::llm::db::subscription_usage::Model,
    feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
    // Map the wire plan enum to the cloud_llm_client plan, which knows the
    // per-plan limits.
    let plan = match plan {
        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
    };

    proto::SubscriptionUsage {
        model_requests_usage_amount: usage.model_requests as u32,
        // Model-request limit may be raised by feature flags (see
        // `model_requests_limit`).
        model_requests_usage_limit: Some(proto::UsageLimit {
            variant: Some(match model_requests_limit(plan, feature_flags) {
                cloud_llm_client::UsageLimit::Limited(limit) => {
                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                        limit: limit as u32,
                    })
                }
                cloud_llm_client::UsageLimit::Unlimited => {
                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                }
            }),
        }),
        edit_predictions_usage_amount: usage.edit_predictions as u32,
        // Edit-prediction limit comes straight from the plan.
        edit_predictions_usage_limit: Some(proto::UsageLimit {
            variant: Some(match plan.edit_predictions_limit() {
                cloud_llm_client::UsageLimit::Limited(limit) => {
                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                        limit: limit as u32,
                    })
                }
                cloud_llm_client::UsageLimit::Unlimited => {
                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                }
            }),
        }),
    }
}
2983
/// Builds a zero-usage proto record for the given plan, used when no usage
/// row exists for the current subscription period. Limits are computed the
/// same way as in `subscription_usage_to_proto`.
fn make_default_subscription_usage(
    plan: proto::Plan,
    feature_flags: &Vec<String>,
) -> proto::SubscriptionUsage {
    // Map the wire plan enum to the cloud_llm_client plan, which knows the
    // per-plan limits.
    let plan = match plan {
        proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
        proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
        proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
    };

    proto::SubscriptionUsage {
        model_requests_usage_amount: 0,
        model_requests_usage_limit: Some(proto::UsageLimit {
            variant: Some(match model_requests_limit(plan, feature_flags) {
                cloud_llm_client::UsageLimit::Limited(limit) => {
                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                        limit: limit as u32,
                    })
                }
                cloud_llm_client::UsageLimit::Unlimited => {
                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                }
            }),
        }),
        edit_predictions_usage_amount: 0,
        edit_predictions_usage_limit: Some(proto::UsageLimit {
            variant: Some(match plan.edit_predictions_limit() {
                cloud_llm_client::UsageLimit::Limited(limit) => {
                    proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
                        limit: limit as u32,
                    })
                }
                cloud_llm_client::UsageLimit::Unlimited => {
                    proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
                }
            }),
        }),
    }
}
3023
/// Sends the session's user a fresh `UpdateUserPlan` message describing
/// their current plan and usage. Send failures are logged, not propagated.
async fn update_user_plan(session: &Session) -> Result<()> {
    let db = session.db().await;

    let update_user_plan = make_update_user_plan_message(
        session.principal.user(),
        session.is_staff(),
        &db.0,
        session.app_state.llm_db.clone(),
    )
    .await?;

    // Best-effort delivery: a failed send is traced but does not fail the
    // caller.
    session
        .peer
        .send(session.connection_id, update_user_plan)
        .trace_err();

    Ok(())
}
3042
3043async fn subscribe_to_channels(
3044 _: proto::SubscribeToChannels,
3045 session: MessageContext,
3046) -> Result<()> {
3047 subscribe_user_to_channels(session.user_id(), &session).await?;
3048 Ok(())
3049}
3050
/// Subscribes a user's connection-pool entries to every channel they are a
/// member of, then sends the connection its initial channel state (channel
/// memberships followed by the full channels update).
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    // Register interest in each channel so future channel events reach
    // this user's connections.
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
3067
/// Creates a new channel.
///
/// Replies to the creator immediately, then (for root channels that grant
/// the creator a membership) subscribes them and pushes their membership,
/// and finally announces the new channel to every connection subscribed to
/// the channel tree's root that is allowed to see it.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    // `membership` is None when the creator's membership is implied by an
    // ancestor channel — TODO confirm against `create_channel` in the db.
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Announce the channel to subscribers of the tree root, respecting
    // per-role visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
3122
3123/// Delete a channel
3124async fn delete_channel(
3125 request: proto::DeleteChannel,
3126 response: Response<proto::DeleteChannel>,
3127 session: MessageContext,
3128) -> Result<()> {
3129 let db = session.db().await;
3130
3131 let channel_id = request.channel_id;
3132 let (root_channel, removed_channels) = db
3133 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3134 .await?;
3135 response.send(proto::Ack {})?;
3136
3137 // Notify members of removed channels
3138 let mut update = proto::UpdateChannels::default();
3139 update
3140 .delete_channels
3141 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3142
3143 let connection_pool = session.connection_pool().await;
3144 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3145 session.peer.send(connection_id, update.clone())?;
3146 }
3147
3148 Ok(())
3149}
3150
3151/// Invite someone to join a channel.
3152async fn invite_channel_member(
3153 request: proto::InviteChannelMember,
3154 response: Response<proto::InviteChannelMember>,
3155 session: MessageContext,
3156) -> Result<()> {
3157 let db = session.db().await;
3158 let channel_id = ChannelId::from_proto(request.channel_id);
3159 let invitee_id = UserId::from_proto(request.user_id);
3160 let InviteMemberResult {
3161 channel,
3162 notifications,
3163 } = db
3164 .invite_channel_member(
3165 channel_id,
3166 invitee_id,
3167 session.user_id(),
3168 request.role().into(),
3169 )
3170 .await?;
3171
3172 let update = proto::UpdateChannels {
3173 channel_invitations: vec![channel.to_proto()],
3174 ..Default::default()
3175 };
3176
3177 let connection_pool = session.connection_pool().await;
3178 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3179 session.peer.send(connection_id, update.clone())?;
3180 }
3181
3182 send_notifications(&connection_pool, &session.peer, notifications);
3183
3184 response.send(proto::Ack {})?;
3185 Ok(())
3186}
3187
3188/// remove someone from a channel
3189async fn remove_channel_member(
3190 request: proto::RemoveChannelMember,
3191 response: Response<proto::RemoveChannelMember>,
3192 session: MessageContext,
3193) -> Result<()> {
3194 let db = session.db().await;
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196 let member_id = UserId::from_proto(request.user_id);
3197
3198 let RemoveChannelMemberResult {
3199 membership_update,
3200 notification_id,
3201 } = db
3202 .remove_channel_member(channel_id, member_id, session.user_id())
3203 .await?;
3204
3205 let mut connection_pool = session.connection_pool().await;
3206 notify_membership_updated(
3207 &mut connection_pool,
3208 membership_update,
3209 member_id,
3210 &session.peer,
3211 );
3212 for connection_id in connection_pool.user_connection_ids(member_id) {
3213 if let Some(notification_id) = notification_id {
3214 session
3215 .peer
3216 .send(
3217 connection_id,
3218 proto::DeleteNotification {
3219 notification_id: notification_id.to_proto(),
3220 },
3221 )
3222 .trace_err();
3223 }
3224 }
3225
3226 response.send(proto::Ack {})?;
3227 Ok(())
3228}
3229
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the user ids into a Vec first: the loop body mutates the
    // pool's subscriptions, so we cannot hold a borrowing iterator over it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users whose role can see the new visibility are (re)subscribed and
        // receive the updated channel; everyone else is unsubscribed and told
        // to delete the channel locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3276
3277/// Alter the role for a user in the channel.
3278async fn set_channel_member_role(
3279 request: proto::SetChannelMemberRole,
3280 response: Response<proto::SetChannelMemberRole>,
3281 session: MessageContext,
3282) -> Result<()> {
3283 let db = session.db().await;
3284 let channel_id = ChannelId::from_proto(request.channel_id);
3285 let member_id = UserId::from_proto(request.user_id);
3286 let result = db
3287 .set_channel_member_role(
3288 channel_id,
3289 session.user_id(),
3290 member_id,
3291 request.role().into(),
3292 )
3293 .await?;
3294
3295 match result {
3296 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3297 let mut connection_pool = session.connection_pool().await;
3298 notify_membership_updated(
3299 &mut connection_pool,
3300 membership_update,
3301 member_id,
3302 &session.peer,
3303 )
3304 }
3305 db::SetMemberRoleResult::InviteUpdated(channel) => {
3306 let update = proto::UpdateChannels {
3307 channel_invitations: vec![channel.to_proto()],
3308 ..Default::default()
3309 };
3310
3311 for connection_id in session
3312 .connection_pool()
3313 .await
3314 .user_connection_ids(member_id)
3315 {
3316 session.peer.send(connection_id, update.clone())?;
3317 }
3318 }
3319 }
3320
3321 response.send(proto::Ack {})?;
3322 Ok(())
3323}
3324
3325/// Change the name of a channel
3326async fn rename_channel(
3327 request: proto::RenameChannel,
3328 response: Response<proto::RenameChannel>,
3329 session: MessageContext,
3330) -> Result<()> {
3331 let db = session.db().await;
3332 let channel_id = ChannelId::from_proto(request.channel_id);
3333 let channel_model = db
3334 .rename_channel(channel_id, session.user_id(), &request.name)
3335 .await?;
3336 let root_id = channel_model.root_id();
3337 let channel = Channel::from_model(channel_model);
3338
3339 response.send(proto::RenameChannelResponse {
3340 channel: Some(channel.to_proto()),
3341 })?;
3342
3343 let connection_pool = session.connection_pool().await;
3344 let update = proto::UpdateChannels {
3345 channels: vec![channel.to_proto()],
3346 ..Default::default()
3347 };
3348 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3349 if role.can_see_channel(channel.visibility) {
3350 session.peer.send(connection_id, update.clone())?;
3351 }
3352 }
3353
3354 Ok(())
3355}
3356
3357/// Move a channel to a new parent.
3358async fn move_channel(
3359 request: proto::MoveChannel,
3360 response: Response<proto::MoveChannel>,
3361 session: MessageContext,
3362) -> Result<()> {
3363 let channel_id = ChannelId::from_proto(request.channel_id);
3364 let to = ChannelId::from_proto(request.to);
3365
3366 let (root_id, channels) = session
3367 .db()
3368 .await
3369 .move_channel(channel_id, to, session.user_id())
3370 .await?;
3371
3372 let connection_pool = session.connection_pool().await;
3373 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3374 let channels = channels
3375 .iter()
3376 .filter_map(|channel| {
3377 if role.can_see_channel(channel.visibility) {
3378 Some(channel.to_proto())
3379 } else {
3380 None
3381 }
3382 })
3383 .collect::<Vec<_>>();
3384 if channels.is_empty() {
3385 continue;
3386 }
3387
3388 let update = proto::UpdateChannels {
3389 channels,
3390 ..Default::default()
3391 };
3392
3393 session.peer.send(connection_id, update.clone())?;
3394 }
3395
3396 response.send(Ack {})?;
3397 Ok(())
3398}
3399
3400async fn reorder_channel(
3401 request: proto::ReorderChannel,
3402 response: Response<proto::ReorderChannel>,
3403 session: MessageContext,
3404) -> Result<()> {
3405 let channel_id = ChannelId::from_proto(request.channel_id);
3406 let direction = request.direction();
3407
3408 let updated_channels = session
3409 .db()
3410 .await
3411 .reorder_channel(channel_id, direction, session.user_id())
3412 .await?;
3413
3414 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3415 let connection_pool = session.connection_pool().await;
3416 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3417 let channels = updated_channels
3418 .iter()
3419 .filter_map(|channel| {
3420 if role.can_see_channel(channel.visibility) {
3421 Some(channel.to_proto())
3422 } else {
3423 None
3424 }
3425 })
3426 .collect::<Vec<_>>();
3427
3428 if channels.is_empty() {
3429 continue;
3430 }
3431
3432 let update = proto::UpdateChannels {
3433 channels,
3434 ..Default::default()
3435 };
3436
3437 session.peer.send(connection_id, update.clone())?;
3438 }
3439 }
3440
3441 response.send(Ack {})?;
3442 Ok(())
3443}
3444
3445/// Get the list of channel members
3446async fn get_channel_members(
3447 request: proto::GetChannelMembers,
3448 response: Response<proto::GetChannelMembers>,
3449 session: MessageContext,
3450) -> Result<()> {
3451 let db = session.db().await;
3452 let channel_id = ChannelId::from_proto(request.channel_id);
3453 let limit = if request.limit == 0 {
3454 u16::MAX as u64
3455 } else {
3456 request.limit
3457 };
3458 let (members, users) = db
3459 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3460 .await?;
3461 response.send(proto::GetChannelMembersResponse { members, users })?;
3462 Ok(())
3463}
3464
3465/// Accept or decline a channel invitation.
3466async fn respond_to_channel_invite(
3467 request: proto::RespondToChannelInvite,
3468 response: Response<proto::RespondToChannelInvite>,
3469 session: MessageContext,
3470) -> Result<()> {
3471 let db = session.db().await;
3472 let channel_id = ChannelId::from_proto(request.channel_id);
3473 let RespondToChannelInvite {
3474 membership_update,
3475 notifications,
3476 } = db
3477 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3478 .await?;
3479
3480 let mut connection_pool = session.connection_pool().await;
3481 if let Some(membership_update) = membership_update {
3482 notify_membership_updated(
3483 &mut connection_pool,
3484 membership_update,
3485 session.user_id(),
3486 &session.peer,
3487 );
3488 } else {
3489 let update = proto::UpdateChannels {
3490 remove_channel_invitations: vec![channel_id.to_proto()],
3491 ..Default::default()
3492 };
3493
3494 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3495 session.peer.send(connection_id, update.clone())?;
3496 }
3497 };
3498
3499 send_notifications(&connection_pool, &session.peer, notifications);
3500
3501 response.send(proto::Ack {})?;
3502
3503 Ok(())
3504}
3505
3506/// Join the channels' room
3507async fn join_channel(
3508 request: proto::JoinChannel,
3509 response: Response<proto::JoinChannel>,
3510 session: MessageContext,
3511) -> Result<()> {
3512 let channel_id = ChannelId::from_proto(request.channel_id);
3513 join_channel_internal(channel_id, Box::new(response), session).await
3514}
3515
/// Abstraction over the two response types (`JoinChannel` and `JoinRoom`)
/// that both resolve to a `JoinRoomResponse`, so `join_channel_internal`
/// can serve either RPC.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Fully-qualified call so this resolves to `Response`'s own `send`,
        // not back into this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3529
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` RPCs via `JoinChannelInternalResponse`.
///
/// Cleans up any stale room connection for the user, joins the channel's
/// room, replies with the room state (plus a LiveKit token when configured),
/// and then broadcasts membership/room/channel updates.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while running the cleanup, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token when a client is configured: guests get a
        // non-publishing token, everyone else may publish. A token failure is
        // logged (`trace_err`) and results in no call-media info rather than
        // failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3623
3624/// Start editing the channel notes
3625async fn join_channel_buffer(
3626 request: proto::JoinChannelBuffer,
3627 response: Response<proto::JoinChannelBuffer>,
3628 session: MessageContext,
3629) -> Result<()> {
3630 let db = session.db().await;
3631 let channel_id = ChannelId::from_proto(request.channel_id);
3632
3633 let open_response = db
3634 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3635 .await?;
3636
3637 let collaborators = open_response.collaborators.clone();
3638 response.send(open_response)?;
3639
3640 let update = UpdateChannelBufferCollaborators {
3641 channel_id: channel_id.to_proto(),
3642 collaborators: collaborators.clone(),
3643 };
3644 channel_buffer_updated(
3645 session.connection_id,
3646 collaborators
3647 .iter()
3648 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3649 &update,
3650 &session.peer,
3651 );
3652
3653 Ok(())
3654}
3655
/// Edit the channel notes
///
/// Applies the client's operations to the channel buffer, echoes the raw
/// operations to the other collaborators, and tells non-collaborating channel
/// members only the latest (epoch, version) marker.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Forward the operations to everyone currently editing the buffer,
    // excluding the sender.
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel members who are not editing the buffer get only the version
    // marker, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3706
3707/// Rejoin the channel notes after a connection blip
3708async fn rejoin_channel_buffers(
3709 request: proto::RejoinChannelBuffers,
3710 response: Response<proto::RejoinChannelBuffers>,
3711 session: MessageContext,
3712) -> Result<()> {
3713 let db = session.db().await;
3714 let buffers = db
3715 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3716 .await?;
3717
3718 for rejoined_buffer in &buffers {
3719 let collaborators_to_notify = rejoined_buffer
3720 .buffer
3721 .collaborators
3722 .iter()
3723 .filter_map(|c| Some(c.peer_id?.into()));
3724 channel_buffer_updated(
3725 session.connection_id,
3726 collaborators_to_notify,
3727 &proto::UpdateChannelBufferCollaborators {
3728 channel_id: rejoined_buffer.buffer.channel_id,
3729 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3730 },
3731 &session.peer,
3732 );
3733 }
3734
3735 response.send(proto::RejoinChannelBuffersResponse {
3736 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3737 })?;
3738
3739 Ok(())
3740}
3741
3742/// Stop editing the channel notes
3743async fn leave_channel_buffer(
3744 request: proto::LeaveChannelBuffer,
3745 response: Response<proto::LeaveChannelBuffer>,
3746 session: MessageContext,
3747) -> Result<()> {
3748 let db = session.db().await;
3749 let channel_id = ChannelId::from_proto(request.channel_id);
3750
3751 let left_buffer = db
3752 .leave_channel_buffer(channel_id, session.connection_id)
3753 .await?;
3754
3755 response.send(Ack {})?;
3756
3757 channel_buffer_updated(
3758 session.connection_id,
3759 left_buffer.connections,
3760 &proto::UpdateChannelBufferCollaborators {
3761 channel_id: channel_id.to_proto(),
3762 collaborators: left_buffer.collaborators,
3763 },
3764 &session.peer,
3765 );
3766
3767 Ok(())
3768}
3769
3770fn channel_buffer_updated<T: EnvelopedMessage>(
3771 sender_id: ConnectionId,
3772 collaborators: impl IntoIterator<Item = ConnectionId>,
3773 message: &T,
3774 peer: &Peer,
3775) {
3776 broadcast(Some(sender_id), collaborators, |peer_id| {
3777 peer.send(peer_id, message.clone())
3778 });
3779}
3780
3781fn send_notifications(
3782 connection_pool: &ConnectionPool,
3783 peer: &Peer,
3784 notifications: db::NotificationBatch,
3785) {
3786 for (user_id, notification) in notifications {
3787 for connection_id in connection_pool.user_connection_ids(user_id) {
3788 if let Err(error) = peer.send(
3789 connection_id,
3790 proto::AddNotification {
3791 notification: Some(notification.clone()),
3792 },
3793 ) {
3794 tracing::error!(
3795 "failed to send notification to {:?} {}",
3796 connection_id,
3797 error
3798 );
3799 }
3800 }
3801 }
3802}
3803
/// Send a message to the channel
///
/// Validates the body, persists the message, pushes the full message to the
/// chat participants, and sends only the latest message id to the remaining
/// channel members.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // NOTE(review): the client-supplied nonce is stored and echoed back —
    // presumably so clients can match/dedupe their own sends; confirm.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to everyone with the chat open, except the sender.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel members who are not chat participants only receive the latest
    // message id, not the message body.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3896
3897/// Delete a channel message
3898async fn remove_channel_message(
3899 request: proto::RemoveChannelMessage,
3900 response: Response<proto::RemoveChannelMessage>,
3901 session: MessageContext,
3902) -> Result<()> {
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 let message_id = MessageId::from_proto(request.message_id);
3905 let (connection_ids, existing_notification_ids) = session
3906 .db()
3907 .await
3908 .remove_channel_message(channel_id, message_id, session.user_id())
3909 .await?;
3910
3911 broadcast(
3912 Some(session.connection_id),
3913 connection_ids,
3914 move |connection| {
3915 session.peer.send(connection, request.clone())?;
3916
3917 for notification_id in &existing_notification_ids {
3918 session.peer.send(
3919 connection,
3920 proto::DeleteNotification {
3921 notification_id: (*notification_id).to_proto(),
3922 },
3923 )?;
3924 }
3925
3926 Ok(())
3927 },
3928 );
3929 response.send(proto::Ack {})?;
3930 Ok(())
3931}
3932
/// Edit an existing channel message and fan the update out to the other chat
/// participants, keeping mention notifications in sync.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: the destructured `message_id` below shadows the request-derived
    // one above with the id returned by the database.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        // `timestamp` comes from the DB record (presumably the original send
        // time — confirm); `edited_at` records this edit.
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    // The pool guard is acquired before broadcasting and only used afterwards
    // for `send_notifications`; it stays held for the whole broadcast.
    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // Delete the mention notifications the DB removed for this edit.
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // Refresh the mention notifications whose content changed.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4016
4017/// Mark a channel message as read
4018async fn acknowledge_channel_message(
4019 request: proto::AckChannelMessage,
4020 session: MessageContext,
4021) -> Result<()> {
4022 let channel_id = ChannelId::from_proto(request.channel_id);
4023 let message_id = MessageId::from_proto(request.message_id);
4024 let notifications = session
4025 .db()
4026 .await
4027 .observe_channel_message(channel_id, session.user_id(), message_id)
4028 .await?;
4029 send_notifications(
4030 &*session.connection_pool().await,
4031 &session.peer,
4032 notifications,
4033 );
4034 Ok(())
4035}
4036
4037/// Mark a buffer version as synced
4038async fn acknowledge_buffer_version(
4039 request: proto::AckBufferOperation,
4040 session: MessageContext,
4041) -> Result<()> {
4042 let buffer_id = BufferId::from_proto(request.buffer_id);
4043 session
4044 .db()
4045 .await
4046 .observe_buffer_version(
4047 buffer_id,
4048 session.user_id(),
4049 request.epoch as i32,
4050 &request.version,
4051 )
4052 .await?;
4053 Ok(())
4054}
4055
4056/// Get a Supermaven API key for the user
4057async fn get_supermaven_api_key(
4058 _request: proto::GetSupermavenApiKey,
4059 response: Response<proto::GetSupermavenApiKey>,
4060 session: MessageContext,
4061) -> Result<()> {
4062 let user_id: String = session.user_id().to_string();
4063 if !session.is_staff() {
4064 return Err(anyhow!("supermaven not enabled for this account"))?;
4065 }
4066
4067 let email = session.email().context("user must have an email")?;
4068
4069 let supermaven_admin_api = session
4070 .supermaven_client
4071 .as_ref()
4072 .context("supermaven not configured")?;
4073
4074 let result = supermaven_admin_api
4075 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4076 .await?;
4077
4078 response.send(proto::GetSupermavenApiKeyResponse {
4079 api_key: result.api_key,
4080 })?;
4081
4082 Ok(())
4083}
4084
4085/// Start receiving chat updates for a channel
4086async fn join_channel_chat(
4087 request: proto::JoinChannelChat,
4088 response: Response<proto::JoinChannelChat>,
4089 session: MessageContext,
4090) -> Result<()> {
4091 let channel_id = ChannelId::from_proto(request.channel_id);
4092
4093 let db = session.db().await;
4094 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4095 .await?;
4096 let messages = db
4097 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4098 .await?;
4099 response.send(proto::JoinChannelChatResponse {
4100 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4101 messages,
4102 })?;
4103 Ok(())
4104}
4105
4106/// Stop receiving chat updates for a channel
4107async fn leave_channel_chat(
4108 request: proto::LeaveChannelChat,
4109 session: MessageContext,
4110) -> Result<()> {
4111 let channel_id = ChannelId::from_proto(request.channel_id);
4112 session
4113 .db()
4114 .await
4115 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4116 .await?;
4117 Ok(())
4118}
4119
4120/// Retrieve the chat history for a channel
4121async fn get_channel_messages(
4122 request: proto::GetChannelMessages,
4123 response: Response<proto::GetChannelMessages>,
4124 session: MessageContext,
4125) -> Result<()> {
4126 let channel_id = ChannelId::from_proto(request.channel_id);
4127 let messages = session
4128 .db()
4129 .await
4130 .get_channel_messages(
4131 channel_id,
4132 session.user_id(),
4133 MESSAGE_COUNT_PER_PAGE,
4134 Some(MessageId::from_proto(request.before_message_id)),
4135 )
4136 .await?;
4137 response.send(proto::GetChannelMessagesResponse {
4138 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4139 messages,
4140 })?;
4141 Ok(())
4142}
4143
4144/// Retrieve specific chat messages
4145async fn get_channel_messages_by_id(
4146 request: proto::GetChannelMessagesById,
4147 response: Response<proto::GetChannelMessagesById>,
4148 session: MessageContext,
4149) -> Result<()> {
4150 let message_ids = request
4151 .message_ids
4152 .iter()
4153 .map(|id| MessageId::from_proto(*id))
4154 .collect::<Vec<_>>();
4155 let messages = session
4156 .db()
4157 .await
4158 .get_channel_messages_by_id(session.user_id(), &message_ids)
4159 .await?;
4160 response.send(proto::GetChannelMessagesResponse {
4161 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4162 messages,
4163 })?;
4164 Ok(())
4165}
4166
4167/// Retrieve the current users notifications
4168async fn get_notifications(
4169 request: proto::GetNotifications,
4170 response: Response<proto::GetNotifications>,
4171 session: MessageContext,
4172) -> Result<()> {
4173 let notifications = session
4174 .db()
4175 .await
4176 .get_notifications(
4177 session.user_id(),
4178 NOTIFICATION_COUNT_PER_PAGE,
4179 request.before_id.map(db::NotificationId::from_proto),
4180 )
4181 .await?;
4182 response.send(proto::GetNotificationsResponse {
4183 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4184 notifications,
4185 })?;
4186 Ok(())
4187}
4188
4189/// Mark notifications as read
4190async fn mark_notification_as_read(
4191 request: proto::MarkNotificationRead,
4192 response: Response<proto::MarkNotificationRead>,
4193 session: MessageContext,
4194) -> Result<()> {
4195 let database = &session.db().await;
4196 let notifications = database
4197 .mark_notification_as_read_by_id(
4198 session.user_id(),
4199 NotificationId::from_proto(request.notification_id),
4200 )
4201 .await?;
4202 send_notifications(
4203 &*session.connection_pool().await,
4204 &session.peer,
4205 notifications,
4206 );
4207 response.send(proto::Ack {})?;
4208 Ok(())
4209}
4210
4211/// Accept the terms of service (tos) on behalf of the current user
4212async fn accept_terms_of_service(
4213 _request: proto::AcceptTermsOfService,
4214 response: Response<proto::AcceptTermsOfService>,
4215 session: MessageContext,
4216) -> Result<()> {
4217 let db = session.db().await;
4218
4219 let accepted_tos_at = Utc::now();
4220 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4221 .await?;
4222
4223 response.send(proto::AcceptTermsOfServiceResponse {
4224 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4225 })?;
4226
4227 Ok(())
4228}
4229
4230fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4231 let message = match message {
4232 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4233 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4234 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4235 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4236 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4237 code: frame.code.into(),
4238 reason: frame.reason.as_str().to_owned().into(),
4239 })),
4240 // We should never receive a frame while reading the message, according
4241 // to the `tungstenite` maintainers:
4242 //
4243 // > It cannot occur when you read messages from the WebSocket, but it
4244 // > can be used when you want to send the raw frames (e.g. you want to
4245 // > send the frames to the WebSocket without composing the full message first).
4246 // >
4247 // > — https://github.com/snapview/tungstenite-rs/issues/268
4248 TungsteniteMessage::Frame(_) => {
4249 bail!("received an unexpected frame while reading the message")
4250 }
4251 };
4252
4253 Ok(message)
4254}
4255
4256fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4257 match message {
4258 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4259 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4260 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4261 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4262 AxumMessage::Close(frame) => {
4263 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4264 code: frame.code.into(),
4265 reason: frame.reason.as_ref().into(),
4266 }))
4267 }
4268 }
4269}
4270
4271fn notify_membership_updated(
4272 connection_pool: &mut ConnectionPool,
4273 result: MembershipUpdated,
4274 user_id: UserId,
4275 peer: &Peer,
4276) {
4277 for membership in &result.new_channels.channel_memberships {
4278 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4279 }
4280 for channel_id in &result.removed_channels {
4281 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4282 }
4283
4284 let user_channels_update = proto::UpdateUserChannels {
4285 channel_memberships: result
4286 .new_channels
4287 .channel_memberships
4288 .iter()
4289 .map(|cm| proto::ChannelMembership {
4290 channel_id: cm.channel_id.to_proto(),
4291 role: cm.role.into(),
4292 })
4293 .collect(),
4294 ..Default::default()
4295 };
4296
4297 let mut update = build_channels_update(result.new_channels);
4298 update.delete_channels = result
4299 .removed_channels
4300 .into_iter()
4301 .map(|id| id.to_proto())
4302 .collect();
4303 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4304
4305 for connection_id in connection_pool.user_connection_ids(user_id) {
4306 peer.send(connection_id, user_channels_update.clone())
4307 .trace_err();
4308 peer.send(connection_id, update.clone()).trace_err();
4309 }
4310}
4311
4312fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4313 proto::UpdateUserChannels {
4314 channel_memberships: channels
4315 .channel_memberships
4316 .iter()
4317 .map(|m| proto::ChannelMembership {
4318 channel_id: m.channel_id.to_proto(),
4319 role: m.role.into(),
4320 })
4321 .collect(),
4322 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4323 observed_channel_message_id: channels.observed_channel_messages.clone(),
4324 }
4325}
4326
4327fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4328 let mut update = proto::UpdateChannels::default();
4329
4330 for channel in channels.channels {
4331 update.channels.push(channel.to_proto());
4332 }
4333
4334 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4335 update.latest_channel_message_ids = channels.latest_channel_messages;
4336
4337 for (channel_id, participants) in channels.channel_participants {
4338 update
4339 .channel_participants
4340 .push(proto::ChannelParticipants {
4341 channel_id: channel_id.to_proto(),
4342 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4343 });
4344 }
4345
4346 for channel in channels.invited_channels {
4347 update.channel_invitations.push(channel.to_proto());
4348 }
4349
4350 update
4351}
4352
4353fn build_initial_contacts_update(
4354 contacts: Vec<db::Contact>,
4355 pool: &ConnectionPool,
4356) -> proto::UpdateContacts {
4357 let mut update = proto::UpdateContacts::default();
4358
4359 for contact in contacts {
4360 match contact {
4361 db::Contact::Accepted { user_id, busy } => {
4362 update.contacts.push(contact_for_user(user_id, busy, pool));
4363 }
4364 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4365 db::Contact::Incoming { user_id } => {
4366 update
4367 .incoming_requests
4368 .push(proto::IncomingContactRequest {
4369 requester_id: user_id.to_proto(),
4370 })
4371 }
4372 }
4373 }
4374
4375 update
4376}
4377
4378fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4379 proto::Contact {
4380 user_id: user_id.to_proto(),
4381 online: pool.is_user_online(user_id),
4382 busy,
4383 }
4384}
4385
4386fn room_updated(room: &proto::Room, peer: &Peer) {
4387 broadcast(
4388 None,
4389 room.participants
4390 .iter()
4391 .filter_map(|participant| Some(participant.peer_id?.into())),
4392 |peer_id| {
4393 peer.send(
4394 peer_id,
4395 proto::RoomUpdated {
4396 room: Some(room.clone()),
4397 },
4398 )
4399 },
4400 );
4401}
4402
4403fn channel_updated(
4404 channel: &db::channel::Model,
4405 room: &proto::Room,
4406 peer: &Peer,
4407 pool: &ConnectionPool,
4408) {
4409 let participants = room
4410 .participants
4411 .iter()
4412 .map(|p| p.user_id)
4413 .collect::<Vec<_>>();
4414
4415 broadcast(
4416 None,
4417 pool.channel_connection_ids(channel.root_id())
4418 .filter_map(|(channel_id, role)| {
4419 role.can_see_channel(channel.visibility)
4420 .then_some(channel_id)
4421 }),
4422 |peer_id| {
4423 peer.send(
4424 peer_id,
4425 proto::UpdateChannels {
4426 channel_participants: vec![proto::ChannelParticipants {
4427 channel_id: channel.id.to_proto(),
4428 participant_user_ids: participants.clone(),
4429 }],
4430 ..Default::default()
4431 },
4432 )
4433 },
4434 );
4435}
4436
4437async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4438 let db = session.db().await;
4439
4440 let contacts = db.get_contacts(user_id).await?;
4441 let busy = db.is_user_busy(user_id).await?;
4442
4443 let pool = session.connection_pool().await;
4444 let updated_contact = contact_for_user(user_id, busy, &pool);
4445 for contact in contacts {
4446 if let db::Contact::Accepted {
4447 user_id: contact_user_id,
4448 ..
4449 } = contact
4450 {
4451 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4452 session
4453 .peer
4454 .send(
4455 contact_conn_id,
4456 proto::UpdateContacts {
4457 contacts: vec![updated_contact.clone()],
4458 remove_contacts: Default::default(),
4459 incoming_requests: Default::default(),
4460 remove_incoming_requests: Default::default(),
4461 outgoing_requests: Default::default(),
4462 remove_outgoing_requests: Default::default(),
4463 },
4464 )
4465 .trace_err();
4466 }
4467 }
4468 }
4469 Ok(())
4470}
4471
/// Removes this session's connection from whatever room it is in, notifying
/// projects, channel members, canceled callees, contacts, and LiveKit.
///
/// Returns `Ok(())` immediately (a no-op) when the connection was not in a
/// room. Errors propagate from the database and from contact updates.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Pulled out of the `if let` scope below so they can be used after the
    // database guard is dropped.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell collaborators in each project this connection had joined.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE: order matters here — `livekit_room` must be taken out of
        // `left_room.room` *before* the room itself is taken, otherwise it
        // would be moved away along with the room.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in a room; nothing to do.
        return Ok(());
    }

    // If the room belonged to a channel, refresh that channel's participant
    // list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Notify users whose incoming calls were canceled by this departure, and
    // remember them so their contact (busy) state gets refreshed below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: remove this participant and, if the room is
    // now empty/deleted, tear the LiveKit room down too.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4545
4546async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4547 let left_channel_buffers = session
4548 .db()
4549 .await
4550 .leave_channel_buffers(session.connection_id)
4551 .await?;
4552
4553 for left_buffer in left_channel_buffers {
4554 channel_buffer_updated(
4555 session.connection_id,
4556 left_buffer.connections,
4557 &proto::UpdateChannelBufferCollaborators {
4558 channel_id: left_buffer.channel_id.to_proto(),
4559 collaborators: left_buffer.collaborators,
4560 },
4561 &session.peer,
4562 );
4563 }
4564
4565 Ok(())
4566}
4567
4568fn project_left(project: &db::LeftProject, session: &Session) {
4569 for connection_id in &project.connection_ids {
4570 if project.should_unshare {
4571 session
4572 .peer
4573 .send(
4574 *connection_id,
4575 proto::UnshareProject {
4576 project_id: project.id.to_proto(),
4577 },
4578 )
4579 .trace_err();
4580 } else {
4581 session
4582 .peer
4583 .send(
4584 *connection_id,
4585 proto::RemoveProjectCollaborator {
4586 project_id: project.id.to_proto(),
4587 peer_id: Some(session.connection_id.into()),
4588 },
4589 )
4590 .trace_err();
4591 }
4592 }
4593}
4594
/// Extension trait for logging an error and discarding it, converting a
/// `Result` into an `Option` of its success value.
pub trait ResultExt {
    // The success type carried through to the returned `Option`.
    type Ok;

    /// Log the error (if any) and return `Some(value)` on success, `None` on error.
    fn trace_err(self) -> Option<Self::Ok>;
}
4600
4601impl<T, E> ResultExt for Result<T, E>
4602where
4603 E: std::fmt::Debug,
4604{
4605 type Ok = T;
4606
4607 #[track_caller]
4608 fn trace_err(self) -> Option<T> {
4609 match self {
4610 Ok(value) => Some(value),
4611 Err(error) => {
4612 tracing::error!("{:?}", error);
4613 None
4614 }
4615 }
4616 }
4617}