1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use futures::TryFutureExt as _;
45use reqwest_client::ReqwestClient;
46use rpc::proto::{MultiLspQuery, split_repository_update};
47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
48use tracing::Span;
49
50use futures::{
51 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
52 stream::FuturesUnordered,
53};
54use prometheus::{IntGauge, register_int_gauge};
55use rpc::{
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57 proto::{
58 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
59 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
60 },
61};
62use semantic_version::SemanticVersion;
63use serde::{Serialize, Serializer};
64use std::{
65 any::TypeId,
66 future::Future,
67 marker::PhantomData,
68 mem,
69 net::SocketAddr,
70 ops::{Deref, DerefMut},
71 rc::Rc,
72 sync::{
73 Arc, OnceLock,
74 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
75 },
76 time::{Duration, Instant},
77};
78use time::OffsetDateTime;
79use tokio::sync::{Semaphore, watch};
80use tower::ServiceBuilder;
81use tracing::{
82 Instrument,
83 field::{self},
84 info_span, instrument,
85};
86
/// How long a disconnected client is given to reconnect before its
/// server-side state is torn down (used by reconnection logic elsewhere).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Pagination sizes and input limits for chat messages and notifications.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
// Hard cap on simultaneously open client connections; enforced by `ConnectionGuard`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Count of currently open connections, incremented/decremented by `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Span field names under which per-message timing metrics are recorded.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased handler for one RPC message type: receives the envelope, the
/// sending client's `Session`, and the tracing span created for the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
106
107pub struct ConnectionGuard;
108
109impl ConnectionGuard {
110 pub fn try_acquire() -> Result<Self, ()> {
111 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
112 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 tracing::error!(
115 "too many concurrent connections: {}",
116 current_connections + 1
117 );
118 return Err(());
119 }
120 Ok(ConnectionGuard)
121 }
122}
123
124impl Drop for ConnectionGuard {
125 fn drop(&mut self) {
126 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
127 }
128}
129
/// One-shot reply channel for a request of type `R`.
///
/// `responded` is shared with the wrapper created in `add_request_handler`,
/// which uses it to detect handlers that return `Ok(())` without replying.
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    /// Sends `payload` back to the client identified by `receipt`.
    fn send(self, payload: R::Response) -> Result<()> {
        // Mark as responded before actually sending so the request wrapper
        // doesn't report a missing reply. NOTE(review): the flag remains set
        // even if `respond` fails below — confirm that is intended.
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}
143
144#[derive(Clone, Debug)]
145pub enum Principal {
146 User(User),
147 Impersonated { user: User, admin: User },
148}
149
150impl Principal {
151 fn user(&self) -> &User {
152 match self {
153 Principal::User(user) => user,
154 Principal::Impersonated { user, .. } => user,
155 }
156 }
157
158 fn update_span(&self, span: &tracing::Span) {
159 match &self {
160 Principal::User(user) => {
161 span.record("user_id", user.id.0);
162 span.record("login", &user.github_login);
163 }
164 Principal::Impersonated { user, admin } => {
165 span.record("user_id", user.id.0);
166 span.record("login", &user.github_login);
167 span.record("impersonator", &admin.github_login);
168 }
169 }
170 }
171}
172
/// A `Session` paired with the tracing span of the message currently being
/// handled. Derefs to `Session` so handlers can use it interchangeably.
#[derive(Clone)]
struct MessageContext {
    session: Session,
    span: tracing::Span,
}

impl Deref for MessageContext {
    type Target = Session;

    fn deref(&self) -> &Self::Target {
        &self.session
    }
}

impl MessageContext {
    /// Forwards `request` to the connection `receiver_id` and records, on the
    /// message span, how long the receiver took (fractional milliseconds
    /// under `HOST_WAITING_MS`).
    pub fn forward_request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = anyhow::Result<T::Response>> {
        let request_start_time = Instant::now();
        let span = self.span.clone();
        tracing::info!("start forwarding request");
        self.peer
            .forward_request(self.connection_id, receiver_id, request)
            // Runs on completion regardless of outcome: record elapsed time.
            .inspect(move |_| {
                span.record(
                    HOST_WAITING_MS,
                    request_start_time.elapsed().as_micros() as f64 / 1000.0,
                );
            })
            .inspect_err(|_| tracing::error!("error forwarding request"))
            .inspect_ok(|_| tracing::info!("finished forwarding request"))
    }
}
208
/// Per-connection state handed (via `MessageContext`) to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; a fresh handle is created per
    /// connection in `handle_connection`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the whole server; tracks all live connections.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
224
225impl Session {
226 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
227 #[cfg(test)]
228 tokio::task::yield_now().await;
229 let guard = self.db.lock().await;
230 #[cfg(test)]
231 tokio::task::yield_now().await;
232 guard
233 }
234
235 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
236 #[cfg(test)]
237 tokio::task::yield_now().await;
238 let guard = self.connection_pool.lock();
239 ConnectionPoolGuard {
240 guard,
241 _not_send: PhantomData,
242 }
243 }
244
245 fn is_staff(&self) -> bool {
246 match &self.principal {
247 Principal::User(user) => user.admin,
248 Principal::Impersonated { .. } => true,
249 }
250 }
251
252 fn user_id(&self) -> UserId {
253 match &self.principal {
254 Principal::User(user) => user.id,
255 Principal::Impersonated { user, .. } => user.id,
256 }
257 }
258
259 pub fn email(&self) -> Option<String> {
260 match &self.principal {
261 Principal::User(user) => user.email_address.clone(),
262 Principal::Impersonated { user, .. } => user.email_address.clone(),
263 }
264 }
265}
266
267impl Debug for Session {
268 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
269 let mut result = f.debug_struct("Session");
270 match &self.principal {
271 Principal::User(user) => {
272 result.field("user", &user.github_login);
273 }
274 Principal::Impersonated { user, admin } => {
275 result.field("user", &user.github_login);
276 result.field("impersonator", &admin.github_login);
277 }
278 }
279 result.field("connection_id", &self.connection_id).finish()
280 }
281}
282
/// Thin newtype over the shared database, dereferencing to `Database` so
/// handlers can call query methods directly on the locked handle.
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}
292
/// The collaboration RPC server: owns the peer, the pool of live client
/// connections, and the registry of message handlers.
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// One handler per message `TypeId`; populated in `Server::new`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcast flag: set to `true` to tell every connection task to exit.
    teardown: watch::Sender<bool>,
}
301
/// Guard over the locked `ConnectionPool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, preventing it
/// from being held across an `.await` that could move to another thread.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
306
/// Serializable snapshot of server state (peer + connection pool), e.g. for
/// diagnostic endpoints.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize what it derefs to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
313
314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
315where
316 S: Serializer,
317 T: Deref<Target = U>,
318 U: Serialize,
319{
320 Serialize::serialize(value.deref(), serializer)
321}
322
323impl Server {
    /// Builds the server and registers a handler for every RPC message type.
    ///
    /// Request handlers must send exactly one response; message handlers are
    /// fire-and-forget. Registering the same message type twice panics (see
    /// `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms, calls, and project sharing.
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Read-only LSP-style requests forwarded to the project host.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            // NOTE(review): `GetCodeLens` is routed as mutating amid the
            // read-only group — confirm this is intentional.
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            // Mutating requests forwarded to the project host.
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer state synchronization between host and guests.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, channels, chat, and notifications.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following, account info, and miscellaneous services.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the project host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            // NOTE(review): `GitReset` and `GitCheckoutFiles` modify the work
            // tree but are routed as read-only — confirm this is intentional.
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            // Debugging and assistant contexts.
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
499
    /// Spawns the background task that, after `CLEANUP_TIMEOUT`, cleans up
    /// resources left behind by previous server instances: stale chat
    /// participants, channel-buffer collaborators, room participants, old
    /// worktree entries, and stale server records.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                // NOTE(review): stale chat participants are deleted again after
                // the room/buffer cleanup below — confirm both calls are intended.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Remove stale collaborators from each channel buffer and
                    // notify the remaining collaborators of the refreshed set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Drop participants that belonged to the stale server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each canceled callee that
                        // their pending call was canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push the refreshed busy state of each affected user
                        // to all of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Delete the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
670
    /// Shuts the server down: tears down the peer, resets the connection
    /// pool, and signals every per-connection task to exit via the
    /// `teardown` watch channel.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }
676
    /// Test-only: tears down and re-initializes the server under a new id,
    /// clearing the teardown flag so connections can be accepted again.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }
684
    /// Test-only accessor for the current server id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
689
    /// Core registration: stores a type-erased handler for message type `M`,
    /// wrapping it with logging and per-message timing metrics.
    ///
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The map key is `TypeId::of::<M>()`, so this downcast is
                // correct by construction.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // total = time spent queued (received -> handler start)
                    // plus actual processing time; all in fractional ms.
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
733
    /// Registers a handler for a fire-and-forget message type `M` (no
    /// response is sent back to the client).
    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }
743
744 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
745 where
746 F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
747 Fut: Send + Future<Output = Result<()>>,
748 M: RequestMessage,
749 {
750 let handler = Arc::new(handler);
751 self.add_handler(move |envelope, session| {
752 let receipt = envelope.receipt();
753 let handler = handler.clone();
754 async move {
755 let peer = session.peer.clone();
756 let responded = Arc::new(AtomicBool::default());
757 let response = Response {
758 peer: peer.clone(),
759 responded: responded.clone(),
760 receipt,
761 };
762 match (handler)(envelope.payload, response, session).await {
763 Ok(()) => {
764 if responded.load(std::sync::atomic::Ordering::SeqCst) {
765 Ok(())
766 } else {
767 Err(anyhow!("handler did not send a response"))?
768 }
769 }
770 Err(error) => {
771 let proto_err = match &error {
772 Error::Internal(err) => err.to_proto(),
773 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
774 };
775 peer.respond_with_error(receipt, proto_err)?;
776 Err(error)
777 }
778 }
779 }
780 })
781 }
782
    /// Drives a single client connection to completion.
    ///
    /// Registers the connection with the peer, sends the initial client
    /// update, then loops processing incoming messages until the connection
    /// closes or the server tears down, finally running `connection_lost`
    /// cleanup. The returned future must be polled to make progress.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // HTTP client for outgoing calls made on behalf of this
            // connection (currently the Supermaven admin API below).
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection-count guard only needs to cover the handshake;
            // release its slot before entering the message loop.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Take a concurrency permit before accepting the next message,
                // back-pressuring the client at MAX_CONCURRENT_HANDLERS.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // Hold the semaphore permit until the handler completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any in-flight foreground handlers, then clean up the
            // session's server-side state.
            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
