1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use futures::TryFutureExt as _;
45use reqwest_client::ReqwestClient;
46use rpc::proto::{MultiLspQuery, split_repository_update};
47use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
48use tracing::Span;
49
50use futures::{
51 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
52 stream::FuturesUnordered,
53};
54use prometheus::{IntGauge, register_int_gauge};
55use rpc::{
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57 proto::{
58 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
59 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
60 },
61};
62use semantic_version::SemanticVersion;
63use serde::{Serialize, Serializer};
64use std::{
65 any::TypeId,
66 future::Future,
67 marker::PhantomData,
68 mem,
69 net::SocketAddr,
70 ops::{Deref, DerefMut},
71 rc::Rc,
72 sync::{
73 Arc, OnceLock,
74 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
75 },
76 time::{Duration, Instant},
77};
78use time::OffsetDateTime;
79use tokio::sync::{Semaphore, watch};
80use tower::ServiceBuilder;
81use tracing::{
82 Instrument,
83 field::{self},
84 info_span, instrument,
85};
86
/// Window allowed for clients to re-establish a dropped connection.
/// NOTE(review): call sites are outside this chunk — confirm exact semantics.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// kubernetes gives terminated pods 10s to shutdown gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Page sizes for chat-message and notification queries, plus hard limits on
// message length and simultaneous websocket connections.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

// Number of currently open connections; incremented/decremented by
// `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

// Tracing span field names used for per-message timing telemetry.
const TOTAL_DURATION_MS: &str = "total_duration_ms";
const PROCESSING_DURATION_MS: &str = "processing_duration_ms";
const QUEUE_DURATION_MS: &str = "queue_duration_ms";
const HOST_WAITING_MS: &str = "host_waiting_ms";

/// Type-erased message handler stored in `Server::handlers`: consumes the
/// envelope, the session, and the tracing span for the message, returning the
/// boxed future that processes it.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session, Span) -> BoxFuture<'static, ()>>;
106
107pub struct ConnectionGuard;
108
109impl ConnectionGuard {
110 pub fn try_acquire() -> Result<Self, ()> {
111 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
112 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
113 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
114 tracing::error!(
115 "too many concurrent connections: {}",
116 current_connections + 1
117 );
118 return Err(());
119 }
120 Ok(ConnectionGuard)
121 }
122}
123
124impl Drop for ConnectionGuard {
125 fn drop(&mut self) {
126 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
127 }
128}
129
130struct Response<R> {
131 peer: Arc<Peer>,
132 receipt: Receipt<R>,
133 responded: Arc<AtomicBool>,
134}
135
136impl<R: RequestMessage> Response<R> {
137 fn send(self, payload: R::Response) -> Result<()> {
138 self.responded.store(true, SeqCst);
139 self.peer.respond(self.receipt, payload)?;
140 Ok(())
141 }
142}
143
144#[derive(Clone, Debug)]
145pub enum Principal {
146 User(User),
147 Impersonated { user: User, admin: User },
148}
149
150impl Principal {
151 fn user(&self) -> &User {
152 match self {
153 Principal::User(user) => user,
154 Principal::Impersonated { user, .. } => user,
155 }
156 }
157
158 fn update_span(&self, span: &tracing::Span) {
159 match &self {
160 Principal::User(user) => {
161 span.record("user_id", user.id.0);
162 span.record("login", &user.github_login);
163 }
164 Principal::Impersonated { user, admin } => {
165 span.record("user_id", user.id.0);
166 span.record("login", &user.github_login);
167 span.record("impersonator", &admin.github_login);
168 }
169 }
170 }
171}
172
173#[derive(Clone)]
174struct MessageContext {
175 session: Session,
176 span: tracing::Span,
177}
178
179impl Deref for MessageContext {
180 type Target = Session;
181
182 fn deref(&self) -> &Self::Target {
183 &self.session
184 }
185}
186
187impl MessageContext {
188 pub fn forward_request<T: RequestMessage>(
189 &self,
190 receiver_id: ConnectionId,
191 request: T,
192 ) -> impl Future<Output = anyhow::Result<T::Response>> {
193 let request_start_time = Instant::now();
194 let span = self.span.clone();
195 tracing::info!("start forwarding request");
196 self.peer
197 .forward_request(self.connection_id, receiver_id, request)
198 .inspect(move |_| {
199 span.record(
200 HOST_WAITING_MS,
201 request_start_time.elapsed().as_micros() as f64 / 1000.0,
202 );
203 })
204 .inspect_err(|_| tracing::error!("error forwarding request"))
205 .inspect_ok(|_| tracing::info!("finished forwarding request"))
206 }
207}
208
/// Per-connection state handed to every message handler.
#[derive(Clone)]
struct Session {
    // Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    // Database handle behind an async mutex; handlers take turns via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
224
225impl Session {
226 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
227 #[cfg(test)]
228 tokio::task::yield_now().await;
229 let guard = self.db.lock().await;
230 #[cfg(test)]
231 tokio::task::yield_now().await;
232 guard
233 }
234
235 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
236 #[cfg(test)]
237 tokio::task::yield_now().await;
238 let guard = self.connection_pool.lock();
239 ConnectionPoolGuard {
240 guard,
241 _not_send: PhantomData,
242 }
243 }
244
245 fn is_staff(&self) -> bool {
246 match &self.principal {
247 Principal::User(user) => user.admin,
248 Principal::Impersonated { .. } => true,
249 }
250 }
251
252 fn user_id(&self) -> UserId {
253 match &self.principal {
254 Principal::User(user) => user.id,
255 Principal::Impersonated { user, .. } => user.id,
256 }
257 }
258
259 pub fn email(&self) -> Option<String> {
260 match &self.principal {
261 Principal::User(user) => user.email_address.clone(),
262 Principal::Impersonated { user, .. } => user.email_address.clone(),
263 }
264 }
265}
266
267impl Debug for Session {
268 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
269 let mut result = f.debug_struct("Session");
270 match &self.principal {
271 Principal::User(user) => {
272 result.field("user", &user.github_login);
273 }
274 Principal::Impersonated { user, admin } => {
275 result.field("user", &user.github_login);
276 result.field("impersonator", &admin.github_login);
277 }
278 }
279 result.field("connection_id", &self.connection_id).finish()
280 }
281}
282
283struct DbHandle(Arc<Database>);
284
285impl Deref for DbHandle {
286 type Target = Database;
287
288 fn deref(&self) -> &Self::Target {
289 self.0.as_ref()
290 }
291}
292
/// The collab RPC server: owns the peer, the pool of live connections, and
/// the table of message handlers registered in `Server::new`.
pub struct Server {
    // Current server id; swapped out by `reset` in tests.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Maps a message's `TypeId` to the boxed handler that processes it.
    handlers: HashMap<TypeId, MessageHandler>,
    // Watch channel flipped to `true` on teardown; connection tasks subscribe to it.
    teardown: watch::Sender<bool>,
}
301
/// Guard over the locked [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes this type `!Send`, so the guard
/// cannot be moved to another thread while the lock is held.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
306
/// Serializable view of the server's state: the peer plus the connection
/// pool (serialized through `serialize_deref` because the pool is only
/// reachable via a lock guard).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
313
314pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
315where
316 S: Serializer,
317 T: Deref<Target = U>,
318 U: Serialize,
319{
320 Serialize::serialize(value.deref(), serializer)
321}
322
323impl Server {
    /// Creates the server and registers every RPC message handler.
    ///
    /// `add_request_handler` registers request/response messages;
    /// `add_message_handler` registers fire-and-forget messages. Project
    /// messages are routed through the generic forwarders
    /// (`forward_read_only_project_request`, `forward_mutating_project_request`)
    /// or broadcast via `broadcast_project_message_from_host`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms, calls, and participants.
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            // Project sharing and worktree state.
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // LSP requests forwarded between collaborators.
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            // Buffer synchronization between host and guests.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            // Users, contacts, and channels.
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            // Following, account info, and miscellaneous.
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            // Git operations forwarded to the project host.
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }
499
    /// Spawns the one-shot background cleanup task for this server instance.
    ///
    /// After `CLEANUP_TIMEOUT` (giving pods from a previous deploy time to shut
    /// down), it removes stale chat participants, channel-buffer collaborators,
    /// room participants, worktree entries, and server rows belonging to prior
    /// instances, notifying any connected clients about the resulting changes.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell
                    // the remaining collaborators about the refreshed set.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        // Remove participants that belonged to the dead server
                        // and broadcast the refreshed room state.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell callees whose incoming calls were canceled.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push updated busy status to the contacts of every
                        // affected user.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Tear down the LiveKit room once it has no participants.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
670
671 pub fn teardown(&self) {
672 self.peer.teardown();
673 self.connection_pool.lock().reset();
674 let _ = self.teardown.send(true);
675 }
676
    /// Test-only: tears the server down, then restarts it under a new id.
    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        // Clear the teardown flag so new connections are accepted again.
        let _ = self.teardown.send(false);
    }
684
    /// Test-only: the server's current id.
    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }
689
    /// Registers `handler` for message type `M`, wrapping it with logging and
    /// timing instrumentation recorded onto the per-message span.
    ///
    /// # Panics
    /// Panics if a handler for `M` was already registered.
    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, MessageContext) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session, span| {
                // The dispatcher selects this closure by the envelope's
                // `TypeId`, so this downcast cannot fail.
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(
                    *envelope,
                    MessageContext {
                        session,
                        span: span.clone(),
                    },
                );
                async move {
                    let result = future.await;
                    // Millisecond timings: total = queue (time spent waiting to
                    // start) + processing (time inside the handler).
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;
                    span.record(TOTAL_DURATION_MS, total_duration_ms);
                    span.record(PROCESSING_DURATION_MS, processing_duration_ms);
                    span.record(QUEUE_DURATION_MS, queue_duration_ms);
                    match result {
                        Err(error) => {
                            tracing::error!(?error, "error handling message")
                        }
                        Ok(()) => tracing::info!("finished handling message"),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }
733
734 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
735 where
736 F: 'static + Send + Sync + Fn(M, MessageContext) -> Fut,
737 Fut: 'static + Send + Future<Output = Result<()>>,
738 M: EnvelopedMessage,
739 {
740 self.add_handler(move |envelope, session| handler(envelope.payload, session));
741 self
742 }
743
    /// Registers a request handler for `M`, enforcing the request/response
    /// contract: the handler must either answer via [`Response::send`] or
    /// return an error, which is relayed back to the requester.
    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, MessageContext) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                // Shared flag set by `Response::send`, letting us detect
                // handlers that succeed without ever responding.
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            // A successful handler that never responded is a bug.
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        // Relay the failure to the requester, then surface it to
                        // the logging wrapper in `add_handler`.
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }
782
    /// Returns the future that drives one client connection: registers it with
    /// the peer, sends the initial client update, then dispatches incoming
    /// messages until the connection closes, I/O fails, or the server tears
    /// down — finally running `connection_lost` cleanup.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        release_channel: Option<String>,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty,
            release_channel=field::Empty,
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }
        if let Some(release_channel) = release_channel {
            span.record("release_channel", release_channel);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once teardown has been signaled.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Release the admission slot once setup is complete.
            // NOTE(review): this frees the `ConnectionGuard` while the
            // connection stays open, so the global limit appears to apply to
            // connections mid-setup only — confirm that is the intent.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                // Take a semaphore permit before accepting the next message,
                // bounding how many handlers run at once for this connection.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                                release_channel=field::Empty,
                                { TOTAL_DURATION_MS }=field::Empty,
                                { PROCESSING_DURATION_MS }=field::Empty,
                                { QUEUE_DURATION_MS }=field::Empty,
                                { HOST_WAITING_MS }=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone(), span.clone());
                                drop(span_enter);

                                // The permit is released only once the handler
                                // future completes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }
974
    /// Sends the initial batch of messages a freshly connected client needs:
    /// the `Hello` handshake, first-connection extras, plan/contact state,
    /// channel subscriptions, and any pending incoming call.
    ///
    /// `send_connection_id` is an optional one-shot channel used by callers
    /// that need to learn the assigned `ConnectionId`.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` tells the client its assigned peer id; it is sent before
        // anything else on the connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Hand the connection id back to the caller, if it asked for it.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // On a user's very first connection, prompt the client to
                // show the contacts UI, then persist that we've done so.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Register the connection and send the initial contact list
                // while holding the pool lock, so no contact update can be
                // interleaved between the two.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                // Version-gated: presumably older clients rely on the server
                // subscribing them to channels — confirm against
                // `should_auto_subscribe_to_channels`.
                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Re-deliver a ring that arrived while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
