1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use reqwest_client::ReqwestClient;
45use rpc::proto::split_repository_update;
46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
/// Reconnection grace period.
/// NOTE(review): not referenced in this part of the file; presumably the window a dropped client
/// has to reconnect before its state is released — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before cleaning up stale resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the next three limits are not used in this part of the file; their names suggest
// page sizes / message length caps for channel messages and notifications — confirm at use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound enforced by `ConnectionGuard::try_acquire` on simultaneously held guards.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of `ConnectionGuard`s currently alive (incremented on acquire, decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
/// Built by `Server::add_handler`; takes the raw envelope plus the session and returns a boxed
/// future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
99
100pub struct ConnectionGuard;
101
102impl ConnectionGuard {
103 pub fn try_acquire() -> Result<Self, ()> {
104 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
105 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
106 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
107 tracing::error!(
108 "too many concurrent connections: {}",
109 current_connections + 1
110 );
111 return Err(());
112 }
113 Ok(ConnectionGuard)
114 }
115}
116
117impl Drop for ConnectionGuard {
118 fn drop(&mut self) {
119 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
120 }
121}
122
123struct Response<R> {
124 peer: Arc<Peer>,
125 receipt: Receipt<R>,
126 responded: Arc<AtomicBool>,
127}
128
129impl<R: RequestMessage> Response<R> {
130 fn send(self, payload: R::Response) -> Result<()> {
131 self.responded.store(true, SeqCst);
132 self.peer.respond(self.receipt, payload)?;
133 Ok(())
134 }
135}
136
137#[derive(Clone, Debug)]
138pub enum Principal {
139 User(User),
140 Impersonated { user: User, admin: User },
141}
142
143impl Principal {
144 fn user(&self) -> &User {
145 match self {
146 Principal::User(user) => user,
147 Principal::Impersonated { user, .. } => user,
148 }
149 }
150
151 fn update_span(&self, span: &tracing::Span) {
152 match &self {
153 Principal::User(user) => {
154 span.record("user_id", user.id.0);
155 span.record("login", &user.github_login);
156 }
157 Principal::Impersonated { user, admin } => {
158 span.record("user_id", user.id.0);
159 span.record("login", &user.github_login);
160 span.record("impersonator", &admin.github_login);
161 }
162 }
163 }
164}
165
/// Per-connection context handed (by clone) to every message handler of that connection.
#[derive(Clone)]
struct Session {
    /// Who this connection acts as (possibly an admin impersonating a user).
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex created fresh for each connection, so handlers of
    /// one connection take turns using it. Acquire via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Connection pool shared with `Server`. Acquire via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Set only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` presumably because nothing reads this field yet — confirm.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// System id passed into `Server::handle_connection`, if any.
    system_id: Option<String>,
    /// Stored but not read here (underscore-prefixed). NOTE(review): confirm why it is kept.
    _executor: Executor,
}
181
182impl Session {
183 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
184 #[cfg(test)]
185 tokio::task::yield_now().await;
186 let guard = self.db.lock().await;
187 #[cfg(test)]
188 tokio::task::yield_now().await;
189 guard
190 }
191
192 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
193 #[cfg(test)]
194 tokio::task::yield_now().await;
195 let guard = self.connection_pool.lock();
196 ConnectionPoolGuard {
197 guard,
198 _not_send: PhantomData,
199 }
200 }
201
202 fn is_staff(&self) -> bool {
203 match &self.principal {
204 Principal::User(user) => user.admin,
205 Principal::Impersonated { .. } => true,
206 }
207 }
208
209 fn user_id(&self) -> UserId {
210 match &self.principal {
211 Principal::User(user) => user.id,
212 Principal::Impersonated { user, .. } => user.id,
213 }
214 }
215
216 pub fn email(&self) -> Option<String> {
217 match &self.principal {
218 Principal::User(user) => user.email_address.clone(),
219 Principal::Impersonated { user, .. } => user.email_address.clone(),
220 }
221 }
222}
223
224impl Debug for Session {
225 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
226 let mut result = f.debug_struct("Session");
227 match &self.principal {
228 Principal::User(user) => {
229 result.field("user", &user.github_login);
230 }
231 Principal::Impersonated { user, admin } => {
232 result.field("user", &user.github_login);
233 result.field("impersonator", &admin.github_login);
234 }
235 }
236 result.field("connection_id", &self.connection_id).finish()
237 }
238}
239
240struct DbHandle(Arc<Database>);
241
242impl Deref for DbHandle {
243 type Target = Database;
244
245 fn deref(&self) -> &Self::Target {
246 self.0.as_ref()
247 }
248}
249
/// The collaboration RPC server: owns the peer, the connection pool, and the handler table.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it via `&self`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections; cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; each connection subscribes in `handle_connection` and exits when it changes.
    teardown: watch::Sender<bool>,
}
258
/// Lock guard for the shared connection pool.
///
/// `_not_send` makes the guard explicitly `!Send`, so a `Send` future cannot hold this
/// (blocking) `parking_lot` lock across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
263
/// Serializable view of the server's peer and connection pool.
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized via `serialize_deref`, i.e. as the value the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
270
271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
272where
273 S: Serializer,
274 T: Deref<Target = U>,
275 U: Serialize,
276{
277 Serialize::serialize(value.deref(), serializer)
278}
279
280impl Server {
281 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
282 let mut server = Self {
283 id: parking_lot::Mutex::new(id),
284 peer: Peer::new(id.0 as u32),
285 app_state: app_state.clone(),
286 connection_pool: Default::default(),
287 handlers: Default::default(),
288 teardown: watch::channel(false).0,
289 };
290
291 server
292 .add_request_handler(ping)
293 .add_request_handler(create_room)
294 .add_request_handler(join_room)
295 .add_request_handler(rejoin_room)
296 .add_request_handler(leave_room)
297 .add_request_handler(set_room_participant_role)
298 .add_request_handler(call)
299 .add_request_handler(cancel_call)
300 .add_message_handler(decline_call)
301 .add_request_handler(update_participant_location)
302 .add_request_handler(share_project)
303 .add_message_handler(unshare_project)
304 .add_request_handler(join_project)
305 .add_message_handler(leave_project)
306 .add_request_handler(update_project)
307 .add_request_handler(update_worktree)
308 .add_request_handler(update_repository)
309 .add_request_handler(remove_repository)
310 .add_message_handler(start_language_server)
311 .add_message_handler(update_language_server)
312 .add_message_handler(update_diagnostic_summary)
313 .add_message_handler(update_worktree_settings)
314 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
317 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
318 .add_request_handler(forward_find_search_candidates_request)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
324 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
325 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
326 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
327 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
328 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
329 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
330 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
334 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
335 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
336 .add_request_handler(
337 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
338 )
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
342 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
343 .add_request_handler(
344 forward_read_only_project_request::<proto::LanguageServerIdForName>,
345 )
346 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
347 .add_request_handler(
348 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
349 )
350 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
351 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
352 .add_request_handler(
353 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
354 )
355 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
356 .add_request_handler(
357 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
358 )
359 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
360 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
361 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
362 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
363 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
364 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
365 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
366 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
369 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
370 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
371 .add_request_handler(
372 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
373 )
374 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
375 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
376 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
377 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
378 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
379 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
380 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
381 .add_message_handler(create_buffer_for_peer)
382 .add_request_handler(update_buffer)
383 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
386 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
387 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
388 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
389 .add_message_handler(
390 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
391 )
392 .add_request_handler(get_users)
393 .add_request_handler(fuzzy_search_users)
394 .add_request_handler(request_contact)
395 .add_request_handler(remove_contact)
396 .add_request_handler(respond_to_contact_request)
397 .add_message_handler(subscribe_to_channels)
398 .add_request_handler(create_channel)
399 .add_request_handler(delete_channel)
400 .add_request_handler(invite_channel_member)
401 .add_request_handler(remove_channel_member)
402 .add_request_handler(set_channel_member_role)
403 .add_request_handler(set_channel_visibility)
404 .add_request_handler(rename_channel)
405 .add_request_handler(join_channel_buffer)
406 .add_request_handler(leave_channel_buffer)
407 .add_message_handler(update_channel_buffer)
408 .add_request_handler(rejoin_channel_buffers)
409 .add_request_handler(get_channel_members)
410 .add_request_handler(respond_to_channel_invite)
411 .add_request_handler(join_channel)
412 .add_request_handler(join_channel_chat)
413 .add_message_handler(leave_channel_chat)
414 .add_request_handler(send_channel_message)
415 .add_request_handler(remove_channel_message)
416 .add_request_handler(update_channel_message)
417 .add_request_handler(get_channel_messages)
418 .add_request_handler(get_channel_messages_by_id)
419 .add_request_handler(get_notifications)
420 .add_request_handler(mark_notification_as_read)
421 .add_request_handler(move_channel)
422 .add_request_handler(reorder_channel)
423 .add_request_handler(follow)
424 .add_message_handler(unfollow)
425 .add_message_handler(update_followers)
426 .add_request_handler(get_private_user_info)
427 .add_request_handler(get_llm_api_token)
428 .add_request_handler(accept_terms_of_service)
429 .add_message_handler(acknowledge_channel_message)
430 .add_message_handler(acknowledge_buffer_version)
431 .add_request_handler(get_supermaven_api_key)
432 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
433 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
434 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
435 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
436 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
437 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
438 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
439 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
440 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
441 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
443 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
444 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
445 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
446 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
447 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
448 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
449 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
450 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
451 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
452 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
453 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
454 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
455 .add_message_handler(update_context);
456
457 Arc::new(server)
458 }
459
/// Spawns a detached background task that waits `CLEANUP_TIMEOUT`, then cleans up resources the
/// database reports as stale for this environment and server id (presumably left behind by
/// servers that are no longer running — confirm against `stale_server_resource_ids`).
///
/// The task runs these steps in order:
/// 1. deletes stale channel-chat participants;
/// 2. for each stale channel buffer, clears stale collaborators and sends the new collaborator
///    list to the remaining connections;
/// 3. for each stale room, clears stale participants, broadcasts the refreshed room (and
///    channel), notifies users whose calls were canceled, pushes refreshed contact entries for
///    affected users, and deletes the LiveKit room if no participants remain;
/// 4. deletes stale channel-chat participants again, clears old worktree entries, and deletes
///    stale servers.
///
/// Every fallible step goes through `trace_err`, which yields an `Option`, so a failure skips
/// only that step.
///
/// Returns `Ok(())` immediately; the cleanup itself happens in the background.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // Created here but only awaited inside the spawned task.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");

            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Drop stale collaborators from each channel buffer and send the refreshed
                // collaborator list to every connection still in that buffer.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Filled in from the DB refresh below; they stay empty/false if the refresh
                    // fails, which makes the follow-up steps no-ops for this room.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        // The LiveKit room is only deleted once nobody is left in it.
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Tell every connection of each user whose call into this room was canceled.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // For each affected user, send an updated contact entry (including busy
                    // status) to the connections of all of that user's accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            // NOTE(review): this repeats the call made right after the cleanup timeout above;
            // confirm whether the second pass is intentional.
            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            app_state
                .db
                .clear_old_worktree_entries(server_id)
                .await
                .trace_err();

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
630
631 pub fn teardown(&self) {
632 self.peer.teardown();
633 self.connection_pool.lock().reset();
634 let _ = self.teardown.send(true);
635 }
636
/// Test-only: tears the server down, reinitializes it under a new `id`, and clears the
/// teardown flag so new connections are accepted again.
///
/// `send_replace` is used because `watch::Sender::send` fails, leaving the old value in place,
/// when no connection is currently subscribed.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    self.teardown.send_replace(false);
}
644
/// Test-only accessor for the current server id (which `reset` may change).
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let guard = self.id.lock();
    *guard
}
649
650 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
651 where
652 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
653 Fut: 'static + Send + Future<Output = Result<()>>,
654 M: EnvelopedMessage,
655 {
656 let prev_handler = self.handlers.insert(
657 TypeId::of::<M>(),
658 Box::new(move |envelope, session| {
659 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
660 let received_at = envelope.received_at;
661 tracing::info!("message received");
662 let start_time = Instant::now();
663 let future = (handler)(*envelope, session);
664 async move {
665 let result = future.await;
666 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
667 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
668 let queue_duration_ms = total_duration_ms - processing_duration_ms;
669 let payload_type = M::NAME;
670
671 match result {
672 Err(error) => {
673 tracing::error!(
674 ?error,
675 total_duration_ms,
676 processing_duration_ms,
677 queue_duration_ms,
678 payload_type,
679 "error handling message"
680 )
681 }
682 Ok(()) => tracing::info!(
683 total_duration_ms,
684 processing_duration_ms,
685 queue_duration_ms,
686 "finished handling message"
687 ),
688 }
689 }
690 .boxed()
691 }),
692 );
693 if prev_handler.is_some() {
694 panic!("registered a handler for the same message twice");
695 }
696 self
697 }
698
699 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
700 where
701 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
702 Fut: 'static + Send + Future<Output = Result<()>>,
703 M: EnvelopedMessage,
704 {
705 self.add_handler(move |envelope, session| handler(envelope.payload, session));
706 self
707 }
708
709 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
710 where
711 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
712 Fut: Send + Future<Output = Result<()>>,
713 M: RequestMessage,
714 {
715 let handler = Arc::new(handler);
716 self.add_handler(move |envelope, session| {
717 let receipt = envelope.receipt();
718 let handler = handler.clone();
719 async move {
720 let peer = session.peer.clone();
721 let responded = Arc::new(AtomicBool::default());
722 let response = Response {
723 peer: peer.clone(),
724 responded: responded.clone(),
725 receipt,
726 };
727 match (handler)(envelope.payload, response, session).await {
728 Ok(()) => {
729 if responded.load(std::sync::atomic::Ordering::SeqCst) {
730 Ok(())
731 } else {
732 Err(anyhow!("handler did not send a response"))?
733 }
734 }
735 Err(error) => {
736 let proto_err = match &error {
737 Error::Internal(err) => err.to_proto(),
738 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
739 };
740 peer.respond_with_error(receipt, proto_err)?;
741 Err(error)
742 }
743 }
744 }
745 })
746 }
747
/// Builds the future that drives one client connection from handshake to sign-out.
///
/// Nothing below the span setup runs until the returned future is polled. When polled it:
/// 1. exits immediately if the teardown flag is already set;
/// 2. registers the connection with the peer and builds the per-connection `Session`;
/// 3. sends the initial client update, then releases `connection_guard`;
/// 4. dispatches incoming messages to the registered handlers until the connection closes,
///    I/O fails, or teardown is signalled;
/// 5. on close or I/O error (but not on teardown), calls `connection_lost` to sign the user out.
///
/// The future is instrumented with a "handle connection" span carrying the address, principal,
/// user agent, and GeoIP country code.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    user_agent: Option<String>,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
    connection_guard: Option<ConnectionGuard>,
) -> impl Future<Output = ()> + use<> {
    let this = self.clone();
    // Empty fields are recorded below (connection_id once the peer assigns it).
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        user_agent=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(user_agent) = user_agent {
        span.record("user_agent", user_agent);
    }

    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    // Subscribed eagerly (before the future is polled) and checked/awaited below.
    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }

        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only built when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
            supermaven_admin_api_key.to_string(),
            http_client.clone(),
        )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this
            .send_initial_client_update(connection_id, zed_version, send_connection_id, &session)
            .await
        {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }
        // Release the concurrent-connection slot once setup has finished, so
        // `MAX_CONCURRENT_CONNECTIONS` bounds connections that are still being set up.
        drop(connection_guard);

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // At most 512 handled messages per connection may be in flight. A permit is acquired
        // before reading each message and released once its handler finishes (or right away if
        // no handler exists), so reading pauses while the cap is reached.
        let concurrent_handlers = Arc::new(Semaphore::new(512));
        loop {
            let next_message = async {
                // The semaphore is local to this function and never closed, so this cannot fail.
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // Branches are polled in order: teardown, connection I/O, progress on in-flight
            // foreground handlers, then the next incoming message.
            futures::select_biased! {
                // Teardown: exit without calling `connection_lost`.
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        // Enter the span only while the handler future is constructed; its
                        // execution is instrumented with the same span below.
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run detached; foreground ones are driven by
                            // this loop via `foreground_message_handlers`.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Cancel any foreground handlers still pending before signing out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
904
    /// Sends the initial burst of messages to a freshly connected client.
    ///
    /// Always sends `Hello` (carrying the client's peer id) and, if a
    /// `send_connection_id` channel was supplied, reports `connection_id` on it
    /// (a dropped receiver is ignored). For user or impersonated principals it then:
    /// - sends `ShowContacts` and sets `connected_once` the first time the user connects;
    /// - calls `update_user_plan`;
    /// - registers the connection in the pool and sends the initial contacts update
    ///   while holding the pool lock;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels`
    ///   returns true for `zed_version`;
    /// - forwards any pending incoming call for the user;
    /// - refreshes the user's contacts via `update_user_contacts`.
    ///
    /// # Errors
    /// Any error from the sends, the database calls or the helper functions is
    /// propagated with `?`, aborting the remaining steps.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // Report the connection id to whoever supplied the channel; a dropped
        // receiver is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First-ever connection: ask the client to show contacts and persist
                // that the user has now connected once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                // Fetched before taking the pool lock; the lock below is only held for
                // the synchronous registration and send inside its scope.
                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                // Deliver a call that was already ringing before this connection existed.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
962
963 pub async fn invite_code_redeemed(
964 self: &Arc<Self>,
965 inviter_id: UserId,
966 invitee_id: UserId,
967 ) -> Result<()> {
968 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
969 if let Some(code) = &user.invite_code {
970 let pool = self.connection_pool.lock();
971 let invitee_contact = contact_for_user(invitee_id, false, &pool);
972 for connection_id in pool.user_connection_ids(inviter_id) {
973 self.peer.send(
974 connection_id,
975 proto::UpdateContacts {
976 contacts: vec![invitee_contact.clone()],
977 ..Default::default()
978 },
979 )?;
980 self.peer.send(
981 connection_id,
982 proto::UpdateInviteInfo {
983 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
984 count: user.invite_count as u32,
985 },
986 )?;
987 }
988 }
989 }
990 Ok(())
991 }
992
993 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
994 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
995 if let Some(invite_code) = &user.invite_code {
996 let pool = self.connection_pool.lock();
997 for connection_id in pool.user_connection_ids(user_id) {
998 self.peer.send(
999 connection_id,
1000 proto::UpdateInviteInfo {
1001 url: format!(
1002 "{}{}",
1003 self.app_state.config.invite_link_prefix, invite_code
1004 ),
1005 count: user.invite_count as u32,
1006 },
1007 )?;
1008 }
1009 }
1010 }
1011 Ok(())
1012 }
1013
1014 pub async fn update_plan_for_user(
1015 self: &Arc<Self>,
1016 user_id: UserId,
1017 update_user_plan: proto::UpdateUserPlan,
1018 ) -> Result<()> {
1019 let pool = self.connection_pool.lock();
1020 for connection_id in pool.user_connection_ids(user_id) {
1021 self.peer
1022 .send(connection_id, update_user_plan.clone())
1023 .trace_err();
1024 }
1025
1026 Ok(())
1027 }
1028
1029 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1030 /// message on the Collab server.
1031 ///
1032 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1033 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1034 let user = self
1035 .app_state
1036 .db
1037 .get_user_by_id(user_id)
1038 .await?
1039 .context("user not found")?;
1040
1041 let update_user_plan = make_update_user_plan_message(
1042 &user,
1043 user.admin,
1044 &self.app_state.db,
1045 self.app_state.llm_db.clone(),
1046 )
1047 .await?;
1048
1049 self.update_plan_for_user(user_id, update_user_plan).await
1050 }
1051
1052 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1053 let pool = self.connection_pool.lock();
1054 for connection_id in pool.user_connection_ids(user_id) {
1055 self.peer
1056 .send(connection_id, proto::RefreshLlmToken {})
1057 .trace_err();
1058 }
1059 }
1060
1061 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1062 ServerSnapshot {
1063 connection_pool: ConnectionPoolGuard {
1064 guard: self.connection_pool.lock(),
1065 _not_send: PhantomData,
1066 },
1067 peer: &self.peer,
1068 }
1069 }
1070}
1071
1072impl Deref for ConnectionPoolGuard<'_> {
1073 type Target = ConnectionPool;
1074
1075 fn deref(&self) -> &Self::Target {
1076 &self.guard
1077 }
1078}
1079
1080impl DerefMut for ConnectionPoolGuard<'_> {
1081 fn deref_mut(&mut self) -> &mut Self::Target {
1082 &mut self.guard
1083 }
1084}
1085
1086impl Drop for ConnectionPoolGuard<'_> {
1087 fn drop(&mut self) {
1088 #[cfg(test)]
1089 self.check_invariants();
1090 }
1091}
1092
1093fn broadcast<F>(
1094 sender_id: Option<ConnectionId>,
1095 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1096 mut f: F,
1097) where
1098 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1099{
1100 for receiver_id in receiver_ids {
1101 if Some(receiver_id) != sender_id {
1102 if let Err(error) = f(receiver_id) {
1103 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1104 }
1105 }
1106 }
1107}
1108
1109pub struct ProtocolVersion(u32);
1110
1111impl Header for ProtocolVersion {
1112 fn name() -> &'static HeaderName {
1113 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1114 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1115 }
1116
1117 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1118 where
1119 Self: Sized,
1120 I: Iterator<Item = &'i axum::http::HeaderValue>,
1121 {
1122 let version = values
1123 .next()
1124 .ok_or_else(axum::headers::Error::invalid)?
1125 .to_str()
1126 .map_err(|_| axum::headers::Error::invalid())?
1127 .parse()
1128 .map_err(|_| axum::headers::Error::invalid())?;
1129 Ok(Self(version))
1130 }
1131
1132 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1133 values.extend([self.0.to_string().parse().unwrap()]);
1134 }
1135}
1136
1137pub struct AppVersionHeader(SemanticVersion);
1138impl Header for AppVersionHeader {
1139 fn name() -> &'static HeaderName {
1140 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1141 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1142 }
1143
1144 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1145 where
1146 Self: Sized,
1147 I: Iterator<Item = &'i axum::http::HeaderValue>,
1148 {
1149 let version = values
1150 .next()
1151 .ok_or_else(axum::headers::Error::invalid)?
1152 .to_str()
1153 .map_err(|_| axum::headers::Error::invalid())?
1154 .parse()
1155 .map_err(|_| axum::headers::Error::invalid())?;
1156 Ok(Self(version))
1157 }
1158
1159 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1160 values.extend([self.0.to_string().parse().unwrap()]);
1161 }
1162}
1163
/// Builds the collab RPC router.
///
/// - `/rpc` upgrades to a WebSocket via `handle_websocket_request`. It is wrapped
///   in a `ServiceBuilder` that adds the `AppState` extension and the
///   `auth::validate_header` middleware.
/// - `/metrics` serves Prometheus metrics via `handle_metrics`. It is registered
///   after that `.layer(...)` call. Per axum's `Router::layer` docs, a layer only
///   wraps routes added before it, so `/metrics` does not pass through
///   `auth::validate_header` (NOTE(review): presumably intentional for scrapers —
///   confirm).
/// - The final `Extension(server)` layer wraps both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to the `/rpc` route registered above.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        // Applies to every route registered above, `/rpc` and `/metrics`.
        .layer(Extension(server))
}
1175
/// Axum handler for `GET /rpc`: validates the client, then upgrades to a WebSocket.
///
/// Responds without upgrading when:
/// - the `x-zed-protocol-version` header differs from `rpc::PROTOCOL_VERSION`
///   (426 Upgrade Required, "client must be upgraded");
/// - the `x-zed-app-version` header is missing (426, "no version header found");
/// - `ZedVersion::can_collaborate` returns false (426, "client must be upgraded");
/// - `ConnectionGuard::try_acquire` fails (503, "Too many concurrent connections").
///
/// Otherwise the axum socket is adapted to tungstenite messages and handed to
/// `Server::handle_connection`, together with the authenticated principal and the
/// optional user-agent, Cloudflare country and system-id headers.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire the connection guard before upgrading, so an over-capacity server
    // rejects the request with a plain HTTP 503 instead of accepting the socket.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Bridge axum's WebSocket message type to the tungstenite messages that
        // `Connection` is built from.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    // Moved into the connection future; presumably held for the
                    // connection's lifetime — TODO confirm in `handle_connection`.
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1249
1250pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1251 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1252 let connections_metric = CONNECTIONS_METRIC
1253 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1254
1255 let connections = server
1256 .connection_pool
1257 .lock()
1258 .connections()
1259 .filter(|connection| !connection.admin)
1260 .count();
1261 connections_metric.set(connections as _);
1262
1263 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1264 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1265 register_int_gauge!(
1266 "shared_projects",
1267 "number of open projects with one or more guests"
1268 )
1269 .unwrap()
1270 });
1271
1272 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1273 shared_projects_metric.set(shared_projects as _);
1274
1275 let encoder = prometheus::TextEncoder::new();
1276 let metric_families = prometheus::gather();
1277 let encoded_metrics = encoder
1278 .encode_to_string(&metric_families)
1279 .map_err(|err| anyhow!("{err}"))?;
1280 Ok(encoded_metrics)
1281}
1282
/// Cleans up after a client connection ends.
///
/// Immediately disconnects the peer, removes the connection from the pool, and
/// records the lost connection in the database (that database error is only
/// logged). It then waits `RECONNECT_TIMEOUT` (presumably to give the client a
/// chance to reconnect). If `teardown` changes first, the delayed cleanup is
/// skipped. Otherwise, once the timeout fires, it:
/// - leaves the session's room and channel buffers;
/// - declines any pending call when the user has no other online connection;
/// - refreshes the user's contacts.
///
/// # Errors
/// Returns an error if removing the connection from the pool or updating the
/// user's contacts fails; the other cleanup failures are only logged.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the arms in order: the timeout arm runs the
    // delayed cleanup, and a `teardown` change cancels it.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call when no other connection of this user
            // is still online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1329
