1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use futures::TryFutureExt as _;
45use reqwest_client::ReqwestClient;
46use rpc::proto::{MultiLspQuery, split_repository_update};
47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
48use tracing::Span;
49
50use futures::{
51 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
52 stream::FuturesUnordered,
53};
54use prometheus::{IntGauge, register_int_gauge};
55use rpc::{
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57 proto::{
58 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
59 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
60 },
61};
62use semantic_version::SemanticVersion;
63use serde::{Serialize, Serializer};
64use std::{
65 any::TypeId,
66 future::Future,
67 marker::PhantomData,
68 mem,
69 net::SocketAddr,
70 ops::{Deref, DerefMut},
71 rc::Rc,
72 sync::{
73 Arc, OnceLock,
74 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
75 },
76 time::{Duration, Instant},
77};
78use time::OffsetDateTime;
79use tokio::sync::{Semaphore, watch};
80use tower::ServiceBuilder;
81use tracing::{
82 Instrument,
83 field::{self},
84 info_span, instrument,
85};
86
/// Grace period associated with client reconnection.
///
/// NOTE(review): not referenced in the code visible here; presumably how long a dropped
/// client may take to reconnect before its server-side state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left behind by earlier
/// servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once those pods are
/// gone their old resources can be cleaned up, so this sits comfortably above that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging and size limits. NOTE(review): not used in the code visible here; the names
// suggest channel-chat paging, chat message length, and notification paging.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Maximum number of `ConnectionGuard`s that may be alive at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently live `ConnectionGuard`s (incremented by `try_acquire`,
/// decremented when a guard is dropped).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Field names of the per-message timing values recorded on each "receive message" span.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";
103
104type MessageHandler =
105 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
106
107pub struct ConnectionGuard;
108
109impl ConnectionGuard {
110 pub fn try_acquire() -> Result<Self, ()> {
111 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
112 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 tracing::error!(
115 "too many concurrent connections: {}",
116 current_connections + 1
117 );
118 return Err(());
119 }
120 Ok(ConnectionGuard)
121 }
122}
123
124impl Drop for ConnectionGuard {
125 fn drop(&mut self) {
126 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
127 }
128}
129
130struct Response<R> {
131 peer: Arc<Peer>,
132 receipt: Receipt<R>,
133 responded: Arc<AtomicBool>,
134}
135
136impl<R: RequestMessage> Response<R> {
137 fn send(self, payload: R::Response) -> Result<()> {
138 self.responded.store(true, SeqCst);
139 self.peer.respond(self.receipt, payload)?;
140 Ok(())
141 }
142}
143
144#[derive(Clone, Debug)]
145pub enum Principal {
146 User(User),
147 Impersonated { user: User, admin: User },
148}
149
150impl Principal {
151 fn user(&self) -> &User {
152 match self {
153 Principal::User(user) => user,
154 Principal::Impersonated { user, .. } => user,
155 }
156 }
157
158 fn update_span(&self, span: &tracing::Span) {
159 match &self {
160 Principal::User(user) => {
161 span.record("user_id", user.id.0);
162 span.record("login", &user.github_login);
163 }
164 Principal::Impersonated { user, admin } => {
165 span.record("user_id", user.id.0);
166 span.record("login", &user.github_login);
167 span.record("impersonator", &admin.github_login);
168 }
169 }
170 }
171}
172
173#[derive(Clone)]
174struct MessageContext {
175 session: Session,
176 span: tracing::Span,
177}
178
179impl Deref for MessageContext {
180 type Target = Session;
181
182 fn deref(&self) -> &Self::Target {
183 &self.session
184 }
185}
186
187impl MessageContext {
188 pub fn forward_request<T: RequestMessage>(
189 &self,
190 receiver_id: ConnectionId,
191 request: T,
192 ) -> impl Future<Output = anyhow::Result<T::Response>> {
193 let request_start_time = Instant::now();
194 let span = self.span.clone();
195 tracing::info!("start forwarding request");
196 self.peer
197 .forward_request(self.connection_id, receiver_id, request)
198 .inspect(move |_| {
199 span.record(
200 HOST_WAITING_MS,
201 request_start_time.elapsed().as_micros() as f64 / 1000.0,
202 );
203 })
204 .inspect_err(|_| tracing::error!("error forwarding request"))
205 .inspect_ok(|_| tracing::info!("finished forwarding request"))
206 }
207}
208
209#[derive(Clone)]
210struct Session {
211 principal: Principal,
212 connection_id: ConnectionId,
213 db: Arc<tokio::sync::Mutex<DbHandle>>,
214 peer: Arc<Peer>,
215 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
216 app_state: Arc<AppState>,
217 supermaven_client: Option<Arc<SupermavenAdminApi>>,
218 /// The GeoIP country code for the user.
219 #[allow(unused)]
220 geoip_country_code: Option<String>,
221 system_id: Option<String>,
222 _executor: Executor,
223}
224
225impl Session {
226 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
227 #[cfg(test)]
228 tokio::task::yield_now().await;
229 let guard = self.db.lock().await;
230 #[cfg(test)]
231 tokio::task::yield_now().await;
232 guard
233 }
234
235 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
236 #[cfg(test)]
237 tokio::task::yield_now().await;
238 let guard = self.connection_pool.lock();
239 ConnectionPoolGuard {
240 guard,
241 _not_send: PhantomData,
242 }
243 }
244
245 fn is_staff(&self) -> bool {
246 match &self.principal {
247 Principal::User(user) => user.admin,
248 Principal::Impersonated { .. } => true,
249 }
250 }
251
252 fn user_id(&self) -> UserId {
253 match &self.principal {
254 Principal::User(user) => user.id,
255 Principal::Impersonated { user, .. } => user.id,
256 }
257 }
258
259 pub fn email(&self) -> Option<String> {
260 match &self.principal {
261 Principal::User(user) => user.email_address.clone(),
262 Principal::Impersonated { user, .. } => user.email_address.clone(),
263 }
264 }
265}
266
267impl Debug for Session {
268 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
269 let mut result = f.debug_struct("Session");
270 match &self.principal {
271 Principal::User(user) => {
272 result.field("user", &user.github_login);
273 }
274 Principal::Impersonated { user, admin } => {
275 result.field("user", &user.github_login);
276 result.field("impersonator", &admin.github_login);
277 }
278 }
279 result.field("connection_id", &self.connection_id).finish()
280 }
281}
282
283struct DbHandle(Arc<Database>);
284
285impl Deref for DbHandle {
286 type Target = Database;
287
288 fn deref(&self) -> &Self::Target {
289 self.0.as_ref()
290 }
291}
292
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers registered in [`Server::new`].
pub struct Server {
    /// Current server id; behind a mutex because `reset` (test-only) replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; connections subscribe to it and stop when it becomes `true`.
    teardown: watch::Sender<bool>,
}
301
/// Lock guard for the server's connection pool.
///
/// `Rc<()>` is `!Send`, so the `_not_send` marker makes this guard `!Send` regardless
/// of the inner guard — presumably to keep it from being held across `.await` in `Send`
/// futures (confirm intent).
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
306
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
313
314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
315where
316 S: Serializer,
317 T: Deref<Target = U>,
318 U: Serialize,
319{
320 Serialize::serialize(value.deref(), serializer)
321}
322
323impl Server {
324 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
325 let mut server = Self {
326 id: parking_lot::Mutex::new(id),
327 peer: Peer::new(id.0 as u32),
328 app_state: app_state.clone(),
329 connection_pool: Default::default(),
330 handlers: Default::default(),
331 teardown: watch::channel(false).0,
332 };
333
334 server
335 .add_request_handler(ping)
336 .add_request_handler(create_room)
337 .add_request_handler(join_room)
338 .add_request_handler(rejoin_room)
339 .add_request_handler(leave_room)
340 .add_request_handler(set_room_participant_role)
341 .add_request_handler(call)
342 .add_request_handler(cancel_call)
343 .add_message_handler(decline_call)
344 .add_request_handler(update_participant_location)
345 .add_request_handler(share_project)
346 .add_message_handler(unshare_project)
347 .add_request_handler(join_project)
348 .add_message_handler(leave_project)
349 .add_request_handler(update_project)
350 .add_request_handler(update_worktree)
351 .add_request_handler(update_repository)
352 .add_request_handler(remove_repository)
353 .add_message_handler(start_language_server)
354 .add_message_handler(update_language_server)
355 .add_message_handler(update_diagnostic_summary)
356 .add_message_handler(update_worktree_settings)
357 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
358 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
359 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
360 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
361 .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
362 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
363 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
364 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
365 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
366 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
367 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
368 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
369 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
370 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
371 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
372 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
373 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
374 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
375 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
376 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
377 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
378 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
379 .add_request_handler(
380 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
381 )
382 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
383 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
384 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
385 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
386 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
387 .add_request_handler(
388 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
389 )
390 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
391 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
392 .add_request_handler(
393 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
394 )
395 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
396 .add_request_handler(
397 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
398 )
399 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
400 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
401 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
402 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
403 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
404 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
405 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
406 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
407 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
408 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
409 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
410 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
411 .add_request_handler(
412 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
413 )
414 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
415 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
416 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
417 .add_request_handler(multi_lsp_query)
418 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
419 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
420 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
421 .add_message_handler(create_buffer_for_peer)
422 .add_request_handler(update_buffer)
423 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
424 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
425 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
426 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
427 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
428 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
429 .add_message_handler(
430 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
431 )
432 .add_request_handler(get_users)
433 .add_request_handler(fuzzy_search_users)
434 .add_request_handler(request_contact)
435 .add_request_handler(remove_contact)
436 .add_request_handler(respond_to_contact_request)
437 .add_message_handler(subscribe_to_channels)
438 .add_request_handler(create_channel)
439 .add_request_handler(delete_channel)
440 .add_request_handler(invite_channel_member)
441 .add_request_handler(remove_channel_member)
442 .add_request_handler(set_channel_member_role)
443 .add_request_handler(set_channel_visibility)
444 .add_request_handler(rename_channel)
445 .add_request_handler(join_channel_buffer)
446 .add_request_handler(leave_channel_buffer)
447 .add_message_handler(update_channel_buffer)
448 .add_request_handler(rejoin_channel_buffers)
449 .add_request_handler(get_channel_members)
450 .add_request_handler(respond_to_channel_invite)
451 .add_request_handler(join_channel)
452 .add_request_handler(join_channel_chat)
453 .add_message_handler(leave_channel_chat)
454 .add_request_handler(send_channel_message)
455 .add_request_handler(remove_channel_message)
456 .add_request_handler(update_channel_message)
457 .add_request_handler(get_channel_messages)
458 .add_request_handler(get_channel_messages_by_id)
459 .add_request_handler(get_notifications)
460 .add_request_handler(mark_notification_as_read)
461 .add_request_handler(move_channel)
462 .add_request_handler(reorder_channel)
463 .add_request_handler(follow)
464 .add_message_handler(unfollow)
465 .add_message_handler(update_followers)
466 .add_request_handler(get_private_user_info)
467 .add_request_handler(get_llm_api_token)
468 .add_request_handler(accept_terms_of_service)
469 .add_message_handler(acknowledge_channel_message)
470 .add_message_handler(acknowledge_buffer_version)
471 .add_request_handler(get_supermaven_api_key)
472 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
473 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
474 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
475 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
476 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
477 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
478 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
479 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
480 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
481 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
482 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
483 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
484 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
485 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
486 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
487 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
488 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
489 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
490 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
491 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
492 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
493 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
494 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
495 .add_message_handler(update_context);
496
497 Arc::new(server)
498 }
499
    /// Schedules cleanup of resources left behind by earlier servers, then returns.
    ///
    /// A detached task (instrumented with a "start server" span) sleeps for
    /// `CLEANUP_TIMEOUT` and then, scoped to this server's id and the configured
    /// `zed_environment`:
    /// 1. deletes stale channel-chat participants;
    /// 2. for every stale channel buffer, clears stale collaborators and sends the
    ///    refreshed collaborator list to the buffer's remaining connections;
    /// 3. for every stale room, clears stale participants, broadcasts the room (and its
    ///    channel, if any), sends `CallCanceled` for calls that were canceled, pushes the
    ///    updated contact entry to each affected user's accepted contacts, and deletes the
    ///    LiveKit room when no participants remain and a LiveKit client is configured;
    /// 4. deletes stale channel-chat participants again, clears old worktree entries, and
    ///    deletes stale servers.
    ///
    /// Every fallible step goes through `trace_err()` and the cleanup keeps going, so the
    /// returned value is always `Ok(())`.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): this same call is repeated after the room/channel cleanup
                // below; confirm whether both passes are needed.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop dead collaborators and tell the rest.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Stale rooms: each iteration gathers what to notify from the DB
                    // result, then sends notifications and deletes LiveKit rooms.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both DB reads complete before the pool is locked.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