974
    /// Sends the handshake sequence to a freshly connected client: the
    /// `Hello` message, first-connection extras, plan/contact state, channel
    /// subscriptions, and any pending incoming call.
    ///
    /// Also registers the connection in the connection pool and reports the
    /// `ConnectionId` back through `send_connection_id`, if provided.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller, if requested; a dropped
        // receiver is not an error.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On the user's very first connection, prompt the contacts UI
                // and persist that they have now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact list
                // under the same pool lock, so no concurrent contact update
                // can slip in between the two steps.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Auto-subscribe when this client version requires the server
                // to do so (see `should_auto_subscribe_to_channels`).
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a ring that arrived while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1032
1033 pub async fn invite_code_redeemed(
1034 self: &Arc<Self>,
1035 inviter_id: UserId,
1036 invitee_id: UserId,
1037 ) -> Result<()> {
1038 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1039 if let Some(code) = &user.invite_code {
1040 let pool = self.connection_pool.lock();
1041 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1042 for connection_id in pool.user_connection_ids(inviter_id) {
1043 self.peer.send(
1044 connection_id,
1045 proto::UpdateContacts {
1046 contacts: vec![invitee_contact.clone()],
1047 ..Default::default()
1048 },
1049 )?;
1050 self.peer.send(
1051 connection_id,
1052 proto::UpdateInviteInfo {
1053 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1054 count: user.invite_count as u32,
1055 },
1056 )?;
1057 }
1058 }
1059 }
1060 Ok(())
1061 }
1062
1063 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1064 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1065 if let Some(invite_code) = &user.invite_code {
1066 let pool = self.connection_pool.lock();
1067 for connection_id in pool.user_connection_ids(user_id) {
1068 self.peer.send(
1069 connection_id,
1070 proto::UpdateInviteInfo {
1071 url: format!(
1072 "{}{}",
1073 self.app_state.config.invite_link_prefix, invite_code
1074 ),
1075 count: user.invite_count as u32,
1076 },
1077 )?;
1078 }
1079 }
1080 }
1081 Ok(())
1082 }
1083
1084 pub async fn update_plan_for_user(
1085 self: &Arc<Self>,
1086 user_id: UserId,
1087 update_user_plan: proto::UpdateUserPlan,
1088 ) -> Result<()> {
1089 let pool = self.connection_pool.lock();
1090 for connection_id in pool.user_connection_ids(user_id) {
1091 self.peer
1092 .send(connection_id, update_user_plan.clone())
1093 .trace_err();
1094 }
1095
1096 Ok(())
1097 }
1098
1099 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1100 /// message on the Collab server.
1101 ///
1102 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1103 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1104 let user = self
1105 .app_state
1106 .db
1107 .get_user_by_id(user_id)
1108 .await?
1109 .context("user not found")?;
1110
1111 let update_user_plan = make_update_user_plan_message(
1112 &user,
1113 user.admin,
1114 &self.app_state.db,
1115 self.app_state.llm_db.clone(),
1116 )
1117 .await?;
1118
1119 self.update_plan_for_user(user_id, update_user_plan).await
1120 }
1121
1122 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1123 let pool = self.connection_pool.lock();
1124 for connection_id in pool.user_connection_ids(user_id) {
1125 self.peer
1126 .send(connection_id, proto::RefreshLlmToken {})
1127 .trace_err();
1128 }
1129 }
1130
1131 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1132 ServerSnapshot {
1133 connection_pool: ConnectionPoolGuard {
1134 guard: self.connection_pool.lock(),
1135 _not_send: PhantomData,
1136 },
1137 peer: &self.peer,
1138 }
1139 }
1140}
1141
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    // Expose read-only access to the underlying pool through the guard.
    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1149
impl DerefMut for ConnectionPoolGuard<'_> {
    // Expose mutable access to the underlying pool through the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1155
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In test builds, verify pool invariants every time a guard is
        // released, so corruption is caught close to its source.
        #[cfg(test)]
        self.check_invariants();
    }
}
1162
1163fn broadcast<F>(
1164 sender_id: Option<ConnectionId>,
1165 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1166 mut f: F,
1167) where
1168 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1169{
1170 for receiver_id in receiver_ids {
1171 if Some(receiver_id) != sender_id {
1172 if let Err(error) = f(receiver_id) {
1173 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1174 }
1175 }
1176 }
1177}
1178
1179pub struct ProtocolVersion(u32);
1180
1181impl Header for ProtocolVersion {
1182 fn name() -> &'static HeaderName {
1183 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1184 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1185 }
1186
1187 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1188 where
1189 Self: Sized,
1190 I: Iterator<Item = &'i axum::http::HeaderValue>,
1191 {
1192 let version = values
1193 .next()
1194 .ok_or_else(axum::headers::Error::invalid)?
1195 .to_str()
1196 .map_err(|_| axum::headers::Error::invalid())?
1197 .parse()
1198 .map_err(|_| axum::headers::Error::invalid())?;
1199 Ok(Self(version))
1200 }
1201
1202 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1203 values.extend([self.0.to_string().parse().unwrap()]);
1204 }
1205}
1206
1207pub struct AppVersionHeader(SemanticVersion);
1208impl Header for AppVersionHeader {
1209 fn name() -> &'static HeaderName {
1210 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1211 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1212 }
1213
1214 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1215 where
1216 Self: Sized,
1217 I: Iterator<Item = &'i axum::http::HeaderValue>,
1218 {
1219 let version = values
1220 .next()
1221 .ok_or_else(axum::headers::Error::invalid)?
1222 .to_str()
1223 .map_err(|_| axum::headers::Error::invalid())?
1224 .parse()
1225 .map_err(|_| axum::headers::Error::invalid())?;
1226 Ok(Self(version))
1227 }
1228
1229 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1230 values.extend([self.0.to_string().parse().unwrap()]);
1231 }
1232}
1233
1234#[derive(Debug)]
1235pub struct ReleaseChannelHeader(String);
1236
1237impl Header for ReleaseChannelHeader {
1238 fn name() -> &'static HeaderName {
1239 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1240 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1241 }
1242
1243 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1244 where
1245 Self: Sized,
1246 I: Iterator<Item = &'i axum::http::HeaderValue>,
1247 {
1248 Ok(Self(
1249 values
1250 .next()
1251 .ok_or_else(axum::headers::Error::invalid)?
1252 .to_str()
1253 .map_err(|_| axum::headers::Error::invalid())?
1254 .to_owned(),
1255 ))
1256 }
1257
1258 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1259 values.extend([self.0.parse().unwrap()]);
1260 }
1261}
1262
/// Builds the collab HTTP router.
///
/// Layer ordering matters here: in axum, `.layer` applies only to routes
/// registered *before* it. So `/rpc` is wrapped in `auth::validate_header`
/// (plus the `AppState` extension), while `/metrics` — added after that
/// layer — is not authenticated. The final `Extension(server)` layer applies
/// to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1274
/// Axum handler for `GET /rpc`: validates the protocol and app-version
/// headers, reserves a connection slot, and upgrades the request to the
/// collaboration WebSocket, handing the socket off to
/// `Server::handle_connection`.
///
/// Responds with `UPGRADE_REQUIRED` when the client's protocol version
/// doesn't match, the version header is missing, or the client build can no
/// longer collaborate; responds with `SERVICE_UNAVAILABLE` when the
/// concurrent connection limit is reached.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match exactly; there is no negotiation.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
        .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
        .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
        .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum WebSocket into the tungstenite-based message stream
        // that `Connection` expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1352
1353pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1354 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1355 let connections_metric = CONNECTIONS_METRIC
1356 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1357
1358 let connections = server
1359 .connection_pool
1360 .lock()
1361 .connections()
1362 .filter(|connection| !connection.admin)
1363 .count();
1364 connections_metric.set(connections as _);
1365
1366 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1367 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1368 register_int_gauge!(
1369 "shared_projects",
1370 "number of open projects with one or more guests"
1371 )
1372 .unwrap()
1373 });
1374
1375 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1376 shared_projects_metric.set(shared_projects as _);
1377
1378 let encoder = prometheus::TextEncoder::new();
1379 let metric_families = prometheus::gather();
1380 let encoded_metrics = encoder
1381 .encode_to_string(&metric_families)
1382 .map_err(|err| anyhow!("{err}"))?;
1383 Ok(encoded_metrics)
1384}
1385
/// Cleans up after a client's connection drops.
///
/// Immediately disconnects the peer and removes the connection from the
/// pool and database, then waits `RECONNECT_TIMEOUT` before tearing down
/// the user's room and channel-buffer state — giving the client a window to
/// reconnect (see `rejoin_room`). A teardown signal short-circuits the wait
/// and skips the teardown entirely.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if this was the user's last live
            // connection; another connection may still answer it.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1432