1330/// Acknowledges a ping from a client, used to keep the connection alive.
1331async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1332 response.send(proto::Ack {})?;
1333 Ok(())
1334}
1335
1336/// Creates a new room for calling (outside of channels)
1337async fn create_room(
1338 _request: proto::CreateRoom,
1339 response: Response<proto::CreateRoom>,
1340 session: Session,
1341) -> Result<()> {
1342 let livekit_room = nanoid::nanoid!(30);
1343
1344 let live_kit_connection_info = util::maybe!(async {
1345 let live_kit = session.app_state.livekit_client.as_ref();
1346 let live_kit = live_kit?;
1347 let user_id = session.user_id().to_string();
1348
1349 let token = live_kit
1350 .room_token(&livekit_room, &user_id.to_string())
1351 .trace_err()?;
1352
1353 Some(proto::LiveKitConnectionInfo {
1354 server_url: live_kit.url().into(),
1355 token,
1356 can_publish: true,
1357 })
1358 })
1359 .await;
1360
1361 let room = session
1362 .db()
1363 .await
1364 .create_room(session.user_id(), session.connection_id, &livekit_room)
1365 .await?;
1366
1367 response.send(proto::CreateRoomResponse {
1368 room: Some(room.clone()),
1369 live_kit_connection_info,
1370 })?;
1371
1372 update_user_contacts(session.user_id(), &session).await?;
1373 Ok(())
1374}
1375
1376/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1377async fn join_room(
1378 request: proto::JoinRoom,
1379 response: Response<proto::JoinRoom>,
1380 session: Session,
1381) -> Result<()> {
1382 let room_id = RoomId::from_proto(request.id);
1383
1384 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1385
1386 if let Some(channel_id) = channel_id {
1387 return join_channel_internal(channel_id, Box::new(response), session).await;
1388 }
1389
1390 let joined_room = {
1391 let room = session
1392 .db()
1393 .await
1394 .join_room(room_id, session.user_id(), session.connection_id)
1395 .await?;
1396 room_updated(&room.room, &session.peer);
1397 room.into_inner()
1398 };
1399
1400 for connection_id in session
1401 .connection_pool()
1402 .await
1403 .user_connection_ids(session.user_id())
1404 {
1405 session
1406 .peer
1407 .send(
1408 connection_id,
1409 proto::CallCanceled {
1410 room_id: room_id.to_proto(),
1411 },
1412 )
1413 .trace_err();
1414 }
1415
1416 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1417 {
1418 live_kit
1419 .room_token(
1420 &joined_room.room.livekit_room,
1421 &session.user_id().to_string(),
1422 )
1423 .trace_err()
1424 .map(|token| proto::LiveKitConnectionInfo {
1425 server_url: live_kit.url().into(),
1426 token,
1427 can_publish: true,
1428 })
1429 } else {
1430 None
1431 };
1432
1433 response.send(proto::JoinRoomResponse {
1434 room: Some(joined_room.room),
1435 channel_id: None,
1436 live_kit_connection_info,
1437 })?;
1438
1439 update_user_contacts(session.user_id(), &session).await?;
1440 Ok(())
1441}
1442
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores this connection's room membership in the database, then:
/// 1. replies with the room plus the reshared and rejoined projects;
/// 2. broadcasts the room update;
/// 3. for each reshared project, tells every collaborator that this peer's id
///    changed from `old_connection_id` to the current connection
///    (`UpdateProjectCollaborator`), and forwards the project's worktree list
///    (`UpdateProject`) to collaborators other than this connection;
/// 4. re-streams rejoined project state via `notify_rejoined_projects`;
/// 5. once the room guard has been consumed, broadcasts the channel update (if the
///    room belongs to a channel) and refreshes the user's contacts.
///
/// Sends to other collaborators are best-effort (logged); failures sending the
/// reply or to this connection are returned.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-map this peer's old connection id to the new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the current worktree list, as if sent by this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1534
/// Brings a rejoining connection up to date on the projects it rejoined.
///
/// First, for every project, tells each collaborator that this peer's id changed
/// from `old_connection_id` to this session's connection (failures are logged).
/// Then, for every project, sends to this session's own connection:
/// - worktree entries, split into chunks by `split_worktree_update`;
/// - diagnostic summaries;
/// - worktree settings files;
/// - updated repositories, split into chunks by `split_repository_update`;
/// - `RemoveRepository` for each removed repository.
///
/// The worktree and repository data are moved out of `rejoined_projects` with
/// `mem::take`, so those fields are empty afterwards.
///
/// # Errors
/// Returns the first error from sending to this session's own connection.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Project-level repositories that changed while the connection was away.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1619
1620/// leave room disconnects from the room.
1621async fn leave_room(
1622 _: proto::LeaveRoom,
1623 response: Response<proto::LeaveRoom>,
1624 session: Session,
1625) -> Result<()> {
1626 leave_room_for_session(&session, session.connection_id).await?;
1627 response.send(proto::Ack {})?;
1628 Ok(())
1629}
1630
1631/// Updates the permissions of someone else in the room.
1632async fn set_room_participant_role(
1633 request: proto::SetRoomParticipantRole,
1634 response: Response<proto::SetRoomParticipantRole>,
1635 session: Session,
1636) -> Result<()> {
1637 let user_id = UserId::from_proto(request.user_id);
1638 let role = ChannelRole::from(request.role());
1639
1640 let (livekit_room, can_publish) = {
1641 let room = session
1642 .db()
1643 .await
1644 .set_room_participant_role(
1645 session.user_id(),
1646 RoomId::from_proto(request.room_id),
1647 user_id,
1648 role,
1649 )
1650 .await?;
1651
1652 let livekit_room = room.livekit_room.clone();
1653 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1654 room_updated(&room, &session.peer);
1655 (livekit_room, can_publish)
1656 };
1657
1658 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1659 live_kit
1660 .update_participant(
1661 livekit_room.clone(),
1662 request.user_id.to_string(),
1663 livekit_api::proto::ParticipantPermission {
1664 can_subscribe: true,
1665 can_publish,
1666 can_publish_data: can_publish,
1667 hidden: false,
1668 recorder: false,
1669 },
1670 )
1671 .await
1672 .trace_err();
1673 }
1674
1675 response.send(proto::Ack {})?;
1676 Ok(())
1677}
1678
1679/// Call someone else into the current room
1680async fn call(
1681 request: proto::Call,
1682 response: Response<proto::Call>,
1683 session: Session,
1684) -> Result<()> {
1685 let room_id = RoomId::from_proto(request.room_id);
1686 let calling_user_id = session.user_id();
1687 let calling_connection_id = session.connection_id;
1688 let called_user_id = UserId::from_proto(request.called_user_id);
1689 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1690 if !session
1691 .db()
1692 .await
1693 .has_contact(calling_user_id, called_user_id)
1694 .await?
1695 {
1696 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1697 }
1698
1699 let incoming_call = {
1700 let (room, incoming_call) = &mut *session
1701 .db()
1702 .await
1703 .call(
1704 room_id,
1705 calling_user_id,
1706 calling_connection_id,
1707 called_user_id,
1708 initial_project_id,
1709 )
1710 .await?;
1711 room_updated(room, &session.peer);
1712 mem::take(incoming_call)
1713 };
1714 update_user_contacts(called_user_id, &session).await?;
1715
1716 let mut calls = session
1717 .connection_pool()
1718 .await
1719 .user_connection_ids(called_user_id)
1720 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1721 .collect::<FuturesUnordered<_>>();
1722
1723 while let Some(call_response) = calls.next().await {
1724 match call_response.as_ref() {
1725 Ok(_) => {
1726 response.send(proto::Ack {})?;
1727 return Ok(());
1728 }
1729 Err(_) => {
1730 call_response.trace_err();
1731 }
1732 }
1733 }
1734
1735 {
1736 let room = session
1737 .db()
1738 .await
1739 .call_failed(room_id, called_user_id)
1740 .await?;
1741 room_updated(&room, &session.peer);
1742 }
1743 update_user_contacts(called_user_id, &session).await?;
1744
1745 Err(anyhow!("failed to ring user"))?
1746}
1747
1748/// Cancel an outgoing call.
1749async fn cancel_call(
1750 request: proto::CancelCall,
1751 response: Response<proto::CancelCall>,
1752 session: Session,
1753) -> Result<()> {
1754 let called_user_id = UserId::from_proto(request.called_user_id);
1755 let room_id = RoomId::from_proto(request.room_id);
1756 {
1757 let room = session
1758 .db()
1759 .await
1760 .cancel_call(room_id, session.connection_id, called_user_id)
1761 .await?;
1762 room_updated(&room, &session.peer);
1763 }
1764
1765 for connection_id in session
1766 .connection_pool()
1767 .await
1768 .user_connection_ids(called_user_id)
1769 {
1770 session
1771 .peer
1772 .send(
1773 connection_id,
1774 proto::CallCanceled {
1775 room_id: room_id.to_proto(),
1776 },
1777 )
1778 .trace_err();
1779 }
1780 response.send(proto::Ack {})?;
1781
1782 update_user_contacts(called_user_id, &session).await?;
1783 Ok(())
1784}
1785
1786/// Decline an incoming call.
1787async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1788 let room_id = RoomId::from_proto(message.room_id);
1789 {
1790 let room = session
1791 .db()
1792 .await
1793 .decline_call(Some(room_id), session.user_id())
1794 .await?
1795 .context("declining call")?;
1796 room_updated(&room, &session.peer);
1797 }
1798
1799 for connection_id in session
1800 .connection_pool()
1801 .await
1802 .user_connection_ids(session.user_id())
1803 {
1804 session
1805 .peer
1806 .send(
1807 connection_id,
1808 proto::CallCanceled {
1809 room_id: room_id.to_proto(),
1810 },
1811 )
1812 .trace_err();
1813 }
1814 update_user_contacts(session.user_id(), &session).await?;
1815 Ok(())
1816}
1817
1818/// Updates other participants in the room with your current location.
1819async fn update_participant_location(
1820 request: proto::UpdateParticipantLocation,
1821 response: Response<proto::UpdateParticipantLocation>,
1822 session: Session,
1823) -> Result<()> {
1824 let room_id = RoomId::from_proto(request.room_id);
1825 let location = request.location.context("invalid location")?;
1826
1827 let db = session.db().await;
1828 let room = db
1829 .update_room_participant_location(room_id, session.connection_id, location)
1830 .await?;
1831
1832 room_updated(&room, &session.peer);
1833 response.send(proto::Ack {})?;
1834 Ok(())
1835}
1836
1837/// Share a project into the room.
1838async fn share_project(
1839 request: proto::ShareProject,
1840 response: Response<proto::ShareProject>,
1841 session: Session,
1842) -> Result<()> {
1843 let (project_id, room) = &*session
1844 .db()
1845 .await
1846 .share_project(
1847 RoomId::from_proto(request.room_id),
1848 session.connection_id,
1849 &request.worktrees,
1850 request.is_ssh_project,
1851 )
1852 .await?;
1853 response.send(proto::ShareProjectResponse {
1854 project_id: project_id.to_proto(),
1855 })?;
1856 room_updated(room, &session.peer);
1857
1858 Ok(())
1859}
1860
1861/// Unshare a project from the room.
1862async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1863 let project_id = ProjectId::from_proto(message.project_id);
1864 unshare_project_internal(project_id, session.connection_id, &session).await
1865}
1866
1867async fn unshare_project_internal(
1868 project_id: ProjectId,
1869 connection_id: ConnectionId,
1870 session: &Session,
1871) -> Result<()> {
1872 let delete = {
1873 let room_guard = session
1874 .db()
1875 .await
1876 .unshare_project(project_id, connection_id)
1877 .await?;
1878
1879 let (delete, room, guest_connection_ids) = &*room_guard;
1880
1881 let message = proto::UnshareProject {
1882 project_id: project_id.to_proto(),
1883 };
1884
1885 broadcast(
1886 Some(connection_id),
1887 guest_connection_ids.iter().copied(),
1888 |conn_id| session.peer.send(conn_id, message.clone()),
1889 );
1890 if let Some(room) = room {
1891 room_updated(room, &session.peer);
1892 }
1893
1894 *delete
1895 };
1896
1897 if delete {
1898 let db = session.db().await;
1899 db.delete_project(project_id).await?;
1900 }
1901
1902 Ok(())
1903}
1904
1905/// Join someone elses shared project.
1906async fn join_project(
1907 request: proto::JoinProject,
1908 response: Response<proto::JoinProject>,
1909 session: Session,
1910) -> Result<()> {
1911 let project_id = ProjectId::from_proto(request.project_id);
1912
1913 tracing::info!(%project_id, "join project");
1914
1915 let db = session.db().await;
1916 let (project, replica_id) = &mut *db
1917 .join_project(
1918 project_id,
1919 session.connection_id,
1920 session.user_id(),
1921 request.committer_name.clone(),
1922 request.committer_email.clone(),
1923 )
1924 .await?;
1925 drop(db);
1926 tracing::info!(%project_id, "join remote project");
1927 let collaborators = project
1928 .collaborators
1929 .iter()
1930 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1931 .map(|collaborator| collaborator.to_proto())
1932 .collect::<Vec<_>>();
1933 let project_id = project.id;
1934 let guest_user_id = session.user_id();
1935
1936 let worktrees = project
1937 .worktrees
1938 .iter()
1939 .map(|(id, worktree)| proto::WorktreeMetadata {
1940 id: *id,
1941 root_name: worktree.root_name.clone(),
1942 visible: worktree.visible,
1943 abs_path: worktree.abs_path.clone(),
1944 })
1945 .collect::<Vec<_>>();
1946
1947 let add_project_collaborator = proto::AddProjectCollaborator {
1948 project_id: project_id.to_proto(),
1949 collaborator: Some(proto::Collaborator {
1950 peer_id: Some(session.connection_id.into()),
1951 replica_id: replica_id.0 as u32,
1952 user_id: guest_user_id.to_proto(),
1953 is_host: false,
1954 committer_name: request.committer_name.clone(),
1955 committer_email: request.committer_email.clone(),
1956 }),
1957 };
1958
1959 for collaborator in &collaborators {
1960 session
1961 .peer
1962 .send(
1963 collaborator.peer_id.unwrap().into(),
1964 add_project_collaborator.clone(),
1965 )
1966 .trace_err();
1967 }
1968
1969 // First, we send the metadata associated with each worktree.
1970 response.send(proto::JoinProjectResponse {
1971 project_id: project.id.0 as u64,
1972 worktrees: worktrees.clone(),
1973 replica_id: replica_id.0 as u32,
1974 collaborators: collaborators.clone(),
1975 language_servers: project.language_servers.clone(),
1976 role: project.role.into(),
1977 })?;
1978
1979 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1980 // Stream this worktree's entries.
1981 let message = proto::UpdateWorktree {
1982 project_id: project_id.to_proto(),
1983 worktree_id,
1984 abs_path: worktree.abs_path.clone(),
1985 root_name: worktree.root_name,
1986 updated_entries: worktree.entries,
1987 removed_entries: Default::default(),
1988 scan_id: worktree.scan_id,
1989 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1990 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1991 removed_repositories: Default::default(),
1992 };
1993 for update in proto::split_worktree_update(message) {
1994 session.peer.send(session.connection_id, update.clone())?;
1995 }
1996
1997 // Stream this worktree's diagnostics.
1998 for summary in worktree.diagnostic_summaries {
1999 session.peer.send(
2000 session.connection_id,
2001 proto::UpdateDiagnosticSummary {
2002 project_id: project_id.to_proto(),
2003 worktree_id: worktree.id,
2004 summary: Some(summary),
2005 },
2006 )?;
2007 }
2008
2009 for settings_file in worktree.settings_files {
2010 session.peer.send(
2011 session.connection_id,
2012 proto::UpdateWorktreeSettings {
2013 project_id: project_id.to_proto(),
2014 worktree_id: worktree.id,
2015 path: settings_file.path,
2016 content: Some(settings_file.content),
2017 kind: Some(settings_file.kind.to_proto() as i32),
2018 },
2019 )?;
2020 }
2021 }
2022
2023 for repository in mem::take(&mut project.repositories) {
2024 for update in split_repository_update(repository) {
2025 session.peer.send(session.connection_id, update)?;
2026 }
2027 }
2028
2029 for language_server in &project.language_servers {
2030 session.peer.send(
2031 session.connection_id,
2032 proto::UpdateLanguageServer {
2033 project_id: project_id.to_proto(),
2034 server_name: Some(language_server.name.clone()),
2035 language_server_id: language_server.id,
2036 variant: Some(
2037 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2038 proto::LspDiskBasedDiagnosticsUpdated {},
2039 ),
2040 ),
2041 },
2042 )?;
2043 }
2044
2045 Ok(())
2046}
2047
2048/// Leave someone elses shared project.
2049async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2050 let sender_id = session.connection_id;
2051 let project_id = ProjectId::from_proto(request.project_id);
2052 let db = session.db().await;
2053
2054 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2055 tracing::info!(
2056 %project_id,
2057 "leave project"
2058 );
2059
2060 project_left(project, &session);
2061 if let Some(room) = room {
2062 room_updated(room, &session.peer);
2063 }
2064
2065 Ok(())
2066}
2067
2068/// Updates other participants with changes to the project
2069async fn update_project(
2070 request: proto::UpdateProject,
2071 response: Response<proto::UpdateProject>,
2072 session: Session,
2073) -> Result<()> {
2074 let project_id = ProjectId::from_proto(request.project_id);
2075 let (room, guest_connection_ids) = &*session
2076 .db()
2077 .await
2078 .update_project(project_id, session.connection_id, &request.worktrees)
2079 .await?;
2080 broadcast(
2081 Some(session.connection_id),
2082 guest_connection_ids.iter().copied(),
2083 |connection_id| {
2084 session
2085 .peer
2086 .forward_send(session.connection_id, connection_id, request.clone())
2087 },
2088 );
2089 if let Some(room) = room {
2090 room_updated(room, &session.peer);
2091 }
2092 response.send(proto::Ack {})?;
2093
2094 Ok(())
2095}
2096
2097/// Updates other participants with changes to the worktree
2098async fn update_worktree(
2099 request: proto::UpdateWorktree,
2100 response: Response<proto::UpdateWorktree>,
2101 session: Session,
2102) -> Result<()> {
2103 let guest_connection_ids = session
2104 .db()
2105 .await
2106 .update_worktree(&request, session.connection_id)
2107 .await?;
2108
2109 broadcast(
2110 Some(session.connection_id),
2111 guest_connection_ids.iter().copied(),
2112 |connection_id| {
2113 session
2114 .peer
2115 .forward_send(session.connection_id, connection_id, request.clone())
2116 },
2117 );
2118 response.send(proto::Ack {})?;
2119 Ok(())
2120}
2121
2122async fn update_repository(
2123 request: proto::UpdateRepository,
2124 response: Response<proto::UpdateRepository>,
2125 session: Session,
2126) -> Result<()> {
2127 let guest_connection_ids = session
2128 .db()
2129 .await
2130 .update_repository(&request, session.connection_id)
2131 .await?;
2132
2133 broadcast(
2134 Some(session.connection_id),
2135 guest_connection_ids.iter().copied(),
2136 |connection_id| {
2137 session
2138 .peer
2139 .forward_send(session.connection_id, connection_id, request.clone())
2140 },
2141 );
2142 response.send(proto::Ack {})?;
2143 Ok(())
2144}
2145
2146async fn remove_repository(
2147 request: proto::RemoveRepository,
2148 response: Response<proto::RemoveRepository>,
2149 session: Session,
2150) -> Result<()> {
2151 let guest_connection_ids = session
2152 .db()
2153 .await
2154 .remove_repository(&request, session.connection_id)
2155 .await?;
2156
2157 broadcast(
2158 Some(session.connection_id),
2159 guest_connection_ids.iter().copied(),
2160 |connection_id| {
2161 session
2162 .peer
2163 .forward_send(session.connection_id, connection_id, request.clone())
2164 },
2165 );
2166 response.send(proto::Ack {})?;
2167 Ok(())
2168}
2169
2170/// Updates other participants with changes to the diagnostics
2171async fn update_diagnostic_summary(
2172 message: proto::UpdateDiagnosticSummary,
2173 session: Session,
2174) -> Result<()> {
2175 let guest_connection_ids = session
2176 .db()
2177 .await
2178 .update_diagnostic_summary(&message, session.connection_id)
2179 .await?;
2180
2181 broadcast(
2182 Some(session.connection_id),
2183 guest_connection_ids.iter().copied(),
2184 |connection_id| {
2185 session
2186 .peer
2187 .forward_send(session.connection_id, connection_id, message.clone())
2188 },
2189 );
2190
2191 Ok(())
2192}
2193
2194/// Updates other participants with changes to the worktree settings
2195async fn update_worktree_settings(
2196 message: proto::UpdateWorktreeSettings,
2197 session: Session,
2198) -> Result<()> {
2199 let guest_connection_ids = session
2200 .db()
2201 .await
2202 .update_worktree_settings(&message, session.connection_id)
2203 .await?;
2204
2205 broadcast(
2206 Some(session.connection_id),
2207 guest_connection_ids.iter().copied(),
2208 |connection_id| {
2209 session
2210 .peer
2211 .forward_send(session.connection_id, connection_id, message.clone())
2212 },
2213 );
2214
2215 Ok(())
2216}
2217
2218/// Notify other participants that a language server has started.
2219async fn start_language_server(
2220 request: proto::StartLanguageServer,
2221 session: Session,
2222) -> Result<()> {
2223 let guest_connection_ids = session
2224 .db()
2225 .await
2226 .start_language_server(&request, session.connection_id)
2227 .await?;
2228
2229 broadcast(
2230 Some(session.connection_id),
2231 guest_connection_ids.iter().copied(),
2232 |connection_id| {
2233 session
2234 .peer
2235 .forward_send(session.connection_id, connection_id, request.clone())
2236 },
2237 );
2238 Ok(())
2239}
2240
2241/// Notify other participants that a language server has changed.
2242async fn update_language_server(
2243 request: proto::UpdateLanguageServer,
2244 session: Session,
2245) -> Result<()> {
2246 let project_id = ProjectId::from_proto(request.project_id);
2247 let project_connection_ids = session
2248 .db()
2249 .await
2250 .project_connection_ids(project_id, session.connection_id, true)
2251 .await?;
2252 broadcast(
2253 Some(session.connection_id),
2254 project_connection_ids.iter().copied(),
2255 |connection_id| {
2256 session
2257 .peer
2258 .forward_send(session.connection_id, connection_id, request.clone())
2259 },
2260 );
2261 Ok(())
2262}
2263
2264/// forward a project request to the host. These requests should be read only
2265/// as guests are allowed to send them.
2266async fn forward_read_only_project_request<T>(
2267 request: T,
2268 response: Response<T>,
2269 session: Session,
2270) -> Result<()>
2271where
2272 T: EntityMessage + RequestMessage,
2273{
2274 let project_id = ProjectId::from_proto(request.remote_entity_id());
2275 let host_connection_id = session
2276 .db()
2277 .await
2278 .host_for_read_only_project_request(project_id, session.connection_id)
2279 .await?;
2280 let payload = session
2281 .peer
2282 .forward_request(session.connection_id, host_connection_id, request)
2283 .await?;
2284 response.send(payload)?;
2285 Ok(())
2286}
2287
2288async fn forward_find_search_candidates_request(
2289 request: proto::FindSearchCandidates,
2290 response: Response<proto::FindSearchCandidates>,
2291 session: Session,
2292) -> Result<()> {
2293 let project_id = ProjectId::from_proto(request.remote_entity_id());
2294 let host_connection_id = session
2295 .db()
2296 .await
2297 .host_for_read_only_project_request(project_id, session.connection_id)
2298 .await?;
2299 let payload = session
2300 .peer
2301 .forward_request(session.connection_id, host_connection_id, request)
2302 .await?;
2303 response.send(payload)?;
2304 Ok(())
2305}
2306
2307/// forward a project request to the host. These requests are disallowed
2308/// for guests.
2309async fn forward_mutating_project_request<T>(
2310 request: T,
2311 response: Response<T>,
2312 session: Session,
2313) -> Result<()>
2314where
2315 T: EntityMessage + RequestMessage,
2316{
2317 let project_id = ProjectId::from_proto(request.remote_entity_id());
2318
2319 let host_connection_id = session
2320 .db()
2321 .await
2322 .host_for_mutating_project_request(project_id, session.connection_id)
2323 .await?;
2324 let payload = session
2325 .peer
2326 .forward_request(session.connection_id, host_connection_id, request)
2327 .await?;
2328 response.send(payload)?;
2329 Ok(())
2330}
2331
2332/// Notify other participants that a new buffer has been created
2333async fn create_buffer_for_peer(
2334 request: proto::CreateBufferForPeer,
2335 session: Session,
2336) -> Result<()> {
2337 session
2338 .db()
2339 .await
2340 .check_user_is_project_host(
2341 ProjectId::from_proto(request.project_id),
2342 session.connection_id,
2343 )
2344 .await?;
2345 let peer_id = request.peer_id.context("invalid peer id")?;
2346 session
2347 .peer
2348 .forward_send(session.connection_id, peer_id.into(), request)?;
2349 Ok(())
2350}
2351
2352/// Notify other participants that a buffer has been updated. This is
2353/// allowed for guests as long as the update is limited to selections.
2354async fn update_buffer(
2355 request: proto::UpdateBuffer,
2356 response: Response<proto::UpdateBuffer>,
2357 session: Session,
2358) -> Result<()> {
2359 let project_id = ProjectId::from_proto(request.project_id);
2360 let mut capability = Capability::ReadOnly;
2361
2362 for op in request.operations.iter() {
2363 match op.variant {
2364 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2365 Some(_) => capability = Capability::ReadWrite,
2366 }
2367 }
2368
2369 let host = {
2370 let guard = session
2371 .db()
2372 .await
2373 .connections_for_buffer_update(project_id, session.connection_id, capability)
2374 .await?;
2375
2376 let (host, guests) = &*guard;
2377
2378 broadcast(
2379 Some(session.connection_id),
2380 guests.clone(),
2381 |connection_id| {
2382 session
2383 .peer
2384 .forward_send(session.connection_id, connection_id, request.clone())
2385 },
2386 );
2387
2388 *host
2389 };
2390
2391 if host != session.connection_id {
2392 session
2393 .peer
2394 .forward_request(session.connection_id, host, request.clone())
2395 .await?;
2396 }
2397
2398 response.send(proto::Ack {})?;
2399 Ok(())
2400}
2401
2402async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2403 let project_id = ProjectId::from_proto(message.project_id);
2404
2405 let operation = message.operation.as_ref().context("invalid operation")?;
2406 let capability = match operation.variant.as_ref() {
2407 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2408 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2409 match buffer_op.variant {
2410 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2411 Capability::ReadOnly
2412 }
2413 _ => Capability::ReadWrite,
2414 }
2415 } else {
2416 Capability::ReadWrite
2417 }
2418 }
2419 Some(_) => Capability::ReadWrite,
2420 None => Capability::ReadOnly,
2421 };
2422
2423 let guard = session
2424 .db()
2425 .await
2426 .connections_for_buffer_update(project_id, session.connection_id, capability)
2427 .await?;
2428
2429 let (host, guests) = &*guard;
2430
2431 broadcast(
2432 Some(session.connection_id),
2433 guests.iter().chain([host]).copied(),
2434 |connection_id| {
2435 session
2436 .peer
2437 .forward_send(session.connection_id, connection_id, message.clone())
2438 },
2439 );
2440
2441 Ok(())
2442}
2443
2444/// Notify other participants that a project has been updated.
2445async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2446 request: T,
2447 session: Session,
2448) -> Result<()> {
2449 let project_id = ProjectId::from_proto(request.remote_entity_id());
2450 let project_connection_ids = session
2451 .db()
2452 .await
2453 .project_connection_ids(project_id, session.connection_id, false)
2454 .await?;
2455
2456 broadcast(
2457 Some(session.connection_id),
2458 project_connection_ids.iter().copied(),
2459 |connection_id| {
2460 session
2461 .peer
2462 .forward_send(session.connection_id, connection_id, request.clone())
2463 },
2464 );
2465 Ok(())
2466}
2467
2468/// Start following another user in a call.
2469async fn follow(
2470 request: proto::Follow,
2471 response: Response<proto::Follow>,
2472 session: Session,
2473) -> Result<()> {
2474 let room_id = RoomId::from_proto(request.room_id);
2475 let project_id = request.project_id.map(ProjectId::from_proto);
2476 let leader_id = request.leader_id.context("invalid leader id")?.into();
2477 let follower_id = session.connection_id;
2478
2479 session
2480 .db()
2481 .await
2482 .check_room_participants(room_id, leader_id, session.connection_id)
2483 .await?;
2484
2485 let response_payload = session
2486 .peer
2487 .forward_request(session.connection_id, leader_id, request)
2488 .await?;
2489 response.send(response_payload)?;
2490
2491 if let Some(project_id) = project_id {
2492 let room = session
2493 .db()
2494 .await
2495 .follow(room_id, project_id, leader_id, follower_id)
2496 .await?;
2497 room_updated(&room, &session.peer);
2498 }
2499
2500 Ok(())
2501}
2502
2503/// Stop following another user in a call.
2504async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2505 let room_id = RoomId::from_proto(request.room_id);
2506 let project_id = request.project_id.map(ProjectId::from_proto);
2507 let leader_id = request.leader_id.context("invalid leader id")?.into();
2508 let follower_id = session.connection_id;
2509
2510 session
2511 .db()
2512 .await
2513 .check_room_participants(room_id, leader_id, session.connection_id)
2514 .await?;
2515
2516 session
2517 .peer
2518 .forward_send(session.connection_id, leader_id, request)?;
2519
2520 if let Some(project_id) = project_id {
2521 let room = session
2522 .db()
2523 .await
2524 .unfollow(room_id, project_id, leader_id, follower_id)
2525 .await?;
2526 room_updated(&room, &session.peer);
2527 }
2528
2529 Ok(())
2530}
2531
2532/// Notify everyone following you of your current location.
2533async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2534 let room_id = RoomId::from_proto(request.room_id);
2535 let database = session.db.lock().await;
2536
2537 let connection_ids = if let Some(project_id) = request.project_id {
2538 let project_id = ProjectId::from_proto(project_id);
2539 database
2540 .project_connection_ids(project_id, session.connection_id, true)
2541 .await?
2542 } else {
2543 database
2544 .room_connection_ids(room_id, session.connection_id)
2545 .await?
2546 };
2547
2548 // For now, don't send view update messages back to that view's current leader.
2549 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2550 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2551 _ => None,
2552 });
2553
2554 for connection_id in connection_ids.iter().cloned() {
2555 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2556 session
2557 .peer
2558 .forward_send(session.connection_id, connection_id, request.clone())?;
2559 }
2560 }
2561 Ok(())
2562}
2563
2564/// Get public data about users.
2565async fn get_users(
2566 request: proto::GetUsers,
2567 response: Response<proto::GetUsers>,
2568 session: Session,
2569) -> Result<()> {
2570 let user_ids = request
2571 .user_ids
2572 .into_iter()
2573 .map(UserId::from_proto)
2574 .collect();
2575 let users = session
2576 .db()
2577 .await
2578 .get_users_by_ids(user_ids)
2579 .await?
2580 .into_iter()
2581 .map(|user| proto::User {
2582 id: user.id.to_proto(),
2583 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2584 github_login: user.github_login,
2585 name: user.name,
2586 })
2587 .collect();
2588 response.send(proto::UsersResponse { users })?;
2589 Ok(())
2590}
2591
2592/// Search for users (to invite) buy Github login
2593async fn fuzzy_search_users(
2594 request: proto::FuzzySearchUsers,
2595 response: Response<proto::FuzzySearchUsers>,
2596 session: Session,
2597) -> Result<()> {
2598 let query = request.query;
2599 let users = match query.len() {
2600 0 => vec![],
2601 1 | 2 => session
2602 .db()
2603 .await
2604 .get_user_by_github_login(&query)
2605 .await?
2606 .into_iter()
2607 .collect(),
2608 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2609 };
2610 let users = users
2611 .into_iter()
2612 .filter(|user| user.id != session.user_id())
2613 .map(|user| proto::User {
2614 id: user.id.to_proto(),
2615 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2616 github_login: user.github_login,
2617 name: user.name,
2618 })
2619 .collect();
2620 response.send(proto::UsersResponse { users })?;
2621 Ok(())
2622}
2623
2624/// Send a contact request to another user.
2625async fn request_contact(
2626 request: proto::RequestContact,
2627 response: Response<proto::RequestContact>,
2628 session: Session,
2629) -> Result<()> {
2630 let requester_id = session.user_id();
2631 let responder_id = UserId::from_proto(request.responder_id);
2632 if requester_id == responder_id {
2633 return Err(anyhow!("cannot add yourself as a contact"))?;
2634 }
2635
2636 let notifications = session
2637 .db()
2638 .await
2639 .send_contact_request(requester_id, responder_id)
2640 .await?;
2641
2642 // Update outgoing contact requests of requester
2643 let mut update = proto::UpdateContacts::default();
2644 update.outgoing_requests.push(responder_id.to_proto());
2645 for connection_id in session
2646 .connection_pool()
2647 .await
2648 .user_connection_ids(requester_id)
2649 {
2650 session.peer.send(connection_id, update.clone())?;
2651 }
2652
2653 // Update incoming contact requests of responder
2654 let mut update = proto::UpdateContacts::default();
2655 update
2656 .incoming_requests
2657 .push(proto::IncomingContactRequest {
2658 requester_id: requester_id.to_proto(),
2659 });
2660 let connection_pool = session.connection_pool().await;
2661 for connection_id in connection_pool.user_connection_ids(responder_id) {
2662 session.peer.send(connection_id, update.clone())?;
2663 }
2664
2665 send_notifications(&connection_pool, &session.peer, notifications);
2666
2667 response.send(proto::Ack {})?;
2668 Ok(())
2669}
2670
2671/// Accept or decline a contact request
2672async fn respond_to_contact_request(
2673 request: proto::RespondToContactRequest,
2674 response: Response<proto::RespondToContactRequest>,
2675 session: Session,
2676) -> Result<()> {
2677 let responder_id = session.user_id();
2678 let requester_id = UserId::from_proto(request.requester_id);
2679 let db = session.db().await;
2680 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2681 db.dismiss_contact_notification(responder_id, requester_id)
2682 .await?;
2683 } else {
2684 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2685
2686 let notifications = db
2687 .respond_to_contact_request(responder_id, requester_id, accept)
2688 .await?;
2689 let requester_busy = db.is_user_busy(requester_id).await?;
2690 let responder_busy = db.is_user_busy(responder_id).await?;
2691
2692 let pool = session.connection_pool().await;
2693 // Update responder with new contact
2694 let mut update = proto::UpdateContacts::default();
2695 if accept {
2696 update
2697 .contacts
2698 .push(contact_for_user(requester_id, requester_busy, &pool));
2699 }
2700 update
2701 .remove_incoming_requests
2702 .push(requester_id.to_proto());
2703 for connection_id in pool.user_connection_ids(responder_id) {
2704 session.peer.send(connection_id, update.clone())?;
2705 }
2706
2707 // Update requester with new contact
2708 let mut update = proto::UpdateContacts::default();
2709 if accept {
2710 update
2711 .contacts
2712 .push(contact_for_user(responder_id, responder_busy, &pool));
2713 }
2714 update
2715 .remove_outgoing_requests
2716 .push(responder_id.to_proto());
2717
2718 for connection_id in pool.user_connection_ids(requester_id) {
2719 session.peer.send(connection_id, update.clone())?;
2720 }
2721
2722 send_notifications(&pool, &session.peer, notifications);
2723 }
2724
2725 response.send(proto::Ack {})?;
2726 Ok(())
2727}
2728
2729/// Remove a contact.
2730async fn remove_contact(
2731 request: proto::RemoveContact,
2732 response: Response<proto::RemoveContact>,
2733 session: Session,
2734) -> Result<()> {
2735 let requester_id = session.user_id();
2736 let responder_id = UserId::from_proto(request.user_id);
2737 let db = session.db().await;
2738 let (contact_accepted, deleted_notification_id) =
2739 db.remove_contact(requester_id, responder_id).await?;
2740
2741 let pool = session.connection_pool().await;
2742 // Update outgoing contact requests of requester
2743 let mut update = proto::UpdateContacts::default();
2744 if contact_accepted {
2745 update.remove_contacts.push(responder_id.to_proto());
2746 } else {
2747 update
2748 .remove_outgoing_requests
2749 .push(responder_id.to_proto());
2750 }
2751 for connection_id in pool.user_connection_ids(requester_id) {
2752 session.peer.send(connection_id, update.clone())?;
2753 }
2754
2755 // Update incoming contact requests of responder
2756 let mut update = proto::UpdateContacts::default();
2757 if contact_accepted {
2758 update.remove_contacts.push(requester_id.to_proto());
2759 } else {
2760 update
2761 .remove_incoming_requests
2762 .push(requester_id.to_proto());
2763 }
2764 for connection_id in pool.user_connection_ids(responder_id) {
2765 session.peer.send(connection_id, update.clone())?;
2766 if let Some(notification_id) = deleted_notification_id {
2767 session.peer.send(
2768 connection_id,
2769 proto::DeleteNotification {
2770 notification_id: notification_id.to_proto(),
2771 },
2772 )?;
2773 }
2774 }
2775
2776 response.send(proto::Ack {})?;
2777 Ok(())
2778}
2779
2780fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2781 version.0.minor() < 139
2782}
2783
2784async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2785 if is_staff {
2786 return Ok(proto::Plan::ZedPro);
2787 }
2788
2789 let subscription = db.get_active_billing_subscription(user_id).await?;
2790 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2791
2792 let plan = if let Some(subscription_kind) = subscription_kind {
2793 match subscription_kind {
2794 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2795 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2796 SubscriptionKind::ZedFree => proto::Plan::Free,
2797 }
2798 } else {
2799 proto::Plan::Free
2800 };
2801
2802 Ok(plan)
2803}
2804
2805async fn make_update_user_plan_message(
2806 user: &User,
2807 is_staff: bool,
2808 db: &Arc<Database>,
2809 llm_db: Option<Arc<LlmDatabase>>,
2810) -> Result<proto::UpdateUserPlan> {
2811 let feature_flags = db.get_user_flags(user.id).await?;
2812 let plan = current_plan(db, user.id, is_staff).await?;
2813 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2814 let billing_preferences = db.get_billing_preferences(user.id).await?;
2815
2816 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2817 let subscription = db.get_active_billing_subscription(user.id).await?;
2818
2819 let subscription_period =
2820 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2821
2822 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2823 llm_db
2824 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2825 .await?
2826 } else {
2827 None
2828 };
2829
2830 (subscription_period, usage)
2831 } else {
2832 (None, None)
2833 };
2834
2835 let bypass_account_age_check = feature_flags
2836 .iter()
2837 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2838 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2839 && !bypass_account_age_check
2840 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2841
2842 Ok(proto::UpdateUserPlan {
2843 plan: plan.into(),
2844 trial_started_at: billing_customer
2845 .as_ref()
2846 .and_then(|billing_customer| billing_customer.trial_started_at)
2847 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2848 is_usage_based_billing_enabled: if is_staff {
2849 Some(true)
2850 } else {
2851 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2852 },
2853 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2854 proto::SubscriptionPeriod {
2855 started_at: started_at.timestamp() as u64,
2856 ended_at: ended_at.timestamp() as u64,
2857 }
2858 }),
2859 account_too_young: Some(account_too_young),
2860 has_overdue_invoices: billing_customer
2861 .map(|billing_customer| billing_customer.has_overdue_invoices),
2862 usage: Some(
2863 usage
2864 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2865 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2866 ),
2867 })
2868}
2869
2870fn model_requests_limit(
2871 plan: zed_llm_client::Plan,
2872 feature_flags: &Vec<String>,
2873) -> zed_llm_client::UsageLimit {
2874 match plan.model_requests_limit() {
2875 zed_llm_client::UsageLimit::Limited(limit) => {
2876 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2877 && feature_flags
2878 .iter()
2879 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2880 {
2881 1_000
2882 } else {
2883 limit
2884 };
2885
2886 zed_llm_client::UsageLimit::Limited(limit)
2887 }
2888 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2889 }
2890}
2891
2892fn subscription_usage_to_proto(
2893 plan: proto::Plan,
2894 usage: crate::llm::db::subscription_usage::Model,
2895 feature_flags: &Vec<String>,
2896) -> proto::SubscriptionUsage {
2897 let plan = match plan {
2898 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2899 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2900 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2901 };
2902
2903 proto::SubscriptionUsage {
2904 model_requests_usage_amount: usage.model_requests as u32,
2905 model_requests_usage_limit: Some(proto::UsageLimit {
2906 variant: Some(match model_requests_limit(plan, feature_flags) {
2907 zed_llm_client::UsageLimit::Limited(limit) => {
2908 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2909 limit: limit as u32,
2910 })
2911 }
2912 zed_llm_client::UsageLimit::Unlimited => {
2913 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2914 }
2915 }),
2916 }),
2917 edit_predictions_usage_amount: usage.edit_predictions as u32,
2918 edit_predictions_usage_limit: Some(proto::UsageLimit {
2919 variant: Some(match plan.edit_predictions_limit() {
2920 zed_llm_client::UsageLimit::Limited(limit) => {
2921 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2922 limit: limit as u32,
2923 })
2924 }
2925 zed_llm_client::UsageLimit::Unlimited => {
2926 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2927 }
2928 }),
2929 }),
2930 }
2931}
2932
2933fn make_default_subscription_usage(
2934 plan: proto::Plan,
2935 feature_flags: &Vec<String>,
2936) -> proto::SubscriptionUsage {
2937 let plan = match plan {
2938 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2939 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2940 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2941 };
2942
2943 proto::SubscriptionUsage {
2944 model_requests_usage_amount: 0,
2945 model_requests_usage_limit: Some(proto::UsageLimit {
2946 variant: Some(match model_requests_limit(plan, feature_flags) {
2947 zed_llm_client::UsageLimit::Limited(limit) => {
2948 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2949 limit: limit as u32,
2950 })
2951 }
2952 zed_llm_client::UsageLimit::Unlimited => {
2953 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2954 }
2955 }),
2956 }),
2957 edit_predictions_usage_amount: 0,
2958 edit_predictions_usage_limit: Some(proto::UsageLimit {
2959 variant: Some(match plan.edit_predictions_limit() {
2960 zed_llm_client::UsageLimit::Limited(limit) => {
2961 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2962 limit: limit as u32,
2963 })
2964 }
2965 zed_llm_client::UsageLimit::Unlimited => {
2966 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2967 }
2968 }),
2969 }),
2970 }
2971}
2972
2973async fn update_user_plan(session: &Session) -> Result<()> {
2974 let db = session.db().await;
2975
2976 let update_user_plan = make_update_user_plan_message(
2977 session.principal.user(),
2978 session.is_staff(),
2979 &db.0,
2980 session.app_state.llm_db.clone(),
2981 )
2982 .await?;
2983
2984 session
2985 .peer
2986 .send(session.connection_id, update_user_plan)
2987 .trace_err();
2988
2989 Ok(())
2990}
2991
2992async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2993 subscribe_user_to_channels(session.user_id(), &session).await?;
2994 Ok(())
2995}
2996
2997async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2998 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2999 let mut pool = session.connection_pool().await;
3000 for membership in &channels_for_user.channel_memberships {
3001 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3002 }
3003 session.peer.send(
3004 session.connection_id,
3005 build_update_user_channels(&channels_for_user),
3006 )?;
3007 session.peer.send(
3008 session.connection_id,
3009 build_channels_update(channels_for_user),
3010 )?;
3011 Ok(())
3012}
3013
3014/// Creates a new channel.
3015async fn create_channel(
3016 request: proto::CreateChannel,
3017 response: Response<proto::CreateChannel>,
3018 session: Session,
3019) -> Result<()> {
3020 let db = session.db().await;
3021
3022 let parent_id = request.parent_id.map(ChannelId::from_proto);
3023 let (channel, membership) = db
3024 .create_channel(&request.name, parent_id, session.user_id())
3025 .await?;
3026
3027 let root_id = channel.root_id();
3028 let channel = Channel::from_model(channel);
3029
3030 response.send(proto::CreateChannelResponse {
3031 channel: Some(channel.to_proto()),
3032 parent_id: request.parent_id,
3033 })?;
3034
3035 let mut connection_pool = session.connection_pool().await;
3036 if let Some(membership) = membership {
3037 connection_pool.subscribe_to_channel(
3038 membership.user_id,
3039 membership.channel_id,
3040 membership.role,
3041 );
3042 let update = proto::UpdateUserChannels {
3043 channel_memberships: vec![proto::ChannelMembership {
3044 channel_id: membership.channel_id.to_proto(),
3045 role: membership.role.into(),
3046 }],
3047 ..Default::default()
3048 };
3049 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3050 session.peer.send(connection_id, update.clone())?;
3051 }
3052 }
3053
3054 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3055 if !role.can_see_channel(channel.visibility) {
3056 continue;
3057 }
3058
3059 let update = proto::UpdateChannels {
3060 channels: vec![channel.to_proto()],
3061 ..Default::default()
3062 };
3063 session.peer.send(connection_id, update.clone())?;
3064 }
3065
3066 Ok(())
3067}
3068
3069/// Delete a channel
3070async fn delete_channel(
3071 request: proto::DeleteChannel,
3072 response: Response<proto::DeleteChannel>,
3073 session: Session,
3074) -> Result<()> {
3075 let db = session.db().await;
3076
3077 let channel_id = request.channel_id;
3078 let (root_channel, removed_channels) = db
3079 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3080 .await?;
3081 response.send(proto::Ack {})?;
3082
3083 // Notify members of removed channels
3084 let mut update = proto::UpdateChannels::default();
3085 update
3086 .delete_channels
3087 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3088
3089 let connection_pool = session.connection_pool().await;
3090 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3091 session.peer.send(connection_id, update.clone())?;
3092 }
3093
3094 Ok(())
3095}
3096
3097/// Invite someone to join a channel.
3098async fn invite_channel_member(
3099 request: proto::InviteChannelMember,
3100 response: Response<proto::InviteChannelMember>,
3101 session: Session,
3102) -> Result<()> {
3103 let db = session.db().await;
3104 let channel_id = ChannelId::from_proto(request.channel_id);
3105 let invitee_id = UserId::from_proto(request.user_id);
3106 let InviteMemberResult {
3107 channel,
3108 notifications,
3109 } = db
3110 .invite_channel_member(
3111 channel_id,
3112 invitee_id,
3113 session.user_id(),
3114 request.role().into(),
3115 )
3116 .await?;
3117
3118 let update = proto::UpdateChannels {
3119 channel_invitations: vec![channel.to_proto()],
3120 ..Default::default()
3121 };
3122
3123 let connection_pool = session.connection_pool().await;
3124 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3125 session.peer.send(connection_id, update.clone())?;
3126 }
3127
3128 send_notifications(&connection_pool, &session.peer, notifications);
3129
3130 response.send(proto::Ack {})?;
3131 Ok(())
3132}
3133
3134/// remove someone from a channel
3135async fn remove_channel_member(
3136 request: proto::RemoveChannelMember,
3137 response: Response<proto::RemoveChannelMember>,
3138 session: Session,
3139) -> Result<()> {
3140 let db = session.db().await;
3141 let channel_id = ChannelId::from_proto(request.channel_id);
3142 let member_id = UserId::from_proto(request.user_id);
3143
3144 let RemoveChannelMemberResult {
3145 membership_update,
3146 notification_id,
3147 } = db
3148 .remove_channel_member(channel_id, member_id, session.user_id())
3149 .await?;
3150
3151 let mut connection_pool = session.connection_pool().await;
3152 notify_membership_updated(
3153 &mut connection_pool,
3154 membership_update,
3155 member_id,
3156 &session.peer,
3157 );
3158 for connection_id in connection_pool.user_connection_ids(member_id) {
3159 if let Some(notification_id) = notification_id {
3160 session
3161 .peer
3162 .send(
3163 connection_id,
3164 proto::DeleteNotification {
3165 notification_id: notification_id.to_proto(),
3166 },
3167 )
3168 .trace_err();
3169 }
3170 }
3171
3172 response.send(proto::Ack {})?;
3173 Ok(())
3174}
3175
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
///
/// After the database applies the change, every user the connection pool tracks
/// under the channel's root is re-evaluated:
/// - users whose role can see the channel at its new visibility are subscribed
///   to it and sent the channel;
/// - everyone else is unsubscribed and told to delete it.
///
/// The requester is acknowledged last.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Snapshot the (user, role) pairs first: the loop body mutates
    // `connection_pool` (subscribe/unsubscribe), which is not possible while an
    // iterator still borrows it.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        // Choose both the subscription change and the message from whether this
        // user's role may see the channel with its new visibility.
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3222
3223/// Alter the role for a user in the channel.
3224async fn set_channel_member_role(
3225 request: proto::SetChannelMemberRole,
3226 response: Response<proto::SetChannelMemberRole>,
3227 session: Session,
3228) -> Result<()> {
3229 let db = session.db().await;
3230 let channel_id = ChannelId::from_proto(request.channel_id);
3231 let member_id = UserId::from_proto(request.user_id);
3232 let result = db
3233 .set_channel_member_role(
3234 channel_id,
3235 session.user_id(),
3236 member_id,
3237 request.role().into(),
3238 )
3239 .await?;
3240
3241 match result {
3242 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3243 let mut connection_pool = session.connection_pool().await;
3244 notify_membership_updated(
3245 &mut connection_pool,
3246 membership_update,
3247 member_id,
3248 &session.peer,
3249 )
3250 }
3251 db::SetMemberRoleResult::InviteUpdated(channel) => {
3252 let update = proto::UpdateChannels {
3253 channel_invitations: vec![channel.to_proto()],
3254 ..Default::default()
3255 };
3256
3257 for connection_id in session
3258 .connection_pool()
3259 .await
3260 .user_connection_ids(member_id)
3261 {
3262 session.peer.send(connection_id, update.clone())?;
3263 }
3264 }
3265 }
3266
3267 response.send(proto::Ack {})?;
3268 Ok(())
3269}
3270
3271/// Change the name of a channel
3272async fn rename_channel(
3273 request: proto::RenameChannel,
3274 response: Response<proto::RenameChannel>,
3275 session: Session,
3276) -> Result<()> {
3277 let db = session.db().await;
3278 let channel_id = ChannelId::from_proto(request.channel_id);
3279 let channel_model = db
3280 .rename_channel(channel_id, session.user_id(), &request.name)
3281 .await?;
3282 let root_id = channel_model.root_id();
3283 let channel = Channel::from_model(channel_model);
3284
3285 response.send(proto::RenameChannelResponse {
3286 channel: Some(channel.to_proto()),
3287 })?;
3288
3289 let connection_pool = session.connection_pool().await;
3290 let update = proto::UpdateChannels {
3291 channels: vec![channel.to_proto()],
3292 ..Default::default()
3293 };
3294 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3295 if role.can_see_channel(channel.visibility) {
3296 session.peer.send(connection_id, update.clone())?;
3297 }
3298 }
3299
3300 Ok(())
3301}
3302
3303/// Move a channel to a new parent.
3304async fn move_channel(
3305 request: proto::MoveChannel,
3306 response: Response<proto::MoveChannel>,
3307 session: Session,
3308) -> Result<()> {
3309 let channel_id = ChannelId::from_proto(request.channel_id);
3310 let to = ChannelId::from_proto(request.to);
3311
3312 let (root_id, channels) = session
3313 .db()
3314 .await
3315 .move_channel(channel_id, to, session.user_id())
3316 .await?;
3317
3318 let connection_pool = session.connection_pool().await;
3319 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3320 let channels = channels
3321 .iter()
3322 .filter_map(|channel| {
3323 if role.can_see_channel(channel.visibility) {
3324 Some(channel.to_proto())
3325 } else {
3326 None
3327 }
3328 })
3329 .collect::<Vec<_>>();
3330 if channels.is_empty() {
3331 continue;
3332 }
3333
3334 let update = proto::UpdateChannels {
3335 channels,
3336 ..Default::default()
3337 };
3338
3339 session.peer.send(connection_id, update.clone())?;
3340 }
3341
3342 response.send(Ack {})?;
3343 Ok(())
3344}
3345
3346async fn reorder_channel(
3347 request: proto::ReorderChannel,
3348 response: Response<proto::ReorderChannel>,
3349 session: Session,
3350) -> Result<()> {
3351 let channel_id = ChannelId::from_proto(request.channel_id);
3352 let direction = request.direction();
3353
3354 let updated_channels = session
3355 .db()
3356 .await
3357 .reorder_channel(channel_id, direction, session.user_id())
3358 .await?;
3359
3360 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3361 let connection_pool = session.connection_pool().await;
3362 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3363 let channels = updated_channels
3364 .iter()
3365 .filter_map(|channel| {
3366 if role.can_see_channel(channel.visibility) {
3367 Some(channel.to_proto())
3368 } else {
3369 None
3370 }
3371 })
3372 .collect::<Vec<_>>();
3373
3374 if channels.is_empty() {
3375 continue;
3376 }
3377
3378 let update = proto::UpdateChannels {
3379 channels,
3380 ..Default::default()
3381 };
3382
3383 session.peer.send(connection_id, update.clone())?;
3384 }
3385 }
3386
3387 response.send(Ack {})?;
3388 Ok(())
3389}
3390
3391/// Get the list of channel members
3392async fn get_channel_members(
3393 request: proto::GetChannelMembers,
3394 response: Response<proto::GetChannelMembers>,
3395 session: Session,
3396) -> Result<()> {
3397 let db = session.db().await;
3398 let channel_id = ChannelId::from_proto(request.channel_id);
3399 let limit = if request.limit == 0 {
3400 u16::MAX as u64
3401 } else {
3402 request.limit
3403 };
3404 let (members, users) = db
3405 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3406 .await?;
3407 response.send(proto::GetChannelMembersResponse { members, users })?;
3408 Ok(())
3409}
3410
3411/// Accept or decline a channel invitation.
3412async fn respond_to_channel_invite(
3413 request: proto::RespondToChannelInvite,
3414 response: Response<proto::RespondToChannelInvite>,
3415 session: Session,
3416) -> Result<()> {
3417 let db = session.db().await;
3418 let channel_id = ChannelId::from_proto(request.channel_id);
3419 let RespondToChannelInvite {
3420 membership_update,
3421 notifications,
3422 } = db
3423 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3424 .await?;
3425
3426 let mut connection_pool = session.connection_pool().await;
3427 if let Some(membership_update) = membership_update {
3428 notify_membership_updated(
3429 &mut connection_pool,
3430 membership_update,
3431 session.user_id(),
3432 &session.peer,
3433 );
3434 } else {
3435 let update = proto::UpdateChannels {
3436 remove_channel_invitations: vec![channel_id.to_proto()],
3437 ..Default::default()
3438 };
3439
3440 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3441 session.peer.send(connection_id, update.clone())?;
3442 }
3443 };
3444
3445 send_notifications(&connection_pool, &session.peer, notifications);
3446
3447 response.send(proto::Ack {})?;
3448
3449 Ok(())
3450}
3451
3452/// Join the channels' room
3453async fn join_channel(
3454 request: proto::JoinChannel,
3455 response: Response<proto::JoinChannel>,
3456 session: Session,
3457) -> Result<()> {
3458 let channel_id = ChannelId::from_proto(request.channel_id);
3459 join_channel_internal(channel_id, Box::new(response), session).await
3460}
3461
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// This lets `join_channel_internal` reply to either a `JoinChannel` or a
/// `JoinRoom` request with the same payload type.
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path makes explicit that this delegates to
        // `Response`'s own `send`.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The fully-qualified path makes explicit that this delegates to
        // `Response`'s own `send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3475