670
671 pub fn teardown(&self) {
672 self.peer.teardown();
673 self.connection_pool.lock().reset();
674 let _ = self.teardown.send(true);
675 }
676
677 #[cfg(test)]
678 pub fn reset(&self, id: ServerId) {
679 self.teardown();
680 *self.id.lock() = id;
681 self.peer.reset(id.0 as u32);
682 let _ = self.teardown.send(false);
683 }
684
685 #[cfg(test)]
686 pub fn id(&self) -> ServerId {
687 *self.id.lock()
688 }
689
690 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
691 where
692 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
693 Fut: 'static + Send + Future<Output = Result<()>>,
694 M: EnvelopedMessage,
695 {
696 let prev_handler = self.handlers.insert(
697 TypeId::of::<M>(),
698 Box::new(move |envelope, session, span| {
699 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
700 let received_at = envelope.received_at;
701 tracing::info!("message received");
702 let start_time = Instant::now();
703 let future = (handler)(
704 *envelope,
705 MessageContext {
706 session,
707 span: span.clone(),
708 },
709 );
710 async move {
711 let result = future.await;
712 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
713 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
714 let queue_duration_ms = total_duration_ms - processing_duration_ms;
715 span.record(TOTAL_DURATION_MS, total_duration_ms);
716 span.record(PROCESSING_DURATION_MS, processing_duration_ms);
717 span.record(QUEUE_DURATION_MS, queue_duration_ms);
718 match result {
719 Err(error) => {
720 tracing::error!(?error, "error handling message")
721 }
722 Ok(()) => tracing::info!("finished handling message"),
723 }
724 }
725 .boxed()
726 }),
727 );
728 if prev_handler.is_some() {
729 panic!("registered a handler for the same message twice");
730 }
731 self
732 }
733
734 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
735 where
736 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
737 Fut: 'static + Send + Future<Output = Result<()>>,
738 M: EnvelopedMessage,
739 {
740 self.add_handler(move |envelope, session| handler(envelope.payload, session));
741 self
742 }
743
744 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
745 where
746 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
747 Fut: Send + Future<Output = Result<()>>,
748 M: RequestMessage,
749 {
750 let handler = Arc::new(handler);
751 self.add_handler(move |envelope, session| {
752 let receipt = envelope.receipt();
753 let handler = handler.clone();
754 async move {
755 let peer = session.peer.clone();
756 let responded = Arc::new(AtomicBool::default());
757 let response = Response {
758 peer: peer.clone(),
759 responded: responded.clone(),
760 receipt,
761 };
762 match (handler)(envelope.payload, response, session).await {
763 Ok(()) => {
764 if responded.load(std::sync::atomic::Ordering::SeqCst) {
765 Ok(())
766 } else {
767 Err(anyhow!("handler did not send a response"))?
768 }
769 }
770 Err(error) => {
771 let proto_err = match &error {
772 Error::Internal(err) => err.to_proto(),
773 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
774 };
775 peer.respond_with_error(receipt, proto_err)?;
776 Err(error)
777 }
778 }
779 }
780 })
781 }
782
    /// Returns a future that drives one client connection from handshake to sign-out.
    ///
    /// The future does the following, in order:
    /// - refuses the connection if the teardown flag is already set;
    /// - registers `connection` with the peer and builds the connection's [`Session`];
    /// - sends the initial client update;
    /// - dispatches incoming messages to the registered handlers until the connection
    ///   closes, I/O fails, or the server tears down.
    ///
    /// Everything runs inside a "handle connection" span that carries the address,
    /// principal, user agent, release channel and GeoIP country code.
    ///
    /// On teardown the future returns immediately. When the loop ends because the
    /// connection closed or I/O failed, `connection_lost` is awaited to sign the user
    /// out.
    ///
    /// `connection_guard`, if given, is released once the initial client update has been
    /// sent, or earlier if the future returns before that point.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribed before the future starts, so a teardown that happens afterwards is
        // observed by `teardown.changed()` in the loop below.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): the two early `return`s below happen after
            // `peer.add_connection` but skip `connection_lost`; confirm the peer
            // connection is cleaned up elsewhere on these paths.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The concurrent-connection slot is only held through the handshake.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            //
            // A permit is acquired before each message is read, so at most
            // `MAX_CONCURRENT_HANDLERS` handlers (foreground and background) are in flight;
            // each handler releases its permit when it finishes.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O completion, then finishing foreground
                // handlers, and only then reading the next message.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Foreground handlers still pending are dropped (cancelled) here; detached
            // background handlers keep running.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
973
    /// Sends the messages a freshly connected client needs before its message
    /// loop starts.
    ///
    /// Always sends `proto::Hello` carrying this connection's peer id, then hands
    /// `connection_id` to `send_connection_id` if one was supplied. For user
    /// principals (including impersonated ones) it then, in order:
    /// - on the user's first-ever connection, sends `ShowContacts` and persists
    ///   `connected_once = true`;
    /// - calls `update_user_plan`;
    /// - registers the connection in the pool and sends the initial contacts
    ///   snapshot built from that pool;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels`
    ///   allows it for `zed_version`;
    /// - re-sends any pending incoming call;
    /// - calls `update_user_contacts` for the user.
    ///
    /// # Errors
    /// Returns the first send or database error; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the connection id to the waiting party, if any; a dropped
        // receiver is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: show contacts once and
                // record that it happened.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Add the connection and build the contacts snapshot under the
                // same pool lock, so the snapshot is built from a pool that
                // already contains this connection. The block releases the lock
                // before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a call that was already ringing before this
                // connection existed.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1031
1032 pub async fn invite_code_redeemed(
1033 self: &Arc<Self>,
1034 inviter_id: UserId,
1035 invitee_id: UserId,
1036 ) -> Result<()> {
1037 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1038 if let Some(code) = &user.invite_code {
1039 let pool = self.connection_pool.lock();
1040 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1041 for connection_id in pool.user_connection_ids(inviter_id) {
1042 self.peer.send(
1043 connection_id,
1044 proto::UpdateContacts {
1045 contacts: vec![invitee_contact.clone()],
1046 ..Default::default()
1047 },
1048 )?;
1049 self.peer.send(
1050 connection_id,
1051 proto::UpdateInviteInfo {
1052 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1053 count: user.invite_count as u32,
1054 },
1055 )?;
1056 }
1057 }
1058 }
1059 Ok(())
1060 }
1061
1062 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1063 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1064 if let Some(invite_code) = &user.invite_code {
1065 let pool = self.connection_pool.lock();
1066 for connection_id in pool.user_connection_ids(user_id) {
1067 self.peer.send(
1068 connection_id,
1069 proto::UpdateInviteInfo {
1070 url: format!(
1071 "{}{}",
1072 self.app_state.config.invite_link_prefix, invite_code
1073 ),
1074 count: user.invite_count as u32,
1075 },
1076 )?;
1077 }
1078 }
1079 }
1080 Ok(())
1081 }
1082
1083 pub async fn update_plan_for_user(
1084 self: &Arc<Self>,
1085 user_id: UserId,
1086 update_user_plan: proto::UpdateUserPlan,
1087 ) -> Result<()> {
1088 let pool = self.connection_pool.lock();
1089 for connection_id in pool.user_connection_ids(user_id) {
1090 self.peer
1091 .send(connection_id, update_user_plan.clone())
1092 .trace_err();
1093 }
1094
1095 Ok(())
1096 }
1097
1098 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1099 /// message on the Collab server.
1100 ///
1101 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1102 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1103 let user = self
1104 .app_state
1105 .db
1106 .get_user_by_id(user_id)
1107 .await?
1108 .context("user not found")?;
1109
1110 let update_user_plan = make_update_user_plan_message(
1111 &user,
1112 user.admin,
1113 &self.app_state.db,
1114 self.app_state.llm_db.clone(),
1115 )
1116 .await?;
1117
1118 self.update_plan_for_user(user_id, update_user_plan).await
1119 }
1120
1121 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1122 let pool = self.connection_pool.lock();
1123 for connection_id in pool.user_connection_ids(user_id) {
1124 self.peer
1125 .send(connection_id, proto::RefreshLlmToken {})
1126 .trace_err();
1127 }
1128 }
1129
1130 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1131 ServerSnapshot {
1132 connection_pool: ConnectionPoolGuard {
1133 guard: self.connection_pool.lock(),
1134 _not_send: PhantomData,
1135 },
1136 peer: &self.peer,
1137 }
1138 }
1139}
1140
1141impl Deref for ConnectionPoolGuard<'_> {
1142 type Target = ConnectionPool;
1143
1144 fn deref(&self) -> &Self::Target {
1145 &self.guard
1146 }
1147}
1148
1149impl DerefMut for ConnectionPoolGuard<'_> {
1150 fn deref_mut(&mut self) -> &mut Self::Target {
1151 &mut self.guard
1152 }
1153}
1154
/// In test builds, checks the connection pool's invariants whenever a guard is
/// released, so violations surface as soon as the guard is dropped. In non-test
/// builds this drop does nothing.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1161
1162fn broadcast<F>(
1163 sender_id: Option<ConnectionId>,
1164 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1165 mut f: F,
1166) where
1167 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1168{
1169 for receiver_id in receiver_ids {
1170 if Some(receiver_id) != sender_id {
1171 if let Err(error) = f(receiver_id) {
1172 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1173 }
1174 }
1175 }
1176}
1177
1178pub struct ProtocolVersion(u32);
1179
1180impl Header for ProtocolVersion {
1181 fn name() -> &'static HeaderName {
1182 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1183 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1184 }
1185
1186 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1187 where
1188 Self: Sized,
1189 I: Iterator<Item = &'i axum::http::HeaderValue>,
1190 {
1191 let version = values
1192 .next()
1193 .ok_or_else(axum::headers::Error::invalid)?
1194 .to_str()
1195 .map_err(|_| axum::headers::Error::invalid())?
1196 .parse()
1197 .map_err(|_| axum::headers::Error::invalid())?;
1198 Ok(Self(version))
1199 }
1200
1201 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1202 values.extend([self.0.to_string().parse().unwrap()]);
1203 }
1204}
1205
1206pub struct AppVersionHeader(SemanticVersion);
1207impl Header for AppVersionHeader {
1208 fn name() -> &'static HeaderName {
1209 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1210 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1211 }
1212
1213 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1214 where
1215 Self: Sized,
1216 I: Iterator<Item = &'i axum::http::HeaderValue>,
1217 {
1218 let version = values
1219 .next()
1220 .ok_or_else(axum::headers::Error::invalid)?
1221 .to_str()
1222 .map_err(|_| axum::headers::Error::invalid())?
1223 .parse()
1224 .map_err(|_| axum::headers::Error::invalid())?;
1225 Ok(Self(version))
1226 }
1227
1228 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1229 values.extend([self.0.to_string().parse().unwrap()]);
1230 }
1231}
1232
1233#[derive(Debug)]
1234pub struct ReleaseChannelHeader(String);
1235
1236impl Header for ReleaseChannelHeader {
1237 fn name() -> &'static HeaderName {
1238 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1239 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1240 }
1241
1242 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1243 where
1244 Self: Sized,
1245 I: Iterator<Item = &'i axum::http::HeaderValue>,
1246 {
1247 Ok(Self(
1248 values
1249 .next()
1250 .ok_or_else(axum::headers::Error::invalid)?
1251 .to_str()
1252 .map_err(|_| axum::headers::Error::invalid())?
1253 .to_owned(),
1254 ))
1255 }
1256
1257 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1258 values.extend([self.0.parse().unwrap()]);
1259 }
1260}
1261
/// Builds the collab RPC router.
///
/// Ordering matters here: `Router::layer` only wraps routes that were added
/// before it. The `ServiceBuilder` layer therefore applies only to `/rpc`. That
/// layer injects the `AppState` extension and runs `auth::validate_header`.
/// `/metrics` is registered afterwards, so it is not behind that layer. The final
/// `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1273
/// Axum handler for `GET /rpc`: validates the client and upgrades the request
/// to a WebSocket.
///
/// Responds with `426 Upgrade Required` in three cases:
/// - the RPC protocol version differs from `rpc::PROTOCOL_VERSION`;
/// - the `x-zed-app-version` header is absent;
/// - that version reports it cannot collaborate.
///
/// Responds with `503 Service Unavailable` when no `ConnectionGuard` can be
/// acquired. Otherwise the socket is adapted to tungstenite message types,
/// wrapped in an RPC `Connection`, and handed to `Server::handle_connection`.
/// The request metadata, the production executor, and the connection guard go
/// along with it.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // client gets a plain HTTP 503 instead of an upgraded socket.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's WebSocket message type to tungstenite's in both
        // directions so `Connection` can drive the socket.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1351
1352pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1353 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1354 let connections_metric = CONNECTIONS_METRIC
1355 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1356
1357 let connections = server
1358 .connection_pool
1359 .lock()
1360 .connections()
1361 .filter(|connection| !connection.admin)
1362 .count();
1363 connections_metric.set(connections as _);
1364
1365 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1366 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1367 register_int_gauge!(
1368 "shared_projects",
1369 "number of open projects with one or more guests"
1370 )
1371 .unwrap()
1372 });
1373
1374 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1375 shared_projects_metric.set(shared_projects as _);
1376
1377 let encoder = prometheus::TextEncoder::new();
1378 let metric_families = prometheus::gather();
1379 let encoded_metrics = encoder
1380 .encode_to_string(&metric_families)
1381 .map_err(|err| anyhow!("{err}"))?;
1382 Ok(encoded_metrics)
1383}
1384
/// Cleans up server-side state after a client connection ends.
///
/// Right away it does three things: disconnects the peer, removes the
/// connection from the pool, and calls the database's `connection_lost` (whose
/// error is only traced).
///
/// It then races a `RECONNECT_TIMEOUT` sleep against a change on `teardown`:
/// - If the sleep finishes first, the session leaves its room and channel
///   buffers. If the user has no other connection online, any pending call is
///   declined. Finally `update_user_contacts` runs.
/// - If `teardown` changes first (presumably server shutdown; confirm with the
///   sender side), that cleanup is skipped.
///
/// # Errors
/// Fails if the connection cannot be removed from the pool, or if
/// `update_user_contacts` fails after the timeout.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the arms in order, so the timeout arm wins if both
    // are ready at the same time.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            // Best-effort cleanup: failures are traced, not propagated.
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is
            // still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1431