1433/// Acknowledges a ping from a client, used to keep the connection alive.
1434async fn ping(
1435 _: proto::Ping,
1436 response: Response<proto::Ping>,
1437 _session: MessageContext,
1438) -> Result<()> {
1439 response.send(proto::Ack {})?;
1440 Ok(())
1441}
1442
/// Creates a new room for calling (outside of channels).
///
/// Generates a random LiveKit room name, mints a LiveKit access token for
/// the caller when a LiveKit client is configured, persists the room, and
/// returns both to the client.
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: MessageContext,
) -> Result<()> {
    let livekit_room = nanoid::nanoid!(30);

    // Best-effort: if LiveKit is unconfigured or token minting fails, the
    // room is still created, just without audio/video connection info.
    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit
            .room_token(&livekit_room, &user_id.to_string())
            .trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1482
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// Channel-backed rooms are delegated to the channel join flow. Otherwise
/// this joins the ad-hoc room, cancels the pending ring on all of the
/// user's connections, and returns the room state plus LiveKit connection
/// info (when configured).
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    // Rooms backed by a channel have their own join path.
    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The user answered on this connection; dismiss the incoming-call UI on
    // all of their connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token; `None` when LiveKit is unconfigured or
    // minting fails.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1549
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores the client's room membership, announces its new connection id
/// to collaborators in reshared projects, replays worktree/repository state
/// for rejoined projects, and finally refreshes channel and contact state.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room;
    let channel;
    // Scope the database guard returned by `rejoin_room` so it is released
    // before the channel/contact updates below; `room` and `channel` carry
    // the results out of the scope.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this participant's peer id changed
            // from the old (dead) connection to the new one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the project's worktree list to the other
            // collaborators (the rejoining connection is excluded).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1641
/// Replays the state of each rejoined project: collaborator peer-id updates
/// go to the other collaborators, while worktree entries, diagnostics,
/// settings files, and repository updates are streamed back to the
/// rejoining connection itself.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        // Tell each collaborator that this participant's peer id changed
        // from the old (dead) connection to the new one.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // The update is split into multiple messages (see
            // `split_worktree_update`), presumably to bound message size.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository updates and removals for the whole project.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1726
1727/// leave room disconnects from the room.
1728async fn leave_room(
1729 _: proto::LeaveRoom,
1730 response: Response<proto::LeaveRoom>,
1731 session: MessageContext,
1732) -> Result<()> {
1733 leave_room_for_session(&session, session.connection_id).await?;
1734 response.send(proto::Ack {})?;
1735 Ok(())
1736}
1737
1738/// Updates the permissions of someone else in the room.
1739async fn set_room_participant_role(
1740 request: proto::SetRoomParticipantRole,
1741 response: Response<proto::SetRoomParticipantRole>,
1742 session: MessageContext,
1743) -> Result<()> {
1744 let user_id = UserId::from_proto(request.user_id);
1745 let role = ChannelRole::from(request.role());
1746
1747 let (livekit_room, can_publish) = {
1748 let room = session
1749 .db()
1750 .await
1751 .set_room_participant_role(
1752 session.user_id(),
1753 RoomId::from_proto(request.room_id),
1754 user_id,
1755 role,
1756 )
1757 .await?;
1758
1759 let livekit_room = room.livekit_room.clone();
1760 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1761 room_updated(&room, &session.peer);
1762 (livekit_room, can_publish)
1763 };
1764
1765 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1766 live_kit
1767 .update_participant(
1768 livekit_room.clone(),
1769 request.user_id.to_string(),
1770 livekit_api::proto::ParticipantPermission {
1771 can_subscribe: true,
1772 can_publish,
1773 can_publish_data: can_publish,
1774 hidden: false,
1775 recorder: false,
1776 },
1777 )
1778 .await
1779 .trace_err();
1780 }
1781
1782 response.send(proto::Ack {})?;
1783 Ok(())
1784}
1785
/// Call someone else into the current room.
///
/// Requires the callee to be a contact. Rings every one of the callee's
/// connections concurrently; the first acknowledgement wins. If no
/// connection answers, the call is marked failed in the database and an
/// error is returned to the caller.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once; any one may pick up.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                // Someone answered — acknowledge the caller and stop waiting.
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // This connection didn't answer; log and wait for the rest.
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: roll back the pending call and notify participants.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1854
1855/// Cancel an outgoing call.
1856async fn cancel_call(
1857 request: proto::CancelCall,
1858 response: Response<proto::CancelCall>,
1859 session: MessageContext,
1860) -> Result<()> {
1861 let called_user_id = UserId::from_proto(request.called_user_id);
1862 let room_id = RoomId::from_proto(request.room_id);
1863 {
1864 let room = session
1865 .db()
1866 .await
1867 .cancel_call(room_id, session.connection_id, called_user_id)
1868 .await?;
1869 room_updated(&room, &session.peer);
1870 }
1871
1872 for connection_id in session
1873 .connection_pool()
1874 .await
1875 .user_connection_ids(called_user_id)
1876 {
1877 session
1878 .peer
1879 .send(
1880 connection_id,
1881 proto::CallCanceled {
1882 room_id: room_id.to_proto(),
1883 },
1884 )
1885 .trace_err();
1886 }
1887 response.send(proto::Ack {})?;
1888
1889 update_user_contacts(called_user_id, &session).await?;
1890 Ok(())
1891}
1892
1893/// Decline an incoming call.
1894async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1895 let room_id = RoomId::from_proto(message.room_id);
1896 {
1897 let room = session
1898 .db()
1899 .await
1900 .decline_call(Some(room_id), session.user_id())
1901 .await?
1902 .context("declining call")?;
1903 room_updated(&room, &session.peer);
1904 }
1905
1906 for connection_id in session
1907 .connection_pool()
1908 .await
1909 .user_connection_ids(session.user_id())
1910 {
1911 session
1912 .peer
1913 .send(
1914 connection_id,
1915 proto::CallCanceled {
1916 room_id: room_id.to_proto(),
1917 },
1918 )
1919 .trace_err();
1920 }
1921 update_user_contacts(session.user_id(), &session).await?;
1922 Ok(())
1923}
1924
1925/// Updates other participants in the room with your current location.
1926async fn update_participant_location(
1927 request: proto::UpdateParticipantLocation,
1928 response: Response<proto::UpdateParticipantLocation>,
1929 session: MessageContext,
1930) -> Result<()> {
1931 let room_id = RoomId::from_proto(request.room_id);
1932 let location = request.location.context("invalid location")?;
1933
1934 let db = session.db().await;
1935 let room = db
1936 .update_room_participant_location(room_id, session.connection_id, location)
1937 .await?;
1938
1939 room_updated(&room, &session.peer);
1940 response.send(proto::Ack {})?;
1941 Ok(())
1942}
1943
1944/// Share a project into the room.
1945async fn share_project(
1946 request: proto::ShareProject,
1947 response: Response<proto::ShareProject>,
1948 session: MessageContext,
1949) -> Result<()> {
1950 let (project_id, room) = &*session
1951 .db()
1952 .await
1953 .share_project(
1954 RoomId::from_proto(request.room_id),
1955 session.connection_id,
1956 &request.worktrees,
1957 request.is_ssh_project,
1958 )
1959 .await?;
1960 response.send(proto::ShareProjectResponse {
1961 project_id: project_id.to_proto(),
1962 })?;
1963 room_updated(room, &session.peer);
1964
1965 Ok(())
1966}
1967
1968/// Unshare a project from the room.
1969async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1970 let project_id = ProjectId::from_proto(message.project_id);
1971 unshare_project_internal(project_id, session.connection_id, &session).await
1972}
1973
/// Shared implementation of unsharing: notifies guests, broadcasts the
/// updated room, and deletes the project row when the database flags it for
/// deletion.
///
/// The room guard is confined to the inner scope so it is released before
/// the project deletion acquires its own database guard.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest the project is gone, skipping the unsharer itself.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2011
/// Join someone else's shared project.
///
/// Adds the caller as a collaborator in the database, announces the new peer
/// to the existing collaborators, answers the request with the project's
/// metadata, and then streams the full project state (worktree entries,
/// diagnostics, settings files, repositories, language-server notices) to the
/// joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    // Register the new collaborator and take a mutable snapshot of the
    // project state; `replica_id` identifies this participant's replica of
    // the shared buffers.
    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(
            project_id,
            session.connection_id,
            session.user_id(),
            request.committer_name.clone(),
            request.committer_email.clone(),
        )
        .await?;
    // Release the database guard before streaming messages to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    // Existing collaborators, excluding the joining connection itself.
    let collaborators = project
        .collaborators
        .iter()
        .filter(|collaborator| collaborator.connection_id != session.connection_id)
        .map(|collaborator| collaborator.to_proto())
        .collect::<Vec<_>>();
    let project_id = project.id;
    let guest_user_id = session.user_id();

    let worktrees = project
        .worktrees
        .iter()
        .map(|(id, worktree)| proto::WorktreeMetadata {
            id: *id,
            root_name: worktree.root_name.clone(),
            visible: worktree.visible,
            abs_path: worktree.abs_path.clone(),
        })
        .collect::<Vec<_>>();

    // Announce the new collaborator to everyone already in the project.
    // Send failures are logged rather than aborting the join.
    let add_project_collaborator = proto::AddProjectCollaborator {
        project_id: project_id.to_proto(),
        collaborator: Some(proto::Collaborator {
            peer_id: Some(session.connection_id.into()),
            replica_id: replica_id.0 as u32,
            user_id: guest_user_id.to_proto(),
            is_host: false,
            committer_name: request.committer_name.clone(),
            committer_email: request.committer_email.clone(),
        }),
    };

    for collaborator in &collaborators {
        session
            .peer
            .send(
                collaborator.peer_id.unwrap().into(),
                add_project_collaborator.clone(),
            )
            .trace_err();
    }

    // First, we send the metadata associated with each worktree.
    let (language_servers, language_server_capabilities) = project
        .language_servers
        .clone()
        .into_iter()
        .map(|server| (server.server, server.capabilities))
        .unzip();
    response.send(proto::JoinProjectResponse {
        project_id: project.id.0 as u64,
        worktrees: worktrees.clone(),
        replica_id: replica_id.0 as u32,
        collaborators: collaborators.clone(),
        language_servers,
        language_server_capabilities,
        role: project.role.into(),
    })?;

    for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
        // Stream this worktree's entries.
        let message = proto::UpdateWorktree {
            project_id: project_id.to_proto(),
            worktree_id,
            abs_path: worktree.abs_path.clone(),
            root_name: worktree.root_name,
            updated_entries: worktree.entries,
            removed_entries: Default::default(),
            scan_id: worktree.scan_id,
            is_last_update: worktree.scan_id == worktree.completed_scan_id,
            updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
            removed_repositories: Default::default(),
        };
        // The update is split into multiple messages — presumably to stay
        // under a message-size limit; confirm in `split_worktree_update`.
        for update in proto::split_worktree_update(message) {
            session.peer.send(session.connection_id, update.clone())?;
        }

        // Stream this worktree's diagnostics.
        let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
        if let Some(summary) = worktree_diagnostics.next() {
            let message = proto::UpdateDiagnosticSummary {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                summary: Some(summary),
                more_summaries: worktree_diagnostics.collect(),
            };
            session.peer.send(session.connection_id, message)?;
        }

        // Stream this worktree's settings files.
        for settings_file in worktree.settings_files {
            session.peer.send(
                session.connection_id,
                proto::UpdateWorktreeSettings {
                    project_id: project_id.to_proto(),
                    worktree_id: worktree.id,
                    path: settings_file.path,
                    content: Some(settings_file.content),
                    kind: Some(settings_file.kind.to_proto() as i32),
                },
            )?;
        }
    }

    // Stream the repository state, split the same way as worktree updates.
    for repository in mem::take(&mut project.repositories) {
        for update in split_repository_update(repository) {
            session.peer.send(session.connection_id, update)?;
        }
    }

    // Tell the client that disk-based diagnostics exist for each language
    // server so it can pull them.
    for language_server in &project.language_servers {
        session.peer.send(
            session.connection_id,
            proto::UpdateLanguageServer {
                project_id: project_id.to_proto(),
                server_name: Some(language_server.server.name.clone()),
                language_server_id: language_server.server.id,
                variant: Some(
                    proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
                        proto::LspDiskBasedDiagnosticsUpdated {},
                    ),
                ),
            },
        )?;
    }

    Ok(())
}
2161
2162/// Leave someone elses shared project.
2163async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2164 let sender_id = session.connection_id;
2165 let project_id = ProjectId::from_proto(request.project_id);
2166 let db = session.db().await;
2167
2168 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2169 tracing::info!(
2170 %project_id,
2171 "leave project"
2172 );
2173
2174 project_left(project, &session);
2175 if let Some(room) = room {
2176 room_updated(room, &session.peer);
2177 }
2178
2179 Ok(())
2180}
2181
2182/// Updates other participants with changes to the project
2183async fn update_project(
2184 request: proto::UpdateProject,
2185 response: Response<proto::UpdateProject>,
2186 session: MessageContext,
2187) -> Result<()> {
2188 let project_id = ProjectId::from_proto(request.project_id);
2189 let (room, guest_connection_ids) = &*session
2190 .db()
2191 .await
2192 .update_project(project_id, session.connection_id, &request.worktrees)
2193 .await?;
2194 broadcast(
2195 Some(session.connection_id),
2196 guest_connection_ids.iter().copied(),
2197 |connection_id| {
2198 session
2199 .peer
2200 .forward_send(session.connection_id, connection_id, request.clone())
2201 },
2202 );
2203 if let Some(room) = room {
2204 room_updated(room, &session.peer);
2205 }
2206 response.send(proto::Ack {})?;
2207
2208 Ok(())
2209}
2210
2211/// Updates other participants with changes to the worktree
2212async fn update_worktree(
2213 request: proto::UpdateWorktree,
2214 response: Response<proto::UpdateWorktree>,
2215 session: MessageContext,
2216) -> Result<()> {
2217 let guest_connection_ids = session
2218 .db()
2219 .await
2220 .update_worktree(&request, session.connection_id)
2221 .await?;
2222
2223 broadcast(
2224 Some(session.connection_id),
2225 guest_connection_ids.iter().copied(),
2226 |connection_id| {
2227 session
2228 .peer
2229 .forward_send(session.connection_id, connection_id, request.clone())
2230 },
2231 );
2232 response.send(proto::Ack {})?;
2233 Ok(())
2234}
2235
2236async fn update_repository(
2237 request: proto::UpdateRepository,
2238 response: Response<proto::UpdateRepository>,
2239 session: MessageContext,
2240) -> Result<()> {
2241 let guest_connection_ids = session
2242 .db()
2243 .await
2244 .update_repository(&request, session.connection_id)
2245 .await?;
2246
2247 broadcast(
2248 Some(session.connection_id),
2249 guest_connection_ids.iter().copied(),
2250 |connection_id| {
2251 session
2252 .peer
2253 .forward_send(session.connection_id, connection_id, request.clone())
2254 },
2255 );
2256 response.send(proto::Ack {})?;
2257 Ok(())
2258}
2259
2260async fn remove_repository(
2261 request: proto::RemoveRepository,
2262 response: Response<proto::RemoveRepository>,
2263 session: MessageContext,
2264) -> Result<()> {
2265 let guest_connection_ids = session
2266 .db()
2267 .await
2268 .remove_repository(&request, session.connection_id)
2269 .await?;
2270
2271 broadcast(
2272 Some(session.connection_id),
2273 guest_connection_ids.iter().copied(),
2274 |connection_id| {
2275 session
2276 .peer
2277 .forward_send(session.connection_id, connection_id, request.clone())
2278 },
2279 );
2280 response.send(proto::Ack {})?;
2281 Ok(())
2282}
2283
2284/// Updates other participants with changes to the diagnostics
2285async fn update_diagnostic_summary(
2286 message: proto::UpdateDiagnosticSummary,
2287 session: MessageContext,
2288) -> Result<()> {
2289 let guest_connection_ids = session
2290 .db()
2291 .await
2292 .update_diagnostic_summary(&message, session.connection_id)
2293 .await?;
2294
2295 broadcast(
2296 Some(session.connection_id),
2297 guest_connection_ids.iter().copied(),
2298 |connection_id| {
2299 session
2300 .peer
2301 .forward_send(session.connection_id, connection_id, message.clone())
2302 },
2303 );
2304
2305 Ok(())
2306}
2307
2308/// Updates other participants with changes to the worktree settings
2309async fn update_worktree_settings(
2310 message: proto::UpdateWorktreeSettings,
2311 session: MessageContext,
2312) -> Result<()> {
2313 let guest_connection_ids = session
2314 .db()
2315 .await
2316 .update_worktree_settings(&message, session.connection_id)
2317 .await?;
2318
2319 broadcast(
2320 Some(session.connection_id),
2321 guest_connection_ids.iter().copied(),
2322 |connection_id| {
2323 session
2324 .peer
2325 .forward_send(session.connection_id, connection_id, message.clone())
2326 },
2327 );
2328
2329 Ok(())
2330}
2331
2332/// Notify other participants that a language server has started.
2333async fn start_language_server(
2334 request: proto::StartLanguageServer,
2335 session: MessageContext,
2336) -> Result<()> {
2337 let guest_connection_ids = session
2338 .db()
2339 .await
2340 .start_language_server(&request, session.connection_id)
2341 .await?;
2342
2343 broadcast(
2344 Some(session.connection_id),
2345 guest_connection_ids.iter().copied(),
2346 |connection_id| {
2347 session
2348 .peer
2349 .forward_send(session.connection_id, connection_id, request.clone())
2350 },
2351 );
2352 Ok(())
2353}
2354
2355/// Notify other participants that a language server has changed.
2356async fn update_language_server(
2357 request: proto::UpdateLanguageServer,
2358 session: MessageContext,
2359) -> Result<()> {
2360 let project_id = ProjectId::from_proto(request.project_id);
2361 let db = session.db().await;
2362
2363 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2364 {
2365 if let Some(capabilities) = update.capabilities.clone() {
2366 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2367 .await?;
2368 }
2369 }
2370
2371 let project_connection_ids = db
2372 .project_connection_ids(project_id, session.connection_id, true)
2373 .await?;
2374 broadcast(
2375 Some(session.connection_id),
2376 project_connection_ids.iter().copied(),
2377 |connection_id| {
2378 session
2379 .peer
2380 .forward_send(session.connection_id, connection_id, request.clone())
2381 },
2382 );
2383 Ok(())
2384}
2385
2386/// forward a project request to the host. These requests should be read only
2387/// as guests are allowed to send them.
2388async fn forward_read_only_project_request<T>(
2389 request: T,
2390 response: Response<T>,
2391 session: MessageContext,
2392) -> Result<()>
2393where
2394 T: EntityMessage + RequestMessage,
2395{
2396 let project_id = ProjectId::from_proto(request.remote_entity_id());
2397 let host_connection_id = session
2398 .db()
2399 .await
2400 .host_for_read_only_project_request(project_id, session.connection_id)
2401 .await?;
2402 let payload = session.forward_request(host_connection_id, request).await?;
2403 response.send(payload)?;
2404 Ok(())
2405}
2406
2407/// forward a project request to the host. These requests are disallowed
2408/// for guests.
2409async fn forward_mutating_project_request<T>(
2410 request: T,
2411 response: Response<T>,
2412 session: MessageContext,
2413) -> Result<()>
2414where
2415 T: EntityMessage + RequestMessage,
2416{
2417 let project_id = ProjectId::from_proto(request.remote_entity_id());
2418
2419 let host_connection_id = session
2420 .db()
2421 .await
2422 .host_for_mutating_project_request(project_id, session.connection_id)
2423 .await?;
2424 let payload = session.forward_request(host_connection_id, request).await?;
2425 response.send(payload)?;
2426 Ok(())
2427}
2428
2429async fn multi_lsp_query(
2430 request: MultiLspQuery,
2431 response: Response<MultiLspQuery>,
2432 session: MessageContext,
2433) -> Result<()> {
2434 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2435 tracing::info!("multi_lsp_query message received");
2436 forward_mutating_project_request(request, response, session).await
2437}
2438
2439/// Notify other participants that a new buffer has been created
2440async fn create_buffer_for_peer(
2441 request: proto::CreateBufferForPeer,
2442 session: MessageContext,
2443) -> Result<()> {
2444 session
2445 .db()
2446 .await
2447 .check_user_is_project_host(
2448 ProjectId::from_proto(request.project_id),
2449 session.connection_id,
2450 )
2451 .await?;
2452 let peer_id = request.peer_id.context("invalid peer id")?;
2453 session
2454 .peer
2455 .forward_send(session.connection_id, peer_id.into(), request)?;
2456 Ok(())
2457}
2458
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations (or empty variants) are treated as read-only;
    // any other operation requires write access to the project.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // Scope the database guard so it is released before we await a response
    // from the host below.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the update out to every guest except the sender.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host gets the update as a request rather than fire-and-forget, so
    // the sender is only acknowledged after the host has responded.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2505