/// Shared implementation for joining a channel's call room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first (the database guard is dropped around the call).
/// 2. Join the channel's room in the database.
/// 3. Build LiveKit connection info when a LiveKit client is configured.
/// 4. Reply to the requester with the room, channel id and LiveKit info.
/// 5. Broadcast any membership change and the room update.
/// 6. Broadcast the channel update.
/// 7. Refresh the user's contacts.
///
/// LiveKit details:
/// - Guests get a guest token and `can_publish = false`; every other role gets
///   a room token and `can_publish = true`.
/// - If token creation fails, `trace_err()` turns the error into `None`, so
///   the reply carries no LiveKit info.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // This block scopes the database guard (`db`) and the first connection-pool
    // guard; both are released before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then
            // re-acquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation fails.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails with "channel not returned" if the joined room has no channel.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3569
3570/// Start editing the channel notes
3571async fn join_channel_buffer(
3572 request: proto::JoinChannelBuffer,
3573 response: Response<proto::JoinChannelBuffer>,
3574 session: Session,
3575) -> Result<()> {
3576 let db = session.db().await;
3577 let channel_id = ChannelId::from_proto(request.channel_id);
3578
3579 let open_response = db
3580 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3581 .await?;
3582
3583 let collaborators = open_response.collaborators.clone();
3584 response.send(open_response)?;
3585
3586 let update = UpdateChannelBufferCollaborators {
3587 channel_id: channel_id.to_proto(),
3588 collaborators: collaborators.clone(),
3589 };
3590 channel_buffer_updated(
3591 session.connection_id,
3592 collaborators
3593 .iter()
3594 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3595 &update,
3596 &session.peer,
3597 );
3598
3599 Ok(())
3600}
3601
3602/// Edit the channel notes
3603async fn update_channel_buffer(
3604 request: proto::UpdateChannelBuffer,
3605 session: Session,
3606) -> Result<()> {
3607 let db = session.db().await;
3608 let channel_id = ChannelId::from_proto(request.channel_id);
3609
3610 let (collaborators, epoch, version) = db
3611 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3612 .await?;
3613
3614 channel_buffer_updated(
3615 session.connection_id,
3616 collaborators.clone(),
3617 &proto::UpdateChannelBuffer {
3618 channel_id: channel_id.to_proto(),
3619 operations: request.operations,
3620 },
3621 &session.peer,
3622 );
3623
3624 let pool = &*session.connection_pool().await;
3625
3626 let non_collaborators =
3627 pool.channel_connection_ids(channel_id)
3628 .filter_map(|(connection_id, _)| {
3629 if collaborators.contains(&connection_id) {
3630 None
3631 } else {
3632 Some(connection_id)
3633 }
3634 });
3635
3636 broadcast(None, non_collaborators, |peer_id| {
3637 session.peer.send(
3638 peer_id,
3639 proto::UpdateChannels {
3640 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3641 channel_id: channel_id.to_proto(),
3642 epoch: epoch as u64,
3643 version: version.clone(),
3644 }],
3645 ..Default::default()
3646 },
3647 )
3648 });
3649
3650 Ok(())
3651}
3652
3653/// Rejoin the channel notes after a connection blip
3654async fn rejoin_channel_buffers(
3655 request: proto::RejoinChannelBuffers,
3656 response: Response<proto::RejoinChannelBuffers>,
3657 session: Session,
3658) -> Result<()> {
3659 let db = session.db().await;
3660 let buffers = db
3661 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3662 .await?;
3663
3664 for rejoined_buffer in &buffers {
3665 let collaborators_to_notify = rejoined_buffer
3666 .buffer
3667 .collaborators
3668 .iter()
3669 .filter_map(|c| Some(c.peer_id?.into()));
3670 channel_buffer_updated(
3671 session.connection_id,
3672 collaborators_to_notify,
3673 &proto::UpdateChannelBufferCollaborators {
3674 channel_id: rejoined_buffer.buffer.channel_id,
3675 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3676 },
3677 &session.peer,
3678 );
3679 }
3680
3681 response.send(proto::RejoinChannelBuffersResponse {
3682 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3683 })?;
3684
3685 Ok(())
3686}
3687
3688/// Stop editing the channel notes
3689async fn leave_channel_buffer(
3690 request: proto::LeaveChannelBuffer,
3691 response: Response<proto::LeaveChannelBuffer>,
3692 session: Session,
3693) -> Result<()> {
3694 let db = session.db().await;
3695 let channel_id = ChannelId::from_proto(request.channel_id);
3696
3697 let left_buffer = db
3698 .leave_channel_buffer(channel_id, session.connection_id)
3699 .await?;
3700
3701 response.send(Ack {})?;
3702
3703 channel_buffer_updated(
3704 session.connection_id,
3705 left_buffer.connections,
3706 &proto::UpdateChannelBufferCollaborators {
3707 channel_id: channel_id.to_proto(),
3708 collaborators: left_buffer.collaborators,
3709 },
3710 &session.peer,
3711 );
3712
3713 Ok(())
3714}
3715
3716fn channel_buffer_updated<T: EnvelopedMessage>(
3717 sender_id: ConnectionId,
3718 collaborators: impl IntoIterator<Item = ConnectionId>,
3719 message: &T,
3720 peer: &Peer,
3721) {
3722 broadcast(Some(sender_id), collaborators, |peer_id| {
3723 peer.send(peer_id, message.clone())
3724 });
3725}
3726
3727fn send_notifications(
3728 connection_pool: &ConnectionPool,
3729 peer: &Peer,
3730 notifications: db::NotificationBatch,
3731) {
3732 for (user_id, notification) in notifications {
3733 for connection_id in connection_pool.user_connection_ids(user_id) {
3734 if let Err(error) = peer.send(
3735 connection_id,
3736 proto::AddNotification {
3737 notification: Some(notification.clone()),
3738 },
3739 ) {
3740 tracing::error!(
3741 "failed to send notification to {:?} {}",
3742 connection_id,
3743 error
3744 );
3745 }
3746 }
3747 }
3748}
3749
/// Send a message to the channel
///
/// The body is trimmed, then rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty. A nonce is required. After the database stores the message:
/// - connections participating in the channel chat receive the full message
///   via `ChannelMessageSent` (the requester's own connection is passed to
///   `broadcast` as the sender);
/// - the requester receives the stored message in the response;
/// - every other connection subscribed to the channel receives only the latest
///   message id;
/// - notifications produced by the database (e.g. for mentions) are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Chat participants get the full message.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers who are not in the chat only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3842
3843/// Delete a channel message
3844async fn remove_channel_message(
3845 request: proto::RemoveChannelMessage,
3846 response: Response<proto::RemoveChannelMessage>,
3847 session: Session,
3848) -> Result<()> {
3849 let channel_id = ChannelId::from_proto(request.channel_id);
3850 let message_id = MessageId::from_proto(request.message_id);
3851 let (connection_ids, existing_notification_ids) = session
3852 .db()
3853 .await
3854 .remove_channel_message(channel_id, message_id, session.user_id())
3855 .await?;
3856
3857 broadcast(
3858 Some(session.connection_id),
3859 connection_ids,
3860 move |connection| {
3861 session.peer.send(connection, request.clone())?;
3862
3863 for notification_id in &existing_notification_ids {
3864 session.peer.send(
3865 connection,
3866 proto::DeleteNotification {
3867 notification_id: (*notification_id).to_proto(),
3868 },
3869 )?;
3870 }
3871
3872 Ok(())
3873 },
3874 );
3875 response.send(proto::Ack {})?;
3876 Ok(())
3877}
3878
3879async fn update_channel_message(
3880 request: proto::UpdateChannelMessage,
3881 response: Response<proto::UpdateChannelMessage>,
3882 session: Session,
3883) -> Result<()> {
3884 let channel_id = ChannelId::from_proto(request.channel_id);
3885 let message_id = MessageId::from_proto(request.message_id);
3886 let updated_at = OffsetDateTime::now_utc();
3887 let UpdatedChannelMessage {
3888 message_id,
3889 participant_connection_ids,
3890 notifications,
3891 reply_to_message_id,
3892 timestamp,
3893 deleted_mention_notification_ids,
3894 updated_mention_notifications,
3895 } = session
3896 .db()
3897 .await
3898 .update_channel_message(
3899 channel_id,
3900 message_id,
3901 session.user_id(),
3902 request.body.as_str(),
3903 &request.mentions,
3904 updated_at,
3905 )
3906 .await?;
3907
3908 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3909
3910 let message = proto::ChannelMessage {
3911 sender_id: session.user_id().to_proto(),
3912 id: message_id.to_proto(),
3913 body: request.body.clone(),
3914 mentions: request.mentions.clone(),
3915 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3916 nonce: Some(nonce),
3917 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3918 edited_at: Some(updated_at.unix_timestamp() as u64),
3919 };
3920
3921 response.send(proto::Ack {})?;
3922
3923 let pool = &*session.connection_pool().await;
3924 broadcast(
3925 Some(session.connection_id),
3926 participant_connection_ids,
3927 |connection| {
3928 session.peer.send(
3929 connection,
3930 proto::ChannelMessageUpdate {
3931 channel_id: channel_id.to_proto(),
3932 message: Some(message.clone()),
3933 },
3934 )?;
3935
3936 for notification_id in &deleted_mention_notification_ids {
3937 session.peer.send(
3938 connection,
3939 proto::DeleteNotification {
3940 notification_id: (*notification_id).to_proto(),
3941 },
3942 )?;
3943 }
3944
3945 for notification in &updated_mention_notifications {
3946 session.peer.send(
3947 connection,
3948 proto::UpdateNotification {
3949 notification: Some(notification.clone()),
3950 },
3951 )?;
3952 }
3953
3954 Ok(())
3955 },
3956 );
3957
3958 send_notifications(pool, &session.peer, notifications);
3959
3960 Ok(())
3961}
3962
3963/// Mark a channel message as read
3964async fn acknowledge_channel_message(
3965 request: proto::AckChannelMessage,
3966 session: Session,
3967) -> Result<()> {
3968 let channel_id = ChannelId::from_proto(request.channel_id);
3969 let message_id = MessageId::from_proto(request.message_id);
3970 let notifications = session
3971 .db()
3972 .await
3973 .observe_channel_message(channel_id, session.user_id(), message_id)
3974 .await?;
3975 send_notifications(
3976 &*session.connection_pool().await,
3977 &session.peer,
3978 notifications,
3979 );
3980 Ok(())
3981}
3982
3983/// Mark a buffer version as synced
3984async fn acknowledge_buffer_version(
3985 request: proto::AckBufferOperation,
3986 session: Session,
3987) -> Result<()> {
3988 let buffer_id = BufferId::from_proto(request.buffer_id);
3989 session
3990 .db()
3991 .await
3992 .observe_buffer_version(
3993 buffer_id,
3994 session.user_id(),
3995 request.epoch as i32,
3996 &request.version,
3997 )
3998 .await?;
3999 Ok(())
4000}
4001
4002/// Get a Supermaven API key for the user
4003async fn get_supermaven_api_key(
4004 _request: proto::GetSupermavenApiKey,
4005 response: Response<proto::GetSupermavenApiKey>,
4006 session: Session,
4007) -> Result<()> {
4008 let user_id: String = session.user_id().to_string();
4009 if !session.is_staff() {
4010 return Err(anyhow!("supermaven not enabled for this account"))?;
4011 }
4012
4013 let email = session.email().context("user must have an email")?;
4014
4015 let supermaven_admin_api = session
4016 .supermaven_client
4017 .as_ref()
4018 .context("supermaven not configured")?;
4019
4020 let result = supermaven_admin_api
4021 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4022 .await?;
4023
4024 response.send(proto::GetSupermavenApiKeyResponse {
4025 api_key: result.api_key,
4026 })?;
4027
4028 Ok(())
4029}
4030
4031/// Start receiving chat updates for a channel
4032async fn join_channel_chat(
4033 request: proto::JoinChannelChat,
4034 response: Response<proto::JoinChannelChat>,
4035 session: Session,
4036) -> Result<()> {
4037 let channel_id = ChannelId::from_proto(request.channel_id);
4038
4039 let db = session.db().await;
4040 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4041 .await?;
4042 let messages = db
4043 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4044 .await?;
4045 response.send(proto::JoinChannelChatResponse {
4046 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4047 messages,
4048 })?;
4049 Ok(())
4050}
4051
4052/// Stop receiving chat updates for a channel
4053async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4054 let channel_id = ChannelId::from_proto(request.channel_id);
4055 session
4056 .db()
4057 .await
4058 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4059 .await?;
4060 Ok(())
4061}
4062
4063/// Retrieve the chat history for a channel
4064async fn get_channel_messages(
4065 request: proto::GetChannelMessages,
4066 response: Response<proto::GetChannelMessages>,
4067 session: Session,
4068) -> Result<()> {
4069 let channel_id = ChannelId::from_proto(request.channel_id);
4070 let messages = session
4071 .db()
4072 .await
4073 .get_channel_messages(
4074 channel_id,
4075 session.user_id(),
4076 MESSAGE_COUNT_PER_PAGE,
4077 Some(MessageId::from_proto(request.before_message_id)),
4078 )
4079 .await?;
4080 response.send(proto::GetChannelMessagesResponse {
4081 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4082 messages,
4083 })?;
4084 Ok(())
4085}
4086
4087/// Retrieve specific chat messages
4088async fn get_channel_messages_by_id(
4089 request: proto::GetChannelMessagesById,
4090 response: Response<proto::GetChannelMessagesById>,
4091 session: Session,
4092) -> Result<()> {
4093 let message_ids = request
4094 .message_ids
4095 .iter()
4096 .map(|id| MessageId::from_proto(*id))
4097 .collect::<Vec<_>>();
4098 let messages = session
4099 .db()
4100 .await
4101 .get_channel_messages_by_id(session.user_id(), &message_ids)
4102 .await?;
4103 response.send(proto::GetChannelMessagesResponse {
4104 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4105 messages,
4106 })?;
4107 Ok(())
4108}
4109
4110/// Retrieve the current users notifications
4111async fn get_notifications(
4112 request: proto::GetNotifications,
4113 response: Response<proto::GetNotifications>,
4114 session: Session,
4115) -> Result<()> {
4116 let notifications = session
4117 .db()
4118 .await
4119 .get_notifications(
4120 session.user_id(),
4121 NOTIFICATION_COUNT_PER_PAGE,
4122 request.before_id.map(db::NotificationId::from_proto),
4123 )
4124 .await?;
4125 response.send(proto::GetNotificationsResponse {
4126 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4127 notifications,
4128 })?;
4129 Ok(())
4130}
4131
4132/// Mark notifications as read
4133async fn mark_notification_as_read(
4134 request: proto::MarkNotificationRead,
4135 response: Response<proto::MarkNotificationRead>,
4136 session: Session,
4137) -> Result<()> {
4138 let database = &session.db().await;
4139 let notifications = database
4140 .mark_notification_as_read_by_id(
4141 session.user_id(),
4142 NotificationId::from_proto(request.notification_id),
4143 )
4144 .await?;
4145 send_notifications(
4146 &*session.connection_pool().await,
4147 &session.peer,
4148 notifications,
4149 );
4150 response.send(proto::Ack {})?;
4151 Ok(())
4152}
4153
4154/// Get the current users information
4155async fn get_private_user_info(
4156 _request: proto::GetPrivateUserInfo,
4157 response: Response<proto::GetPrivateUserInfo>,
4158 session: Session,
4159) -> Result<()> {
4160 let db = session.db().await;
4161
4162 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4163 let user = db
4164 .get_user_by_id(session.user_id())
4165 .await?
4166 .context("user not found")?;
4167 let flags = db.get_user_flags(session.user_id()).await?;
4168
4169 response.send(proto::GetPrivateUserInfoResponse {
4170 metrics_id,
4171 staff: user.admin,
4172 flags,
4173 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4174 })?;
4175 Ok(())
4176}
4177
4178/// Accept the terms of service (tos) on behalf of the current user
4179async fn accept_terms_of_service(
4180 _request: proto::AcceptTermsOfService,
4181 response: Response<proto::AcceptTermsOfService>,
4182 session: Session,
4183) -> Result<()> {
4184 let db = session.db().await;
4185
4186 let accepted_tos_at = Utc::now();
4187 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4188 .await?;
4189
4190 response.send(proto::AcceptTermsOfServiceResponse {
4191 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4192 })?;
4193
4194 // When the user accepts the terms of service, we want to refresh their LLM
4195 // token to grant access.
4196 session
4197 .peer
4198 .send(session.connection_id, proto::RefreshLlmToken {})?;
4199
4200 Ok(())
4201}
4202
/// Issue an LLM API token for the current user.
///
/// The user must have accepted the terms of service. The handler may create
/// billing state as a side effect:
/// - If the user has no billing customer, a Stripe customer is found or created
///   by email and a billing customer record is found or created for it.
/// - If the user has no active billing subscription, they are subscribed to
///   Zed Free in Stripe and a matching billing subscription row is created.
///
/// The token is built from the user, staff status, billing customer,
/// preferences, feature flags, subscription, system id and server config.
///
/// # Errors
/// Fails if any of the following occurs:
/// - the user is missing;
/// - the terms of service have not been accepted;
/// - the Stripe client or billing object is not configured;
/// - any database or Stripe call fails;
/// - token creation fails.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // No token until the terms of service have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the existing billing customer, or create one via Stripe.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription, or subscribe the customer to Zed Free and
    // record the new subscription.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4286