1432/// Acknowledges a ping from a client, used to keep the connection alive.
1433async fn ping(
1434 _: proto::Ping,
1435 response: Response<proto::Ping>,
1436 _session: MessageContext,
1437) -> Result<()> {
1438 response.send(proto::Ack {})?;
1439 Ok(())
1440}
1441
1442/// Creates a new room for calling (outside of channels)
1443async fn create_room(
1444 _request: proto::CreateRoom,
1445 response: Response<proto::CreateRoom>,
1446 session: MessageContext,
1447) -> Result<()> {
1448 let livekit_room = nanoid::nanoid!(30);
1449
1450 let live_kit_connection_info = util::maybe!(async {
1451 let live_kit = session.app_state.livekit_client.as_ref();
1452 let live_kit = live_kit?;
1453 let user_id = session.user_id().to_string();
1454
1455 let token = live_kit
1456 .room_token(&livekit_room, &user_id.to_string())
1457 .trace_err()?;
1458
1459 Some(proto::LiveKitConnectionInfo {
1460 server_url: live_kit.url().into(),
1461 token,
1462 can_publish: true,
1463 })
1464 })
1465 .await;
1466
1467 let room = session
1468 .db()
1469 .await
1470 .create_room(session.user_id(), session.connection_id, &livekit_room)
1471 .await?;
1472
1473 response.send(proto::CreateRoomResponse {
1474 room: Some(room.clone()),
1475 live_kit_connection_info,
1476 })?;
1477
1478 update_user_contacts(session.user_id(), &session).await?;
1479 Ok(())
1480}
1481
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// If the room belongs to a channel, the whole request is delegated to
/// `join_channel_internal`. Otherwise the handler:
/// - joins the room in the database and broadcasts the room update;
/// - sends `CallCanceled` to every connection of the joining user;
/// - mints a LiveKit token when a LiveKit client is configured;
/// - replies with the joined room;
/// - calls `update_user_contacts` for the user.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms backed by a channel go through the channel join path.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // Join in the database and broadcast the new room state. `into_inner`
    // unwraps the DB result so nothing from that call outlives this block.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Tell every connection of this user that the call for this room is no
    // longer pending (presumably so other clients stop showing it as ringing).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort: a token failure is traced and leaves the connection info empty.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1548
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While the database result is held, the handler:
/// - re-attaches this connection to the room in the database;
/// - replies with the room plus the reshared and rejoined projects;
/// - broadcasts the room update;
/// - for every reshared project, tells each collaborator that this connection
///   replaced its old one, and re-sends the project's worktree metadata to the
///   other collaborators;
/// - streams rejoined-project state via `notify_rejoined_projects`.
///
/// After the database result is released, the room's channel (if any) is
/// broadcast with `channel_updated` and `update_user_contacts` is called.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    // Assigned inside the block below so they outlive the DB result.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Swap this connection's old peer id for the new one on every
            // collaborator (best-effort).
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree metadata to every collaborator except
            // this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // The DB result has been released here.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1640
/// Brings a reconnecting connection up to date on the projects it rejoined.
///
/// First, every collaborator of each project is told that this connection
/// replaced `old_connection_id`. Then, for each project, the following are
/// streamed to this connection:
/// - worktree entries (split into several `UpdateWorktree` messages);
/// - diagnostic summaries;
/// - worktree settings files;
/// - updated repositories (split);
/// - removed repositories.
///
/// The worktree and repository lists are moved out of `rejoined_projects` with
/// `mem::take`, so they are empty when this returns.
///
/// # Errors
/// Returns the first failed send to this connection. Sends to collaborators
/// are best-effort (`trace_err`).
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // Final only once the latest scan has also completed.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics: all summaries go in one message
            // (first in `summary`, the rest in `more_summaries`); nothing is sent
            // when there are none.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1725
1726/// leave room disconnects from the room.
1727async fn leave_room(
1728 _: proto::LeaveRoom,
1729 response: Response<proto::LeaveRoom>,
1730 session: MessageContext,
1731) -> Result<()> {
1732 leave_room_for_session(&session, session.connection_id).await?;
1733 response.send(proto::Ack {})?;
1734 Ok(())
1735}
1736
1737/// Updates the permissions of someone else in the room.
1738async fn set_room_participant_role(
1739 request: proto::SetRoomParticipantRole,
1740 response: Response<proto::SetRoomParticipantRole>,
1741 session: MessageContext,
1742) -> Result<()> {
1743 let user_id = UserId::from_proto(request.user_id);
1744 let role = ChannelRole::from(request.role());
1745
1746 let (livekit_room, can_publish) = {
1747 let room = session
1748 .db()
1749 .await
1750 .set_room_participant_role(
1751 session.user_id(),
1752 RoomId::from_proto(request.room_id),
1753 user_id,
1754 role,
1755 )
1756 .await?;
1757
1758 let livekit_room = room.livekit_room.clone();
1759 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1760 room_updated(&room, &session.peer);
1761 (livekit_room, can_publish)
1762 };
1763
1764 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1765 live_kit
1766 .update_participant(
1767 livekit_room.clone(),
1768 request.user_id.to_string(),
1769 livekit_api::proto::ParticipantPermission {
1770 can_subscribe: true,
1771 can_publish,
1772 can_publish_data: can_publish,
1773 hidden: false,
1774 recorder: false,
1775 },
1776 )
1777 .await
1778 .trace_err();
1779 }
1780
1781 response.send(proto::Ack {})?;
1782 Ok(())
1783}
1784
/// Call someone else into the current room
///
/// The callee must be a contact of the caller. The handler:
/// 1. Records the call in the database and broadcasts the room update.
/// 2. Calls `update_user_contacts` for the callee.
/// 3. Sends `incoming_call` as a request to all of the callee's connections
///    concurrently. The first connection to respond successfully acknowledges
///    this request.
///
/// If no connection responds successfully (including when the callee has no
/// connections), the call is marked failed in the database, the room update
/// is broadcast again, contacts are updated, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out so nothing from the DB result outlives this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring every connection of the callee concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            // First success wins; returning drops the remaining in-flight requests.
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            // A failed ring is traced and we keep waiting for the others.
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection answered successfully: mark the call as failed.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1853
/// Cancel an outgoing call.
///
/// The handler:
/// - cancels, in the database, the pending call from this connection to
///   `called_user_id`, and broadcasts the room update;
/// - sends `CallCanceled` to every connection of the called user;
/// - acknowledges the request;
/// - calls `update_user_contacts` for the called user.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: MessageContext,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    // The connection-pool guard is held for the duration of this loop; sends are
    // best-effort.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1891
/// Decline an incoming call.
///
/// This is a one-way message, so there is no response to send. The handler:
/// - records the decline for `room_id` in the database (it errors with
///   "declining call" if no room comes back) and broadcasts the room update;
/// - sends `CallCanceled` to every connection of the declining user;
/// - calls `update_user_contacts`.
async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .context("declining call")?;
        room_updated(&room, &session.peer);
    }

    // Notify all of this user's connections (best-effort).
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1923
1924/// Updates other participants in the room with your current location.
1925async fn update_participant_location(
1926 request: proto::UpdateParticipantLocation,
1927 response: Response<proto::UpdateParticipantLocation>,
1928 session: MessageContext,
1929) -> Result<()> {
1930 let room_id = RoomId::from_proto(request.room_id);
1931 let location = request.location.context("invalid location")?;
1932
1933 let db = session.db().await;
1934 let room = db
1935 .update_room_participant_location(room_id, session.connection_id, location)
1936 .await?;
1937
1938 room_updated(&room, &session.peer);
1939 response.send(proto::Ack {})?;
1940 Ok(())
1941}
1942
/// Share a project into the room.
///
/// Registers this connection's project and its worktrees in the database, then
/// replies with the new project id, and finally broadcasts the room update.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: MessageContext,
) -> Result<()> {
    // `project_id` and `room` borrow from the value returned by the DB call; the
    // `&*` keeps that value alive (temporary lifetime extension) for the rest of
    // the function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1966
1967/// Unshare a project from the room.
1968async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1969 let project_id = ProjectId::from_proto(message.project_id);
1970 unshare_project_internal(project_id, session.connection_id, &session).await
1971}
1972
/// Unshares `project_id` on behalf of `connection_id`.
///
/// The handler:
/// - sends `UnshareProject` to the project's guest connections, excluding
///   `connection_id` itself;
/// - broadcasts the room update when the project belonged to a room;
/// - deletes the project from the database when the DB layer reports it
///   should be deleted.
///
/// # Errors
/// Propagates database errors from unsharing or deleting.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the DB result is dropped before the (possible) delete below
    // re-acquires the database.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2010