1032
1033 pub async fn invite_code_redeemed(
1034 self: &Arc<Self>,
1035 inviter_id: UserId,
1036 invitee_id: UserId,
1037 ) -> Result<()> {
1038 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1039 if let Some(code) = &user.invite_code {
1040 let pool = self.connection_pool.lock();
1041 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1042 for connection_id in pool.user_connection_ids(inviter_id) {
1043 self.peer.send(
1044 connection_id,
1045 proto::UpdateContacts {
1046 contacts: vec![invitee_contact.clone()],
1047 ..Default::default()
1048 },
1049 )?;
1050 self.peer.send(
1051 connection_id,
1052 proto::UpdateInviteInfo {
1053 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1054 count: user.invite_count as u32,
1055 },
1056 )?;
1057 }
1058 }
1059 }
1060 Ok(())
1061 }
1062
1063 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1064 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1065 if let Some(invite_code) = &user.invite_code {
1066 let pool = self.connection_pool.lock();
1067 for connection_id in pool.user_connection_ids(user_id) {
1068 self.peer.send(
1069 connection_id,
1070 proto::UpdateInviteInfo {
1071 url: format!(
1072 "{}{}",
1073 self.app_state.config.invite_link_prefix, invite_code
1074 ),
1075 count: user.invite_count as u32,
1076 },
1077 )?;
1078 }
1079 }
1080 }
1081 Ok(())
1082 }
1083
    /// Captures a view of the server's current state: the locked connection
    /// pool plus a reference to the peer. The pool lock is held for the
    /// lifetime of the returned snapshot.
    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                // Marker field; presumably keeps the guard `!Send` — see the
                // `ConnectionPoolGuard` definition to confirm.
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
1093}
1094
// Lets a `ConnectionPoolGuard` be used anywhere a `&ConnectionPool` is
// expected, forwarding to the inner lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}
1102
// Mutable counterpart of the `Deref` impl above: exposes `&mut
// ConnectionPool` through the guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}
1108
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // In tests, verify the pool is still internally consistent every
        // time a guard is released; compiled out of production builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1115
1116fn broadcast<F>(
1117 sender_id: Option<ConnectionId>,
1118 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1119 mut f: F,
1120) where
1121 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1122{
1123 for receiver_id in receiver_ids {
1124 if Some(receiver_id) != sender_id {
1125 if let Err(error) = f(receiver_id) {
1126 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1127 }
1128 }
1129 }
1130}
1131
1132pub struct ProtocolVersion(u32);
1133
1134impl Header for ProtocolVersion {
1135 fn name() -> &'static HeaderName {
1136 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1137 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1138 }
1139
1140 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1141 where
1142 Self: Sized,
1143 I: Iterator<Item = &'i axum::http::HeaderValue>,
1144 {
1145 let version = values
1146 .next()
1147 .ok_or_else(axum::headers::Error::invalid)?
1148 .to_str()
1149 .map_err(|_| axum::headers::Error::invalid())?
1150 .parse()
1151 .map_err(|_| axum::headers::Error::invalid())?;
1152 Ok(Self(version))
1153 }
1154
1155 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1156 values.extend([self.0.to_string().parse().unwrap()]);
1157 }
1158}
1159
1160pub struct AppVersionHeader(SemanticVersion);
1161impl Header for AppVersionHeader {
1162 fn name() -> &'static HeaderName {
1163 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1164 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1165 }
1166
1167 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1168 where
1169 Self: Sized,
1170 I: Iterator<Item = &'i axum::http::HeaderValue>,
1171 {
1172 let version = values
1173 .next()
1174 .ok_or_else(axum::headers::Error::invalid)?
1175 .to_str()
1176 .map_err(|_| axum::headers::Error::invalid())?
1177 .parse()
1178 .map_err(|_| axum::headers::Error::invalid())?;
1179 Ok(Self(version))
1180 }
1181
1182 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1183 values.extend([self.0.to_string().parse().unwrap()]);
1184 }
1185}
1186
1187#[derive(Debug)]
1188pub struct ReleaseChannelHeader(String);
1189
1190impl Header for ReleaseChannelHeader {
1191 fn name() -> &'static HeaderName {
1192 static ZED_RELEASE_CHANNEL: OnceLock<HeaderName> = OnceLock::new();
1193 ZED_RELEASE_CHANNEL.get_or_init(|| HeaderName::from_static("x-zed-release-channel"))
1194 }
1195
1196 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1197 where
1198 Self: Sized,
1199 I: Iterator<Item = &'i axum::http::HeaderValue>,
1200 {
1201 Ok(Self(
1202 values
1203 .next()
1204 .ok_or_else(axum::headers::Error::invalid)?
1205 .to_str()
1206 .map_err(|_| axum::headers::Error::invalid())?
1207 .to_owned(),
1208 ))
1209 }
1210
1211 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1212 values.extend([self.0.parse().unwrap()]);
1213 }
1214}
1215
1216pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1217 Router::new()
1218 .route("/rpc", get(handle_websocket_request))
1219 .layer(
1220 ServiceBuilder::new()
1221 .layer(Extension(server.app_state.clone()))
1222 .layer(middleware::from_fn(auth::validate_header)),
1223 )
1224 .route("/metrics", get(handle_metrics))
1225 .layer(Extension(server))
1226}
1227
/// Axum handler that upgrades an authenticated `/rpc` request into a collab
/// websocket connection.
///
/// Requests with a mismatched protocol version, a missing app-version
/// header, or an app version that can no longer collaborate are rejected
/// with 426 UPGRADE REQUIRED. A capacity check runs before the upgrade so
/// an overloaded server answers 503 instead of accepting the socket.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    release_channel_header: Option<TypedHeader<ReleaseChannelHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    // The RPC protocol version must match the server's exactly.
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    let release_channel = release_channel_header.map(|header| header.0.0);

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so an over-capacity
    // server can reply 503 instead of accepting the upgrade and dropping it.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt the axum websocket into the tungstenite message types the
        // rpc `Connection` expects, in both directions.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    release_channel,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1305
1306pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1307 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1308 let connections_metric = CONNECTIONS_METRIC
1309 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1310
1311 let connections = server
1312 .connection_pool
1313 .lock()
1314 .connections()
1315 .filter(|connection| !connection.admin)
1316 .count();
1317 connections_metric.set(connections as _);
1318
1319 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1320 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1321 register_int_gauge!(
1322 "shared_projects",
1323 "number of open projects with one or more guests"
1324 )
1325 .unwrap()
1326 });
1327
1328 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1329 shared_projects_metric.set(shared_projects as _);
1330
1331 let encoder = prometheus::TextEncoder::new();
1332 let metric_families = prometheus::gather();
1333 let encoded_metrics = encoder
1334 .encode_to_string(&metric_families)
1335 .map_err(|err| anyhow!("{err}"))?;
1336 Ok(encoded_metrics)
1337}
1338
/// Cleans up after a client's websocket drops.
///
/// Immediately disconnects the peer and removes the connection from the
/// pool, then waits up to `RECONNECT_TIMEOUT` before releasing the user's
/// rooms and channel buffers, so a quickly-reconnecting client can resume.
/// A teardown signal short-circuits the wait and skips the cleanup.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Drop the transport and unregister the connection right away.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: record the lost connection in the database.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        // Grace period expired without a reconnect: release resources.
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // If this was the user's last connection, decline any call that
            // was ringing them and broadcast the room change.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        // Deliberate server-side teardown: skip the resource cleanup.
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1385
/// Acknowledges a ping from a client, used to keep the connection alive.
/// Replies immediately with an `Ack`; no server state is touched.
async fn ping(
    _: proto::Ping,
    response: Response<proto::Ping>,
    _session: MessageContext,
) -> Result<()> {
    response.send(proto::Ack {})?;
    Ok(())
}
1395
1396/// Creates a new room for calling (outside of channels)
1397async fn create_room(
1398 _request: proto::CreateRoom,
1399 response: Response<proto::CreateRoom>,
1400 session: MessageContext,
1401) -> Result<()> {
1402 let livekit_room = nanoid::nanoid!(30);
1403
1404 let live_kit_connection_info = util::maybe!(async {
1405 let live_kit = session.app_state.livekit_client.as_ref();
1406 let live_kit = live_kit?;
1407 let user_id = session.user_id().to_string();
1408
1409 let token = live_kit
1410 .room_token(&livekit_room, &user_id.to_string())
1411 .trace_err()?;
1412
1413 Some(proto::LiveKitConnectionInfo {
1414 server_url: live_kit.url().into(),
1415 token,
1416 can_publish: true,
1417 })
1418 })
1419 .await;
1420
1421 let room = session
1422 .db()
1423 .await
1424 .create_room(session.user_id(), session.connection_id, &livekit_room)
1425 .await?;
1426
1427 response.send(proto::CreateRoomResponse {
1428 room: Some(room.clone()),
1429 live_kit_connection_info,
1430 })?;
1431
1432 update_user_contacts(session.user_id(), &session).await?;
1433 Ok(())
1434}
1435
/// Join a room from an invitation. Equivalent to joining a channel if there is one.
///
/// After the database join succeeds, dismisses the ring on the user's other
/// devices, mints a LiveKit token when configured, and replies with the
/// room state.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: MessageContext,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    // Channel-backed rooms have their own join path.
    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // The user answered here: dismiss the ring on all of their connections.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    // Best-effort LiveKit token: a failure degrades to a join without call
    // media rather than an error.
    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1502