/// Broadcast a context operation (`proto::UpdateContext`) to the project.
///
/// As with `update_buffer`, selection-only buffer operations are treated as
/// read-only; anything else requires write access.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Derive the capability this operation requires.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no payload is conservatively
                // treated as a write.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in this broadcast.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2547
2548/// Notify other participants that a project has been updated.
2549async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2550 request: T,
2551 session: MessageContext,
2552) -> Result<()> {
2553 let project_id = ProjectId::from_proto(request.remote_entity_id());
2554 let project_connection_ids = session
2555 .db()
2556 .await
2557 .project_connection_ids(project_id, session.connection_id, false)
2558 .await?;
2559
2560 broadcast(
2561 Some(session.connection_id),
2562 project_connection_ids.iter().copied(),
2563 |connection_id| {
2564 session
2565 .peer
2566 .forward_send(session.connection_id, connection_id, request.clone())
2567 },
2568 );
2569 Ok(())
2570}
2571
2572/// Start following another user in a call.
2573async fn follow(
2574 request: proto::Follow,
2575 response: Response<proto::Follow>,
2576 session: MessageContext,
2577) -> Result<()> {
2578 let room_id = RoomId::from_proto(request.room_id);
2579 let project_id = request.project_id.map(ProjectId::from_proto);
2580 let leader_id = request.leader_id.context("invalid leader id")?.into();
2581 let follower_id = session.connection_id;
2582
2583 session
2584 .db()
2585 .await
2586 .check_room_participants(room_id, leader_id, session.connection_id)
2587 .await?;
2588
2589 let response_payload = session.forward_request(leader_id, request).await?;
2590 response.send(response_payload)?;
2591
2592 if let Some(project_id) = project_id {
2593 let room = session
2594 .db()
2595 .await
2596 .follow(room_id, project_id, leader_id, follower_id)
2597 .await?;
2598 room_updated(&room, &session.peer);
2599 }
2600
2601 Ok(())
2602}
2603
2604/// Stop following another user in a call.
2605async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2606 let room_id = RoomId::from_proto(request.room_id);
2607 let project_id = request.project_id.map(ProjectId::from_proto);
2608 let leader_id = request.leader_id.context("invalid leader id")?.into();
2609 let follower_id = session.connection_id;
2610
2611 session
2612 .db()
2613 .await
2614 .check_room_participants(room_id, leader_id, session.connection_id)
2615 .await?;
2616
2617 session
2618 .peer
2619 .forward_send(session.connection_id, leader_id, request)?;
2620
2621 if let Some(project_id) = project_id {
2622 let room = session
2623 .db()
2624 .await
2625 .unfollow(room_id, project_id, leader_id, follower_id)
2626 .await?;
2627 room_updated(&room, &session.peer);
2628 }
2629
2630 Ok(())
2631}
2632
2633/// Notify everyone following you of your current location.
2634async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2635 let room_id = RoomId::from_proto(request.room_id);
2636 let database = session.db.lock().await;
2637
2638 let connection_ids = if let Some(project_id) = request.project_id {
2639 let project_id = ProjectId::from_proto(project_id);
2640 database
2641 .project_connection_ids(project_id, session.connection_id, true)
2642 .await?
2643 } else {
2644 database
2645 .room_connection_ids(room_id, session.connection_id)
2646 .await?
2647 };
2648
2649 // For now, don't send view update messages back to that view's current leader.
2650 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2651 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2652 _ => None,
2653 });
2654
2655 for connection_id in connection_ids.iter().cloned() {
2656 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2657 session
2658 .peer
2659 .forward_send(session.connection_id, connection_id, request.clone())?;
2660 }
2661 }
2662 Ok(())
2663}
2664
2665/// Get public data about users.
2666async fn get_users(
2667 request: proto::GetUsers,
2668 response: Response<proto::GetUsers>,
2669 session: MessageContext,
2670) -> Result<()> {
2671 let user_ids = request
2672 .user_ids
2673 .into_iter()
2674 .map(UserId::from_proto)
2675 .collect();
2676 let users = session
2677 .db()
2678 .await
2679 .get_users_by_ids(user_ids)
2680 .await?
2681 .into_iter()
2682 .map(|user| proto::User {
2683 id: user.id.to_proto(),
2684 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2685 github_login: user.github_login,
2686 name: user.name,
2687 })
2688 .collect();
2689 response.send(proto::UsersResponse { users })?;
2690 Ok(())
2691}
2692
2693/// Search for users (to invite) buy Github login
2694async fn fuzzy_search_users(
2695 request: proto::FuzzySearchUsers,
2696 response: Response<proto::FuzzySearchUsers>,
2697 session: MessageContext,
2698) -> Result<()> {
2699 let query = request.query;
2700 let users = match query.len() {
2701 0 => vec![],
2702 1 | 2 => session
2703 .db()
2704 .await
2705 .get_user_by_github_login(&query)
2706 .await?
2707 .into_iter()
2708 .collect(),
2709 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2710 };
2711 let users = users
2712 .into_iter()
2713 .filter(|user| user.id != session.user_id())
2714 .map(|user| proto::User {
2715 id: user.id.to_proto(),
2716 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2717 github_login: user.github_login,
2718 name: user.name,
2719 })
2720 .collect();
2721 response.send(proto::UsersResponse { users })?;
2722 Ok(())
2723}
2724
/// Send a contact request to another user.
///
/// Records the request, then pushes `UpdateContacts` deltas to every active
/// connection of both the requester (outgoing side) and the responder
/// (incoming side), followed by any notifications the database created.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    // Hold the pool guard here so it can also be used for notifications below.
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
2771
/// Accept or decline a contact request
///
/// A `Dismiss` response only clears the responder's notification; any other
/// response (accept or decline) updates the relationship and pushes
/// `UpdateContacts` deltas to both users' active connections.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // The incoming request disappears whether accepted or declined.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2829
/// Remove a contact.
///
/// Handles both cases the database reports: an accepted contact (remove the
/// contact on both sides) or a still-pending request (remove the outgoing /
/// incoming request on the respective sides). Any notification tied to the
/// request is also deleted from the responder's clients.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Clear the (now-stale) contact-request notification, if one existed.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2880
/// Whether the server should automatically subscribe this connection to its
/// channels, based on the client's version.
///
/// NOTE(review): presumably clients from 0.139 onward request subscriptions
/// themselves via `SubscribeToChannels` — confirm against client versions.
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2884
2885async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2886 if is_staff {
2887 return Ok(proto::Plan::ZedPro);
2888 }
2889
2890 let subscription = db.get_active_billing_subscription(user_id).await?;
2891 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2892
2893 let plan = if let Some(subscription_kind) = subscription_kind {
2894 match subscription_kind {
2895 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2896 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2897 SubscriptionKind::ZedFree => proto::Plan::Free,
2898 }
2899 } else {
2900 proto::Plan::Free
2901 };
2902
2903 Ok(plan)
2904}
2905
/// Build the `UpdateUserPlan` message describing a user's plan, subscription
/// period, usage, and account-age gating.
///
/// `llm_db` is optional; when absent, the subscription period and usage are
/// omitted and default usage values are substituted.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Resolve the current subscription period and, if one exists, the usage
    // recorded in the LLM database for that period.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Accounts younger than the minimum age are flagged unless they are on
    // Zed Pro or carry the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always get usage-based billing; others follow their stored
        // preference (None when no preference row exists).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Fall back to the plan's default usage when none was recorded.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2970