2011/// Join someone elses shared project.
2012async fn join_project(
2013 request: proto::JoinProject,
2014 response: Response<proto::JoinProject>,
2015 session: MessageContext,
2016) -> Result<()> {
2017 let project_id = ProjectId::from_proto(request.project_id);
2018
2019 tracing::info!(%project_id, "join project");
2020
2021 let db = session.db().await;
2022 let (project, replica_id) = &mut *db
2023 .join_project(
2024 project_id,
2025 session.connection_id,
2026 session.user_id(),
2027 request.committer_name.clone(),
2028 request.committer_email.clone(),
2029 )
2030 .await?;
2031 drop(db);
2032 tracing::info!(%project_id, "join remote project");
2033 let collaborators = project
2034 .collaborators
2035 .iter()
2036 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2037 .map(|collaborator| collaborator.to_proto())
2038 .collect::<Vec<_>>();
2039 let project_id = project.id;
2040 let guest_user_id = session.user_id();
2041
2042 let worktrees = project
2043 .worktrees
2044 .iter()
2045 .map(|(id, worktree)| proto::WorktreeMetadata {
2046 id: *id,
2047 root_name: worktree.root_name.clone(),
2048 visible: worktree.visible,
2049 abs_path: worktree.abs_path.clone(),
2050 })
2051 .collect::<Vec<_>>();
2052
2053 let add_project_collaborator = proto::AddProjectCollaborator {
2054 project_id: project_id.to_proto(),
2055 collaborator: Some(proto::Collaborator {
2056 peer_id: Some(session.connection_id.into()),
2057 replica_id: replica_id.0 as u32,
2058 user_id: guest_user_id.to_proto(),
2059 is_host: false,
2060 committer_name: request.committer_name.clone(),
2061 committer_email: request.committer_email.clone(),
2062 }),
2063 };
2064
2065 for collaborator in &collaborators {
2066 session
2067 .peer
2068 .send(
2069 collaborator.peer_id.unwrap().into(),
2070 add_project_collaborator.clone(),
2071 )
2072 .trace_err();
2073 }
2074
2075 // First, we send the metadata associated with each worktree.
2076 let (language_servers, language_server_capabilities) = project
2077 .language_servers
2078 .clone()
2079 .into_iter()
2080 .map(|server| (server.server, server.capabilities))
2081 .unzip();
2082 response.send(proto::JoinProjectResponse {
2083 project_id: project.id.0 as u64,
2084 worktrees: worktrees.clone(),
2085 replica_id: replica_id.0 as u32,
2086 collaborators: collaborators.clone(),
2087 language_servers,
2088 language_server_capabilities,
2089 role: project.role.into(),
2090 })?;
2091
2092 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2093 // Stream this worktree's entries.
2094 let message = proto::UpdateWorktree {
2095 project_id: project_id.to_proto(),
2096 worktree_id,
2097 abs_path: worktree.abs_path.clone(),
2098 root_name: worktree.root_name,
2099 updated_entries: worktree.entries,
2100 removed_entries: Default::default(),
2101 scan_id: worktree.scan_id,
2102 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2103 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2104 removed_repositories: Default::default(),
2105 };
2106 for update in proto::split_worktree_update(message) {
2107 session.peer.send(session.connection_id, update.clone())?;
2108 }
2109
2110 // Stream this worktree's diagnostics.
2111 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2112 if let Some(summary) = worktree_diagnostics.next() {
2113 let message = proto::UpdateDiagnosticSummary {
2114 project_id: project.id.to_proto(),
2115 worktree_id: worktree.id,
2116 summary: Some(summary),
2117 more_summaries: worktree_diagnostics.collect(),
2118 };
2119 session.peer.send(session.connection_id, message)?;
2120 }
2121
2122 for settings_file in worktree.settings_files {
2123 session.peer.send(
2124 session.connection_id,
2125 proto::UpdateWorktreeSettings {
2126 project_id: project_id.to_proto(),
2127 worktree_id: worktree.id,
2128 path: settings_file.path,
2129 content: Some(settings_file.content),
2130 kind: Some(settings_file.kind.to_proto() as i32),
2131 },
2132 )?;
2133 }
2134 }
2135
2136 for repository in mem::take(&mut project.repositories) {
2137 for update in split_repository_update(repository) {
2138 session.peer.send(session.connection_id, update)?;
2139 }
2140 }
2141
2142 for language_server in &project.language_servers {
2143 session.peer.send(
2144 session.connection_id,
2145 proto::UpdateLanguageServer {
2146 project_id: project_id.to_proto(),
2147 server_name: Some(language_server.server.name.clone()),
2148 language_server_id: language_server.server.id,
2149 variant: Some(
2150 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2151 proto::LspDiskBasedDiagnosticsUpdated {},
2152 ),
2153 ),
2154 },
2155 )?;
2156 }
2157
2158 Ok(())
2159}
2160
2161/// Leave someone elses shared project.
2162async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2163 let sender_id = session.connection_id;
2164 let project_id = ProjectId::from_proto(request.project_id);
2165 let db = session.db().await;
2166
2167 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2168 tracing::info!(
2169 %project_id,
2170 "leave project"
2171 );
2172
2173 project_left(project, &session);
2174 if let Some(room) = room {
2175 room_updated(room, &session.peer);
2176 }
2177
2178 Ok(())
2179}
2180
2181/// Updates other participants with changes to the project
2182async fn update_project(
2183 request: proto::UpdateProject,
2184 response: Response<proto::UpdateProject>,
2185 session: MessageContext,
2186) -> Result<()> {
2187 let project_id = ProjectId::from_proto(request.project_id);
2188 let (room, guest_connection_ids) = &*session
2189 .db()
2190 .await
2191 .update_project(project_id, session.connection_id, &request.worktrees)
2192 .await?;
2193 broadcast(
2194 Some(session.connection_id),
2195 guest_connection_ids.iter().copied(),
2196 |connection_id| {
2197 session
2198 .peer
2199 .forward_send(session.connection_id, connection_id, request.clone())
2200 },
2201 );
2202 if let Some(room) = room {
2203 room_updated(room, &session.peer);
2204 }
2205 response.send(proto::Ack {})?;
2206
2207 Ok(())
2208}
2209
2210/// Updates other participants with changes to the worktree
2211async fn update_worktree(
2212 request: proto::UpdateWorktree,
2213 response: Response<proto::UpdateWorktree>,
2214 session: MessageContext,
2215) -> Result<()> {
2216 let guest_connection_ids = session
2217 .db()
2218 .await
2219 .update_worktree(&request, session.connection_id)
2220 .await?;
2221
2222 broadcast(
2223 Some(session.connection_id),
2224 guest_connection_ids.iter().copied(),
2225 |connection_id| {
2226 session
2227 .peer
2228 .forward_send(session.connection_id, connection_id, request.clone())
2229 },
2230 );
2231 response.send(proto::Ack {})?;
2232 Ok(())
2233}
2234
2235async fn update_repository(
2236 request: proto::UpdateRepository,
2237 response: Response<proto::UpdateRepository>,
2238 session: MessageContext,
2239) -> Result<()> {
2240 let guest_connection_ids = session
2241 .db()
2242 .await
2243 .update_repository(&request, session.connection_id)
2244 .await?;
2245
2246 broadcast(
2247 Some(session.connection_id),
2248 guest_connection_ids.iter().copied(),
2249 |connection_id| {
2250 session
2251 .peer
2252 .forward_send(session.connection_id, connection_id, request.clone())
2253 },
2254 );
2255 response.send(proto::Ack {})?;
2256 Ok(())
2257}
2258
2259async fn remove_repository(
2260 request: proto::RemoveRepository,
2261 response: Response<proto::RemoveRepository>,
2262 session: MessageContext,
2263) -> Result<()> {
2264 let guest_connection_ids = session
2265 .db()
2266 .await
2267 .remove_repository(&request, session.connection_id)
2268 .await?;
2269
2270 broadcast(
2271 Some(session.connection_id),
2272 guest_connection_ids.iter().copied(),
2273 |connection_id| {
2274 session
2275 .peer
2276 .forward_send(session.connection_id, connection_id, request.clone())
2277 },
2278 );
2279 response.send(proto::Ack {})?;
2280 Ok(())
2281}
2282
2283/// Updates other participants with changes to the diagnostics
2284async fn update_diagnostic_summary(
2285 message: proto::UpdateDiagnosticSummary,
2286 session: MessageContext,
2287) -> Result<()> {
2288 let guest_connection_ids = session
2289 .db()
2290 .await
2291 .update_diagnostic_summary(&message, session.connection_id)
2292 .await?;
2293
2294 broadcast(
2295 Some(session.connection_id),
2296 guest_connection_ids.iter().copied(),
2297 |connection_id| {
2298 session
2299 .peer
2300 .forward_send(session.connection_id, connection_id, message.clone())
2301 },
2302 );
2303
2304 Ok(())
2305}
2306
2307/// Updates other participants with changes to the worktree settings
2308async fn update_worktree_settings(
2309 message: proto::UpdateWorktreeSettings,
2310 session: MessageContext,
2311) -> Result<()> {
2312 let guest_connection_ids = session
2313 .db()
2314 .await
2315 .update_worktree_settings(&message, session.connection_id)
2316 .await?;
2317
2318 broadcast(
2319 Some(session.connection_id),
2320 guest_connection_ids.iter().copied(),
2321 |connection_id| {
2322 session
2323 .peer
2324 .forward_send(session.connection_id, connection_id, message.clone())
2325 },
2326 );
2327
2328 Ok(())
2329}
2330
2331/// Notify other participants that a language server has started.
2332async fn start_language_server(
2333 request: proto::StartLanguageServer,
2334 session: MessageContext,
2335) -> Result<()> {
2336 let guest_connection_ids = session
2337 .db()
2338 .await
2339 .start_language_server(&request, session.connection_id)
2340 .await?;
2341
2342 broadcast(
2343 Some(session.connection_id),
2344 guest_connection_ids.iter().copied(),
2345 |connection_id| {
2346 session
2347 .peer
2348 .forward_send(session.connection_id, connection_id, request.clone())
2349 },
2350 );
2351 Ok(())
2352}
2353
2354/// Notify other participants that a language server has changed.
2355async fn update_language_server(
2356 request: proto::UpdateLanguageServer,
2357 session: MessageContext,
2358) -> Result<()> {
2359 let project_id = ProjectId::from_proto(request.project_id);
2360 let db = session.db().await;
2361
2362 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2363 {
2364 if let Some(capabilities) = update.capabilities.clone() {
2365 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2366 .await?;
2367 }
2368 }
2369
2370 let project_connection_ids = db
2371 .project_connection_ids(project_id, session.connection_id, true)
2372 .await?;
2373 broadcast(
2374 Some(session.connection_id),
2375 project_connection_ids.iter().copied(),
2376 |connection_id| {
2377 session
2378 .peer
2379 .forward_send(session.connection_id, connection_id, request.clone())
2380 },
2381 );
2382 Ok(())
2383}
2384
2385/// forward a project request to the host. These requests should be read only
2386/// as guests are allowed to send them.
2387async fn forward_read_only_project_request<T>(
2388 request: T,
2389 response: Response<T>,
2390 session: MessageContext,
2391) -> Result<()>
2392where
2393 T: EntityMessage + RequestMessage,
2394{
2395 let project_id = ProjectId::from_proto(request.remote_entity_id());
2396 let host_connection_id = session
2397 .db()
2398 .await
2399 .host_for_read_only_project_request(project_id, session.connection_id)
2400 .await?;
2401 let payload = session.forward_request(host_connection_id, request).await?;
2402 response.send(payload)?;
2403 Ok(())
2404}
2405
2406/// forward a project request to the host. These requests are disallowed
2407/// for guests.
2408async fn forward_mutating_project_request<T>(
2409 request: T,
2410 response: Response<T>,
2411 session: MessageContext,
2412) -> Result<()>
2413where
2414 T: EntityMessage + RequestMessage,
2415{
2416 let project_id = ProjectId::from_proto(request.remote_entity_id());
2417
2418 let host_connection_id = session
2419 .db()
2420 .await
2421 .host_for_mutating_project_request(project_id, session.connection_id)
2422 .await?;
2423 let payload = session.forward_request(host_connection_id, request).await?;
2424 response.send(payload)?;
2425 Ok(())
2426}
2427
2428async fn multi_lsp_query(
2429 request: MultiLspQuery,
2430 response: Response<MultiLspQuery>,
2431 session: MessageContext,
2432) -> Result<()> {
2433 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2434 tracing::info!("multi_lsp_query message received");
2435 forward_mutating_project_request(request, response, session).await
2436}
2437
2438/// Notify other participants that a new buffer has been created
2439async fn create_buffer_for_peer(
2440 request: proto::CreateBufferForPeer,
2441 session: MessageContext,
2442) -> Result<()> {
2443 session
2444 .db()
2445 .await
2446 .check_user_is_project_host(
2447 ProjectId::from_proto(request.project_id),
2448 session.connection_id,
2449 )
2450 .await?;
2451 let peer_id = request.peer_id.context("invalid peer id")?;
2452 session
2453 .peer
2454 .forward_send(session.connection_id, peer_id.into(), request)?;
2455 Ok(())
2456}
2457
2458/// Notify other participants that a buffer has been updated. This is
2459/// allowed for guests as long as the update is limited to selections.
2460async fn update_buffer(
2461 request: proto::UpdateBuffer,
2462 response: Response<proto::UpdateBuffer>,
2463 session: MessageContext,
2464) -> Result<()> {
2465 let project_id = ProjectId::from_proto(request.project_id);
2466 let mut capability = Capability::ReadOnly;
2467
2468 for op in request.operations.iter() {
2469 match op.variant {
2470 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2471 Some(_) => capability = Capability::ReadWrite,
2472 }
2473 }
2474
2475 let host = {
2476 let guard = session
2477 .db()
2478 .await
2479 .connections_for_buffer_update(project_id, session.connection_id, capability)
2480 .await?;
2481
2482 let (host, guests) = &*guard;
2483
2484 broadcast(
2485 Some(session.connection_id),
2486 guests.clone(),
2487 |connection_id| {
2488 session
2489 .peer
2490 .forward_send(session.connection_id, connection_id, request.clone())
2491 },
2492 );
2493
2494 *host
2495 };
2496
2497 if host != session.connection_id {
2498 session.forward_request(host, request.clone()).await?;
2499 }
2500
2501 response.send(proto::Ack {})?;
2502 Ok(())
2503}
2504
2505async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
2506 let project_id = ProjectId::from_proto(message.project_id);
2507
2508 let operation = message.operation.as_ref().context("invalid operation")?;
2509 let capability = match operation.variant.as_ref() {
2510 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2511 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2512 match buffer_op.variant {
2513 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2514 Capability::ReadOnly
2515 }
2516 _ => Capability::ReadWrite,
2517 }
2518 } else {
2519 Capability::ReadWrite
2520 }
2521 }
2522 Some(_) => Capability::ReadWrite,
2523 None => Capability::ReadOnly,
2524 };
2525
2526 let guard = session
2527 .db()
2528 .await
2529 .connections_for_buffer_update(project_id, session.connection_id, capability)
2530 .await?;
2531
2532 let (host, guests) = &*guard;
2533
2534 broadcast(
2535 Some(session.connection_id),
2536 guests.iter().chain([host]).copied(),
2537 |connection_id| {
2538 session
2539 .peer
2540 .forward_send(session.connection_id, connection_id, message.clone())
2541 },
2542 );
2543
2544 Ok(())
2545}
2546
2547/// Notify other participants that a project has been updated.
2548async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2549 request: T,
2550 session: MessageContext,
2551) -> Result<()> {
2552 let project_id = ProjectId::from_proto(request.remote_entity_id());
2553 let project_connection_ids = session
2554 .db()
2555 .await
2556 .project_connection_ids(project_id, session.connection_id, false)
2557 .await?;
2558
2559 broadcast(
2560 Some(session.connection_id),
2561 project_connection_ids.iter().copied(),
2562 |connection_id| {
2563 session
2564 .peer
2565 .forward_send(session.connection_id, connection_id, request.clone())
2566 },
2567 );
2568 Ok(())
2569}
2570
2571/// Start following another user in a call.
2572async fn follow(
2573 request: proto::Follow,
2574 response: Response<proto::Follow>,
2575 session: MessageContext,
2576) -> Result<()> {
2577 let room_id = RoomId::from_proto(request.room_id);
2578 let project_id = request.project_id.map(ProjectId::from_proto);
2579 let leader_id = request.leader_id.context("invalid leader id")?.into();
2580 let follower_id = session.connection_id;
2581
2582 session
2583 .db()
2584 .await
2585 .check_room_participants(room_id, leader_id, session.connection_id)
2586 .await?;
2587
2588 let response_payload = session.forward_request(leader_id, request).await?;
2589 response.send(response_payload)?;
2590
2591 if let Some(project_id) = project_id {
2592 let room = session
2593 .db()
2594 .await
2595 .follow(room_id, project_id, leader_id, follower_id)
2596 .await?;
2597 room_updated(&room, &session.peer);
2598 }
2599
2600 Ok(())
2601}
2602
2603/// Stop following another user in a call.
2604async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2605 let room_id = RoomId::from_proto(request.room_id);
2606 let project_id = request.project_id.map(ProjectId::from_proto);
2607 let leader_id = request.leader_id.context("invalid leader id")?.into();
2608 let follower_id = session.connection_id;
2609
2610 session
2611 .db()
2612 .await
2613 .check_room_participants(room_id, leader_id, session.connection_id)
2614 .await?;
2615
2616 session
2617 .peer
2618 .forward_send(session.connection_id, leader_id, request)?;
2619
2620 if let Some(project_id) = project_id {
2621 let room = session
2622 .db()
2623 .await
2624 .unfollow(room_id, project_id, leader_id, follower_id)
2625 .await?;
2626 room_updated(&room, &session.peer);
2627 }
2628
2629 Ok(())
2630}
2631
2632/// Notify everyone following you of your current location.
2633async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2634 let room_id = RoomId::from_proto(request.room_id);
2635 let database = session.db.lock().await;
2636
2637 let connection_ids = if let Some(project_id) = request.project_id {
2638 let project_id = ProjectId::from_proto(project_id);
2639 database
2640 .project_connection_ids(project_id, session.connection_id, true)
2641 .await?
2642 } else {
2643 database
2644 .room_connection_ids(room_id, session.connection_id)
2645 .await?
2646 };
2647
2648 // For now, don't send view update messages back to that view's current leader.
2649 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2650 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2651 _ => None,
2652 });
2653
2654 for connection_id in connection_ids.iter().cloned() {
2655 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2656 session
2657 .peer
2658 .forward_send(session.connection_id, connection_id, request.clone())?;
2659 }
2660 }
2661 Ok(())
2662}
2663
2664/// Get public data about users.
2665async fn get_users(
2666 request: proto::GetUsers,
2667 response: Response<proto::GetUsers>,
2668 session: MessageContext,
2669) -> Result<()> {
2670 let user_ids = request
2671 .user_ids
2672 .into_iter()
2673 .map(UserId::from_proto)
2674 .collect();
2675 let users = session
2676 .db()
2677 .await
2678 .get_users_by_ids(user_ids)
2679 .await?
2680 .into_iter()
2681 .map(|user| proto::User {
2682 id: user.id.to_proto(),
2683 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2684 github_login: user.github_login,
2685 name: user.name,
2686 })
2687 .collect();
2688 response.send(proto::UsersResponse { users })?;
2689 Ok(())
2690}
2691
2692/// Search for users (to invite) buy Github login
2693async fn fuzzy_search_users(
2694 request: proto::FuzzySearchUsers,
2695 response: Response<proto::FuzzySearchUsers>,
2696 session: MessageContext,
2697) -> Result<()> {
2698 let query = request.query;
2699 let users = match query.len() {
2700 0 => vec![],
2701 1 | 2 => session
2702 .db()
2703 .await
2704 .get_user_by_github_login(&query)
2705 .await?
2706 .into_iter()
2707 .collect(),
2708 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2709 };
2710 let users = users
2711 .into_iter()
2712 .filter(|user| user.id != session.user_id())
2713 .map(|user| proto::User {
2714 id: user.id.to_proto(),
2715 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2716 github_login: user.github_login,
2717 name: user.name,
2718 })
2719 .collect();
2720 response.send(proto::UsersResponse { users })?;
2721 Ok(())
2722}
2723
2724/// Send a contact request to another user.
2725async fn request_contact(
2726 request: proto::RequestContact,
2727 response: Response<proto::RequestContact>,
2728 session: MessageContext,
2729) -> Result<()> {
2730 let requester_id = session.user_id();
2731 let responder_id = UserId::from_proto(request.responder_id);
2732 if requester_id == responder_id {
2733 return Err(anyhow!("cannot add yourself as a contact"))?;
2734 }
2735
2736 let notifications = session
2737 .db()
2738 .await
2739 .send_contact_request(requester_id, responder_id)
2740 .await?;
2741
2742 // Update outgoing contact requests of requester
2743 let mut update = proto::UpdateContacts::default();
2744 update.outgoing_requests.push(responder_id.to_proto());
2745 for connection_id in session
2746 .connection_pool()
2747 .await
2748 .user_connection_ids(requester_id)
2749 {
2750 session.peer.send(connection_id, update.clone())?;
2751 }
2752
2753 // Update incoming contact requests of responder
2754 let mut update = proto::UpdateContacts::default();
2755 update
2756 .incoming_requests
2757 .push(proto::IncomingContactRequest {
2758 requester_id: requester_id.to_proto(),
2759 });
2760 let connection_pool = session.connection_pool().await;
2761 for connection_id in connection_pool.user_connection_ids(responder_id) {
2762 session.peer.send(connection_id, update.clone())?;
2763 }
2764
2765 send_notifications(&connection_pool, &session.peer, notifications);
2766
2767 response.send(proto::Ack {})?;
2768 Ok(())
2769}
2770
/// Accept or decline a contact request
///
/// The caller is the responder, and `request.requester_id` identifies the
/// user who sent the original contact request.
///
/// - `Dismiss` only clears the contact notification in the database. No
///   contact updates are pushed.
/// - `Accept` records acceptance. Any other non-`Dismiss` value records a
///   decline.
///
/// For accept or decline, both users' connections receive an
/// `UpdateContacts` that removes the pending request and, on acceptance,
/// adds the new contact. Notifications returned by the database are then
/// delivered. The caller receives an `Ack` on success.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database handle is held for the rest of the function.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every non-Dismiss response other than Accept is treated as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2828
2829/// Remove a contact.
2830async fn remove_contact(
2831 request: proto::RemoveContact,
2832 response: Response<proto::RemoveContact>,
2833 session: MessageContext,
2834) -> Result<()> {
2835 let requester_id = session.user_id();
2836 let responder_id = UserId::from_proto(request.user_id);
2837 let db = session.db().await;
2838 let (contact_accepted, deleted_notification_id) =
2839 db.remove_contact(requester_id, responder_id).await?;
2840
2841 let pool = session.connection_pool().await;
2842 // Update outgoing contact requests of requester
2843 let mut update = proto::UpdateContacts::default();
2844 if contact_accepted {
2845 update.remove_contacts.push(responder_id.to_proto());
2846 } else {
2847 update
2848 .remove_outgoing_requests
2849 .push(responder_id.to_proto());
2850 }
2851 for connection_id in pool.user_connection_ids(requester_id) {
2852 session.peer.send(connection_id, update.clone())?;
2853 }
2854
2855 // Update incoming contact requests of responder
2856 let mut update = proto::UpdateContacts::default();
2857 if contact_accepted {
2858 update.remove_contacts.push(requester_id.to_proto());
2859 } else {
2860 update
2861 .remove_incoming_requests
2862 .push(requester_id.to_proto());
2863 }
2864 for connection_id in pool.user_connection_ids(responder_id) {
2865 session.peer.send(connection_id, update.clone())?;
2866 if let Some(notification_id) = deleted_notification_id {
2867 session.peer.send(
2868 connection_id,
2869 proto::DeleteNotification {
2870 notification_id: notification_id.to_proto(),
2871 },
2872 )?;
2873 }
2874 }
2875
2876 response.send(proto::Ack {})?;
2877 Ok(())
2878}
2879
2880fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2881 version.0.minor() < 139
2882}
2883
2884async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2885 if is_staff {
2886 return Ok(proto::Plan::ZedPro);
2887 }
2888
2889 let subscription = db.get_active_billing_subscription(user_id).await?;
2890 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2891
2892 let plan = if let Some(subscription_kind) = subscription_kind {
2893 match subscription_kind {
2894 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2895 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2896 SubscriptionKind::ZedFree => proto::Plan::Free,
2897 }
2898 } else {
2899 proto::Plan::Free
2900 };
2901
2902 Ok(plan)
2903}
2904
/// Builds the `UpdateUserPlan` message describing `user`'s current plan.
///
/// - Staff users are reported as `ZedPro` (via `current_plan`) with
///   usage-based billing enabled.
/// - The subscription period and usage are only looked up when `llm_db` is
///   provided. Otherwise both are `None`.
/// - When no usage record is available, the message carries zeroed counters
///   with the plan's limits (`make_default_subscription_usage`).
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): `current_plan` above already fetched the active
        // subscription; the two lookups could share one query.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        // Usage is only queried when there is a current billing period.
        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // An account is "too young" for LLM use unless the user is on Zed Pro,
    // has the bypass feature flag, or the account is at least
    // MIN_ACCOUNT_AGE_FOR_LLM_USE old.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // Timestamps are sent as Unix seconds.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2969