1503/// Rejoin room is used to reconnect to a room after connection errors.
1504async fn rejoin_room(
1505 request: proto::RejoinRoom,
1506 response: Response<proto::RejoinRoom>,
1507 session: MessageContext,
1508) -> Result<()> {
1509 let room;
1510 let channel;
1511 {
1512 let mut rejoined_room = session
1513 .db()
1514 .await
1515 .rejoin_room(request, session.user_id(), session.connection_id)
1516 .await?;
1517
1518 response.send(proto::RejoinRoomResponse {
1519 room: Some(rejoined_room.room.clone()),
1520 reshared_projects: rejoined_room
1521 .reshared_projects
1522 .iter()
1523 .map(|project| proto::ResharedProject {
1524 id: project.id.to_proto(),
1525 collaborators: project
1526 .collaborators
1527 .iter()
1528 .map(|collaborator| collaborator.to_proto())
1529 .collect(),
1530 })
1531 .collect(),
1532 rejoined_projects: rejoined_room
1533 .rejoined_projects
1534 .iter()
1535 .map(|rejoined_project| rejoined_project.to_proto())
1536 .collect(),
1537 })?;
1538 room_updated(&rejoined_room.room, &session.peer);
1539
1540 for project in &rejoined_room.reshared_projects {
1541 for collaborator in &project.collaborators {
1542 session
1543 .peer
1544 .send(
1545 collaborator.connection_id,
1546 proto::UpdateProjectCollaborator {
1547 project_id: project.id.to_proto(),
1548 old_peer_id: Some(project.old_connection_id.into()),
1549 new_peer_id: Some(session.connection_id.into()),
1550 },
1551 )
1552 .trace_err();
1553 }
1554
1555 broadcast(
1556 Some(session.connection_id),
1557 project
1558 .collaborators
1559 .iter()
1560 .map(|collaborator| collaborator.connection_id),
1561 |connection_id| {
1562 session.peer.forward_send(
1563 session.connection_id,
1564 connection_id,
1565 proto::UpdateProject {
1566 project_id: project.id.to_proto(),
1567 worktrees: project.worktrees.clone(),
1568 },
1569 )
1570 },
1571 );
1572 }
1573
1574 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1575
1576 let rejoined_room = rejoined_room.into_inner();
1577
1578 room = rejoined_room.room;
1579 channel = rejoined_room.channel;
1580 }
1581
1582 if let Some(channel) = channel {
1583 channel_updated(
1584 &channel,
1585 &room,
1586 &session.peer,
1587 &*session.connection_pool().await,
1588 );
1589 }
1590
1591 update_user_contacts(session.user_id(), &session).await?;
1592 Ok(())
1593}
1594
/// Notifies collaborators and re-streams project state for every project a
/// reconnecting client is rejoining as a guest.
///
/// First announces the client's new peer id to all collaborators, then
/// re-sends each worktree's entries, diagnostics, and settings, followed by
/// repository updates and removals. Streamed payloads are moved out of
/// `rejoined_projects` via `mem::take`, leaving those fields empty.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                // The update is final only once the completed scan has
                // caught up with the latest scan id.
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            // `split_worktree_update` may break this into several messages.
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
            if let Some(summary) = worktree_diagnostics.next() {
                let message = proto::UpdateDiagnosticSummary {
                    project_id: project.id.to_proto(),
                    worktree_id: worktree.id,
                    summary: Some(summary),
                    more_summaries: worktree_diagnostics.collect(),
                };
                session.peer.send(session.connection_id, message)?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Stream repository changes and removals.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1679
1680/// leave room disconnects from the room.
1681async fn leave_room(
1682 _: proto::LeaveRoom,
1683 response: Response<proto::LeaveRoom>,
1684 session: MessageContext,
1685) -> Result<()> {
1686 leave_room_for_session(&session, session.connection_id).await?;
1687 response.send(proto::Ack {})?;
1688 Ok(())
1689}
1690
1691/// Updates the permissions of someone else in the room.
1692async fn set_room_participant_role(
1693 request: proto::SetRoomParticipantRole,
1694 response: Response<proto::SetRoomParticipantRole>,
1695 session: MessageContext,
1696) -> Result<()> {
1697 let user_id = UserId::from_proto(request.user_id);
1698 let role = ChannelRole::from(request.role());
1699
1700 let (livekit_room, can_publish) = {
1701 let room = session
1702 .db()
1703 .await
1704 .set_room_participant_role(
1705 session.user_id(),
1706 RoomId::from_proto(request.room_id),
1707 user_id,
1708 role,
1709 )
1710 .await?;
1711
1712 let livekit_room = room.livekit_room.clone();
1713 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1714 room_updated(&room, &session.peer);
1715 (livekit_room, can_publish)
1716 };
1717
1718 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1719 live_kit
1720 .update_participant(
1721 livekit_room.clone(),
1722 request.user_id.to_string(),
1723 livekit_api::proto::ParticipantPermission {
1724 can_subscribe: true,
1725 can_publish,
1726 can_publish_data: can_publish,
1727 hidden: false,
1728 recorder: false,
1729 },
1730 )
1731 .await
1732 .trace_err();
1733 }
1734
1735 response.send(proto::Ack {})?;
1736 Ok(())
1737}
1738
1739/// Call someone else into the current room
1740async fn call(
1741 request: proto::Call,
1742 response: Response<proto::Call>,
1743 session: MessageContext,
1744) -> Result<()> {
1745 let room_id = RoomId::from_proto(request.room_id);
1746 let calling_user_id = session.user_id();
1747 let calling_connection_id = session.connection_id;
1748 let called_user_id = UserId::from_proto(request.called_user_id);
1749 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1750 if !session
1751 .db()
1752 .await
1753 .has_contact(calling_user_id, called_user_id)
1754 .await?
1755 {
1756 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1757 }
1758
1759 let incoming_call = {
1760 let (room, incoming_call) = &mut *session
1761 .db()
1762 .await
1763 .call(
1764 room_id,
1765 calling_user_id,
1766 calling_connection_id,
1767 called_user_id,
1768 initial_project_id,
1769 )
1770 .await?;
1771 room_updated(room, &session.peer);
1772 mem::take(incoming_call)
1773 };
1774 update_user_contacts(called_user_id, &session).await?;
1775
1776 let mut calls = session
1777 .connection_pool()
1778 .await
1779 .user_connection_ids(called_user_id)
1780 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1781 .collect::<FuturesUnordered<_>>();
1782
1783 while let Some(call_response) = calls.next().await {
1784 match call_response.as_ref() {
1785 Ok(_) => {
1786 response.send(proto::Ack {})?;
1787 return Ok(());
1788 }
1789 Err(_) => {
1790 call_response.trace_err();
1791 }
1792 }
1793 }
1794
1795 {
1796 let room = session
1797 .db()
1798 .await
1799 .call_failed(room_id, called_user_id)
1800 .await?;
1801 room_updated(&room, &session.peer);
1802 }
1803 update_user_contacts(called_user_id, &session).await?;
1804
1805 Err(anyhow!("failed to ring user"))?
1806}
1807
1808/// Cancel an outgoing call.
1809async fn cancel_call(
1810 request: proto::CancelCall,
1811 response: Response<proto::CancelCall>,
1812 session: MessageContext,
1813) -> Result<()> {
1814 let called_user_id = UserId::from_proto(request.called_user_id);
1815 let room_id = RoomId::from_proto(request.room_id);
1816 {
1817 let room = session
1818 .db()
1819 .await
1820 .cancel_call(room_id, session.connection_id, called_user_id)
1821 .await?;
1822 room_updated(&room, &session.peer);
1823 }
1824
1825 for connection_id in session
1826 .connection_pool()
1827 .await
1828 .user_connection_ids(called_user_id)
1829 {
1830 session
1831 .peer
1832 .send(
1833 connection_id,
1834 proto::CallCanceled {
1835 room_id: room_id.to_proto(),
1836 },
1837 )
1838 .trace_err();
1839 }
1840 response.send(proto::Ack {})?;
1841
1842 update_user_contacts(called_user_id, &session).await?;
1843 Ok(())
1844}
1845
1846/// Decline an incoming call.
1847async fn decline_call(message: proto::DeclineCall, session: MessageContext) -> Result<()> {
1848 let room_id = RoomId::from_proto(message.room_id);
1849 {
1850 let room = session
1851 .db()
1852 .await
1853 .decline_call(Some(room_id), session.user_id())
1854 .await?
1855 .context("declining call")?;
1856 room_updated(&room, &session.peer);
1857 }
1858
1859 for connection_id in session
1860 .connection_pool()
1861 .await
1862 .user_connection_ids(session.user_id())
1863 {
1864 session
1865 .peer
1866 .send(
1867 connection_id,
1868 proto::CallCanceled {
1869 room_id: room_id.to_proto(),
1870 },
1871 )
1872 .trace_err();
1873 }
1874 update_user_contacts(session.user_id(), &session).await?;
1875 Ok(())
1876}
1877
1878/// Updates other participants in the room with your current location.
1879async fn update_participant_location(
1880 request: proto::UpdateParticipantLocation,
1881 response: Response<proto::UpdateParticipantLocation>,
1882 session: MessageContext,
1883) -> Result<()> {
1884 let room_id = RoomId::from_proto(request.room_id);
1885 let location = request.location.context("invalid location")?;
1886
1887 let db = session.db().await;
1888 let room = db
1889 .update_room_participant_location(room_id, session.connection_id, location)
1890 .await?;
1891
1892 room_updated(&room, &session.peer);
1893 response.send(proto::Ack {})?;
1894 Ok(())
1895}
1896
1897/// Share a project into the room.
1898async fn share_project(
1899 request: proto::ShareProject,
1900 response: Response<proto::ShareProject>,
1901 session: MessageContext,
1902) -> Result<()> {
1903 let (project_id, room) = &*session
1904 .db()
1905 .await
1906 .share_project(
1907 RoomId::from_proto(request.room_id),
1908 session.connection_id,
1909 &request.worktrees,
1910 request.is_ssh_project,
1911 )
1912 .await?;
1913 response.send(proto::ShareProjectResponse {
1914 project_id: project_id.to_proto(),
1915 })?;
1916 room_updated(room, &session.peer);
1917
1918 Ok(())
1919}
1920
1921/// Unshare a project from the room.
1922async fn unshare_project(message: proto::UnshareProject, session: MessageContext) -> Result<()> {
1923 let project_id = ProjectId::from_proto(message.project_id);
1924 unshare_project_internal(project_id, session.connection_id, &session).await
1925}
1926
/// Shared implementation of project unsharing (see `unshare_project`).
///
/// Notifies every guest that the project went away, broadcasts the updated
/// room (if any), and deletes the project row when the database flags it
/// for deletion.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Tell every guest — but not the unsharing connection itself — that
        // the project is gone.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so the guard is dropped before the delete below.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1964