2971fn model_requests_limit(
2972 plan: cloud_llm_client::Plan,
2973 feature_flags: &Vec<String>,
2974) -> cloud_llm_client::UsageLimit {
2975 match plan.model_requests_limit() {
2976 cloud_llm_client::UsageLimit::Limited(limit) => {
2977 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2978 && feature_flags
2979 .iter()
2980 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2981 {
2982 1_000
2983 } else {
2984 limit
2985 };
2986
2987 cloud_llm_client::UsageLimit::Limited(limit)
2988 }
2989 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2990 }
2991}
2992
2993fn subscription_usage_to_proto(
2994 plan: proto::Plan,
2995 usage: crate::llm::db::subscription_usage::Model,
2996 feature_flags: &Vec<String>,
2997) -> proto::SubscriptionUsage {
2998 let plan = match plan {
2999 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3000 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3001 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3002 };
3003
3004 proto::SubscriptionUsage {
3005 model_requests_usage_amount: usage.model_requests as u32,
3006 model_requests_usage_limit: Some(proto::UsageLimit {
3007 variant: Some(match model_requests_limit(plan, feature_flags) {
3008 cloud_llm_client::UsageLimit::Limited(limit) => {
3009 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3010 limit: limit as u32,
3011 })
3012 }
3013 cloud_llm_client::UsageLimit::Unlimited => {
3014 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3015 }
3016 }),
3017 }),
3018 edit_predictions_usage_amount: usage.edit_predictions as u32,
3019 edit_predictions_usage_limit: Some(proto::UsageLimit {
3020 variant: Some(match plan.edit_predictions_limit() {
3021 cloud_llm_client::UsageLimit::Limited(limit) => {
3022 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3023 limit: limit as u32,
3024 })
3025 }
3026 cloud_llm_client::UsageLimit::Unlimited => {
3027 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3028 }
3029 }),
3030 }),
3031 }
3032}
3033
3034fn make_default_subscription_usage(
3035 plan: proto::Plan,
3036 feature_flags: &Vec<String>,
3037) -> proto::SubscriptionUsage {
3038 let plan = match plan {
3039 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
3040 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
3041 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
3042 };
3043
3044 proto::SubscriptionUsage {
3045 model_requests_usage_amount: 0,
3046 model_requests_usage_limit: Some(proto::UsageLimit {
3047 variant: Some(match model_requests_limit(plan, feature_flags) {
3048 cloud_llm_client::UsageLimit::Limited(limit) => {
3049 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3050 limit: limit as u32,
3051 })
3052 }
3053 cloud_llm_client::UsageLimit::Unlimited => {
3054 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3055 }
3056 }),
3057 }),
3058 edit_predictions_usage_amount: 0,
3059 edit_predictions_usage_limit: Some(proto::UsageLimit {
3060 variant: Some(match plan.edit_predictions_limit() {
3061 cloud_llm_client::UsageLimit::Limited(limit) => {
3062 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3063 limit: limit as u32,
3064 })
3065 }
3066 cloud_llm_client::UsageLimit::Unlimited => {
3067 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3068 }
3069 }),
3070 }),
3071 }
3072}
3073
3074async fn update_user_plan(session: &Session) -> Result<()> {
3075 let db = session.db().await;
3076
3077 let update_user_plan = make_update_user_plan_message(
3078 session.principal.user(),
3079 session.is_staff(),
3080 &db.0,
3081 session.app_state.llm_db.clone(),
3082 )
3083 .await?;
3084
3085 session
3086 .peer
3087 .send(session.connection_id, update_user_plan)
3088 .trace_err();
3089
3090 Ok(())
3091}
3092
/// Handler for `SubscribeToChannels`: subscribes the requesting user's
/// connection to all of their channels.
async fn subscribe_to_channels(
    _: proto::SubscribeToChannels,
    session: MessageContext,
) -> Result<()> {
    subscribe_user_to_channels(session.user_id(), &session).await?;
    Ok(())
}
3100
3101async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3102 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3103 let mut pool = session.connection_pool().await;
3104 for membership in &channels_for_user.channel_memberships {
3105 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3106 }
3107 session.peer.send(
3108 session.connection_id,
3109 build_update_user_channels(&channels_for_user),
3110 )?;
3111 session.peer.send(
3112 session.connection_id,
3113 build_channels_update(channels_for_user),
3114 )?;
3115 Ok(())
3116}
3117
3118/// Creates a new channel.
3119async fn create_channel(
3120 request: proto::CreateChannel,
3121 response: Response<proto::CreateChannel>,
3122 session: MessageContext,
3123) -> Result<()> {
3124 let db = session.db().await;
3125
3126 let parent_id = request.parent_id.map(ChannelId::from_proto);
3127 let (channel, membership) = db
3128 .create_channel(&request.name, parent_id, session.user_id())
3129 .await?;
3130
3131 let root_id = channel.root_id();
3132 let channel = Channel::from_model(channel);
3133
3134 response.send(proto::CreateChannelResponse {
3135 channel: Some(channel.to_proto()),
3136 parent_id: request.parent_id,
3137 })?;
3138
3139 let mut connection_pool = session.connection_pool().await;
3140 if let Some(membership) = membership {
3141 connection_pool.subscribe_to_channel(
3142 membership.user_id,
3143 membership.channel_id,
3144 membership.role,
3145 );
3146 let update = proto::UpdateUserChannels {
3147 channel_memberships: vec![proto::ChannelMembership {
3148 channel_id: membership.channel_id.to_proto(),
3149 role: membership.role.into(),
3150 }],
3151 ..Default::default()
3152 };
3153 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3154 session.peer.send(connection_id, update.clone())?;
3155 }
3156 }
3157
3158 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3159 if !role.can_see_channel(channel.visibility) {
3160 continue;
3161 }
3162
3163 let update = proto::UpdateChannels {
3164 channels: vec![channel.to_proto()],
3165 ..Default::default()
3166 };
3167 session.peer.send(connection_id, update.clone())?;
3168 }
3169
3170 Ok(())
3171}
3172
3173/// Delete a channel
3174async fn delete_channel(
3175 request: proto::DeleteChannel,
3176 response: Response<proto::DeleteChannel>,
3177 session: MessageContext,
3178) -> Result<()> {
3179 let db = session.db().await;
3180
3181 let channel_id = request.channel_id;
3182 let (root_channel, removed_channels) = db
3183 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3184 .await?;
3185 response.send(proto::Ack {})?;
3186
3187 // Notify members of removed channels
3188 let mut update = proto::UpdateChannels::default();
3189 update
3190 .delete_channels
3191 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3192
3193 let connection_pool = session.connection_pool().await;
3194 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3195 session.peer.send(connection_id, update.clone())?;
3196 }
3197
3198 Ok(())
3199}
3200
3201/// Invite someone to join a channel.
3202async fn invite_channel_member(
3203 request: proto::InviteChannelMember,
3204 response: Response<proto::InviteChannelMember>,
3205 session: MessageContext,
3206) -> Result<()> {
3207 let db = session.db().await;
3208 let channel_id = ChannelId::from_proto(request.channel_id);
3209 let invitee_id = UserId::from_proto(request.user_id);
3210 let InviteMemberResult {
3211 channel,
3212 notifications,
3213 } = db
3214 .invite_channel_member(
3215 channel_id,
3216 invitee_id,
3217 session.user_id(),
3218 request.role().into(),
3219 )
3220 .await?;
3221
3222 let update = proto::UpdateChannels {
3223 channel_invitations: vec![channel.to_proto()],
3224 ..Default::default()
3225 };
3226
3227 let connection_pool = session.connection_pool().await;
3228 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3229 session.peer.send(connection_id, update.clone())?;
3230 }
3231
3232 send_notifications(&connection_pool, &session.peer, notifications);
3233
3234 response.send(proto::Ack {})?;
3235 Ok(())
3236}
3237
3238/// remove someone from a channel
3239async fn remove_channel_member(
3240 request: proto::RemoveChannelMember,
3241 response: Response<proto::RemoveChannelMember>,
3242 session: MessageContext,
3243) -> Result<()> {
3244 let db = session.db().await;
3245 let channel_id = ChannelId::from_proto(request.channel_id);
3246 let member_id = UserId::from_proto(request.user_id);
3247
3248 let RemoveChannelMemberResult {
3249 membership_update,
3250 notification_id,
3251 } = db
3252 .remove_channel_member(channel_id, member_id, session.user_id())
3253 .await?;
3254
3255 let mut connection_pool = session.connection_pool().await;
3256 notify_membership_updated(
3257 &mut connection_pool,
3258 membership_update,
3259 member_id,
3260 &session.peer,
3261 );
3262 for connection_id in connection_pool.user_connection_ids(member_id) {
3263 if let Some(notification_id) = notification_id {
3264 session
3265 .peer
3266 .send(
3267 connection_id,
3268 proto::DeleteNotification {
3269 notification_id: notification_id.to_proto(),
3270 },
3271 )
3272 .trace_err();
3273 }
3274 }
3275
3276 response.send(proto::Ack {})?;
3277 Ok(())
3278}
3279
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels,
/// (though members-only channels can appear at any point in the hierarchy).
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    // The database layer performs the permission check and visibility update.
    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the user ids up front: the loop body mutates the pool's channel
    // subscriptions, so the `channel_user_ids` iterator (which borrows the
    // pool) cannot stay alive across those calls.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Users who can still see the channel under its new visibility receive
        // the updated channel; everyone else is told to delete it locally.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3326