2970fn model_requests_limit(
2971 plan: cloud_llm_client::Plan,
2972 feature_flags: &Vec<String>,
2973) -> cloud_llm_client::UsageLimit {
2974 match plan.model_requests_limit() {
2975 cloud_llm_client::UsageLimit::Limited(limit) => {
2976 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2977 && feature_flags
2978 .iter()
2979 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2980 {
2981 1_000
2982 } else {
2983 limit
2984 };
2985
2986 cloud_llm_client::UsageLimit::Limited(limit)
2987 }
2988 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2989 }
2990}
2991
2992fn subscription_usage_to_proto(
2993 plan: proto::Plan,
2994 usage: crate::llm::db::subscription_usage::Model,
2995 feature_flags: &Vec<String>,
2996) -> proto::SubscriptionUsage {
2997 let plan = match plan {
2998 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2999 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3000 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3001 };
3002
3003 proto::SubscriptionUsage {
3004 model_requests_usage_amount: usage.model_requests as u32,
3005 model_requests_usage_limit: Some(proto::UsageLimit {
3006 variant: Some(match model_requests_limit(plan, feature_flags) {
3007 cloud_llm_client::UsageLimit::Limited(limit) => {
3008 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3009 limit: limit as u32,
3010 })
3011 }
3012 cloud_llm_client::UsageLimit::Unlimited => {
3013 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3014 }
3015 }),
3016 }),
3017 edit_predictions_usage_amount: usage.edit_predictions as u32,
3018 edit_predictions_usage_limit: Some(proto::UsageLimit {
3019 variant: Some(match plan.edit_predictions_limit() {
3020 cloud_llm_client::UsageLimit::Limited(limit) => {
3021 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3022 limit: limit as u32,
3023 })
3024 }
3025 cloud_llm_client::UsageLimit::Unlimited => {
3026 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3027 }
3028 }),
3029 }),
3030 }
3031}
3032
3033fn make_default_subscription_usage(
3034 plan: proto::Plan,
3035 feature_flags: &Vec<String>,
3036) -> proto::SubscriptionUsage {
3037 let plan = match plan {
3038 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3039 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3040 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3041 };
3042
3043 proto::SubscriptionUsage {
3044 model_requests_usage_amount: 0,
3045 model_requests_usage_limit: Some(proto::UsageLimit {
3046 variant: Some(match model_requests_limit(plan, feature_flags) {
3047 cloud_llm_client::UsageLimit::Limited(limit) => {
3048 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3049 limit: limit as u32,
3050 })
3051 }
3052 cloud_llm_client::UsageLimit::Unlimited => {
3053 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3054 }
3055 }),
3056 }),
3057 edit_predictions_usage_amount: 0,
3058 edit_predictions_usage_limit: Some(proto::UsageLimit {
3059 variant: Some(match plan.edit_predictions_limit() {
3060 cloud_llm_client::UsageLimit::Limited(limit) => {
3061 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3062 limit: limit as u32,
3063 })
3064 }
3065 cloud_llm_client::UsageLimit::Unlimited => {
3066 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3067 }
3068 }),
3069 }),
3070 }
3071}
3072
3073async fn update_user_plan(session: &Session) -> Result<()> {
3074 let db = session.db().await;
3075
3076 let update_user_plan = make_update_user_plan_message(
3077 session.principal.user(),
3078 session.is_staff(),
3079 &db.0,
3080 session.app_state.llm_db.clone(),
3081 )
3082 .await?;
3083
3084 session
3085 .peer
3086 .send(session.connection_id, update_user_plan)
3087 .trace_err();
3088
3089 Ok(())
3090}
3091
3092async fn subscribe_to_channels(
3093 _: proto::SubscribeToChannels,
3094 session: MessageContext,
3095) -> Result<()> {
3096 subscribe_user_to_channels(session.user_id(), &session).await?;
3097 Ok(())
3098}
3099
3100async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3101 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3102 let mut pool = session.connection_pool().await;
3103 for membership in &channels_for_user.channel_memberships {
3104 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3105 }
3106 session.peer.send(
3107 session.connection_id,
3108 build_update_user_channels(&channels_for_user),
3109 )?;
3110 session.peer.send(
3111 session.connection_id,
3112 build_channels_update(channels_for_user),
3113 )?;
3114 Ok(())
3115}
3116
3117/// Creates a new channel.
3118async fn create_channel(
3119 request: proto::CreateChannel,
3120 response: Response<proto::CreateChannel>,
3121 session: MessageContext,
3122) -> Result<()> {
3123 let db = session.db().await;
3124
3125 let parent_id = request.parent_id.map(ChannelId::from_proto);
3126 let (channel, membership) = db
3127 .create_channel(&request.name, parent_id, session.user_id())
3128 .await?;
3129
3130 let root_id = channel.root_id();
3131 let channel = Channel::from_model(channel);
3132
3133 response.send(proto::CreateChannelResponse {
3134 channel: Some(channel.to_proto()),
3135 parent_id: request.parent_id,
3136 })?;
3137
3138 let mut connection_pool = session.connection_pool().await;
3139 if let Some(membership) = membership {
3140 connection_pool.subscribe_to_channel(
3141 membership.user_id,
3142 membership.channel_id,
3143 membership.role,
3144 );
3145 let update = proto::UpdateUserChannels {
3146 channel_memberships: vec![proto::ChannelMembership {
3147 channel_id: membership.channel_id.to_proto(),
3148 role: membership.role.into(),
3149 }],
3150 ..Default::default()
3151 };
3152 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3153 session.peer.send(connection_id, update.clone())?;
3154 }
3155 }
3156
3157 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3158 if !role.can_see_channel(channel.visibility) {
3159 continue;
3160 }
3161
3162 let update = proto::UpdateChannels {
3163 channels: vec![channel.to_proto()],
3164 ..Default::default()
3165 };
3166 session.peer.send(connection_id, update.clone())?;
3167 }
3168
3169 Ok(())
3170}
3171
3172/// Delete a channel
3173async fn delete_channel(
3174 request: proto::DeleteChannel,
3175 response: Response<proto::DeleteChannel>,
3176 session: MessageContext,
3177) -> Result<()> {
3178 let db = session.db().await;
3179
3180 let channel_id = request.channel_id;
3181 let (root_channel, removed_channels) = db
3182 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3183 .await?;
3184 response.send(proto::Ack {})?;
3185
3186 // Notify members of removed channels
3187 let mut update = proto::UpdateChannels::default();
3188 update
3189 .delete_channels
3190 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3191
3192 let connection_pool = session.connection_pool().await;
3193 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3194 session.peer.send(connection_id, update.clone())?;
3195 }
3196
3197 Ok(())
3198}
3199
3200/// Invite someone to join a channel.
3201async fn invite_channel_member(
3202 request: proto::InviteChannelMember,
3203 response: Response<proto::InviteChannelMember>,
3204 session: MessageContext,
3205) -> Result<()> {
3206 let db = session.db().await;
3207 let channel_id = ChannelId::from_proto(request.channel_id);
3208 let invitee_id = UserId::from_proto(request.user_id);
3209 let InviteMemberResult {
3210 channel,
3211 notifications,
3212 } = db
3213 .invite_channel_member(
3214 channel_id,
3215 invitee_id,
3216 session.user_id(),
3217 request.role().into(),
3218 )
3219 .await?;
3220
3221 let update = proto::UpdateChannels {
3222 channel_invitations: vec![channel.to_proto()],
3223 ..Default::default()
3224 };
3225
3226 let connection_pool = session.connection_pool().await;
3227 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3228 session.peer.send(connection_id, update.clone())?;
3229 }
3230
3231 send_notifications(&connection_pool, &session.peer, notifications);
3232
3233 response.send(proto::Ack {})?;
3234 Ok(())
3235}
3236
3237/// remove someone from a channel
3238async fn remove_channel_member(
3239 request: proto::RemoveChannelMember,
3240 response: Response<proto::RemoveChannelMember>,
3241 session: MessageContext,
3242) -> Result<()> {
3243 let db = session.db().await;
3244 let channel_id = ChannelId::from_proto(request.channel_id);
3245 let member_id = UserId::from_proto(request.user_id);
3246
3247 let RemoveChannelMemberResult {
3248 membership_update,
3249 notification_id,
3250 } = db
3251 .remove_channel_member(channel_id, member_id, session.user_id())
3252 .await?;
3253
3254 let mut connection_pool = session.connection_pool().await;
3255 notify_membership_updated(
3256 &mut connection_pool,
3257 membership_update,
3258 member_id,
3259 &session.peer,
3260 );
3261 for connection_id in connection_pool.user_connection_ids(member_id) {
3262 if let Some(notification_id) = notification_id {
3263 session
3264 .peer
3265 .send(
3266 connection_id,
3267 proto::DeleteNotification {
3268 notification_id: notification_id.to_proto(),
3269 },
3270 )
3271 .trace_err();
3272 }
3273 }
3274
3275 response.send(proto::Ack {})?;
3276 Ok(())
3277}
3278
3279/// Toggle the channel between public and private.
3280/// Care is taken to maintain the invariant that public channels only descend from public channels,
3281/// (though members-only channels can appear at any point in the hierarchy).
3282async fn set_channel_visibility(
3283 request: proto::SetChannelVisibility,
3284 response: Response<proto::SetChannelVisibility>,
3285 session: MessageContext,
3286) -> Result<()> {
3287 let db = session.db().await;
3288 let channel_id = ChannelId::from_proto(request.channel_id);
3289 let visibility = request.visibility().into();
3290
3291 let channel_model = db
3292 .set_channel_visibility(channel_id, visibility, session.user_id())
3293 .await?;
3294 let root_id = channel_model.root_id();
3295 let channel = Channel::from_model(channel_model);
3296
3297 let mut connection_pool = session.connection_pool().await;
3298 for (user_id, role) in connection_pool
3299 .channel_user_ids(root_id)
3300 .collect::<Vec<_>>()
3301 .into_iter()
3302 {
3303 let update = if role.can_see_channel(channel.visibility) {
3304 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3305 proto::UpdateChannels {
3306 channels: vec![channel.to_proto()],
3307 ..Default::default()
3308 }
3309 } else {
3310 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3311 proto::UpdateChannels {
3312 delete_channels: vec![channel.id.to_proto()],
3313 ..Default::default()
3314 }
3315 };
3316
3317 for connection_id in connection_pool.user_connection_ids(user_id) {
3318 session.peer.send(connection_id, update.clone())?;
3319 }
3320 }
3321
3322 response.send(proto::Ack {})?;
3323 Ok(())
3324}
3325
3326/// Alter the role for a user in the channel.
3327async fn set_channel_member_role(
3328 request: proto::SetChannelMemberRole,
3329 response: Response<proto::SetChannelMemberRole>,
3330 session: MessageContext,
3331) -> Result<()> {
3332 let db = session.db().await;
3333 let channel_id = ChannelId::from_proto(request.channel_id);
3334 let member_id = UserId::from_proto(request.user_id);
3335 let result = db
3336 .set_channel_member_role(
3337 channel_id,
3338 session.user_id(),
3339 member_id,
3340 request.role().into(),
3341 )
3342 .await?;
3343
3344 match result {
3345 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3346 let mut connection_pool = session.connection_pool().await;
3347 notify_membership_updated(
3348 &mut connection_pool,
3349 membership_update,
3350 member_id,
3351 &session.peer,
3352 )
3353 }
3354 db::SetMemberRoleResult::InviteUpdated(channel) => {
3355 let update = proto::UpdateChannels {
3356 channel_invitations: vec![channel.to_proto()],
3357 ..Default::default()
3358 };
3359
3360 for connection_id in session
3361 .connection_pool()
3362 .await
3363 .user_connection_ids(member_id)
3364 {
3365 session.peer.send(connection_id, update.clone())?;
3366 }
3367 }
3368 }
3369
3370 response.send(proto::Ack {})?;
3371 Ok(())
3372}
3373
3374/// Change the name of a channel
3375async fn rename_channel(
3376 request: proto::RenameChannel,
3377 response: Response<proto::RenameChannel>,
3378 session: MessageContext,
3379) -> Result<()> {
3380 let db = session.db().await;
3381 let channel_id = ChannelId::from_proto(request.channel_id);
3382 let channel_model = db
3383 .rename_channel(channel_id, session.user_id(), &request.name)
3384 .await?;
3385 let root_id = channel_model.root_id();
3386 let channel = Channel::from_model(channel_model);
3387
3388 response.send(proto::RenameChannelResponse {
3389 channel: Some(channel.to_proto()),
3390 })?;
3391
3392 let connection_pool = session.connection_pool().await;
3393 let update = proto::UpdateChannels {
3394 channels: vec![channel.to_proto()],
3395 ..Default::default()
3396 };
3397 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3398 if role.can_see_channel(channel.visibility) {
3399 session.peer.send(connection_id, update.clone())?;
3400 }
3401 }
3402
3403 Ok(())
3404}
3405
3406/// Move a channel to a new parent.
3407async fn move_channel(
3408 request: proto::MoveChannel,
3409 response: Response<proto::MoveChannel>,
3410 session: MessageContext,
3411) -> Result<()> {
3412 let channel_id = ChannelId::from_proto(request.channel_id);
3413 let to = ChannelId::from_proto(request.to);
3414
3415 let (root_id, channels) = session
3416 .db()
3417 .await
3418 .move_channel(channel_id, to, session.user_id())
3419 .await?;
3420
3421 let connection_pool = session.connection_pool().await;
3422 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3423 let channels = channels
3424 .iter()
3425 .filter_map(|channel| {
3426 if role.can_see_channel(channel.visibility) {
3427 Some(channel.to_proto())
3428 } else {
3429 None
3430 }
3431 })
3432 .collect::<Vec<_>>();
3433 if channels.is_empty() {
3434 continue;
3435 }
3436
3437 let update = proto::UpdateChannels {
3438 channels,
3439 ..Default::default()
3440 };
3441
3442 session.peer.send(connection_id, update.clone())?;
3443 }
3444
3445 response.send(Ack {})?;
3446 Ok(())
3447}
3448
3449async fn reorder_channel(
3450 request: proto::ReorderChannel,
3451 response: Response<proto::ReorderChannel>,
3452 session: MessageContext,
3453) -> Result<()> {
3454 let channel_id = ChannelId::from_proto(request.channel_id);
3455 let direction = request.direction();
3456
3457 let updated_channels = session
3458 .db()
3459 .await
3460 .reorder_channel(channel_id, direction, session.user_id())
3461 .await?;
3462
3463 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3464 let connection_pool = session.connection_pool().await;
3465 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3466 let channels = updated_channels
3467 .iter()
3468 .filter_map(|channel| {
3469 if role.can_see_channel(channel.visibility) {
3470 Some(channel.to_proto())
3471 } else {
3472 None
3473 }
3474 })
3475 .collect::<Vec<_>>();
3476
3477 if channels.is_empty() {
3478 continue;
3479 }
3480
3481 let update = proto::UpdateChannels {
3482 channels,
3483 ..Default::default()
3484 };
3485
3486 session.peer.send(connection_id, update.clone())?;
3487 }
3488 }
3489
3490 response.send(Ack {})?;
3491 Ok(())
3492}
3493
3494/// Get the list of channel members
3495async fn get_channel_members(
3496 request: proto::GetChannelMembers,
3497 response: Response<proto::GetChannelMembers>,
3498 session: MessageContext,
3499) -> Result<()> {
3500 let db = session.db().await;
3501 let channel_id = ChannelId::from_proto(request.channel_id);
3502 let limit = if request.limit == 0 {
3503 u16::MAX as u64
3504 } else {
3505 request.limit
3506 };
3507 let (members, users) = db
3508 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3509 .await?;
3510 response.send(proto::GetChannelMembersResponse { members, users })?;
3511 Ok(())
3512}
3513
3514/// Accept or decline a channel invitation.
3515async fn respond_to_channel_invite(
3516 request: proto::RespondToChannelInvite,
3517 response: Response<proto::RespondToChannelInvite>,
3518 session: MessageContext,
3519) -> Result<()> {
3520 let db = session.db().await;
3521 let channel_id = ChannelId::from_proto(request.channel_id);
3522 let RespondToChannelInvite {
3523 membership_update,
3524 notifications,
3525 } = db
3526 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3527 .await?;
3528
3529 let mut connection_pool = session.connection_pool().await;
3530 if let Some(membership_update) = membership_update {
3531 notify_membership_updated(
3532 &mut connection_pool,
3533 membership_update,
3534 session.user_id(),
3535 &session.peer,
3536 );
3537 } else {
3538 let update = proto::UpdateChannels {
3539 remove_channel_invitations: vec![channel_id.to_proto()],
3540 ..Default::default()
3541 };
3542
3543 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3544 session.peer.send(connection_id, update.clone())?;
3545 }
3546 };
3547
3548 send_notifications(&connection_pool, &session.peer, notifications);
3549
3550 response.send(proto::Ack {})?;
3551
3552 Ok(())
3553}
3554
3555/// Join the channels' room
3556async fn join_channel(
3557 request: proto::JoinChannel,
3558 response: Response<proto::JoinChannel>,
3559 session: MessageContext,
3560) -> Result<()> {
3561 let channel_id = ChannelId::from_proto(request.channel_id);
3562 join_channel_internal(channel_id, Box::new(response), session).await
3563}
3564
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can reply to either
/// request type.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Path-qualified call: inherent methods take precedence over trait
        // methods, so this resolves to `Response::send` and does not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same as above: forwards to the inherent `Response::send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3578