1965/// Join someone elses shared project.
1966async fn join_project(
1967 request: proto::JoinProject,
1968 response: Response<proto::JoinProject>,
1969 session: MessageContext,
1970) -> Result<()> {
1971 let project_id = ProjectId::from_proto(request.project_id);
1972
1973 tracing::info!(%project_id, "join project");
1974
1975 let db = session.db().await;
1976 let (project, replica_id) = &mut *db
1977 .join_project(
1978 project_id,
1979 session.connection_id,
1980 session.user_id(),
1981 request.committer_name.clone(),
1982 request.committer_email.clone(),
1983 )
1984 .await?;
1985 drop(db);
1986 tracing::info!(%project_id, "join remote project");
1987 let collaborators = project
1988 .collaborators
1989 .iter()
1990 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1991 .map(|collaborator| collaborator.to_proto())
1992 .collect::<Vec<_>>();
1993 let project_id = project.id;
1994 let guest_user_id = session.user_id();
1995
1996 let worktrees = project
1997 .worktrees
1998 .iter()
1999 .map(|(id, worktree)| proto::WorktreeMetadata {
2000 id: *id,
2001 root_name: worktree.root_name.clone(),
2002 visible: worktree.visible,
2003 abs_path: worktree.abs_path.clone(),
2004 })
2005 .collect::<Vec<_>>();
2006
2007 let add_project_collaborator = proto::AddProjectCollaborator {
2008 project_id: project_id.to_proto(),
2009 collaborator: Some(proto::Collaborator {
2010 peer_id: Some(session.connection_id.into()),
2011 replica_id: replica_id.0 as u32,
2012 user_id: guest_user_id.to_proto(),
2013 is_host: false,
2014 committer_name: request.committer_name.clone(),
2015 committer_email: request.committer_email.clone(),
2016 }),
2017 };
2018
2019 for collaborator in &collaborators {
2020 session
2021 .peer
2022 .send(
2023 collaborator.peer_id.unwrap().into(),
2024 add_project_collaborator.clone(),
2025 )
2026 .trace_err();
2027 }
2028
2029 // First, we send the metadata associated with each worktree.
2030 let (language_servers, language_server_capabilities) = project
2031 .language_servers
2032 .clone()
2033 .into_iter()
2034 .map(|server| (server.server, server.capabilities))
2035 .unzip();
2036 response.send(proto::JoinProjectResponse {
2037 project_id: project.id.0 as u64,
2038 worktrees: worktrees.clone(),
2039 replica_id: replica_id.0 as u32,
2040 collaborators: collaborators.clone(),
2041 language_servers,
2042 language_server_capabilities,
2043 role: project.role.into(),
2044 })?;
2045
2046 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2047 // Stream this worktree's entries.
2048 let message = proto::UpdateWorktree {
2049 project_id: project_id.to_proto(),
2050 worktree_id,
2051 abs_path: worktree.abs_path.clone(),
2052 root_name: worktree.root_name,
2053 updated_entries: worktree.entries,
2054 removed_entries: Default::default(),
2055 scan_id: worktree.scan_id,
2056 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2057 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2058 removed_repositories: Default::default(),
2059 };
2060 for update in proto::split_worktree_update(message) {
2061 session.peer.send(session.connection_id, update.clone())?;
2062 }
2063
2064 // Stream this worktree's diagnostics.
2065 let mut worktree_diagnostics = worktree.diagnostic_summaries.into_iter();
2066 if let Some(summary) = worktree_diagnostics.next() {
2067 let message = proto::UpdateDiagnosticSummary {
2068 project_id: project.id.to_proto(),
2069 worktree_id: worktree.id,
2070 summary: Some(summary),
2071 more_summaries: worktree_diagnostics.collect(),
2072 };
2073 session.peer.send(session.connection_id, message)?;
2074 }
2075
2076 for settings_file in worktree.settings_files {
2077 session.peer.send(
2078 session.connection_id,
2079 proto::UpdateWorktreeSettings {
2080 project_id: project_id.to_proto(),
2081 worktree_id: worktree.id,
2082 path: settings_file.path,
2083 content: Some(settings_file.content),
2084 kind: Some(settings_file.kind.to_proto() as i32),
2085 },
2086 )?;
2087 }
2088 }
2089
2090 for repository in mem::take(&mut project.repositories) {
2091 for update in split_repository_update(repository) {
2092 session.peer.send(session.connection_id, update)?;
2093 }
2094 }
2095
2096 for language_server in &project.language_servers {
2097 session.peer.send(
2098 session.connection_id,
2099 proto::UpdateLanguageServer {
2100 project_id: project_id.to_proto(),
2101 server_name: Some(language_server.server.name.clone()),
2102 language_server_id: language_server.server.id,
2103 variant: Some(
2104 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2105 proto::LspDiskBasedDiagnosticsUpdated {},
2106 ),
2107 ),
2108 },
2109 )?;
2110 }
2111
2112 Ok(())
2113}
2114
2115/// Leave someone elses shared project.
2116async fn leave_project(request: proto::LeaveProject, session: MessageContext) -> Result<()> {
2117 let sender_id = session.connection_id;
2118 let project_id = ProjectId::from_proto(request.project_id);
2119 let db = session.db().await;
2120
2121 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2122 tracing::info!(
2123 %project_id,
2124 "leave project"
2125 );
2126
2127 project_left(project, &session);
2128 if let Some(room) = room {
2129 room_updated(room, &session.peer);
2130 }
2131
2132 Ok(())
2133}
2134
2135/// Updates other participants with changes to the project
2136async fn update_project(
2137 request: proto::UpdateProject,
2138 response: Response<proto::UpdateProject>,
2139 session: MessageContext,
2140) -> Result<()> {
2141 let project_id = ProjectId::from_proto(request.project_id);
2142 let (room, guest_connection_ids) = &*session
2143 .db()
2144 .await
2145 .update_project(project_id, session.connection_id, &request.worktrees)
2146 .await?;
2147 broadcast(
2148 Some(session.connection_id),
2149 guest_connection_ids.iter().copied(),
2150 |connection_id| {
2151 session
2152 .peer
2153 .forward_send(session.connection_id, connection_id, request.clone())
2154 },
2155 );
2156 if let Some(room) = room {
2157 room_updated(room, &session.peer);
2158 }
2159 response.send(proto::Ack {})?;
2160
2161 Ok(())
2162}
2163
2164/// Updates other participants with changes to the worktree
2165async fn update_worktree(
2166 request: proto::UpdateWorktree,
2167 response: Response<proto::UpdateWorktree>,
2168 session: MessageContext,
2169) -> Result<()> {
2170 let guest_connection_ids = session
2171 .db()
2172 .await
2173 .update_worktree(&request, session.connection_id)
2174 .await?;
2175
2176 broadcast(
2177 Some(session.connection_id),
2178 guest_connection_ids.iter().copied(),
2179 |connection_id| {
2180 session
2181 .peer
2182 .forward_send(session.connection_id, connection_id, request.clone())
2183 },
2184 );
2185 response.send(proto::Ack {})?;
2186 Ok(())
2187}
2188
2189async fn update_repository(
2190 request: proto::UpdateRepository,
2191 response: Response<proto::UpdateRepository>,
2192 session: MessageContext,
2193) -> Result<()> {
2194 let guest_connection_ids = session
2195 .db()
2196 .await
2197 .update_repository(&request, session.connection_id)
2198 .await?;
2199
2200 broadcast(
2201 Some(session.connection_id),
2202 guest_connection_ids.iter().copied(),
2203 |connection_id| {
2204 session
2205 .peer
2206 .forward_send(session.connection_id, connection_id, request.clone())
2207 },
2208 );
2209 response.send(proto::Ack {})?;
2210 Ok(())
2211}
2212
2213async fn remove_repository(
2214 request: proto::RemoveRepository,
2215 response: Response<proto::RemoveRepository>,
2216 session: MessageContext,
2217) -> Result<()> {
2218 let guest_connection_ids = session
2219 .db()
2220 .await
2221 .remove_repository(&request, session.connection_id)
2222 .await?;
2223
2224 broadcast(
2225 Some(session.connection_id),
2226 guest_connection_ids.iter().copied(),
2227 |connection_id| {
2228 session
2229 .peer
2230 .forward_send(session.connection_id, connection_id, request.clone())
2231 },
2232 );
2233 response.send(proto::Ack {})?;
2234 Ok(())
2235}
2236
2237/// Updates other participants with changes to the diagnostics
2238async fn update_diagnostic_summary(
2239 message: proto::UpdateDiagnosticSummary,
2240 session: MessageContext,
2241) -> Result<()> {
2242 let guest_connection_ids = session
2243 .db()
2244 .await
2245 .update_diagnostic_summary(&message, session.connection_id)
2246 .await?;
2247
2248 broadcast(
2249 Some(session.connection_id),
2250 guest_connection_ids.iter().copied(),
2251 |connection_id| {
2252 session
2253 .peer
2254 .forward_send(session.connection_id, connection_id, message.clone())
2255 },
2256 );
2257
2258 Ok(())
2259}
2260
2261/// Updates other participants with changes to the worktree settings
2262async fn update_worktree_settings(
2263 message: proto::UpdateWorktreeSettings,
2264 session: MessageContext,
2265) -> Result<()> {
2266 let guest_connection_ids = session
2267 .db()
2268 .await
2269 .update_worktree_settings(&message, session.connection_id)
2270 .await?;
2271
2272 broadcast(
2273 Some(session.connection_id),
2274 guest_connection_ids.iter().copied(),
2275 |connection_id| {
2276 session
2277 .peer
2278 .forward_send(session.connection_id, connection_id, message.clone())
2279 },
2280 );
2281
2282 Ok(())
2283}
2284
2285/// Notify other participants that a language server has started.
2286async fn start_language_server(
2287 request: proto::StartLanguageServer,
2288 session: MessageContext,
2289) -> Result<()> {
2290 let guest_connection_ids = session
2291 .db()
2292 .await
2293 .start_language_server(&request, session.connection_id)
2294 .await?;
2295
2296 broadcast(
2297 Some(session.connection_id),
2298 guest_connection_ids.iter().copied(),
2299 |connection_id| {
2300 session
2301 .peer
2302 .forward_send(session.connection_id, connection_id, request.clone())
2303 },
2304 );
2305 Ok(())
2306}
2307
2308/// Notify other participants that a language server has changed.
2309async fn update_language_server(
2310 request: proto::UpdateLanguageServer,
2311 session: MessageContext,
2312) -> Result<()> {
2313 let project_id = ProjectId::from_proto(request.project_id);
2314 let db = session.db().await;
2315
2316 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2317 {
2318 if let Some(capabilities) = update.capabilities.clone() {
2319 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2320 .await?;
2321 }
2322 }
2323
2324 let project_connection_ids = db
2325 .project_connection_ids(project_id, session.connection_id, true)
2326 .await?;
2327 broadcast(
2328 Some(session.connection_id),
2329 project_connection_ids.iter().copied(),
2330 |connection_id| {
2331 session
2332 .peer
2333 .forward_send(session.connection_id, connection_id, request.clone())
2334 },
2335 );
2336 Ok(())
2337}
2338
2339/// forward a project request to the host. These requests should be read only
2340/// as guests are allowed to send them.
2341async fn forward_read_only_project_request<T>(
2342 request: T,
2343 response: Response<T>,
2344 session: MessageContext,
2345) -> Result<()>
2346where
2347 T: EntityMessage + RequestMessage,
2348{
2349 let project_id = ProjectId::from_proto(request.remote_entity_id());
2350 let host_connection_id = session
2351 .db()
2352 .await
2353 .host_for_read_only_project_request(project_id, session.connection_id)
2354 .await?;
2355 let payload = session.forward_request(host_connection_id, request).await?;
2356 response.send(payload)?;
2357 Ok(())
2358}
2359
2360/// forward a project request to the host. These requests are disallowed
2361/// for guests.
2362async fn forward_mutating_project_request<T>(
2363 request: T,
2364 response: Response<T>,
2365 session: MessageContext,
2366) -> Result<()>
2367where
2368 T: EntityMessage + RequestMessage,
2369{
2370 let project_id = ProjectId::from_proto(request.remote_entity_id());
2371
2372 let host_connection_id = session
2373 .db()
2374 .await
2375 .host_for_mutating_project_request(project_id, session.connection_id)
2376 .await?;
2377 let payload = session.forward_request(host_connection_id, request).await?;
2378 response.send(payload)?;
2379 Ok(())
2380}
2381
/// Handles a batched multi-server LSP query by forwarding it to the
/// project host like any other mutating project request.
async fn multi_lsp_query(
    request: MultiLspQuery,
    response: Response<MultiLspQuery>,
    session: MessageContext,
) -> Result<()> {
    // Record which underlying LSP request this batch carries on the
    // current tracing span, then log receipt for observability.
    tracing::Span::current().record("multi_lsp_query_request", request.request_str());
    tracing::info!("multi_lsp_query message received");
    forward_mutating_project_request(request, response, session).await
}
2391
2392/// Notify other participants that a new buffer has been created
2393async fn create_buffer_for_peer(
2394 request: proto::CreateBufferForPeer,
2395 session: MessageContext,
2396) -> Result<()> {
2397 session
2398 .db()
2399 .await
2400 .check_user_is_project_host(
2401 ProjectId::from_proto(request.project_id),
2402 session.connection_id,
2403 )
2404 .await?;
2405 let peer_id = request.peer_id.context("invalid peer id")?;
2406 session
2407 .peer
2408 .forward_send(session.connection_id, peer_id.into(), request)?;
2409 Ok(())
2410}
2411
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: MessageContext,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let mut capability = Capability::ReadOnly;

    // Selection-only operations (or empty variants) may come from read-only
    // guests; any other operation variant requires write access.
    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    let host = {
        // The database guard is confined to this scope so it is released
        // before we await the host round-trip below.
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(project_id, session.connection_id, capability)
            .await?;

        let (host, guests) = &*guard;

        // Fan the update out to every other guest connection.
        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Unless the sender is the host itself, wait for the host to accept the
    // update before acknowledging the sender.
    if host != session.connection_id {
        session.forward_request(host, request.clone()).await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2458
/// Relays a context (assistant conversation) operation to the other
/// participants of the project, after checking the sender has the
/// capability the operation requires.
async fn update_context(message: proto::UpdateContext, session: MessageContext) -> Result<()> {
    let project_id = ProjectId::from_proto(message.project_id);

    let operation = message.operation.as_ref().context("invalid operation")?;
    // Mirrors `update_buffer`: selection-only buffer operations are allowed
    // from read-only guests; every other operation needs write access.
    let capability = match operation.variant.as_ref() {
        Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
            if let Some(buffer_op) = buffer_op.operation.as_ref() {
                match buffer_op.variant {
                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {
                        Capability::ReadOnly
                    }
                    _ => Capability::ReadWrite,
                }
            } else {
                // A buffer operation with no inner payload is treated
                // conservatively as a write.
                Capability::ReadWrite
            }
        }
        Some(_) => Capability::ReadWrite,
        None => Capability::ReadOnly,
    };

    let guard = session
        .db()
        .await
        .connections_for_buffer_update(project_id, session.connection_id, capability)
        .await?;

    let (host, guests) = &*guard;

    // Unlike `update_buffer`, the host is included in the broadcast here
    // and no acknowledgement round-trip is awaited.
    broadcast(
        Some(session.connection_id),
        guests.iter().chain([host]).copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, message.clone())
        },
    );

    Ok(())
}
2500
2501/// Notify other participants that a project has been updated.
2502async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2503 request: T,
2504 session: MessageContext,
2505) -> Result<()> {
2506 let project_id = ProjectId::from_proto(request.remote_entity_id());
2507 let project_connection_ids = session
2508 .db()
2509 .await
2510 .project_connection_ids(project_id, session.connection_id, false)
2511 .await?;
2512
2513 broadcast(
2514 Some(session.connection_id),
2515 project_connection_ids.iter().copied(),
2516 |connection_id| {
2517 session
2518 .peer
2519 .forward_send(session.connection_id, connection_id, request.clone())
2520 },
2521 );
2522 Ok(())
2523}
2524
2525/// Start following another user in a call.
2526async fn follow(
2527 request: proto::Follow,
2528 response: Response<proto::Follow>,
2529 session: MessageContext,
2530) -> Result<()> {
2531 let room_id = RoomId::from_proto(request.room_id);
2532 let project_id = request.project_id.map(ProjectId::from_proto);
2533 let leader_id = request.leader_id.context("invalid leader id")?.into();
2534 let follower_id = session.connection_id;
2535
2536 session
2537 .db()
2538 .await
2539 .check_room_participants(room_id, leader_id, session.connection_id)
2540 .await?;
2541
2542 let response_payload = session.forward_request(leader_id, request).await?;
2543 response.send(response_payload)?;
2544
2545 if let Some(project_id) = project_id {
2546 let room = session
2547 .db()
2548 .await
2549 .follow(room_id, project_id, leader_id, follower_id)
2550 .await?;
2551 room_updated(&room, &session.peer);
2552 }
2553
2554 Ok(())
2555}
2556
2557/// Stop following another user in a call.
2558async fn unfollow(request: proto::Unfollow, session: MessageContext) -> Result<()> {
2559 let room_id = RoomId::from_proto(request.room_id);
2560 let project_id = request.project_id.map(ProjectId::from_proto);
2561 let leader_id = request.leader_id.context("invalid leader id")?.into();
2562 let follower_id = session.connection_id;
2563
2564 session
2565 .db()
2566 .await
2567 .check_room_participants(room_id, leader_id, session.connection_id)
2568 .await?;
2569
2570 session
2571 .peer
2572 .forward_send(session.connection_id, leader_id, request)?;
2573
2574 if let Some(project_id) = project_id {
2575 let room = session
2576 .db()
2577 .await
2578 .unfollow(room_id, project_id, leader_id, follower_id)
2579 .await?;
2580 room_updated(&room, &session.peer);
2581 }
2582
2583 Ok(())
2584}
2585
2586/// Notify everyone following you of your current location.
2587async fn update_followers(request: proto::UpdateFollowers, session: MessageContext) -> Result<()> {
2588 let room_id = RoomId::from_proto(request.room_id);
2589 let database = session.db.lock().await;
2590
2591 let connection_ids = if let Some(project_id) = request.project_id {
2592 let project_id = ProjectId::from_proto(project_id);
2593 database
2594 .project_connection_ids(project_id, session.connection_id, true)
2595 .await?
2596 } else {
2597 database
2598 .room_connection_ids(room_id, session.connection_id)
2599 .await?
2600 };
2601
2602 // For now, don't send view update messages back to that view's current leader.
2603 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2604 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2605 _ => None,
2606 });
2607
2608 for connection_id in connection_ids.iter().cloned() {
2609 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2610 session
2611 .peer
2612 .forward_send(session.connection_id, connection_id, request.clone())?;
2613 }
2614 }
2615 Ok(())
2616}
2617
2618/// Get public data about users.
2619async fn get_users(
2620 request: proto::GetUsers,
2621 response: Response<proto::GetUsers>,
2622 session: MessageContext,
2623) -> Result<()> {
2624 let user_ids = request
2625 .user_ids
2626 .into_iter()
2627 .map(UserId::from_proto)
2628 .collect();
2629 let users = session
2630 .db()
2631 .await
2632 .get_users_by_ids(user_ids)
2633 .await?
2634 .into_iter()
2635 .map(|user| proto::User {
2636 id: user.id.to_proto(),
2637 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2638 github_login: user.github_login,
2639 name: user.name,
2640 })
2641 .collect();
2642 response.send(proto::UsersResponse { users })?;
2643 Ok(())
2644}
2645
2646/// Search for users (to invite) buy Github login
2647async fn fuzzy_search_users(
2648 request: proto::FuzzySearchUsers,
2649 response: Response<proto::FuzzySearchUsers>,
2650 session: MessageContext,
2651) -> Result<()> {
2652 let query = request.query;
2653 let users = match query.len() {
2654 0 => vec![],
2655 1 | 2 => session
2656 .db()
2657 .await
2658 .get_user_by_github_login(&query)
2659 .await?
2660 .into_iter()
2661 .collect(),
2662 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2663 };
2664 let users = users
2665 .into_iter()
2666 .filter(|user| user.id != session.user_id())
2667 .map(|user| proto::User {
2668 id: user.id.to_proto(),
2669 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2670 github_login: user.github_login,
2671 name: user.name,
2672 })
2673 .collect();
2674 response.send(proto::UsersResponse { users })?;
2675 Ok(())
2676}
2677
2678/// Send a contact request to another user.
2679async fn request_contact(
2680 request: proto::RequestContact,
2681 response: Response<proto::RequestContact>,
2682 session: MessageContext,
2683) -> Result<()> {
2684 let requester_id = session.user_id();
2685 let responder_id = UserId::from_proto(request.responder_id);
2686 if requester_id == responder_id {
2687 return Err(anyhow!("cannot add yourself as a contact"))?;
2688 }
2689
2690 let notifications = session
2691 .db()
2692 .await
2693 .send_contact_request(requester_id, responder_id)
2694 .await?;
2695
2696 // Update outgoing contact requests of requester
2697 let mut update = proto::UpdateContacts::default();
2698 update.outgoing_requests.push(responder_id.to_proto());
2699 for connection_id in session
2700 .connection_pool()
2701 .await
2702 .user_connection_ids(requester_id)
2703 {
2704 session.peer.send(connection_id, update.clone())?;
2705 }
2706
2707 // Update incoming contact requests of responder
2708 let mut update = proto::UpdateContacts::default();
2709 update
2710 .incoming_requests
2711 .push(proto::IncomingContactRequest {
2712 requester_id: requester_id.to_proto(),
2713 });
2714 let connection_pool = session.connection_pool().await;
2715 for connection_id in connection_pool.user_connection_ids(responder_id) {
2716 session.peer.send(connection_id, update.clone())?;
2717 }
2718
2719 send_notifications(&connection_pool, &session.peer, notifications);
2720
2721 response.send(proto::Ack {})?;
2722 Ok(())
2723}
2724
/// Accept or decline a contact request
///
/// A `Dismiss` response only clears the notification (no contact state is
/// changed in this branch). Accept/decline updates both users' contact
/// lists on all of their connections and delivers any resulting
/// notifications.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: MessageContext,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy state is embedded in the contact entries built below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        // Either way, the incoming request is no longer pending for the responder.
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        // Either way, the outgoing request is no longer pending for the requester.
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2782
/// Remove a contact.
///
/// Works for both established contacts and still-pending requests:
/// `contact_accepted` from the database decides whether we retract a
/// contact entry or a pending request on each side. If a contact-request
/// notification was deleted, the responder's connections are told to drop
/// it as well.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: MessageContext,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        // Retract the contact-request notification too, if one was deleted.
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2833
/// Whether the server should subscribe this client to its channels
/// automatically at connect time. Clients at or above 0.139 are expected
/// to request this themselves (see `subscribe_to_channels`).
fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
    version.0.minor() < 139
}
2837
2838async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2839 if is_staff {
2840 return Ok(proto::Plan::ZedPro);
2841 }
2842
2843 let subscription = db.get_active_billing_subscription(user_id).await?;
2844 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2845
2846 let plan = if let Some(subscription_kind) = subscription_kind {
2847 match subscription_kind {
2848 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2849 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2850 SubscriptionKind::ZedFree => proto::Plan::Free,
2851 }
2852 } else {
2853 proto::Plan::Free
2854 };
2855
2856 Ok(plan)
2857}
2858
/// Assembles the `UpdateUserPlan` message describing the user's plan,
/// subscription period, usage, and billing-related flags.
///
/// `llm_db` is optional: without it, no subscription period or usage can
/// be looked up, and the message falls back to default usage for the plan.
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    // Usage lives in the LLM database and is scoped to the current
    // subscription period; both are `None` when either is unavailable.
    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // Accounts younger than the minimum age are blocked from LLM use,
    // unless they are on Pro or carry the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always have usage-based billing; others follow their
        // stored billing preference (or `None` when no preferences exist).
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        // Always report usage: real numbers when we have them, otherwise
        // the plan's defaults.
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2923
2924fn model_requests_limit(
2925 plan: cloud_llm_client::Plan,
2926 feature_flags: &Vec<String>,
2927) -> cloud_llm_client::UsageLimit {
2928 match plan.model_requests_limit() {
2929 cloud_llm_client::UsageLimit::Limited(limit) => {
2930 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2931 && feature_flags
2932 .iter()
2933 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2934 {
2935 1_000
2936 } else {
2937 limit
2938 };
2939
2940 cloud_llm_client::UsageLimit::Limited(limit)
2941 }
2942 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2943 }
2944}
2945
2946fn subscription_usage_to_proto(
2947 plan: proto::Plan,
2948 usage: crate::llm::db::subscription_usage::Model,
2949 feature_flags: &Vec<String>,
2950) -> proto::SubscriptionUsage {
2951 let plan = match plan {
2952 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2953 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2954 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2955 };
2956
2957 proto::SubscriptionUsage {
2958 model_requests_usage_amount: usage.model_requests as u32,
2959 model_requests_usage_limit: Some(proto::UsageLimit {
2960 variant: Some(match model_requests_limit(plan, feature_flags) {
2961 cloud_llm_client::UsageLimit::Limited(limit) => {
2962 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2963 limit: limit as u32,
2964 })
2965 }
2966 cloud_llm_client::UsageLimit::Unlimited => {
2967 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2968 }
2969 }),
2970 }),
2971 edit_predictions_usage_amount: usage.edit_predictions as u32,
2972 edit_predictions_usage_limit: Some(proto::UsageLimit {
2973 variant: Some(match plan.edit_predictions_limit() {
2974 cloud_llm_client::UsageLimit::Limited(limit) => {
2975 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2976 limit: limit as u32,
2977 })
2978 }
2979 cloud_llm_client::UsageLimit::Unlimited => {
2980 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2981 }
2982 }),
2983 }),
2984 }
2985}
2986
2987fn make_default_subscription_usage(
2988 plan: proto::Plan,
2989 feature_flags: &Vec<String>,
2990) -> proto::SubscriptionUsage {
2991 let plan = match plan {
2992 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2993 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2994 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2995 };
2996
2997 proto::SubscriptionUsage {
2998 model_requests_usage_amount: 0,
2999 model_requests_usage_limit: Some(proto::UsageLimit {
3000 variant: Some(match model_requests_limit(plan, feature_flags) {
3001 cloud_llm_client::UsageLimit::Limited(limit) => {
3002 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3003 limit: limit as u32,
3004 })
3005 }
3006 cloud_llm_client::UsageLimit::Unlimited => {
3007 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3008 }
3009 }),
3010 }),
3011 edit_predictions_usage_amount: 0,
3012 edit_predictions_usage_limit: Some(proto::UsageLimit {
3013 variant: Some(match plan.edit_predictions_limit() {
3014 cloud_llm_client::UsageLimit::Limited(limit) => {
3015 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
3016 limit: limit as u32,
3017 })
3018 }
3019 cloud_llm_client::UsageLimit::Unlimited => {
3020 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
3021 }
3022 }),
3023 }),
3024 }
3025}
3026
3027async fn update_user_plan(session: &Session) -> Result<()> {
3028 let db = session.db().await;
3029
3030 let update_user_plan = make_update_user_plan_message(
3031 session.principal.user(),
3032 session.is_staff(),
3033 &db.0,
3034 session.app_state.llm_db.clone(),
3035 )
3036 .await?;
3037
3038 session
3039 .peer
3040 .send(session.connection_id, update_user_plan)
3041 .trace_err();
3042
3043 Ok(())
3044}
3045
3046async fn subscribe_to_channels(
3047 _: proto::SubscribeToChannels,
3048 session: MessageContext,
3049) -> Result<()> {
3050 subscribe_user_to_channels(session.user_id(), &session).await?;
3051 Ok(())
3052}
3053
3054async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3055 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3056 let mut pool = session.connection_pool().await;
3057 for membership in &channels_for_user.channel_memberships {
3058 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3059 }
3060 session.peer.send(
3061 session.connection_id,
3062 build_update_user_channels(&channels_for_user),
3063 )?;
3064 session.peer.send(
3065 session.connection_id,
3066 build_channels_update(channels_for_user),
3067 )?;
3068 Ok(())
3069}
3070
/// Creates a new channel.
///
/// Replies to the creator with the new channel, registers any membership
/// the database created for the creator, and advertises the channel to
/// every connection on the root channel that is allowed to see it.
async fn create_channel(
    request: proto::CreateChannel,
    response: Response<proto::CreateChannel>,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;

    let parent_id = request.parent_id.map(ChannelId::from_proto);
    let (channel, membership) = db
        .create_channel(&request.name, parent_id, session.user_id())
        .await?;

    let root_id = channel.root_id();
    let channel = Channel::from_model(channel);

    // Reply to the creator first so its client can proceed immediately.
    response.send(proto::CreateChannelResponse {
        channel: Some(channel.to_proto()),
        parent_id: request.parent_id,
    })?;

    let mut connection_pool = session.connection_pool().await;
    // NOTE(review): membership appears to be returned only when the
    // creation grants the creator a new membership — confirm against
    // `Database::create_channel`. Register it with the pool and inform all
    // of that user's connections of the new role.
    if let Some(membership) = membership {
        connection_pool.subscribe_to_channel(
            membership.user_id,
            membership.channel_id,
            membership.role,
        );
        let update = proto::UpdateUserChannels {
            channel_memberships: vec![proto::ChannelMembership {
                channel_id: membership.channel_id.to_proto(),
                role: membership.role.into(),
            }],
            ..Default::default()
        };
        for connection_id in connection_pool.user_connection_ids(membership.user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    // Advertise the new channel to root-channel connections whose role can
    // see a channel of this visibility.
    for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
        if !role.can_see_channel(channel.visibility) {
            continue;
        }

        let update = proto::UpdateChannels {
            channels: vec![channel.to_proto()],
            ..Default::default()
        };
        session.peer.send(connection_id, update.clone())?;
    }

    Ok(())
}
3125
3126/// Delete a channel
3127async fn delete_channel(
3128 request: proto::DeleteChannel,
3129 response: Response<proto::DeleteChannel>,
3130 session: MessageContext,
3131) -> Result<()> {
3132 let db = session.db().await;
3133
3134 let channel_id = request.channel_id;
3135 let (root_channel, removed_channels) = db
3136 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3137 .await?;
3138 response.send(proto::Ack {})?;
3139
3140 // Notify members of removed channels
3141 let mut update = proto::UpdateChannels::default();
3142 update
3143 .delete_channels
3144 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3145
3146 let connection_pool = session.connection_pool().await;
3147 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3148 session.peer.send(connection_id, update.clone())?;
3149 }
3150
3151 Ok(())
3152}
3153
3154/// Invite someone to join a channel.
3155async fn invite_channel_member(
3156 request: proto::InviteChannelMember,
3157 response: Response<proto::InviteChannelMember>,
3158 session: MessageContext,
3159) -> Result<()> {
3160 let db = session.db().await;
3161 let channel_id = ChannelId::from_proto(request.channel_id);
3162 let invitee_id = UserId::from_proto(request.user_id);
3163 let InviteMemberResult {
3164 channel,
3165 notifications,
3166 } = db
3167 .invite_channel_member(
3168 channel_id,
3169 invitee_id,
3170 session.user_id(),
3171 request.role().into(),
3172 )
3173 .await?;
3174
3175 let update = proto::UpdateChannels {
3176 channel_invitations: vec![channel.to_proto()],
3177 ..Default::default()
3178 };
3179
3180 let connection_pool = session.connection_pool().await;
3181 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3182 session.peer.send(connection_id, update.clone())?;
3183 }
3184
3185 send_notifications(&connection_pool, &session.peer, notifications);
3186
3187 response.send(proto::Ack {})?;
3188 Ok(())
3189}
3190
3191/// remove someone from a channel
3192async fn remove_channel_member(
3193 request: proto::RemoveChannelMember,
3194 response: Response<proto::RemoveChannelMember>,
3195 session: MessageContext,
3196) -> Result<()> {
3197 let db = session.db().await;
3198 let channel_id = ChannelId::from_proto(request.channel_id);
3199 let member_id = UserId::from_proto(request.user_id);
3200
3201 let RemoveChannelMemberResult {
3202 membership_update,
3203 notification_id,
3204 } = db
3205 .remove_channel_member(channel_id, member_id, session.user_id())
3206 .await?;
3207
3208 let mut connection_pool = session.connection_pool().await;
3209 notify_membership_updated(
3210 &mut connection_pool,
3211 membership_update,
3212 member_id,
3213 &session.peer,
3214 );
3215 for connection_id in connection_pool.user_connection_ids(member_id) {
3216 if let Some(notification_id) = notification_id {
3217 session
3218 .peer
3219 .send(
3220 connection_id,
3221 proto::DeleteNotification {
3222 notification_id: notification_id.to_proto(),
3223 },
3224 )
3225 .trace_err();
3226 }
3227 }
3228
3229 response.send(proto::Ack {})?;
3230 Ok(())
3231}
3232
3233/// Toggle the channel between public and private.
3234/// Care is taken to maintain the invariant that public channels only descend from public channels,
3235/// (though members-only channels can appear at any point in the hierarchy).
3236async fn set_channel_visibility(
3237 request: proto::SetChannelVisibility,
3238 response: Response<proto::SetChannelVisibility>,
3239 session: MessageContext,
3240) -> Result<()> {
3241 let db = session.db().await;
3242 let channel_id = ChannelId::from_proto(request.channel_id);
3243 let visibility = request.visibility().into();
3244
3245 let channel_model = db
3246 .set_channel_visibility(channel_id, visibility, session.user_id())
3247 .await?;
3248 let root_id = channel_model.root_id();
3249 let channel = Channel::from_model(channel_model);
3250
3251 let mut connection_pool = session.connection_pool().await;
3252 for (user_id, role) in connection_pool
3253 .channel_user_ids(root_id)
3254 .collect::<Vec<_>>()
3255 .into_iter()
3256 {
3257 let update = if role.can_see_channel(channel.visibility) {
3258 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3259 proto::UpdateChannels {
3260 channels: vec![channel.to_proto()],
3261 ..Default::default()
3262 }
3263 } else {
3264 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3265 proto::UpdateChannels {
3266 delete_channels: vec![channel.id.to_proto()],
3267 ..Default::default()
3268 }
3269 };
3270
3271 for connection_id in connection_pool.user_connection_ids(user_id) {
3272 session.peer.send(connection_id, update.clone())?;
3273 }
3274 }
3275
3276 response.send(proto::Ack {})?;
3277 Ok(())
3278}
3279
3280/// Alter the role for a user in the channel.
3281async fn set_channel_member_role(
3282 request: proto::SetChannelMemberRole,
3283 response: Response<proto::SetChannelMemberRole>,
3284 session: MessageContext,
3285) -> Result<()> {
3286 let db = session.db().await;
3287 let channel_id = ChannelId::from_proto(request.channel_id);
3288 let member_id = UserId::from_proto(request.user_id);
3289 let result = db
3290 .set_channel_member_role(
3291 channel_id,
3292 session.user_id(),
3293 member_id,
3294 request.role().into(),
3295 )
3296 .await?;
3297
3298 match result {
3299 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3300 let mut connection_pool = session.connection_pool().await;
3301 notify_membership_updated(
3302 &mut connection_pool,
3303 membership_update,
3304 member_id,
3305 &session.peer,
3306 )
3307 }
3308 db::SetMemberRoleResult::InviteUpdated(channel) => {
3309 let update = proto::UpdateChannels {
3310 channel_invitations: vec![channel.to_proto()],
3311 ..Default::default()
3312 };
3313
3314 for connection_id in session
3315 .connection_pool()
3316 .await
3317 .user_connection_ids(member_id)
3318 {
3319 session.peer.send(connection_id, update.clone())?;
3320 }
3321 }
3322 }
3323
3324 response.send(proto::Ack {})?;
3325 Ok(())
3326}
3327
3328/// Change the name of a channel
3329async fn rename_channel(
3330 request: proto::RenameChannel,
3331 response: Response<proto::RenameChannel>,
3332 session: MessageContext,
3333) -> Result<()> {
3334 let db = session.db().await;
3335 let channel_id = ChannelId::from_proto(request.channel_id);
3336 let channel_model = db
3337 .rename_channel(channel_id, session.user_id(), &request.name)
3338 .await?;
3339 let root_id = channel_model.root_id();
3340 let channel = Channel::from_model(channel_model);
3341
3342 response.send(proto::RenameChannelResponse {
3343 channel: Some(channel.to_proto()),
3344 })?;
3345
3346 let connection_pool = session.connection_pool().await;
3347 let update = proto::UpdateChannels {
3348 channels: vec![channel.to_proto()],
3349 ..Default::default()
3350 };
3351 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352 if role.can_see_channel(channel.visibility) {
3353 session.peer.send(connection_id, update.clone())?;
3354 }
3355 }
3356
3357 Ok(())
3358}
3359
3360/// Move a channel to a new parent.
3361async fn move_channel(
3362 request: proto::MoveChannel,
3363 response: Response<proto::MoveChannel>,
3364 session: MessageContext,
3365) -> Result<()> {
3366 let channel_id = ChannelId::from_proto(request.channel_id);
3367 let to = ChannelId::from_proto(request.to);
3368
3369 let (root_id, channels) = session
3370 .db()
3371 .await
3372 .move_channel(channel_id, to, session.user_id())
3373 .await?;
3374
3375 let connection_pool = session.connection_pool().await;
3376 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3377 let channels = channels
3378 .iter()
3379 .filter_map(|channel| {
3380 if role.can_see_channel(channel.visibility) {
3381 Some(channel.to_proto())
3382 } else {
3383 None
3384 }
3385 })
3386 .collect::<Vec<_>>();
3387 if channels.is_empty() {
3388 continue;
3389 }
3390
3391 let update = proto::UpdateChannels {
3392 channels,
3393 ..Default::default()
3394 };
3395
3396 session.peer.send(connection_id, update.clone())?;
3397 }
3398
3399 response.send(Ack {})?;
3400 Ok(())
3401}
3402
3403async fn reorder_channel(
3404 request: proto::ReorderChannel,
3405 response: Response<proto::ReorderChannel>,
3406 session: MessageContext,
3407) -> Result<()> {
3408 let channel_id = ChannelId::from_proto(request.channel_id);
3409 let direction = request.direction();
3410
3411 let updated_channels = session
3412 .db()
3413 .await
3414 .reorder_channel(channel_id, direction, session.user_id())
3415 .await?;
3416
3417 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3418 let connection_pool = session.connection_pool().await;
3419 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3420 let channels = updated_channels
3421 .iter()
3422 .filter_map(|channel| {
3423 if role.can_see_channel(channel.visibility) {
3424 Some(channel.to_proto())
3425 } else {
3426 None
3427 }
3428 })
3429 .collect::<Vec<_>>();
3430
3431 if channels.is_empty() {
3432 continue;
3433 }
3434
3435 let update = proto::UpdateChannels {
3436 channels,
3437 ..Default::default()
3438 };
3439
3440 session.peer.send(connection_id, update.clone())?;
3441 }
3442 }
3443
3444 response.send(Ack {})?;
3445 Ok(())
3446}
3447
3448/// Get the list of channel members
3449async fn get_channel_members(
3450 request: proto::GetChannelMembers,
3451 response: Response<proto::GetChannelMembers>,
3452 session: MessageContext,
3453) -> Result<()> {
3454 let db = session.db().await;
3455 let channel_id = ChannelId::from_proto(request.channel_id);
3456 let limit = if request.limit == 0 {
3457 u16::MAX as u64
3458 } else {
3459 request.limit
3460 };
3461 let (members, users) = db
3462 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3463 .await?;
3464 response.send(proto::GetChannelMembersResponse { members, users })?;
3465 Ok(())
3466}
3467
3468/// Accept or decline a channel invitation.
3469async fn respond_to_channel_invite(
3470 request: proto::RespondToChannelInvite,
3471 response: Response<proto::RespondToChannelInvite>,
3472 session: MessageContext,
3473) -> Result<()> {
3474 let db = session.db().await;
3475 let channel_id = ChannelId::from_proto(request.channel_id);
3476 let RespondToChannelInvite {
3477 membership_update,
3478 notifications,
3479 } = db
3480 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3481 .await?;
3482
3483 let mut connection_pool = session.connection_pool().await;
3484 if let Some(membership_update) = membership_update {
3485 notify_membership_updated(
3486 &mut connection_pool,
3487 membership_update,
3488 session.user_id(),
3489 &session.peer,
3490 );
3491 } else {
3492 let update = proto::UpdateChannels {
3493 remove_channel_invitations: vec![channel_id.to_proto()],
3494 ..Default::default()
3495 };
3496
3497 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3498 session.peer.send(connection_id, update.clone())?;
3499 }
3500 };
3501
3502 send_notifications(&connection_pool, &session.peer, notifications);
3503
3504 response.send(proto::Ack {})?;
3505
3506 Ok(())
3507}
3508
3509/// Join the channels' room
3510async fn join_channel(
3511 request: proto::JoinChannel,
3512 response: Response<proto::JoinChannel>,
3513 session: MessageContext,
3514) -> Result<()> {
3515 let channel_id = ChannelId::from_proto(request.channel_id);
3516 join_channel_internal(channel_id, Box::new(response), session).await
3517}
3518
/// Abstraction over the two response types that can carry a
/// `proto::JoinRoomResponse`, letting `join_channel_internal` serve both the
/// `JoinChannel` and `JoinRoom` request paths.
trait JoinChannelInternalResponse {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forward to the concrete `Response<proto::JoinChannel>` send.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forward to the concrete `Response<proto::JoinRoom>` send.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3532
/// Shared implementation for joining a channel's call, used by both the
/// `JoinChannel` and `JoinRoom` paths via `JoinChannelInternalResponse`.
///
/// Cleans up any stale room connection left by an unclean quit, joins the
/// channel's room, mints a LiveKit token when a client is configured, and
/// broadcasts the room/channel/contact updates that result.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: MessageContext,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle before the cleanup call, then re-acquire
            // it afterwards so the join below runs against a fresh handle.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Build LiveKit connection info if a client is configured: guests get
        // a non-publishing guest token, all other roles get a publish token.
        // Token minting failures are logged and yield `None` (no call media).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Reply to the requester before broadcasting to everyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The db handle from the block above is dropped here; the channel update
    // re-acquires the connection pool on its own.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3626
3627/// Start editing the channel notes
3628async fn join_channel_buffer(
3629 request: proto::JoinChannelBuffer,
3630 response: Response<proto::JoinChannelBuffer>,
3631 session: MessageContext,
3632) -> Result<()> {
3633 let db = session.db().await;
3634 let channel_id = ChannelId::from_proto(request.channel_id);
3635
3636 let open_response = db
3637 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3638 .await?;
3639
3640 let collaborators = open_response.collaborators.clone();
3641 response.send(open_response)?;
3642
3643 let update = UpdateChannelBufferCollaborators {
3644 channel_id: channel_id.to_proto(),
3645 collaborators: collaborators.clone(),
3646 };
3647 channel_buffer_updated(
3648 session.connection_id,
3649 collaborators
3650 .iter()
3651 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3652 &update,
3653 &session.peer,
3654 );
3655
3656 Ok(())
3657}
3658
/// Edit the channel notes
///
/// Persists `operations` to the channel's buffer, relays the operations to
/// the other collaborators, and tells non-collaborating channel subscribers
/// that a newer buffer version exists (so they can show it as changed).
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: MessageContext,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    // Send the raw operations to everyone currently editing the buffer
    // (excluding the sender — see `channel_buffer_updated`).
    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // Channel subscribers who are not active collaborators only need the
    // latest (epoch, version) marker, not the operations themselves.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
3709
3710/// Rejoin the channel notes after a connection blip
3711async fn rejoin_channel_buffers(
3712 request: proto::RejoinChannelBuffers,
3713 response: Response<proto::RejoinChannelBuffers>,
3714 session: MessageContext,
3715) -> Result<()> {
3716 let db = session.db().await;
3717 let buffers = db
3718 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3719 .await?;
3720
3721 for rejoined_buffer in &buffers {
3722 let collaborators_to_notify = rejoined_buffer
3723 .buffer
3724 .collaborators
3725 .iter()
3726 .filter_map(|c| Some(c.peer_id?.into()));
3727 channel_buffer_updated(
3728 session.connection_id,
3729 collaborators_to_notify,
3730 &proto::UpdateChannelBufferCollaborators {
3731 channel_id: rejoined_buffer.buffer.channel_id,
3732 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3733 },
3734 &session.peer,
3735 );
3736 }
3737
3738 response.send(proto::RejoinChannelBuffersResponse {
3739 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3740 })?;
3741
3742 Ok(())
3743}
3744
3745/// Stop editing the channel notes
3746async fn leave_channel_buffer(
3747 request: proto::LeaveChannelBuffer,
3748 response: Response<proto::LeaveChannelBuffer>,
3749 session: MessageContext,
3750) -> Result<()> {
3751 let db = session.db().await;
3752 let channel_id = ChannelId::from_proto(request.channel_id);
3753
3754 let left_buffer = db
3755 .leave_channel_buffer(channel_id, session.connection_id)
3756 .await?;
3757
3758 response.send(Ack {})?;
3759
3760 channel_buffer_updated(
3761 session.connection_id,
3762 left_buffer.connections,
3763 &proto::UpdateChannelBufferCollaborators {
3764 channel_id: channel_id.to_proto(),
3765 collaborators: left_buffer.collaborators,
3766 },
3767 &session.peer,
3768 );
3769
3770 Ok(())
3771}
3772
/// Sends `message` to each collaborator connection via `broadcast`, passing
/// `sender_id` as the connection to exclude (presumably so the originator
/// doesn't receive its own update — confirm against `broadcast`'s definition).
fn channel_buffer_updated<T: EnvelopedMessage>(
    sender_id: ConnectionId,
    collaborators: impl IntoIterator<Item = ConnectionId>,
    message: &T,
    peer: &Peer,
) {
    broadcast(Some(sender_id), collaborators, |peer_id| {
        peer.send(peer_id, message.clone())
    });
}
3783
3784fn send_notifications(
3785 connection_pool: &ConnectionPool,
3786 peer: &Peer,
3787 notifications: db::NotificationBatch,
3788) {
3789 for (user_id, notification) in notifications {
3790 for connection_id in connection_pool.user_connection_ids(user_id) {
3791 if let Err(error) = peer.send(
3792 connection_id,
3793 proto::AddNotification {
3794 notification: Some(notification.clone()),
3795 },
3796 ) {
3797 tracing::error!(
3798 "failed to send notification to {:?} {}",
3799 connection_id,
3800 error
3801 );
3802 }
3803 }
3804 }
3805}
3806
/// Send a message to the channel
///
/// Validates and persists the message, echoes it to the channel's active
/// chat participants, and bumps the latest-message marker for channel
/// subscribers who aren't in the chat. Notifications produced by the
/// database (e.g. for mentions) are delivered last.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    // The client-supplied nonce is stored and echoed back so clients can
    // correlate the broadcast message with their pending send.
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Echo the full message to chat participants (excluding the sender).
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers who aren't active chat participants only get the
    // latest message id, so clients can show an unread indicator.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3899
3900/// Delete a channel message
3901async fn remove_channel_message(
3902 request: proto::RemoveChannelMessage,
3903 response: Response<proto::RemoveChannelMessage>,
3904 session: MessageContext,
3905) -> Result<()> {
3906 let channel_id = ChannelId::from_proto(request.channel_id);
3907 let message_id = MessageId::from_proto(request.message_id);
3908 let (connection_ids, existing_notification_ids) = session
3909 .db()
3910 .await
3911 .remove_channel_message(channel_id, message_id, session.user_id())
3912 .await?;
3913
3914 broadcast(
3915 Some(session.connection_id),
3916 connection_ids,
3917 move |connection| {
3918 session.peer.send(connection, request.clone())?;
3919
3920 for notification_id in &existing_notification_ids {
3921 session.peer.send(
3922 connection,
3923 proto::DeleteNotification {
3924 notification_id: (*notification_id).to_proto(),
3925 },
3926 )?;
3927 }
3928
3929 Ok(())
3930 },
3931 );
3932 response.send(proto::Ack {})?;
3933 Ok(())
3934}
3935
/// Edit an existing channel message.
///
/// Persists the new body/mentions, then broadcasts the updated message to
/// the other chat participants along with any mention-notification deletions
/// and updates that resulted from the edit.
async fn update_channel_message(
    request: proto::UpdateChannelMessage,
    response: Response<proto::UpdateChannelMessage>,
    session: MessageContext,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let updated_at = OffsetDateTime::now_utc();
    // Note: `message_id` is shadowed below by the id returned from the
    // database, which is what gets broadcast.
    let UpdatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
        reply_to_message_id,
        timestamp,
        deleted_mention_notification_ids,
        updated_mention_notifications,
    } = session
        .db()
        .await
        .update_channel_message(
            channel_id,
            message_id,
            session.user_id(),
            request.body.as_str(),
            &request.mentions,
            updated_at,
        )
        .await?;