4287fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4288 let message = match message {
4289 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4290 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4291 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4292 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4293 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4294 code: frame.code.into(),
4295 reason: frame.reason.as_str().to_owned().into(),
4296 })),
4297 // We should never receive a frame while reading the message, according
4298 // to the `tungstenite` maintainers:
4299 //
4300 // > It cannot occur when you read messages from the WebSocket, but it
4301 // > can be used when you want to send the raw frames (e.g. you want to
4302 // > send the frames to the WebSocket without composing the full message first).
4303 // >
4304 // > — https://github.com/snapview/tungstenite-rs/issues/268
4305 TungsteniteMessage::Frame(_) => {
4306 bail!("received an unexpected frame while reading the message")
4307 }
4308 };
4309
4310 Ok(message)
4311}
4312
4313fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4314 match message {
4315 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4316 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4317 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4318 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4319 AxumMessage::Close(frame) => {
4320 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4321 code: frame.code.into(),
4322 reason: frame.reason.as_ref().into(),
4323 }))
4324 }
4325 }
4326}
4327
4328fn notify_membership_updated(
4329 connection_pool: &mut ConnectionPool,
4330 result: MembershipUpdated,
4331 user_id: UserId,
4332 peer: &Peer,
4333) {
4334 for membership in &result.new_channels.channel_memberships {
4335 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4336 }
4337 for channel_id in &result.removed_channels {
4338 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4339 }
4340
4341 let user_channels_update = proto::UpdateUserChannels {
4342 channel_memberships: result
4343 .new_channels
4344 .channel_memberships
4345 .iter()
4346 .map(|cm| proto::ChannelMembership {
4347 channel_id: cm.channel_id.to_proto(),
4348 role: cm.role.into(),
4349 })
4350 .collect(),
4351 ..Default::default()
4352 };
4353
4354 let mut update = build_channels_update(result.new_channels);
4355 update.delete_channels = result
4356 .removed_channels
4357 .into_iter()
4358 .map(|id| id.to_proto())
4359 .collect();
4360 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4361
4362 for connection_id in connection_pool.user_connection_ids(user_id) {
4363 peer.send(connection_id, user_channels_update.clone())
4364 .trace_err();
4365 peer.send(connection_id, update.clone()).trace_err();
4366 }
4367}
4368
4369fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4370 proto::UpdateUserChannels {
4371 channel_memberships: channels
4372 .channel_memberships
4373 .iter()
4374 .map(|m| proto::ChannelMembership {
4375 channel_id: m.channel_id.to_proto(),
4376 role: m.role.into(),
4377 })
4378 .collect(),
4379 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4380 observed_channel_message_id: channels.observed_channel_messages.clone(),
4381 }
4382}
4383
4384fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4385 let mut update = proto::UpdateChannels::default();
4386
4387 for channel in channels.channels {
4388 update.channels.push(channel.to_proto());
4389 }
4390
4391 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4392 update.latest_channel_message_ids = channels.latest_channel_messages;
4393
4394 for (channel_id, participants) in channels.channel_participants {
4395 update
4396 .channel_participants
4397 .push(proto::ChannelParticipants {
4398 channel_id: channel_id.to_proto(),
4399 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4400 });
4401 }
4402
4403 for channel in channels.invited_channels {
4404 update.channel_invitations.push(channel.to_proto());
4405 }
4406
4407 update
4408}
4409
4410fn build_initial_contacts_update(
4411 contacts: Vec<db::Contact>,
4412 pool: &ConnectionPool,
4413) -> proto::UpdateContacts {
4414 let mut update = proto::UpdateContacts::default();
4415
4416 for contact in contacts {
4417 match contact {
4418 db::Contact::Accepted { user_id, busy } => {
4419 update.contacts.push(contact_for_user(user_id, busy, pool));
4420 }
4421 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4422 db::Contact::Incoming { user_id } => {
4423 update
4424 .incoming_requests
4425 .push(proto::IncomingContactRequest {
4426 requester_id: user_id.to_proto(),
4427 })
4428 }
4429 }
4430 }
4431
4432 update
4433}
4434
4435fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4436 proto::Contact {
4437 user_id: user_id.to_proto(),
4438 online: pool.is_user_online(user_id),
4439 busy,
4440 }
4441}
4442
4443fn room_updated(room: &proto::Room, peer: &Peer) {
4444 broadcast(
4445 None,
4446 room.participants
4447 .iter()
4448 .filter_map(|participant| Some(participant.peer_id?.into())),
4449 |peer_id| {
4450 peer.send(
4451 peer_id,
4452 proto::RoomUpdated {
4453 room: Some(room.clone()),
4454 },
4455 )
4456 },
4457 );
4458}
4459
4460fn channel_updated(
4461 channel: &db::channel::Model,
4462 room: &proto::Room,
4463 peer: &Peer,
4464 pool: &ConnectionPool,
4465) {
4466 let participants = room
4467 .participants
4468 .iter()
4469 .map(|p| p.user_id)
4470 .collect::<Vec<_>>();
4471
4472 broadcast(
4473 None,
4474 pool.channel_connection_ids(channel.root_id())
4475 .filter_map(|(channel_id, role)| {
4476 role.can_see_channel(channel.visibility)
4477 .then_some(channel_id)
4478 }),
4479 |peer_id| {
4480 peer.send(
4481 peer_id,
4482 proto::UpdateChannels {
4483 channel_participants: vec![proto::ChannelParticipants {
4484 channel_id: channel.id.to_proto(),
4485 participant_user_ids: participants.clone(),
4486 }],
4487 ..Default::default()
4488 },
4489 )
4490 },
4491 );
4492}
4493
4494async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4495 let db = session.db().await;
4496
4497 let contacts = db.get_contacts(user_id).await?;
4498 let busy = db.is_user_busy(user_id).await?;
4499
4500 let pool = session.connection_pool().await;
4501 let updated_contact = contact_for_user(user_id, busy, &pool);
4502 for contact in contacts {
4503 if let db::Contact::Accepted {
4504 user_id: contact_user_id,
4505 ..
4506 } = contact
4507 {
4508 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4509 session
4510 .peer
4511 .send(
4512 contact_conn_id,
4513 proto::UpdateContacts {
4514 contacts: vec![updated_contact.clone()],
4515 remove_contacts: Default::default(),
4516 incoming_requests: Default::default(),
4517 remove_incoming_requests: Default::default(),
4518 outgoing_requests: Default::default(),
4519 remove_outgoing_requests: Default::default(),
4520 },
4521 )
4522 .trace_err();
4523 }
4524 }
4525 }
4526 Ok(())
4527}
4528
/// Removes `connection_id` from its room and runs the resulting notifications and cleanup.
///
/// If `leave_room` returns `None` (presumably the connection was not in a room), this
/// returns `Ok(())` and does nothing else. Otherwise, in order, it:
/// - notifies each project's connections via `project_left`;
/// - broadcasts the updated room via `room_updated`;
/// - if the room belongs to a channel, broadcasts the channel's participants via
///   `channel_updated`;
/// - sends `CallCanceled` to every connection of each user whose call was canceled;
/// - refreshes contact status for the leaving user and each canceled callee via
///   `update_user_contacts`;
/// - if a LiveKit client is configured, removes the user from the LiveKit room, and
///   also deletes that room when `left_room.deleted` was set.
///
/// Database errors and `update_user_contacts` errors are propagated. Message sends and
/// LiveKit calls only log failures via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once in the `if let`
    // body below (the `else` branch returns), so the values outlive `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the guard returned by `session.db().await` is a temporary in this
    // `if let` scrutinee, so it is presumably still held while the body runs
    // (`project_left`, `room_updated`). Rewriting this as `let ... else` would release
    // it earlier.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): the LiveKit room name is taken out of `left_room.room` before the
        // room itself is taken below. The room broadcast by `room_updated` therefore
        // carries the default `livekit_room` value. Presumably intentional; confirm.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is dropped before the loop below. Presumably
    // this lets `update_user_contacts`, which acquires the pool itself, lock it.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): `?` returns on the first failure, which skips the LiveKit cleanup
    // below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4602