/// Joins the room of `channel_id` on behalf of the session's user.
///
/// Called by `join_channel`; the reply goes through any
/// `JoinChannelInternalResponse` implementation.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database guard is dropped while leaving and re-acquired
///    afterwards.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, mint a token:
///    - guests get a guest token and `can_publish = false`;
///    - everyone else gets a room token and `can_publish = true`.
///
///    A token error goes through `trace_err` and results in no connection info.
/// 4. Reply with the room, its channel id, and the LiveKit info.
/// 5. Propagate any membership change, then broadcast the room update.
/// 6. Once the inner scope ends (releasing the database and connection-pool
///    guards), broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of the following fail: a database call, leaving the
/// stale room, the reply, or the contact update. Also returns an error if the
/// joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before leaving; presumably
            // `leave_room_for_session` acquires it itself (TODO confirm).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // Ending this block drops `db` and `connection_pool`.
        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3672
3673/// Start editing the channel notes
3674async fn join_channel_buffer(
3675 request: proto::JoinChannelBuffer,
3676 response: Response<proto::JoinChannelBuffer>,
3677 session: MessageContext,
3678) -> Result<()> {
3679 let db = session.db().await;
3680 let channel_id = ChannelId::from_proto(request.channel_id);
3681
3682 let open_response = db
3683 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3684 .await?;
3685
3686 let collaborators = open_response.collaborators.clone();
3687 response.send(open_response)?;
3688
3689 let update = UpdateChannelBufferCollaborators {
3690 channel_id: channel_id.to_proto(),
3691 collaborators: collaborators.clone(),
3692 };
3693 channel_buffer_updated(
3694 session.connection_id,
3695 collaborators
3696 .iter()
3697 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3698 &update,
3699 &session.peer,
3700 );
3701
3702 Ok(())
3703}
3704
3705/// Edit the channel notes
3706async fn update_channel_buffer(
3707 request: proto::UpdateChannelBuffer,
3708 session: MessageContext,
3709) -> Result<()> {
3710 let db = session.db().await;
3711 let channel_id = ChannelId::from_proto(request.channel_id);
3712
3713 let (collaborators, epoch, version) = db
3714 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3715 .await?;
3716
3717 channel_buffer_updated(
3718 session.connection_id,
3719 collaborators.clone(),
3720 &proto::UpdateChannelBuffer {
3721 channel_id: channel_id.to_proto(),
3722 operations: request.operations,
3723 },
3724 &session.peer,
3725 );
3726
3727 let pool = &*session.connection_pool().await;
3728
3729 let non_collaborators =
3730 pool.channel_connection_ids(channel_id)
3731 .filter_map(|(connection_id, _)| {
3732 if collaborators.contains(&connection_id) {
3733 None
3734 } else {
3735 Some(connection_id)
3736 }
3737 });
3738
3739 broadcast(None, non_collaborators, |peer_id| {
3740 session.peer.send(
3741 peer_id,
3742 proto::UpdateChannels {
3743 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3744 channel_id: channel_id.to_proto(),
3745 epoch: epoch as u64,
3746 version: version.clone(),
3747 }],
3748 ..Default::default()
3749 },
3750 )
3751 });
3752
3753 Ok(())
3754}
3755
3756/// Rejoin the channel notes after a connection blip
3757async fn rejoin_channel_buffers(
3758 request: proto::RejoinChannelBuffers,
3759 response: Response<proto::RejoinChannelBuffers>,
3760 session: MessageContext,
3761) -> Result<()> {
3762 let db = session.db().await;
3763 let buffers = db
3764 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3765 .await?;
3766
3767 for rejoined_buffer in &buffers {
3768 let collaborators_to_notify = rejoined_buffer
3769 .buffer
3770 .collaborators
3771 .iter()
3772 .filter_map(|c| Some(c.peer_id?.into()));
3773 channel_buffer_updated(
3774 session.connection_id,
3775 collaborators_to_notify,
3776 &proto::UpdateChannelBufferCollaborators {
3777 channel_id: rejoined_buffer.buffer.channel_id,
3778 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3779 },
3780 &session.peer,
3781 );
3782 }
3783
3784 response.send(proto::RejoinChannelBuffersResponse {
3785 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3786 })?;
3787
3788 Ok(())
3789}
3790
3791/// Stop editing the channel notes
3792async fn leave_channel_buffer(
3793 request: proto::LeaveChannelBuffer,
3794 response: Response<proto::LeaveChannelBuffer>,
3795 session: MessageContext,
3796) -> Result<()> {
3797 let db = session.db().await;
3798 let channel_id = ChannelId::from_proto(request.channel_id);
3799
3800 let left_buffer = db
3801 .leave_channel_buffer(channel_id, session.connection_id)
3802 .await?;
3803
3804 response.send(Ack {})?;
3805
3806 channel_buffer_updated(
3807 session.connection_id,
3808 left_buffer.connections,
3809 &proto::UpdateChannelBufferCollaborators {
3810 channel_id: channel_id.to_proto(),
3811 collaborators: left_buffer.collaborators,
3812 },
3813 &session.peer,
3814 );
3815
3816 Ok(())
3817}
3818
3819fn channel_buffer_updated<T: EnvelopedMessage>(
3820 sender_id: ConnectionId,
3821 collaborators: impl IntoIterator<Item = ConnectionId>,
3822 message: &T,
3823 peer: &Peer,
3824) {
3825 broadcast(Some(sender_id), collaborators, |peer_id| {
3826 peer.send(peer_id, message.clone())
3827 });
3828}
3829
3830fn send_notifications(
3831 connection_pool: &ConnectionPool,
3832 peer: &Peer,
3833 notifications: db::NotificationBatch,
3834) {
3835 for (user_id, notification) in notifications {
3836 for connection_id in connection_pool.user_connection_ids(user_id) {
3837 if let Err(error) = peer.send(
3838 connection_id,
3839 proto::AddNotification {
3840 notification: Some(notification.clone()),
3841 },
3842 ) {
3843 tracing::error!(
3844 "failed to send notification to {:?} {}",
3845 connection_id,
3846 error
3847 );
3848 }
3849 }
3850 }
3851}
3852
/// Send a message to the channel.
///
/// Validates the request:
/// - the trimmed body must be non-empty and at most `MAX_MESSAGE_LEN` bytes;
/// - a nonce must be present.
///
/// It then stores the message and:
/// - broadcasts `ChannelMessageSent` to the participant connections the
///   database reports, excluding the sender;
/// - replies to the sender with the stored message;
/// - sends the new latest message id to the channel's subscribed connections
///   that are not participants;
/// - delivers any notifications returned by the database.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `len()` counts bytes, not characters.
    if body.len() > MAX_MESSAGE_LEN {
        // `Err(..)?` converts the anyhow error into this function's error type.
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Push the full message to the other participants, then reply to the sender.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers that are not participants only get the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3945
3946/// Delete a channel message
3947async fn remove_channel_message(
3948 request: proto::RemoveChannelMessage,
3949 response: Response<proto::RemoveChannelMessage>,
3950 session: MessageContext,
3951) -> Result<()> {
3952 let channel_id = ChannelId::from_proto(request.channel_id);
3953 let message_id = MessageId::from_proto(request.message_id);
3954 let (connection_ids, existing_notification_ids) = session
3955 .db()
3956 .await
3957 .remove_channel_message(channel_id, message_id, session.user_id())
3958 .await?;
3959
3960 broadcast(
3961 Some(session.connection_id),
3962 connection_ids,
3963 move |connection| {
3964 session.peer.send(connection, request.clone())?;
3965
3966 for notification_id in &existing_notification_ids {
3967 session.peer.send(
3968 connection,
3969 proto::DeleteNotification {
3970 notification_id: (*notification_id).to_proto(),
3971 },
3972 )?;
3973 }
3974
3975 Ok(())
3976 },
3977 );
3978 response.send(proto::Ack {})?;
3979 Ok(())
3980}
3981
3982async fn update_channel_message(
3983 request: proto::UpdateChannelMessage,
3984 response: Response<proto::UpdateChannelMessage>,
3985 session: MessageContext,
3986) -> Result<()> {
3987 let channel_id = ChannelId::from_proto(request.channel_id);
3988 let message_id = MessageId::from_proto(request.message_id);
3989 let updated_at = OffsetDateTime::now_utc();
3990 let UpdatedChannelMessage {
3991 message_id,
3992 participant_connection_ids,
3993 notifications,
3994 reply_to_message_id,
3995 timestamp,
3996 deleted_mention_notification_ids,
3997 updated_mention_notifications,
3998 } = session
3999 .db()
4000 .await
4001 .update_channel_message(
4002 channel_id,
4003 message_id,
4004 session.user_id(),
4005 request.body.as_str(),
4006 &request.mentions,
4007 updated_at,
4008 )
4009 .await?;
4010
4011 let nonce = request.nonce.clone().context("nonce can't be blank")?;
4012
4013 let message = proto::ChannelMessage {
4014 sender_id: session.user_id().to_proto(),
4015 id: message_id.to_proto(),
4016 body: request.body.clone(),
4017 mentions: request.mentions.clone(),
4018 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4019 nonce: Some(nonce),
4020 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4021 edited_at: Some(updated_at.unix_timestamp() as u64),
4022 };
4023
4024 response.send(proto::Ack {})?;
4025
4026 let pool = &*session.connection_pool().await;
4027 broadcast(
4028 Some(session.connection_id),
4029 participant_connection_ids,
4030 |connection| {
4031 session.peer.send(
4032 connection,
4033 proto::ChannelMessageUpdate {
4034 channel_id: channel_id.to_proto(),
4035 message: Some(message.clone()),
4036 },
4037 )?;
4038
4039 for notification_id in &deleted_mention_notification_ids {
4040 session.peer.send(
4041 connection,
4042 proto::DeleteNotification {
4043 notification_id: (*notification_id).to_proto(),
4044 },
4045 )?;
4046 }
4047
4048 for notification in &updated_mention_notifications {
4049 session.peer.send(
4050 connection,
4051 proto::UpdateNotification {
4052 notification: Some(notification.clone()),
4053 },
4054 )?;
4055 }
4056
4057 Ok(())
4058 },
4059 );
4060
4061 send_notifications(pool, &session.peer, notifications);
4062
4063 Ok(())
4064}
4065
4066/// Mark a channel message as read
4067async fn acknowledge_channel_message(
4068 request: proto::AckChannelMessage,
4069 session: MessageContext,
4070) -> Result<()> {
4071 let channel_id = ChannelId::from_proto(request.channel_id);
4072 let message_id = MessageId::from_proto(request.message_id);
4073 let notifications = session
4074 .db()
4075 .await
4076 .observe_channel_message(channel_id, session.user_id(), message_id)
4077 .await?;
4078 send_notifications(
4079 &*session.connection_pool().await,
4080 &session.peer,
4081 notifications,
4082 );
4083 Ok(())
4084}
4085
4086/// Mark a buffer version as synced
4087async fn acknowledge_buffer_version(
4088 request: proto::AckBufferOperation,
4089 session: MessageContext,
4090) -> Result<()> {
4091 let buffer_id = BufferId::from_proto(request.buffer_id);
4092 session
4093 .db()
4094 .await
4095 .observe_buffer_version(
4096 buffer_id,
4097 session.user_id(),
4098 request.epoch as i32,
4099 &request.version,
4100 )
4101 .await?;
4102 Ok(())
4103}
4104
4105/// Get a Supermaven API key for the user
4106async fn get_supermaven_api_key(
4107 _request: proto::GetSupermavenApiKey,
4108 response: Response<proto::GetSupermavenApiKey>,
4109 session: MessageContext,
4110) -> Result<()> {
4111 let user_id: String = session.user_id().to_string();
4112 if !session.is_staff() {
4113 return Err(anyhow!("supermaven not enabled for this account"))?;
4114 }
4115
4116 let email = session.email().context("user must have an email")?;
4117
4118 let supermaven_admin_api = session
4119 .supermaven_client
4120 .as_ref()
4121 .context("supermaven not configured")?;
4122
4123 let result = supermaven_admin_api
4124 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4125 .await?;
4126
4127 response.send(proto::GetSupermavenApiKeyResponse {
4128 api_key: result.api_key,
4129 })?;
4130
4131 Ok(())
4132}
4133
4134/// Start receiving chat updates for a channel
4135async fn join_channel_chat(
4136 request: proto::JoinChannelChat,
4137 response: Response<proto::JoinChannelChat>,
4138 session: MessageContext,
4139) -> Result<()> {
4140 let channel_id = ChannelId::from_proto(request.channel_id);
4141
4142 let db = session.db().await;
4143 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4144 .await?;
4145 let messages = db
4146 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4147 .await?;
4148 response.send(proto::JoinChannelChatResponse {
4149 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4150 messages,
4151 })?;
4152 Ok(())
4153}
4154
4155/// Stop receiving chat updates for a channel
4156async fn leave_channel_chat(
4157 request: proto::LeaveChannelChat,
4158 session: MessageContext,
4159) -> Result<()> {
4160 let channel_id = ChannelId::from_proto(request.channel_id);
4161 session
4162 .db()
4163 .await
4164 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4165 .await?;
4166 Ok(())
4167}
4168
4169/// Retrieve the chat history for a channel
4170async fn get_channel_messages(
4171 request: proto::GetChannelMessages,
4172 response: Response<proto::GetChannelMessages>,
4173 session: MessageContext,
4174) -> Result<()> {
4175 let channel_id = ChannelId::from_proto(request.channel_id);
4176 let messages = session
4177 .db()
4178 .await
4179 .get_channel_messages(
4180 channel_id,
4181 session.user_id(),
4182 MESSAGE_COUNT_PER_PAGE,
4183 Some(MessageId::from_proto(request.before_message_id)),
4184 )
4185 .await?;
4186 response.send(proto::GetChannelMessagesResponse {
4187 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4188 messages,
4189 })?;
4190 Ok(())
4191}
4192
4193/// Retrieve specific chat messages
4194async fn get_channel_messages_by_id(
4195 request: proto::GetChannelMessagesById,
4196 response: Response<proto::GetChannelMessagesById>,
4197 session: MessageContext,
4198) -> Result<()> {
4199 let message_ids = request
4200 .message_ids
4201 .iter()
4202 .map(|id| MessageId::from_proto(*id))
4203 .collect::<Vec<_>>();
4204 let messages = session
4205 .db()
4206 .await
4207 .get_channel_messages_by_id(session.user_id(), &message_ids)
4208 .await?;
4209 response.send(proto::GetChannelMessagesResponse {
4210 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4211 messages,
4212 })?;
4213 Ok(())
4214}
4215
4216/// Retrieve the current users notifications
4217async fn get_notifications(
4218 request: proto::GetNotifications,
4219 response: Response<proto::GetNotifications>,
4220 session: MessageContext,
4221) -> Result<()> {
4222 let notifications = session
4223 .db()
4224 .await
4225 .get_notifications(
4226 session.user_id(),
4227 NOTIFICATION_COUNT_PER_PAGE,
4228 request.before_id.map(db::NotificationId::from_proto),
4229 )
4230 .await?;
4231 response.send(proto::GetNotificationsResponse {
4232 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4233 notifications,
4234 })?;
4235 Ok(())
4236}
4237
4238/// Mark notifications as read
4239async fn mark_notification_as_read(
4240 request: proto::MarkNotificationRead,
4241 response: Response<proto::MarkNotificationRead>,
4242 session: MessageContext,
4243) -> Result<()> {
4244 let database = &session.db().await;
4245 let notifications = database
4246 .mark_notification_as_read_by_id(
4247 session.user_id(),
4248 NotificationId::from_proto(request.notification_id),
4249 )
4250 .await?;
4251 send_notifications(
4252 &*session.connection_pool().await,
4253 &session.peer,
4254 notifications,
4255 );
4256 response.send(proto::Ack {})?;
4257 Ok(())
4258}
4259
4260/// Get the current users information
4261async fn get_private_user_info(
4262 _request: proto::GetPrivateUserInfo,
4263 response: Response<proto::GetPrivateUserInfo>,
4264 session: MessageContext,
4265) -> Result<()> {
4266 let db = session.db().await;
4267
4268 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4269 let user = db
4270 .get_user_by_id(session.user_id())
4271 .await?
4272 .context("user not found")?;
4273 let flags = db.get_user_flags(session.user_id()).await?;
4274
4275 response.send(proto::GetPrivateUserInfoResponse {
4276 metrics_id,
4277 staff: user.admin,
4278 flags,
4279 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4280 })?;
4281 Ok(())
4282}
4283
4284/// Accept the terms of service (tos) on behalf of the current user
4285async fn accept_terms_of_service(
4286 _request: proto::AcceptTermsOfService,
4287 response: Response<proto::AcceptTermsOfService>,
4288 session: MessageContext,
4289) -> Result<()> {
4290 let db = session.db().await;
4291
4292 let accepted_tos_at = Utc::now();
4293 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4294 .await?;
4295
4296 response.send(proto::AcceptTermsOfServiceResponse {
4297 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4298 })?;
4299
4300 // When the user accepts the terms of service, we want to refresh their LLM
4301 // token to grant access.
4302 session
4303 .peer
4304 .send(session.connection_id, proto::RefreshLlmToken {})?;
4305
4306 Ok(())
4307}
4308
4309async fn get_llm_api_token(
4310 _request: proto::GetLlmToken,
4311 response: Response<proto::GetLlmToken>,
4312 session: MessageContext,
4313) -> Result<()> {
4314 let db = session.db().await;
4315
4316 let flags = db.get_user_flags(session.user_id()).await?;
4317
4318 let user_id = session.user_id();
4319 let user = db
4320 .get_user_by_id(user_id)
4321 .await?
4322 .with_context(|| format!("user {user_id} not found"))?;
4323
4324 if user.accepted_tos_at.is_none() {
4325 Err(anyhow!("terms of service not accepted"))?
4326 }
4327
4328 let stripe_client = session
4329 .app_state
4330 .stripe_client
4331 .as_ref()
4332 .context("failed to retrieve Stripe client")?;
4333
4334 let stripe_billing = session
4335 .app_state
4336 .stripe_billing
4337 .as_ref()
4338 .context("failed to retrieve Stripe billing object")?;
4339
4340 let billing_customer = if let Some(billing_customer) =
4341 db.get_billing_customer_by_user_id(user.id).await?
4342 {
4343 billing_customer
4344 } else {
4345 let customer_id = stripe_billing
4346 .find_or_create_customer_by_email(user.email_address.as_deref())
4347 .await?;
4348
4349 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4350 .await?
4351 .context("billing customer not found")?
4352 };
4353
4354 let billing_subscription =
4355 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4356 billing_subscription
4357 } else {
4358 let stripe_customer_id =
4359 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4360
4361 let stripe_subscription = stripe_billing
4362 .subscribe_to_zed_free(stripe_customer_id)
4363 .await?;
4364
4365 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4366 billing_customer_id: billing_customer.id,
4367 kind: Some(SubscriptionKind::ZedFree),
4368 stripe_subscription_id: stripe_subscription.id.to_string(),
4369 stripe_subscription_status: stripe_subscription.status.into(),
4370 stripe_cancellation_reason: None,
4371 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4372 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4373 })
4374 .await?
4375 };
4376
4377 let billing_preferences = db.get_billing_preferences(user.id).await?;
4378
4379 let token = LlmTokenClaims::create(
4380 &user,
4381 session.is_staff(),
4382 billing_customer,
4383 billing_preferences,
4384 &flags,
4385 billing_subscription,
4386 session.system_id.clone(),
4387 &session.app_state.config,
4388 )?;
4389 response.send(proto::GetLlmTokenResponse { token })?;
4390 Ok(())
4391}
4392
4393fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4394 let message = match message {
4395 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4396 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4397 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4398 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4399 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4400 code: frame.code.into(),
4401 reason: frame.reason.as_str().to_owned().into(),
4402 })),
4403 // We should never receive a frame while reading the message, according
4404 // to the `tungstenite` maintainers:
4405 //
4406 // > It cannot occur when you read messages from the WebSocket, but it
4407 // > can be used when you want to send the raw frames (e.g. you want to
4408 // > send the frames to the WebSocket without composing the full message first).
4409 // >
4410 // > — https://github.com/snapview/tungstenite-rs/issues/268
4411 TungsteniteMessage::Frame(_) => {
4412 bail!("received an unexpected frame while reading the message")
4413 }
4414 };
4415
4416 Ok(message)
4417}
4418
4419fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4420 match message {
4421 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4422 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4423 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4424 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4425 AxumMessage::Close(frame) => {
4426 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4427 code: frame.code.into(),
4428 reason: frame.reason.as_ref().into(),
4429 }))
4430 }
4431 }
4432}
4433
4434fn notify_membership_updated(
4435 connection_pool: &mut ConnectionPool,
4436 result: MembershipUpdated,
4437 user_id: UserId,
4438 peer: &Peer,
4439) {
4440 for membership in &result.new_channels.channel_memberships {
4441 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4442 }
4443 for channel_id in &result.removed_channels {
4444 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4445 }
4446
4447 let user_channels_update = proto::UpdateUserChannels {
4448 channel_memberships: result
4449 .new_channels
4450 .channel_memberships
4451 .iter()
4452 .map(|cm| proto::ChannelMembership {
4453 channel_id: cm.channel_id.to_proto(),
4454 role: cm.role.into(),
4455 })
4456 .collect(),
4457 ..Default::default()
4458 };
4459
4460 let mut update = build_channels_update(result.new_channels);
4461 update.delete_channels = result
4462 .removed_channels
4463 .into_iter()
4464 .map(|id| id.to_proto())
4465 .collect();
4466 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4467
4468 for connection_id in connection_pool.user_connection_ids(user_id) {
4469 peer.send(connection_id, user_channels_update.clone())
4470 .trace_err();
4471 peer.send(connection_id, update.clone()).trace_err();
4472 }
4473}
4474
4475fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4476 proto::UpdateUserChannels {
4477 channel_memberships: channels
4478 .channel_memberships
4479 .iter()
4480 .map(|m| proto::ChannelMembership {
4481 channel_id: m.channel_id.to_proto(),
4482 role: m.role.into(),
4483 })
4484 .collect(),
4485 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4486 observed_channel_message_id: channels.observed_channel_messages.clone(),
4487 }
4488}
4489
4490fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4491 let mut update = proto::UpdateChannels::default();
4492
4493 for channel in channels.channels {
4494 update.channels.push(channel.to_proto());
4495 }
4496
4497 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4498 update.latest_channel_message_ids = channels.latest_channel_messages;
4499
4500 for (channel_id, participants) in channels.channel_participants {
4501 update
4502 .channel_participants
4503 .push(proto::ChannelParticipants {
4504 channel_id: channel_id.to_proto(),
4505 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4506 });
4507 }
4508
4509 for channel in channels.invited_channels {
4510 update.channel_invitations.push(channel.to_proto());
4511 }
4512
4513 update
4514}
4515
4516fn build_initial_contacts_update(
4517 contacts: Vec<db::Contact>,
4518 pool: &ConnectionPool,
4519) -> proto::UpdateContacts {
4520 let mut update = proto::UpdateContacts::default();
4521
4522 for contact in contacts {
4523 match contact {
4524 db::Contact::Accepted { user_id, busy } => {
4525 update.contacts.push(contact_for_user(user_id, busy, pool));
4526 }
4527 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4528 db::Contact::Incoming { user_id } => {
4529 update
4530 .incoming_requests
4531 .push(proto::IncomingContactRequest {
4532 requester_id: user_id.to_proto(),
4533 })
4534 }
4535 }
4536 }
4537
4538 update
4539}
4540
4541fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4542 proto::Contact {
4543 user_id: user_id.to_proto(),
4544 online: pool.is_user_online(user_id),
4545 busy,
4546 }
4547}
4548
4549fn room_updated(room: &proto::Room, peer: &Peer) {
4550 broadcast(
4551 None,
4552 room.participants
4553 .iter()
4554 .filter_map(|participant| Some(participant.peer_id?.into())),
4555 |peer_id| {
4556 peer.send(
4557 peer_id,
4558 proto::RoomUpdated {
4559 room: Some(room.clone()),
4560 },
4561 )
4562 },
4563 );
4564}
4565
4566fn channel_updated(
4567 channel: &db::channel::Model,
4568 room: &proto::Room,
4569 peer: &Peer,
4570 pool: &ConnectionPool,
4571) {
4572 let participants = room
4573 .participants
4574 .iter()
4575 .map(|p| p.user_id)
4576 .collect::<Vec<_>>();
4577
4578 broadcast(
4579 None,
4580 pool.channel_connection_ids(channel.root_id())
4581 .filter_map(|(channel_id, role)| {
4582 role.can_see_channel(channel.visibility)
4583 .then_some(channel_id)
4584 }),
4585 |peer_id| {
4586 peer.send(
4587 peer_id,
4588 proto::UpdateChannels {
4589 channel_participants: vec![proto::ChannelParticipants {
4590 channel_id: channel.id.to_proto(),
4591 participant_user_ids: participants.clone(),
4592 }],
4593 ..Default::default()
4594 },
4595 )
4596 },
4597 );
4598}
4599
4600async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4601 let db = session.db().await;
4602
4603 let contacts = db.get_contacts(user_id).await?;
4604 let busy = db.is_user_busy(user_id).await?;
4605
4606 let pool = session.connection_pool().await;
4607 let updated_contact = contact_for_user(user_id, busy, &pool);
4608 for contact in contacts {
4609 if let db::Contact::Accepted {
4610 user_id: contact_user_id,
4611 ..
4612 } = contact
4613 {
4614 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4615 session
4616 .peer
4617 .send(
4618 contact_conn_id,
4619 proto::UpdateContacts {
4620 contacts: vec![updated_contact.clone()],
4621 remove_contacts: Default::default(),
4622 incoming_requests: Default::default(),
4623 remove_incoming_requests: Default::default(),
4624 outgoing_requests: Default::default(),
4625 remove_outgoing_requests: Default::default(),
4626 },
4627 )
4628 .trace_err();
4629 }
4630 }
4631 }
4632 Ok(())
4633}
4634
/// Remove `connection_id` from whatever room it is in, then propagate the
/// consequences.
///
/// If the database reports no room was left, this returns `Ok(())` without
/// doing anything. Otherwise it:
/// - notifies the projects the connection left (`project_left`);
/// - broadcasts the updated room and, if the room belongs to a channel, the
///   channel's participant list;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled, and refreshes contact status for them and for this user;
/// - removes the user from the LiveKit room, deleting that room when the
///   database reported the room as deleted. LiveKit failures are logged and
///   ignored.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once in the
    // `Some` branch below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the DB handle from `session.db().await` is a temporary of
    // this `if let` scrutinee, so it appears to stay alive for the whole
    // `Some` body below. Keep that in mind before adding awaits or other
    // lock acquisitions here.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // `livekit_room` is moved out before the room itself is taken, so the
        // `room` broadcast below carries a default (empty) `livekit_room`.
        // NOTE(review): confirm clients do not read `livekit_room` from
        // `RoomUpdated`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the
    // `update_user_contacts` calls below, which acquire the pool themselves.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: errors are logged via `trace_err` and do
    // not fail the leave.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4708