    let nonce = request.nonce.clone().context("nonce can't be blank")?;

    // Rebuild the full message, preserving the original timestamp and
    // marking the edit time.
    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body: request.body.clone(),
        mentions: request.mentions.clone(),
        timestamp: timestamp.assume_utc().unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
        edited_at: Some(updated_at.unix_timestamp() as u64),
    };

    response.send(proto::Ack {})?;

    let pool = &*session.connection_pool().await;
    broadcast(
        Some(session.connection_id),
        participant_connection_ids,
        |connection| {
            // The updated message itself…
            session.peer.send(
                connection,
                proto::ChannelMessageUpdate {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )?;

            // …then retract mention notifications removed by the edit…
            for notification_id in &deleted_mention_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            // …and refresh the ones whose content changed.
            for notification in &updated_mention_notifications {
                session.peer.send(
                    connection,
                    proto::UpdateNotification {
                        notification: Some(notification.clone()),
                    },
                )?;
            }

            Ok(())
        },
    );

    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4019
4020/// Mark a channel message as read
4021async fn acknowledge_channel_message(
4022 request: proto::AckChannelMessage,
4023 session: MessageContext,
4024) -> Result<()> {
4025 let channel_id = ChannelId::from_proto(request.channel_id);
4026 let message_id = MessageId::from_proto(request.message_id);
4027 let notifications = session
4028 .db()
4029 .await
4030 .observe_channel_message(channel_id, session.user_id(), message_id)
4031 .await?;
4032 send_notifications(
4033 &*session.connection_pool().await,
4034 &session.peer,
4035 notifications,
4036 );
4037 Ok(())
4038}
4039
4040/// Mark a buffer version as synced
4041async fn acknowledge_buffer_version(
4042 request: proto::AckBufferOperation,
4043 session: MessageContext,
4044) -> Result<()> {
4045 let buffer_id = BufferId::from_proto(request.buffer_id);
4046 session
4047 .db()
4048 .await
4049 .observe_buffer_version(
4050 buffer_id,
4051 session.user_id(),
4052 request.epoch as i32,
4053 &request.version,
4054 )
4055 .await?;
4056 Ok(())
4057}
4058
4059/// Get a Supermaven API key for the user
4060async fn get_supermaven_api_key(
4061 _request: proto::GetSupermavenApiKey,
4062 response: Response<proto::GetSupermavenApiKey>,
4063 session: MessageContext,
4064) -> Result<()> {
4065 let user_id: String = session.user_id().to_string();
4066 if !session.is_staff() {
4067 return Err(anyhow!("supermaven not enabled for this account"))?;
4068 }
4069
4070 let email = session.email().context("user must have an email")?;
4071
4072 let supermaven_admin_api = session
4073 .supermaven_client
4074 .as_ref()
4075 .context("supermaven not configured")?;
4076
4077 let result = supermaven_admin_api
4078 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4079 .await?;
4080
4081 response.send(proto::GetSupermavenApiKeyResponse {
4082 api_key: result.api_key,
4083 })?;
4084
4085 Ok(())
4086}
4087
4088/// Start receiving chat updates for a channel
4089async fn join_channel_chat(
4090 request: proto::JoinChannelChat,
4091 response: Response<proto::JoinChannelChat>,
4092 session: MessageContext,
4093) -> Result<()> {
4094 let channel_id = ChannelId::from_proto(request.channel_id);
4095
4096 let db = session.db().await;
4097 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4098 .await?;
4099 let messages = db
4100 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4101 .await?;
4102 response.send(proto::JoinChannelChatResponse {
4103 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4104 messages,
4105 })?;
4106 Ok(())
4107}
4108
4109/// Stop receiving chat updates for a channel
4110async fn leave_channel_chat(
4111 request: proto::LeaveChannelChat,
4112 session: MessageContext,
4113) -> Result<()> {
4114 let channel_id = ChannelId::from_proto(request.channel_id);
4115 session
4116 .db()
4117 .await
4118 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4119 .await?;
4120 Ok(())
4121}
4122
4123/// Retrieve the chat history for a channel
4124async fn get_channel_messages(
4125 request: proto::GetChannelMessages,
4126 response: Response<proto::GetChannelMessages>,
4127 session: MessageContext,
4128) -> Result<()> {
4129 let channel_id = ChannelId::from_proto(request.channel_id);
4130 let messages = session
4131 .db()
4132 .await
4133 .get_channel_messages(
4134 channel_id,
4135 session.user_id(),
4136 MESSAGE_COUNT_PER_PAGE,
4137 Some(MessageId::from_proto(request.before_message_id)),
4138 )
4139 .await?;
4140 response.send(proto::GetChannelMessagesResponse {
4141 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4142 messages,
4143 })?;
4144 Ok(())
4145}
4146
4147/// Retrieve specific chat messages
4148async fn get_channel_messages_by_id(
4149 request: proto::GetChannelMessagesById,
4150 response: Response<proto::GetChannelMessagesById>,
4151 session: MessageContext,
4152) -> Result<()> {
4153 let message_ids = request
4154 .message_ids
4155 .iter()
4156 .map(|id| MessageId::from_proto(*id))
4157 .collect::<Vec<_>>();
4158 let messages = session
4159 .db()
4160 .await
4161 .get_channel_messages_by_id(session.user_id(), &message_ids)
4162 .await?;
4163 response.send(proto::GetChannelMessagesResponse {
4164 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4165 messages,
4166 })?;
4167 Ok(())
4168}
4169
4170/// Retrieve the current users notifications
4171async fn get_notifications(
4172 request: proto::GetNotifications,
4173 response: Response<proto::GetNotifications>,
4174 session: MessageContext,
4175) -> Result<()> {
4176 let notifications = session
4177 .db()
4178 .await
4179 .get_notifications(
4180 session.user_id(),
4181 NOTIFICATION_COUNT_PER_PAGE,
4182 request.before_id.map(db::NotificationId::from_proto),
4183 )
4184 .await?;
4185 response.send(proto::GetNotificationsResponse {
4186 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4187 notifications,
4188 })?;
4189 Ok(())
4190}
4191
4192/// Mark notifications as read
4193async fn mark_notification_as_read(
4194 request: proto::MarkNotificationRead,
4195 response: Response<proto::MarkNotificationRead>,
4196 session: MessageContext,
4197) -> Result<()> {
4198 let database = &session.db().await;
4199 let notifications = database
4200 .mark_notification_as_read_by_id(
4201 session.user_id(),
4202 NotificationId::from_proto(request.notification_id),
4203 )
4204 .await?;
4205 send_notifications(
4206 &*session.connection_pool().await,
4207 &session.peer,
4208 notifications,
4209 );
4210 response.send(proto::Ack {})?;
4211 Ok(())
4212}
4213
4214/// Get the current users information
4215async fn get_private_user_info(
4216 _request: proto::GetPrivateUserInfo,
4217 response: Response<proto::GetPrivateUserInfo>,
4218 session: MessageContext,
4219) -> Result<()> {
4220 let db = session.db().await;
4221
4222 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4223 let user = db
4224 .get_user_by_id(session.user_id())
4225 .await?
4226 .context("user not found")?;
4227 let flags = db.get_user_flags(session.user_id()).await?;
4228
4229 response.send(proto::GetPrivateUserInfoResponse {
4230 metrics_id,
4231 staff: user.admin,
4232 flags,
4233 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4234 })?;
4235 Ok(())
4236}
4237
4238/// Accept the terms of service (tos) on behalf of the current user
4239async fn accept_terms_of_service(
4240 _request: proto::AcceptTermsOfService,
4241 response: Response<proto::AcceptTermsOfService>,
4242 session: MessageContext,
4243) -> Result<()> {
4244 let db = session.db().await;
4245
4246 let accepted_tos_at = Utc::now();
4247 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4248 .await?;
4249
4250 response.send(proto::AcceptTermsOfServiceResponse {
4251 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4252 })?;
4253
4254 // When the user accepts the terms of service, we want to refresh their LLM
4255 // token to grant access.
4256 session
4257 .peer
4258 .send(session.connection_id, proto::RefreshLlmToken {})?;
4259
4260 Ok(())
4261}
4262
4263async fn get_llm_api_token(
4264 _request: proto::GetLlmToken,
4265 response: Response<proto::GetLlmToken>,
4266 session: MessageContext,
4267) -> Result<()> {
4268 let db = session.db().await;
4269
4270 let flags = db.get_user_flags(session.user_id()).await?;
4271
4272 let user_id = session.user_id();
4273 let user = db
4274 .get_user_by_id(user_id)
4275 .await?
4276 .with_context(|| format!("user {user_id} not found"))?;
4277
4278 if user.accepted_tos_at.is_none() {
4279 Err(anyhow!("terms of service not accepted"))?
4280 }
4281
4282 let stripe_client = session
4283 .app_state
4284 .stripe_client
4285 .as_ref()
4286 .context("failed to retrieve Stripe client")?;
4287
4288 let stripe_billing = session
4289 .app_state
4290 .stripe_billing
4291 .as_ref()
4292 .context("failed to retrieve Stripe billing object")?;
4293
4294 let billing_customer = if let Some(billing_customer) =
4295 db.get_billing_customer_by_user_id(user.id).await?
4296 {
4297 billing_customer
4298 } else {
4299 let customer_id = stripe_billing
4300 .find_or_create_customer_by_email(user.email_address.as_deref())
4301 .await?;
4302
4303 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4304 .await?
4305 .context("billing customer not found")?
4306 };
4307
4308 let billing_subscription =
4309 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4310 billing_subscription
4311 } else {
4312 let stripe_customer_id =
4313 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4314
4315 let stripe_subscription = stripe_billing
4316 .subscribe_to_zed_free(stripe_customer_id)
4317 .await?;
4318
4319 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4320 billing_customer_id: billing_customer.id,
4321 kind: Some(SubscriptionKind::ZedFree),
4322 stripe_subscription_id: stripe_subscription.id.to_string(),
4323 stripe_subscription_status: stripe_subscription.status.into(),
4324 stripe_cancellation_reason: None,
4325 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4326 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4327 })
4328 .await?
4329 };
4330
4331 let billing_preferences = db.get_billing_preferences(user.id).await?;
4332
4333 let token = LlmTokenClaims::create(
4334 &user,
4335 session.is_staff(),
4336 billing_customer,
4337 billing_preferences,
4338 &flags,
4339 billing_subscription,
4340 session.system_id.clone(),
4341 &session.app_state.config,
4342 )?;
4343 response.send(proto::GetLlmTokenResponse { token })?;
4344 Ok(())
4345}
4346
4347fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4348 let message = match message {
4349 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4350 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4351 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4352 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4353 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4354 code: frame.code.into(),
4355 reason: frame.reason.as_str().to_owned().into(),
4356 })),
4357 // We should never receive a frame while reading the message, according
4358 // to the `tungstenite` maintainers:
4359 //
4360 // > It cannot occur when you read messages from the WebSocket, but it
4361 // > can be used when you want to send the raw frames (e.g. you want to
4362 // > send the frames to the WebSocket without composing the full message first).
4363 // >
4364 // > — https://github.com/snapview/tungstenite-rs/issues/268
4365 TungsteniteMessage::Frame(_) => {
4366 bail!("received an unexpected frame while reading the message")
4367 }
4368 };
4369
4370 Ok(message)
4371}
4372
4373fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4374 match message {
4375 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4376 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4377 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4378 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4379 AxumMessage::Close(frame) => {
4380 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4381 code: frame.code.into(),
4382 reason: frame.reason.as_ref().into(),
4383 }))
4384 }
4385 }
4386}
4387
4388fn notify_membership_updated(
4389 connection_pool: &mut ConnectionPool,
4390 result: MembershipUpdated,
4391 user_id: UserId,
4392 peer: &Peer,
4393) {
4394 for membership in &result.new_channels.channel_memberships {
4395 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4396 }
4397 for channel_id in &result.removed_channels {
4398 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4399 }
4400
4401 let user_channels_update = proto::UpdateUserChannels {
4402 channel_memberships: result
4403 .new_channels
4404 .channel_memberships
4405 .iter()
4406 .map(|cm| proto::ChannelMembership {
4407 channel_id: cm.channel_id.to_proto(),
4408 role: cm.role.into(),
4409 })
4410 .collect(),
4411 ..Default::default()
4412 };
4413
4414 let mut update = build_channels_update(result.new_channels);
4415 update.delete_channels = result
4416 .removed_channels
4417 .into_iter()
4418 .map(|id| id.to_proto())
4419 .collect();
4420 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4421
4422 for connection_id in connection_pool.user_connection_ids(user_id) {
4423 peer.send(connection_id, user_channels_update.clone())
4424 .trace_err();
4425 peer.send(connection_id, update.clone()).trace_err();
4426 }
4427}
4428
4429fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4430 proto::UpdateUserChannels {
4431 channel_memberships: channels
4432 .channel_memberships
4433 .iter()
4434 .map(|m| proto::ChannelMembership {
4435 channel_id: m.channel_id.to_proto(),
4436 role: m.role.into(),
4437 })
4438 .collect(),
4439 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4440 observed_channel_message_id: channels.observed_channel_messages.clone(),
4441 }
4442}
4443
4444fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4445 let mut update = proto::UpdateChannels::default();
4446
4447 for channel in channels.channels {
4448 update.channels.push(channel.to_proto());
4449 }
4450
4451 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4452 update.latest_channel_message_ids = channels.latest_channel_messages;
4453
4454 for (channel_id, participants) in channels.channel_participants {
4455 update
4456 .channel_participants
4457 .push(proto::ChannelParticipants {
4458 channel_id: channel_id.to_proto(),
4459 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4460 });
4461 }
4462
4463 for channel in channels.invited_channels {
4464 update.channel_invitations.push(channel.to_proto());
4465 }
4466
4467 update
4468}
4469
4470fn build_initial_contacts_update(
4471 contacts: Vec<db::Contact>,
4472 pool: &ConnectionPool,
4473) -> proto::UpdateContacts {
4474 let mut update = proto::UpdateContacts::default();
4475
4476 for contact in contacts {
4477 match contact {
4478 db::Contact::Accepted { user_id, busy } => {
4479 update.contacts.push(contact_for_user(user_id, busy, pool));
4480 }
4481 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4482 db::Contact::Incoming { user_id } => {
4483 update
4484 .incoming_requests
4485 .push(proto::IncomingContactRequest {
4486 requester_id: user_id.to_proto(),
4487 })
4488 }
4489 }
4490 }
4491
4492 update
4493}
4494
4495fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4496 proto::Contact {
4497 user_id: user_id.to_proto(),
4498 online: pool.is_user_online(user_id),
4499 busy,
4500 }
4501}
4502
4503fn room_updated(room: &proto::Room, peer: &Peer) {
4504 broadcast(
4505 None,
4506 room.participants
4507 .iter()
4508 .filter_map(|participant| Some(participant.peer_id?.into())),
4509 |peer_id| {
4510 peer.send(
4511 peer_id,
4512 proto::RoomUpdated {
4513 room: Some(room.clone()),
4514 },
4515 )
4516 },
4517 );
4518}
4519
4520fn channel_updated(
4521 channel: &db::channel::Model,
4522 room: &proto::Room,
4523 peer: &Peer,
4524 pool: &ConnectionPool,
4525) {
4526 let participants = room
4527 .participants
4528 .iter()
4529 .map(|p| p.user_id)
4530 .collect::<Vec<_>>();
4531
4532 broadcast(
4533 None,
4534 pool.channel_connection_ids(channel.root_id())
4535 .filter_map(|(channel_id, role)| {
4536 role.can_see_channel(channel.visibility)
4537 .then_some(channel_id)
4538 }),
4539 |peer_id| {
4540 peer.send(
4541 peer_id,
4542 proto::UpdateChannels {
4543 channel_participants: vec![proto::ChannelParticipants {
4544 channel_id: channel.id.to_proto(),
4545 participant_user_ids: participants.clone(),
4546 }],
4547 ..Default::default()
4548 },
4549 )
4550 },
4551 );
4552}
4553
4554async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4555 let db = session.db().await;
4556
4557 let contacts = db.get_contacts(user_id).await?;
4558 let busy = db.is_user_busy(user_id).await?;
4559
4560 let pool = session.connection_pool().await;
4561 let updated_contact = contact_for_user(user_id, busy, &pool);
4562 for contact in contacts {
4563 if let db::Contact::Accepted {
4564 user_id: contact_user_id,
4565 ..
4566 } = contact
4567 {
4568 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4569 session
4570 .peer
4571 .send(
4572 contact_conn_id,
4573 proto::UpdateContacts {
4574 contacts: vec![updated_contact.clone()],
4575 remove_contacts: Default::default(),
4576 incoming_requests: Default::default(),
4577 remove_incoming_requests: Default::default(),
4578 outgoing_requests: Default::default(),
4579 remove_outgoing_requests: Default::default(),
4580 },
4581 )
4582 .trace_err();
4583 }
4584 }
4585 }
4586 Ok(())
4587}
4588
4589async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4590 let mut contacts_to_update = HashSet::default();
4591
4592 let room_id;
4593 let canceled_calls_to_user_ids;
4594 let livekit_room;
4595 let delete_livekit_room;
4596 let room;
4597 let channel;
4598
4599 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4600 contacts_to_update.insert(session.user_id());
4601
4602 for project in left_room.left_projects.values() {
4603 project_left(project, session);
4604 }
4605
4606 room_id = RoomId::from_proto(left_room.room.id);
4607 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4608 livekit_room = mem::take(&mut left_room.room.livekit_room);
4609 delete_livekit_room = left_room.deleted;
4610 room = mem::take(&mut left_room.room);
4611 channel = mem::take(&mut left_room.channel);
4612
4613 room_updated(&room, &session.peer);
4614 } else {
4615 return Ok(());
4616 }
4617
4618 if let Some(channel) = channel {
4619 channel_updated(
4620 &channel,
4621 &room,
4622 &session.peer,
4623 &*session.connection_pool().await,
4624 );
4625 }
4626
4627 {
4628 let pool = session.connection_pool().await;
4629 for canceled_user_id in canceled_calls_to_user_ids {
4630 for connection_id in pool.user_connection_ids(canceled_user_id) {
4631 session
4632 .peer
4633 .send(
4634 connection_id,
4635 proto::CallCanceled {
4636 room_id: room_id.to_proto(),
4637 },
4638 )
4639 .trace_err();
4640 }
4641 contacts_to_update.insert(canceled_user_id);
4642 }
4643 }
4644
4645 for contact_user_id in contacts_to_update {
4646 update_user_contacts(contact_user_id, session).await?;
4647 }
4648
4649 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4650 live_kit
4651 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4652 .await
4653 .trace_err();
4654
4655 if delete_livekit_room {
4656 live_kit.delete_room(livekit_room).await.trace_err();
4657 }
4658 }
4659
4660 Ok(())
4661}
4662
4663async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4664 let left_channel_buffers = session
4665 .db()
4666 .await
4667 .leave_channel_buffers(session.connection_id)
4668 .await?;
4669
4670 for left_buffer in left_channel_buffers {
4671 channel_buffer_updated(
4672 session.connection_id,
4673 left_buffer.connections,
4674 &proto::UpdateChannelBufferCollaborators {
4675 channel_id: left_buffer.channel_id.to_proto(),
4676 collaborators: left_buffer.collaborators,
4677 },
4678 &session.peer,
4679 );
4680 }
4681
4682 Ok(())
4683}
4684
4685fn project_left(project: &db::LeftProject, session: &Session) {
4686 for connection_id in &project.connection_ids {
4687 if project.should_unshare {
4688 session
4689 .peer
4690 .send(
4691 *connection_id,
4692 proto::UnshareProject {
4693 project_id: project.id.to_proto(),
4694 },
4695 )
4696 .trace_err();
4697 } else {
4698 session
4699 .peer
4700 .send(
4701 *connection_id,
4702 proto::RemoveProjectCollaborator {
4703 project_id: project.id.to_proto(),
4704 peer_id: Some(session.connection_id.into()),
4705 },
4706 )
4707 .trace_err();
4708 }
4709 }
4710}
4711
/// Extension trait for logging an error and discarding it, rather than
/// propagating it to the caller.
pub trait ResultExt {
    /// The success type carried by the implementing result.
    type Ok;

    /// Log the error (if any) and convert the result into an `Option`,
    /// yielding `None` on failure.
    fn trace_err(self) -> Option<Self::Ok>;
}
4717
4718impl<T, E> ResultExt for Result<T, E>
4719where
4720 E: std::fmt::Debug,
4721{
4722 type Ok = T;
4723
4724 #[track_caller]
4725 fn trace_err(self) -> Option<T> {
4726 match self {
4727 Ok(value) => Some(value),
4728 Err(error) => {
4729 tracing::error!("{:?}", error);
4730 None
4731 }
4732 }
4733 }
4734}