4603async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4604 let left_channel_buffers = session
4605 .db()
4606 .await
4607 .leave_channel_buffers(session.connection_id)
4608 .await?;
4609
4610 for left_buffer in left_channel_buffers {
4611 channel_buffer_updated(
4612 session.connection_id,
4613 left_buffer.connections,
4614 &proto::UpdateChannelBufferCollaborators {
4615 channel_id: left_buffer.channel_id.to_proto(),
4616 collaborators: left_buffer.collaborators,
4617 },
4618 &session.peer,
4619 );
4620 }
4621
4622 Ok(())
4623}
4624
4625fn project_left(project: &db::LeftProject, session: &Session) {
4626 for connection_id in &project.connection_ids {
4627 if project.should_unshare {
4628 session
4629 .peer
4630 .send(
4631 *connection_id,
4632 proto::UnshareProject {
4633 project_id: project.id.to_proto(),
4634 },
4635 )
4636 .trace_err();
4637 } else {
4638 session
4639 .peer
4640 .send(
4641 *connection_id,
4642 proto::RemoveProjectCollaborator {
4643 project_id: project.id.to_proto(),
4644 peer_id: Some(session.connection_id.into()),
4645 },
4646 )
4647 .trace_err();
4648 }
4649 }
4650}
4651
4652pub trait ResultExt {
4653 type Ok;
4654
4655 fn trace_err(self) -> Option<Self::Ok>;
4656}
4657
4658impl<T, E> ResultExt for Result<T, E>
4659where
4660 E: std::fmt::Debug,
4661{
4662 type Ok = T;
4663
4664 #[track_caller]
4665 fn trace_err(self) -> Option<T> {
4666 match self {
4667 Ok(value) => Some(value),
4668 Err(error) => {
4669 tracing::error!("{:?}", error);
4670 None
4671 }
4672 }
4673 }
4674}