3327/// Alter the role for a user in the channel.
3328async fn set_channel_member_role(
3329 request: proto::SetChannelMemberRole,
3330 response: Response<proto::SetChannelMemberRole>,
3331 session: MessageContext,
3332) -> Result<()> {
3333 let db = session.db().await;
3334 let channel_id = ChannelId::from_proto(request.channel_id);
3335 let member_id = UserId::from_proto(request.user_id);
3336 let result = db
3337 .set_channel_member_role(
3338 channel_id,
3339 session.user_id(),
3340 member_id,
3341 request.role().into(),
3342 )
3343 .await?;
3344
3345 match result {
3346 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3347 let mut connection_pool = session.connection_pool().await;
3348 notify_membership_updated(
3349 &mut connection_pool,
3350 membership_update,
3351 member_id,
3352 &session.peer,
3353 )
3354 }
3355 db::SetMemberRoleResult::InviteUpdated(channel) => {
3356 let update = proto::UpdateChannels {
3357 channel_invitations: vec![channel.to_proto()],
3358 ..Default::default()
3359 };
3360
3361 for connection_id in session
3362 .connection_pool()
3363 .await
3364 .user_connection_ids(member_id)
3365 {
3366 session.peer.send(connection_id, update.clone())?;
3367 }
3368 }
3369 }
3370
3371 response.send(proto::Ack {})?;
3372 Ok(())
3373}
3374
3375/// Change the name of a channel
3376async fn rename_channel(
3377 request: proto::RenameChannel,
3378 response: Response<proto::RenameChannel>,
3379 session: MessageContext,
3380) -> Result<()> {
3381 let db = session.db().await;
3382 let channel_id = ChannelId::from_proto(request.channel_id);
3383 let channel_model = db
3384 .rename_channel(channel_id, session.user_id(), &request.name)
3385 .await?;
3386 let root_id = channel_model.root_id();
3387 let channel = Channel::from_model(channel_model);
3388
3389 response.send(proto::RenameChannelResponse {
3390 channel: Some(channel.to_proto()),
3391 })?;
3392
3393 let connection_pool = session.connection_pool().await;
3394 let update = proto::UpdateChannels {
3395 channels: vec![channel.to_proto()],
3396 ..Default::default()
3397 };
3398 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3399 if role.can_see_channel(channel.visibility) {
3400 session.peer.send(connection_id, update.clone())?;
3401 }
3402 }
3403
3404 Ok(())
3405}
3406
3407/// Move a channel to a new parent.
3408async fn move_channel(
3409 request: proto::MoveChannel,
3410 response: Response<proto::MoveChannel>,
3411 session: MessageContext,
3412) -> Result<()> {
3413 let channel_id = ChannelId::from_proto(request.channel_id);
3414 let to = ChannelId::from_proto(request.to);
3415
3416 let (root_id, channels) = session
3417 .db()
3418 .await
3419 .move_channel(channel_id, to, session.user_id())
3420 .await?;
3421
3422 let connection_pool = session.connection_pool().await;
3423 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3424 let channels = channels
3425 .iter()
3426 .filter_map(|channel| {
3427 if role.can_see_channel(channel.visibility) {
3428 Some(channel.to_proto())
3429 } else {
3430 None
3431 }
3432 })
3433 .collect::<Vec<_>>();
3434 if channels.is_empty() {
3435 continue;
3436 }
3437
3438 let update = proto::UpdateChannels {
3439 channels,
3440 ..Default::default()
3441 };
3442
3443 session.peer.send(connection_id, update.clone())?;
3444 }
3445
3446 response.send(Ack {})?;
3447 Ok(())
3448}
3449
3450async fn reorder_channel(
3451 request: proto::ReorderChannel,
3452 response: Response<proto::ReorderChannel>,
3453 session: MessageContext,
3454) -> Result<()> {
3455 let channel_id = ChannelId::from_proto(request.channel_id);
3456 let direction = request.direction();
3457
3458 let updated_channels = session
3459 .db()
3460 .await
3461 .reorder_channel(channel_id, direction, session.user_id())
3462 .await?;
3463
3464 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3465 let connection_pool = session.connection_pool().await;
3466 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3467 let channels = updated_channels
3468 .iter()
3469 .filter_map(|channel| {
3470 if role.can_see_channel(channel.visibility) {
3471 Some(channel.to_proto())
3472 } else {
3473 None
3474 }
3475 })
3476 .collect::<Vec<_>>();
3477
3478 if channels.is_empty() {
3479 continue;
3480 }
3481
3482 let update = proto::UpdateChannels {
3483 channels,
3484 ..Default::default()
3485 };
3486
3487 session.peer.send(connection_id, update.clone())?;
3488 }
3489 }
3490
3491 response.send(Ack {})?;
3492 Ok(())
3493}
3494
3495/// Get the list of channel members
3496async fn get_channel_members(
3497 request: proto::GetChannelMembers,
3498 response: Response<proto::GetChannelMembers>,
3499 session: MessageContext,
3500) -> Result<()> {
3501 let db = session.db().await;
3502 let channel_id = ChannelId::from_proto(request.channel_id);
3503 let limit = if request.limit == 0 {
3504 u16::MAX as u64
3505 } else {
3506 request.limit
3507 };
3508 let (members, users) = db
3509 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3510 .await?;
3511 response.send(proto::GetChannelMembersResponse { members, users })?;
3512 Ok(())
3513}
3514
3515/// Accept or decline a channel invitation.
3516async fn respond_to_channel_invite(
3517 request: proto::RespondToChannelInvite,
3518 response: Response<proto::RespondToChannelInvite>,
3519 session: MessageContext,
3520) -> Result<()> {
3521 let db = session.db().await;
3522 let channel_id = ChannelId::from_proto(request.channel_id);
3523 let RespondToChannelInvite {
3524 membership_update,
3525 notifications,
3526 } = db
3527 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3528 .await?;
3529
3530 let mut connection_pool = session.connection_pool().await;
3531 if let Some(membership_update) = membership_update {
3532 notify_membership_updated(
3533 &mut connection_pool,
3534 membership_update,
3535 session.user_id(),
3536 &session.peer,
3537 );
3538 } else {
3539 let update = proto::UpdateChannels {
3540 remove_channel_invitations: vec![channel_id.to_proto()],
3541 ..Default::default()
3542 };
3543
3544 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3545 session.peer.send(connection_id, update.clone())?;
3546 }
3547 };
3548
3549 send_notifications(&connection_pool, &session.peer, notifications);
3550
3551 response.send(proto::Ack {})?;
3552
3553 Ok(())
3554}
3555
3556/// Join the channels' room
3557async fn join_channel(
3558 request: proto::JoinChannel,
3559 response: Response<proto::JoinChannel>,
3560 session: MessageContext,
3561) -> Result<()> {
3562 let channel_id = ChannelId::from_proto(request.channel_id);
3563 join_channel_internal(channel_id, Box::new(response), session).await
3564}
3565
/// Abstracts over the two request types that can complete a channel join —
/// `JoinChannel` and `JoinRoom` — both of which reply with a
/// `JoinRoomResponse`, so `join_channel_internal` can serve either one.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate straight to the underlying typed response.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegate straight to the underlying typed response.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3579
/// Shared implementation for joining a channel's room, used by both the
/// `JoinChannel` and `JoinRoom` handlers (see `JoinChannelInternalResponse`).
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release our database handle before the cleanup call (presumably
            // because `leave_room_for_session` needs its own access — TODO
            // confirm), then re-acquire it afterwards.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Mint a LiveKit token if audio/video is configured: guests get a
        // guest token with publishing disabled, everyone else a full room
        // token. Token-minting failures are logged and treated as "no
        // LiveKit info" rather than failing the join.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The block above has ended, so the database handle and connection pool
    // guards are released before these broader notifications go out.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3673
3674/// Start editing the channel notes
3675async fn join_channel_buffer(
3676 request: proto::JoinChannelBuffer,
3677 response: Response<proto::JoinChannelBuffer>,
3678 session: MessageContext,
3679) -> Result<()> {
3680 let db = session.db().await;
3681 let channel_id = ChannelId::from_proto(request.channel_id);
3682
3683 let open_response = db
3684 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3685 .await?;
3686
3687 let collaborators = open_response.collaborators.clone();
3688 response.send(open_response)?;
3689
3690 let update = UpdateChannelBufferCollaborators {
3691 channel_id: channel_id.to_proto(),
3692 collaborators: collaborators.clone(),
3693 };
3694 channel_buffer_updated(
3695 session.connection_id,
3696 collaborators
3697 .iter()
3698 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3699 &update,
3700 &session.peer,
3701 );
3702
3703 Ok(())
3704}
3705
3706/// Edit the channel notes
3707async fn update_channel_buffer(
3708 request: proto::UpdateChannelBuffer,
3709 session: MessageContext,
3710) -> Result<()> {
3711 let db = session.db().await;
3712 let channel_id = ChannelId::from_proto(request.channel_id);
3713
3714 let (collaborators, epoch, version) = db
3715 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3716 .await?;
3717
3718 channel_buffer_updated(
3719 session.connection_id,
3720 collaborators.clone(),
3721 &proto::UpdateChannelBuffer {
3722 channel_id: channel_id.to_proto(),
3723 operations: request.operations,
3724 },
3725 &session.peer,
3726 );
3727
3728 let pool = &*session.connection_pool().await;
3729
3730 let non_collaborators =
3731 pool.channel_connection_ids(channel_id)
3732 .filter_map(|(connection_id, _)| {
3733 if collaborators.contains(&connection_id) {
3734 None
3735 } else {
3736 Some(connection_id)
3737 }
3738 });
3739
3740 broadcast(None, non_collaborators, |peer_id| {
3741 session.peer.send(
3742 peer_id,
3743 proto::UpdateChannels {
3744 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3745 channel_id: channel_id.to_proto(),
3746 epoch: epoch as u64,
3747 version: version.clone(),
3748 }],
3749 ..Default::default()
3750 },
3751 )
3752 });
3753
3754 Ok(())
3755}
3756
3757/// Rejoin the channel notes after a connection blip
3758async fn rejoin_channel_buffers(
3759 request: proto::RejoinChannelBuffers,
3760 response: Response<proto::RejoinChannelBuffers>,
3761 session: MessageContext,
3762) -> Result<()> {
3763 let db = session.db().await;
3764 let buffers = db
3765 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3766 .await?;
3767
3768 for rejoined_buffer in &buffers {
3769 let collaborators_to_notify = rejoined_buffer
3770 .buffer
3771 .collaborators
3772 .iter()
3773 .filter_map(|c| Some(c.peer_id?.into()));
3774 channel_buffer_updated(
3775 session.connection_id,
3776 collaborators_to_notify,
3777 &proto::UpdateChannelBufferCollaborators {
3778 channel_id: rejoined_buffer.buffer.channel_id,
3779 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3780 },
3781 &session.peer,
3782 );
3783 }
3784
3785 response.send(proto::RejoinChannelBuffersResponse {
3786 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3787 })?;
3788
3789 Ok(())
3790}
3791
3792/// Stop editing the channel notes
3793async fn leave_channel_buffer(
3794 request: proto::LeaveChannelBuffer,
3795 response: Response<proto::LeaveChannelBuffer>,
3796 session: MessageContext,
3797) -> Result<()> {
3798 let db = session.db().await;
3799 let channel_id = ChannelId::from_proto(request.channel_id);
3800
3801 let left_buffer = db
3802 .leave_channel_buffer(channel_id, session.connection_id)
3803 .await?;
3804
3805 response.send(Ack {})?;
3806
3807 channel_buffer_updated(
3808 session.connection_id,
3809 left_buffer.connections,
3810 &proto::UpdateChannelBufferCollaborators {
3811 channel_id: channel_id.to_proto(),
3812 collaborators: left_buffer.collaborators,
3813 },
3814 &session.peer,
3815 );
3816
3817 Ok(())
3818}
3819
3820fn channel_buffer_updated<T: EnvelopedMessage>(
3821 sender_id: ConnectionId,
3822 collaborators: impl IntoIterator<Item = ConnectionId>,
3823 message: &T,
3824 peer: &Peer,
3825) {
3826 broadcast(Some(sender_id), collaborators, |peer_id| {
3827 peer.send(peer_id, message.clone())
3828 });
3829}
3830
3831fn send_notifications(
3832 connection_pool: &ConnectionPool,
3833 peer: &Peer,
3834 notifications: db::NotificationBatch,
3835) {
3836 for (user_id, notification) in notifications {
3837 for connection_id in connection_pool.user_connection_ids(user_id) {
3838 if let Err(error) = peer.send(
3839 connection_id,
3840 proto::AddNotification {
3841 notification: Some(notification.clone()),
3842 },
3843 ) {
3844 tracing::error!(
3845 "failed to send notification to {:?} {}",
3846 connection_id,
3847 error
3848 );
3849 }
3850 }
3851 }
3852}
3853
/// Send a message to the channel
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // The nonce comes from the client; presumably it lets clients correlate
    // this message with their locally-echoed copy — TODO confirm.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Deliver the full message to the chat participants' connections; the
    // sender's own connection is passed to `broadcast` separately, and the
    // sender receives the message via the response below instead.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections subscribed to the channel but not in the chat only get a
    // lightweight "latest message id" update.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3946