4709async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4710 let left_channel_buffers = session
4711 .db()
4712 .await
4713 .leave_channel_buffers(session.connection_id)
4714 .await?;
4715
4716 for left_buffer in left_channel_buffers {
4717 channel_buffer_updated(
4718 session.connection_id,
4719 left_buffer.connections,
4720 &proto::UpdateChannelBufferCollaborators {
4721 channel_id: left_buffer.channel_id.to_proto(),
4722 collaborators: left_buffer.collaborators,
4723 },
4724 &session.peer,
4725 );
4726 }
4727
4728 Ok(())
4729}
4730
4731fn project_left(project: &db::LeftProject, session: &Session) {
4732 for connection_id in &project.connection_ids {
4733 if project.should_unshare {
4734 session
4735 .peer
4736 .send(
4737 *connection_id,
4738 proto::UnshareProject {
4739 project_id: project.id.to_proto(),
4740 },
4741 )
4742 .trace_err();
4743 } else {
4744 session
4745 .peer
4746 .send(
4747 *connection_id,
4748 proto::RemoveProjectCollaborator {
4749 project_id: project.id.to_proto(),
4750 peer_id: Some(session.connection_id.into()),
4751 },
4752 )
4753 .trace_err();
4754 }
4755 }
4756}
4757
/// Extension for `Result` values whose error should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success type of the extended `Result`.
    type Ok;

    /// Convert `self` into an `Option`.
    ///
    /// Returns `Some` with the success value, or logs the error and returns
    /// `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
4763
4764impl<T, E> ResultExt for Result<T, E>
4765where
4766 E: std::fmt::Debug,
4767{
4768 type Ok = T;
4769
4770 #[track_caller]
4771 fn trace_err(self) -> Option<T> {
4772 match self {
4773 Ok(value) => Some(value),
4774 Err(error) => {
4775 tracing::error!("{:?}", error);
4776 None
4777 }
4778 }
4779 }
4780}