3947/// Delete a channel message
3948async fn remove_channel_message(
3949 request: proto::RemoveChannelMessage,
3950 response: Response<proto::RemoveChannelMessage>,
3951 session: MessageContext,
3952) -> Result<()> {
3953 let channel_id = ChannelId::from_proto(request.channel_id);
3954 let message_id = MessageId::from_proto(request.message_id);
3955 let (connection_ids, existing_notification_ids) = session
3956 .db()
3957 .await
3958 .remove_channel_message(channel_id, message_id, session.user_id())
3959 .await?;
3960
3961 broadcast(
3962 Some(session.connection_id),
3963 connection_ids,
3964 move |connection| {
3965 session.peer.send(connection, request.clone())?;
3966
3967 for notification_id in &existing_notification_ids {
3968 session.peer.send(
3969 connection,
3970 proto::DeleteNotification {
3971 notification_id: (*notification_id).to_proto(),
3972 },
3973 )?;
3974 }
3975
3976 Ok(())
3977 },
3978 );
3979 response.send(proto::Ack {})?;
3980 Ok(())
3981}
3982
/// Edit the body/mentions of an existing channel message and fan the update
/// (plus any notification changes it caused) out to the chat participants.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // NOTE: `message_id` is shadowed below by the id the database returns,
    // which is the one used for the outgoing message.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        // `timestamp` (the original send time) is naive; treat it as UTC.
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    // Send participants the edited message, delete notifications for mentions
    // the edit removed, and update notifications for mentions it kept/changed.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4066
4067/// Mark a channel message as read
4068async fn acknowledge_channel_message(
4069 request: proto::AckChannelMessage,
4070 session: MessageContext,
4071) -> Result<()> {
4072 let channel_id = ChannelId::from_proto(request.channel_id);
4073 let message_id = MessageId::from_proto(request.message_id);
4074 let notifications = session
4075 .db()
4076 .await
4077 .observe_channel_message(channel_id, session.user_id(), message_id)
4078 .await?;
4079 send_notifications(
4080 &*session.connection_pool().await,
4081 &session.peer,
4082 notifications,
4083 );
4084 Ok(())
4085}
4086
4087/// Mark a buffer version as synced
4088async fn acknowledge_buffer_version(
4089 request: proto::AckBufferOperation,
4090 session: MessageContext,
4091) -> Result<()> {
4092 let buffer_id = BufferId::from_proto(request.buffer_id);
4093 session
4094 .db()
4095 .await
4096 .observe_buffer_version(
4097 buffer_id,
4098 session.user_id(),
4099 request.epoch as i32,
4100 &request.version,
4101 )
4102 .await?;
4103 Ok(())
4104}
4105
4106/// Get a Supermaven API key for the user
4107async fn get_supermaven_api_key(
4108 _request: proto::GetSupermavenApiKey,
4109 response: Response<proto::GetSupermavenApiKey>,
4110 session: MessageContext,
4111) -> Result<()> {
4112 let user_id: String = session.user_id().to_string();
4113 if !session.is_staff() {
4114 return Err(anyhow!("supermaven not enabled for this account"))?;
4115 }
4116
4117 let email = session.email().context("user must have an email")?;
4118
4119 let supermaven_admin_api = session
4120 .supermaven_client
4121 .as_ref()
4122 .context("supermaven not configured")?;
4123
4124 let result = supermaven_admin_api
4125 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4126 .await?;
4127
4128 response.send(proto::GetSupermavenApiKeyResponse {
4129 api_key: result.api_key,
4130 })?;
4131
4132 Ok(())
4133}
4134
4135/// Start receiving chat updates for a channel
4136async fn join_channel_chat(
4137 request: proto::JoinChannelChat,
4138 response: Response<proto::JoinChannelChat>,
4139 session: MessageContext,
4140) -> Result<()> {
4141 let channel_id = ChannelId::from_proto(request.channel_id);
4142
4143 let db = session.db().await;
4144 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4145 .await?;
4146 let messages = db
4147 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4148 .await?;
4149 response.send(proto::JoinChannelChatResponse {
4150 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4151 messages,
4152 })?;
4153 Ok(())
4154}
4155
4156/// Stop receiving chat updates for a channel
4157async fn leave_channel_chat(
4158 request: proto::LeaveChannelChat,
4159 session: MessageContext,
4160) -> Result<()> {
4161 let channel_id = ChannelId::from_proto(request.channel_id);
4162 session
4163 .db()
4164 .await
4165 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4166 .await?;
4167 Ok(())
4168}
4169
4170/// Retrieve the chat history for a channel
4171async fn get_channel_messages(
4172 request: proto::GetChannelMessages,
4173 response: Response<proto::GetChannelMessages>,
4174 session: MessageContext,
4175) -> Result<()> {
4176 let channel_id = ChannelId::from_proto(request.channel_id);
4177 let messages = session
4178 .db()
4179 .await
4180 .get_channel_messages(
4181 channel_id,
4182 session.user_id(),
4183 MESSAGE_COUNT_PER_PAGE,
4184 Some(MessageId::from_proto(request.before_message_id)),
4185 )
4186 .await?;
4187 response.send(proto::GetChannelMessagesResponse {
4188 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4189 messages,
4190 })?;
4191 Ok(())
4192}
4193
4194/// Retrieve specific chat messages
4195async fn get_channel_messages_by_id(
4196 request: proto::GetChannelMessagesById,
4197 response: Response<proto::GetChannelMessagesById>,
4198 session: MessageContext,
4199) -> Result<()> {
4200 let message_ids = request
4201 .message_ids
4202 .iter()
4203 .map(|id| MessageId::from_proto(*id))
4204 .collect::<Vec<_>>();
4205 let messages = session
4206 .db()
4207 .await
4208 .get_channel_messages_by_id(session.user_id(), &message_ids)
4209 .await?;
4210 response.send(proto::GetChannelMessagesResponse {
4211 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4212 messages,
4213 })?;
4214 Ok(())
4215}
4216
4217/// Retrieve the current users notifications
4218async fn get_notifications(
4219 request: proto::GetNotifications,
4220 response: Response<proto::GetNotifications>,
4221 session: MessageContext,
4222) -> Result<()> {
4223 let notifications = session
4224 .db()
4225 .await
4226 .get_notifications(
4227 session.user_id(),
4228 NOTIFICATION_COUNT_PER_PAGE,
4229 request.before_id.map(db::NotificationId::from_proto),
4230 )
4231 .await?;
4232 response.send(proto::GetNotificationsResponse {
4233 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4234 notifications,
4235 })?;
4236 Ok(())
4237}
4238
4239/// Mark notifications as read
4240async fn mark_notification_as_read(
4241 request: proto::MarkNotificationRead,
4242 response: Response<proto::MarkNotificationRead>,
4243 session: MessageContext,
4244) -> Result<()> {
4245 let database = &session.db().await;
4246 let notifications = database
4247 .mark_notification_as_read_by_id(
4248 session.user_id(),
4249 NotificationId::from_proto(request.notification_id),
4250 )
4251 .await?;
4252 send_notifications(
4253 &*session.connection_pool().await,
4254 &session.peer,
4255 notifications,
4256 );
4257 response.send(proto::Ack {})?;
4258 Ok(())
4259}
4260
4261/// Get the current users information
4262async fn get_private_user_info(
4263 _request: proto::GetPrivateUserInfo,
4264 response: Response<proto::GetPrivateUserInfo>,
4265 session: MessageContext,
4266) -> Result<()> {
4267 let db = session.db().await;
4268
4269 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4270 let user = db
4271 .get_user_by_id(session.user_id())
4272 .await?
4273 .context("user not found")?;
4274 let flags = db.get_user_flags(session.user_id()).await?;
4275
4276 response.send(proto::GetPrivateUserInfoResponse {
4277 metrics_id,
4278 staff: user.admin,
4279 flags,
4280 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4281 })?;
4282 Ok(())
4283}
4284
4285/// Accept the terms of service (tos) on behalf of the current user
4286async fn accept_terms_of_service(
4287 _request: proto::AcceptTermsOfService,
4288 response: Response<proto::AcceptTermsOfService>,
4289 session: MessageContext,
4290) -> Result<()> {
4291 let db = session.db().await;
4292
4293 let accepted_tos_at = Utc::now();
4294 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4295 .await?;
4296
4297 response.send(proto::AcceptTermsOfServiceResponse {
4298 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4299 })?;
4300
4301 // When the user accepts the terms of service, we want to refresh their LLM
4302 // token to grant access.
4303 session
4304 .peer
4305 .send(session.connection_id, proto::RefreshLlmToken {})?;
4306
4307 Ok(())
4308}
4309
/// Mint an LLM API token for the requesting user.
///
/// The user must exist and must have accepted the terms of service. As a
/// side effect, a Stripe billing customer and a Zed Free subscription are
/// created for users that don't already have them, so the issued token
/// always reflects a concrete billing state.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // LLM access is gated on ToS acceptance; see `accept_terms_of_service`.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Ensure a billing customer exists, creating one in Stripe (looked up
    // by email) on first use.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Ensure an active subscription exists; users without one are put on
    // the free plan. The Stripe subscription is created first, then
    // mirrored into our database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Sign all of the above into the claims the client presents to the
    // LLM service.
    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4393
4394fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4395 let message = match message {
4396 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4397 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4398 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4399 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4400 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4401 code: frame.code.into(),
4402 reason: frame.reason.as_str().to_owned().into(),
4403 })),
4404 // We should never receive a frame while reading the message, according
4405 // to the `tungstenite` maintainers:
4406 //
4407 // > It cannot occur when you read messages from the WebSocket, but it
4408 // > can be used when you want to send the raw frames (e.g. you want to
4409 // > send the frames to the WebSocket without composing the full message first).
4410 // >
4411 // > — https://github.com/snapview/tungstenite-rs/issues/268
4412 TungsteniteMessage::Frame(_) => {
4413 bail!("received an unexpected frame while reading the message")
4414 }
4415 };
4416
4417 Ok(message)
4418}
4419
4420fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4421 match message {
4422 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4423 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4424 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4425 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4426 AxumMessage::Close(frame) => {
4427 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4428 code: frame.code.into(),
4429 reason: frame.reason.as_ref().into(),
4430 }))
4431 }
4432 }
4433}
4434
4435fn notify_membership_updated(
4436 connection_pool: &mut ConnectionPool,
4437 result: MembershipUpdated,
4438 user_id: UserId,
4439 peer: &Peer,
4440) {
4441 for membership in &result.new_channels.channel_memberships {
4442 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4443 }
4444 for channel_id in &result.removed_channels {
4445 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4446 }
4447
4448 let user_channels_update = proto::UpdateUserChannels {
4449 channel_memberships: result
4450 .new_channels
4451 .channel_memberships
4452 .iter()
4453 .map(|cm| proto::ChannelMembership {
4454 channel_id: cm.channel_id.to_proto(),
4455 role: cm.role.into(),
4456 })
4457 .collect(),
4458 ..Default::default()
4459 };
4460
4461 let mut update = build_channels_update(result.new_channels);
4462 update.delete_channels = result
4463 .removed_channels
4464 .into_iter()
4465 .map(|id| id.to_proto())
4466 .collect();
4467 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4468
4469 for connection_id in connection_pool.user_connection_ids(user_id) {
4470 peer.send(connection_id, user_channels_update.clone())
4471 .trace_err();
4472 peer.send(connection_id, update.clone()).trace_err();
4473 }
4474}
4475
4476fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4477 proto::UpdateUserChannels {
4478 channel_memberships: channels
4479 .channel_memberships
4480 .iter()
4481 .map(|m| proto::ChannelMembership {
4482 channel_id: m.channel_id.to_proto(),
4483 role: m.role.into(),
4484 })
4485 .collect(),
4486 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4487 observed_channel_message_id: channels.observed_channel_messages.clone(),
4488 }
4489}
4490
4491fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4492 let mut update = proto::UpdateChannels::default();
4493
4494 for channel in channels.channels {
4495 update.channels.push(channel.to_proto());
4496 }
4497
4498 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4499 update.latest_channel_message_ids = channels.latest_channel_messages;
4500
4501 for (channel_id, participants) in channels.channel_participants {
4502 update
4503 .channel_participants
4504 .push(proto::ChannelParticipants {
4505 channel_id: channel_id.to_proto(),
4506 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4507 });
4508 }
4509
4510 for channel in channels.invited_channels {
4511 update.channel_invitations.push(channel.to_proto());
4512 }
4513
4514 update
4515}
4516
4517fn build_initial_contacts_update(
4518 contacts: Vec<db::Contact>,
4519 pool: &ConnectionPool,
4520) -> proto::UpdateContacts {
4521 let mut update = proto::UpdateContacts::default();
4522
4523 for contact in contacts {
4524 match contact {
4525 db::Contact::Accepted { user_id, busy } => {
4526 update.contacts.push(contact_for_user(user_id, busy, pool));
4527 }
4528 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4529 db::Contact::Incoming { user_id } => {
4530 update
4531 .incoming_requests
4532 .push(proto::IncomingContactRequest {
4533 requester_id: user_id.to_proto(),
4534 })
4535 }
4536 }
4537 }
4538
4539 update
4540}
4541
4542fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4543 proto::Contact {
4544 user_id: user_id.to_proto(),
4545 online: pool.is_user_online(user_id),
4546 busy,
4547 }
4548}
4549
4550fn room_updated(room: &proto::Room, peer: &Peer) {
4551 broadcast(
4552 None,
4553 room.participants
4554 .iter()
4555 .filter_map(|participant| Some(participant.peer_id?.into())),
4556 |peer_id| {
4557 peer.send(
4558 peer_id,
4559 proto::RoomUpdated {
4560 room: Some(room.clone()),
4561 },
4562 )
4563 },
4564 );
4565}
4566
4567fn channel_updated(
4568 channel: &db::channel::Model,
4569 room: &proto::Room,
4570 peer: &Peer,
4571 pool: &ConnectionPool,
4572) {
4573 let participants = room
4574 .participants
4575 .iter()
4576 .map(|p| p.user_id)
4577 .collect::<Vec<_>>();
4578
4579 broadcast(
4580 None,
4581 pool.channel_connection_ids(channel.root_id())
4582 .filter_map(|(channel_id, role)| {
4583 role.can_see_channel(channel.visibility)
4584 .then_some(channel_id)
4585 }),
4586 |peer_id| {
4587 peer.send(
4588 peer_id,
4589 proto::UpdateChannels {
4590 channel_participants: vec![proto::ChannelParticipants {
4591 channel_id: channel.id.to_proto(),
4592 participant_user_ids: participants.clone(),
4593 }],
4594 ..Default::default()
4595 },
4596 )
4597 },
4598 );
4599}
4600
4601async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4602 let db = session.db().await;
4603
4604 let contacts = db.get_contacts(user_id).await?;
4605 let busy = db.is_user_busy(user_id).await?;
4606
4607 let pool = session.connection_pool().await;
4608 let updated_contact = contact_for_user(user_id, busy, &pool);
4609 for contact in contacts {
4610 if let db::Contact::Accepted {
4611 user_id: contact_user_id,
4612 ..
4613 } = contact
4614 {
4615 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4616 session
4617 .peer
4618 .send(
4619 contact_conn_id,
4620 proto::UpdateContacts {
4621 contacts: vec![updated_contact.clone()],
4622 remove_contacts: Default::default(),
4623 incoming_requests: Default::default(),
4624 remove_incoming_requests: Default::default(),
4625 outgoing_requests: Default::default(),
4626 remove_outgoing_requests: Default::default(),
4627 },
4628 )
4629 .trace_err();
4630 }
4631 }
4632 }
4633 Ok(())
4634}
4635
/// Tear down the session's room membership: leave the room in the
/// database, notify affected projects, channels, and contacts, and clean
/// up the LiveKit room if it was deleted.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below; the function returns early when
    // the connection wasn't in a room.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        // Tell remaining collaborators about each project this session
        // was in.
        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection wasn't in any room; nothing to do.
        return Ok(());
    }

    // If the room belonged to a channel, refresh that channel's
    // participant list for everyone who can see it.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        let pool = session.connection_pool().await;
        // Leaving the room cancels any outgoing calls; tell each affected
        // user's connections and refresh their contact state too.
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged, not propagated.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4709
4710async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4711 let left_channel_buffers = session
4712 .db()
4713 .await
4714 .leave_channel_buffers(session.connection_id)
4715 .await?;
4716
4717 for left_buffer in left_channel_buffers {
4718 channel_buffer_updated(
4719 session.connection_id,
4720 left_buffer.connections,
4721 &proto::UpdateChannelBufferCollaborators {
4722 channel_id: left_buffer.channel_id.to_proto(),
4723 collaborators: left_buffer.collaborators,
4724 },
4725 &session.peer,
4726 );
4727 }
4728
4729 Ok(())
4730}
4731
4732fn project_left(project: &db::LeftProject, session: &Session) {
4733 for connection_id in &project.connection_ids {
4734 if project.should_unshare {
4735 session
4736 .peer
4737 .send(
4738 *connection_id,
4739 proto::UnshareProject {
4740 project_id: project.id.to_proto(),
4741 },
4742 )
4743 .trace_err();
4744 } else {
4745 session
4746 .peer
4747 .send(
4748 *connection_id,
4749 proto::RemoveProjectCollaborator {
4750 project_id: project.id.to_proto(),
4751 peer_id: Some(session.connection_id.into()),
4752 },
4753 )
4754 .trace_err();
4755 }
4756 }
4757}
4758
/// Extension trait for logging an error and discarding it.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Log the error (if any) and convert the result into an `Option`
    /// of the success value.
    fn trace_err(self) -> Option<Self::Ok>;
}
4764
4765impl<T, E> ResultExt for Result<T, E>
4766where
4767 E: std::fmt::Debug,
4768{
4769 type Ok = T;
4770
4771 #[track_caller]
4772 fn trace_err(self) -> Option<T> {
4773 match self {
4774 Ok(value) => Some(value),
4775 Err(error) => {
4776 tracing::error!("{:?}", error);
4777 None
4778 }
4779 }
4780 }
4781}