1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::headers::UserAgent;
27use axum::{
28 Extension, Router, TypedHeader,
29 body::Body,
30 extract::{
31 ConnectInfo, WebSocketUpgrade,
32 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
33 },
34 headers::{Header, HeaderName},
35 http::StatusCode,
36 middleware,
37 response::IntoResponse,
38 routing::get,
39};
40use chrono::Utc;
41use collections::{HashMap, HashSet};
42pub use connection_pool::{ConnectionPool, ZedVersion};
43use core::fmt::{self, Debug, Formatter};
44use reqwest_client::ReqwestClient;
45use rpc::proto::{MultiLspQuery, split_repository_update};
46use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
47
48use futures::{
49 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
50 stream::FuturesUnordered,
51};
52use prometheus::{IntGauge, register_int_gauge};
53use rpc::{
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55 proto::{
56 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
57 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
58 },
59};
60use semantic_version::SemanticVersion;
61use serde::{Serialize, Serializer};
62use std::{
63 any::TypeId,
64 future::Future,
65 marker::PhantomData,
66 mem,
67 net::SocketAddr,
68 ops::{Deref, DerefMut},
69 rc::Rc,
70 sync::{
71 Arc, OnceLock,
72 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
73 },
74 time::{Duration, Instant},
75};
76use time::OffsetDateTime;
77use tokio::sync::{Semaphore, watch};
78use tower::ServiceBuilder;
79use tracing::{
80 Instrument,
81 field::{self},
82 info_span, instrument,
83};
84
85pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
86
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
88pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);
89
90const MESSAGE_COUNT_PER_PAGE: usize = 100;
91const MAX_MESSAGE_LEN: usize = 1024;
92const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
93const MAX_CONCURRENT_CONNECTIONS: usize = 512;
94
95static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
96
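/// Type-erased handler for an incoming message: it receives the envelope and the
/// per-connection [`Session`] and returns a boxed future that resolves once handling is done.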
97type MessageHandler =
98 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
99
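/// RAII guard that counts a connection against [`MAX_CONCURRENT_CONNECTIONS`]; dropping it
/// releases the slot again.
///
/// A minimal usage sketch (the real call site in `handle_websocket_request` maps the error to
/// an HTTP 503 response instead):
///
/// ```ignore
/// let guard = ConnectionGuard::try_acquire()
///     .map_err(|()| anyhow::anyhow!("too many concurrent connections"))?;
/// // ... serve the connection while holding `guard`; dropping it frees the slot.
/// ```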
100pub struct ConnectionGuard;
101
102impl ConnectionGuard {
103 pub fn try_acquire() -> Result<Self, ()> {
104 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
105 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
106 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
107 tracing::error!(
108 "too many concurrent connections: {}",
109 current_connections + 1
110 );
111 return Err(());
112 }
113 Ok(ConnectionGuard)
114 }
115}
116
117impl Drop for ConnectionGuard {
118 fn drop(&mut self) {
119 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
120 }
121}
122
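/// Handle given to request handlers for sending exactly one reply to the request identified
/// by `receipt`; `responded` records whether a reply was actually sent.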
123struct Response<R> {
124 peer: Arc<Peer>,
125 receipt: Receipt<R>,
126 responded: Arc<AtomicBool>,
127}
128
129impl<R: RequestMessage> Response<R> {
130 fn send(self, payload: R::Response) -> Result<()> {
131 self.responded.store(true, SeqCst);
132 self.peer.respond(self.receipt, payload)?;
133 Ok(())
134 }
135}
136
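/// The authenticated identity behind a connection: either a regular user, or a user being
/// impersonated by an admin.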
137#[derive(Clone, Debug)]
138pub enum Principal {
139 User(User),
140 Impersonated { user: User, admin: User },
141}
142
143impl Principal {
144 fn user(&self) -> &User {
145 match self {
146 Principal::User(user) => user,
147 Principal::Impersonated { user, .. } => user,
148 }
149 }
150
151 fn update_span(&self, span: &tracing::Span) {
152 match &self {
153 Principal::User(user) => {
154 span.record("user_id", user.id.0);
155 span.record("login", &user.github_login);
156 }
157 Principal::Impersonated { user, admin } => {
158 span.record("user_id", user.id.0);
159 span.record("login", &user.github_login);
160 span.record("impersonator", &admin.github_login);
161 }
162 }
163 }
164}
165
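/// Per-connection state passed to every message handler: the authenticated principal plus
/// handles to the database, peer, connection pool, and application state.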
166#[derive(Clone)]
167struct Session {
168 principal: Principal,
169 connection_id: ConnectionId,
170 db: Arc<tokio::sync::Mutex<DbHandle>>,
171 peer: Arc<Peer>,
172 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
173 app_state: Arc<AppState>,
174 supermaven_client: Option<Arc<SupermavenAdminApi>>,
175 /// The GeoIP country code for the user.
176 #[allow(unused)]
177 geoip_country_code: Option<String>,
178 system_id: Option<String>,
179 _executor: Executor,
180}
181
182impl Session {
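    /// Acquires the shared database handle; the `cfg(test)` `yield_now` calls add extra
    /// scheduling points around the lock in tests.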
183 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
184 #[cfg(test)]
185 tokio::task::yield_now().await;
186 let guard = self.db.lock().await;
187 #[cfg(test)]
188 tokio::task::yield_now().await;
189 guard
190 }
191
192 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
193 #[cfg(test)]
194 tokio::task::yield_now().await;
195 let guard = self.connection_pool.lock();
196 ConnectionPoolGuard {
197 guard,
198 _not_send: PhantomData,
199 }
200 }
201
202 fn is_staff(&self) -> bool {
203 match &self.principal {
204 Principal::User(user) => user.admin,
205 Principal::Impersonated { .. } => true,
206 }
207 }
208
209 fn user_id(&self) -> UserId {
210 match &self.principal {
211 Principal::User(user) => user.id,
212 Principal::Impersonated { user, .. } => user.id,
213 }
214 }
215
216 pub fn email(&self) -> Option<String> {
217 match &self.principal {
218 Principal::User(user) => user.email_address.clone(),
219 Principal::Impersonated { user, .. } => user.email_address.clone(),
220 }
221 }
222}
223
224impl Debug for Session {
225 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
226 let mut result = f.debug_struct("Session");
227 match &self.principal {
228 Principal::User(user) => {
229 result.field("user", &user.github_login);
230 }
231 Principal::Impersonated { user, admin } => {
232 result.field("user", &user.github_login);
233 result.field("impersonator", &admin.github_login);
234 }
235 }
236 result.field("connection_id", &self.connection_id).finish()
237 }
238}
239
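/// Thin wrapper around the shared [`Database`] handle that dereferences to it.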
240struct DbHandle(Arc<Database>);
241
242impl Deref for DbHandle {
243 type Target = Database;
244
245 fn deref(&self) -> &Self::Target {
246 self.0.as_ref()
247 }
248}
249
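/// The collab RPC server: owns the peer, the connection pool, and the registry of message
/// handlers that incoming client messages are dispatched to.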
250pub struct Server {
251 id: parking_lot::Mutex<ServerId>,
252 peer: Arc<Peer>,
253 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
254 app_state: Arc<AppState>,
255 handlers: HashMap<TypeId, MessageHandler>,
256 teardown: watch::Sender<bool>,
257}
258
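/// Lock guard for the [`ConnectionPool`]. The `PhantomData<Rc<()>>` marker keeps the guard
/// `!Send`, and its invariants are checked on drop in tests.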
259pub(crate) struct ConnectionPoolGuard<'a> {
260 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
261 _not_send: PhantomData<Rc<()>>,
262}
263
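/// Serializable view over the server's peer and locked connection pool.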
264#[derive(Serialize)]
265pub struct ServerSnapshot<'a> {
266 peer: &'a Peer,
267 #[serde(serialize_with = "serialize_deref")]
268 connection_pool: ConnectionPoolGuard<'a>,
269}
270
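/// Serde helper that serializes a value through its `Deref` target (used above to serialize
/// the connection pool while it is held behind a guard).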
271pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
272where
273 S: Serializer,
274 T: Deref<Target = U>,
275 U: Serialize,
276{
277 Serialize::serialize(value.deref(), serializer)
278}
279
280impl Server {
281 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
282 let mut server = Self {
283 id: parking_lot::Mutex::new(id),
284 peer: Peer::new(id.0 as u32),
285 app_state: app_state.clone(),
286 connection_pool: Default::default(),
287 handlers: Default::default(),
288 teardown: watch::channel(false).0,
289 };
290
291 server
292 .add_request_handler(ping)
293 .add_request_handler(create_room)
294 .add_request_handler(join_room)
295 .add_request_handler(rejoin_room)
296 .add_request_handler(leave_room)
297 .add_request_handler(set_room_participant_role)
298 .add_request_handler(call)
299 .add_request_handler(cancel_call)
300 .add_message_handler(decline_call)
301 .add_request_handler(update_participant_location)
302 .add_request_handler(share_project)
303 .add_message_handler(unshare_project)
304 .add_request_handler(join_project)
305 .add_message_handler(leave_project)
306 .add_request_handler(update_project)
307 .add_request_handler(update_worktree)
308 .add_request_handler(update_repository)
309 .add_request_handler(remove_repository)
310 .add_message_handler(start_language_server)
311 .add_message_handler(update_language_server)
312 .add_message_handler(update_diagnostic_summary)
313 .add_message_handler(update_worktree_settings)
314 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
317 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
318 .add_request_handler(forward_find_search_candidates_request)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
323 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
324 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
325 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
326 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
327 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
328 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
329 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
330 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
334 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
335 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
336 .add_request_handler(
337 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
338 )
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
342 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
343 .add_request_handler(
344 forward_read_only_project_request::<proto::LanguageServerIdForName>,
345 )
346 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
347 .add_request_handler(
348 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
349 )
350 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
351 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
352 .add_request_handler(
353 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
354 )
355 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
356 .add_request_handler(
357 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
358 )
359 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
360 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
361 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
362 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
363 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
364 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
365 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
366 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
369 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
370 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
371 .add_request_handler(
372 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
373 )
374 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
375 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
376 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
377 .add_request_handler(multi_lsp_query)
378 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
379 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
380 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
381 .add_message_handler(create_buffer_for_peer)
382 .add_request_handler(update_buffer)
383 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
386 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
387 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
388 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
389 .add_message_handler(
390 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
391 )
392 .add_request_handler(get_users)
393 .add_request_handler(fuzzy_search_users)
394 .add_request_handler(request_contact)
395 .add_request_handler(remove_contact)
396 .add_request_handler(respond_to_contact_request)
397 .add_message_handler(subscribe_to_channels)
398 .add_request_handler(create_channel)
399 .add_request_handler(delete_channel)
400 .add_request_handler(invite_channel_member)
401 .add_request_handler(remove_channel_member)
402 .add_request_handler(set_channel_member_role)
403 .add_request_handler(set_channel_visibility)
404 .add_request_handler(rename_channel)
405 .add_request_handler(join_channel_buffer)
406 .add_request_handler(leave_channel_buffer)
407 .add_message_handler(update_channel_buffer)
408 .add_request_handler(rejoin_channel_buffers)
409 .add_request_handler(get_channel_members)
410 .add_request_handler(respond_to_channel_invite)
411 .add_request_handler(join_channel)
412 .add_request_handler(join_channel_chat)
413 .add_message_handler(leave_channel_chat)
414 .add_request_handler(send_channel_message)
415 .add_request_handler(remove_channel_message)
416 .add_request_handler(update_channel_message)
417 .add_request_handler(get_channel_messages)
418 .add_request_handler(get_channel_messages_by_id)
419 .add_request_handler(get_notifications)
420 .add_request_handler(mark_notification_as_read)
421 .add_request_handler(move_channel)
422 .add_request_handler(reorder_channel)
423 .add_request_handler(follow)
424 .add_message_handler(unfollow)
425 .add_message_handler(update_followers)
426 .add_request_handler(get_private_user_info)
427 .add_request_handler(get_llm_api_token)
428 .add_request_handler(accept_terms_of_service)
429 .add_message_handler(acknowledge_channel_message)
430 .add_message_handler(acknowledge_buffer_version)
431 .add_request_handler(get_supermaven_api_key)
432 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
433 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
434 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
435 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
436 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
437 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
438 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
439 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
440 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
441 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
443 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
444 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
445 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
446 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
447 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
448 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
449 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
450 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
451 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
452 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
453 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
454 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
455 .add_message_handler(update_context);
456
457 Arc::new(server)
458 }
459
460 pub async fn start(&self) -> Result<()> {
461 let server_id = *self.id.lock();
462 let app_state = self.app_state.clone();
463 let peer = self.peer.clone();
464 let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
465 let pool = self.connection_pool.clone();
466 let livekit_client = self.app_state.livekit_client.clone();
467
468 let span = info_span!("start server");
469 self.app_state.executor.spawn_detached(
470 async move {
471 tracing::info!("waiting for cleanup timeout");
472 timeout.await;
473 tracing::info!("cleanup timeout expired, retrieving stale rooms");
474
475 app_state
476 .db
477 .delete_stale_channel_chat_participants(
478 &app_state.config.zed_environment,
479 server_id,
480 )
481 .await
482 .trace_err();
483
484 if let Some((room_ids, channel_ids)) = app_state
485 .db
486 .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
487 .await
488 .trace_err()
489 {
490 tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
491 tracing::info!(
492 stale_channel_buffer_count = channel_ids.len(),
493 "retrieved stale channel buffers"
494 );
495
496 for channel_id in channel_ids {
497 if let Some(refreshed_channel_buffer) = app_state
498 .db
499 .clear_stale_channel_buffer_collaborators(channel_id, server_id)
500 .await
501 .trace_err()
502 {
503 for connection_id in refreshed_channel_buffer.connection_ids {
504 peer.send(
505 connection_id,
506 proto::UpdateChannelBufferCollaborators {
507 channel_id: channel_id.to_proto(),
508 collaborators: refreshed_channel_buffer
509 .collaborators
510 .clone(),
511 },
512 )
513 .trace_err();
514 }
515 }
516 }
517
518 for room_id in room_ids {
519 let mut contacts_to_update = HashSet::default();
520 let mut canceled_calls_to_user_ids = Vec::new();
521 let mut livekit_room = String::new();
522 let mut delete_livekit_room = false;
523
524 if let Some(mut refreshed_room) = app_state
525 .db
526 .clear_stale_room_participants(room_id, server_id)
527 .await
528 .trace_err()
529 {
530 tracing::info!(
531 room_id = room_id.0,
532 new_participant_count = refreshed_room.room.participants.len(),
533 "refreshed room"
534 );
535 room_updated(&refreshed_room.room, &peer);
536 if let Some(channel) = refreshed_room.channel.as_ref() {
537 channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
538 }
539 contacts_to_update
540 .extend(refreshed_room.stale_participant_user_ids.iter().copied());
541 contacts_to_update
542 .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
543 canceled_calls_to_user_ids =
544 mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
545 livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
546 delete_livekit_room = refreshed_room.room.participants.is_empty();
547 }
548
549 {
550 let pool = pool.lock();
551 for canceled_user_id in canceled_calls_to_user_ids {
552 for connection_id in pool.user_connection_ids(canceled_user_id) {
553 peer.send(
554 connection_id,
555 proto::CallCanceled {
556 room_id: room_id.to_proto(),
557 },
558 )
559 .trace_err();
560 }
561 }
562 }
563
564 for user_id in contacts_to_update {
565 let busy = app_state.db.is_user_busy(user_id).await.trace_err();
566 let contacts = app_state.db.get_contacts(user_id).await.trace_err();
567 if let Some((busy, contacts)) = busy.zip(contacts) {
568 let pool = pool.lock();
569 let updated_contact = contact_for_user(user_id, busy, &pool);
570 for contact in contacts {
571 if let db::Contact::Accepted {
572 user_id: contact_user_id,
573 ..
574 } = contact
575 {
576 for contact_conn_id in
577 pool.user_connection_ids(contact_user_id)
578 {
579 peer.send(
580 contact_conn_id,
581 proto::UpdateContacts {
582 contacts: vec![updated_contact.clone()],
583 remove_contacts: Default::default(),
584 incoming_requests: Default::default(),
585 remove_incoming_requests: Default::default(),
586 outgoing_requests: Default::default(),
587 remove_outgoing_requests: Default::default(),
588 },
589 )
590 .trace_err();
591 }
592 }
593 }
594 }
595 }
596
597 if let Some(live_kit) = livekit_client.as_ref() {
598 if delete_livekit_room {
599 live_kit.delete_room(livekit_room).await.trace_err();
600 }
601 }
602 }
603 }
604
605 app_state
606 .db
607 .delete_stale_channel_chat_participants(
608 &app_state.config.zed_environment,
609 server_id,
610 )
611 .await
612 .trace_err();
613
614 app_state
615 .db
616 .clear_old_worktree_entries(server_id)
617 .await
618 .trace_err();
619
620 app_state
621 .db
622 .delete_stale_servers(&app_state.config.zed_environment, server_id)
623 .await
624 .trace_err();
625 }
626 .instrument(span),
627 );
628 Ok(())
629 }
630
631 pub fn teardown(&self) {
632 self.peer.teardown();
633 self.connection_pool.lock().reset();
634 let _ = self.teardown.send(true);
635 }
636
637 #[cfg(test)]
638 pub fn reset(&self, id: ServerId) {
639 self.teardown();
640 *self.id.lock() = id;
641 self.peer.reset(id.0 as u32);
642 let _ = self.teardown.send(false);
643 }
644
645 #[cfg(test)]
646 pub fn id(&self) -> ServerId {
647 *self.id.lock()
648 }
649
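    /// Registers the type-erased handler for message type `M`, wrapping it with logging of
    /// queue and processing durations. Panics if a handler for `M` is registered twice.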
650 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
651 where
652 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
653 Fut: 'static + Send + Future<Output = Result<()>>,
654 M: EnvelopedMessage,
655 {
656 let prev_handler = self.handlers.insert(
657 TypeId::of::<M>(),
658 Box::new(move |envelope, session| {
659 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
660 let received_at = envelope.received_at;
661 tracing::info!("message received");
662 let start_time = Instant::now();
663 let future = (handler)(*envelope, session);
664 async move {
665 let result = future.await;
666 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
667 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
668 let queue_duration_ms = total_duration_ms - processing_duration_ms;
669 let payload_type = M::NAME;
670
671 match result {
672 Err(error) => {
673 tracing::error!(
674 ?error,
675 total_duration_ms,
676 processing_duration_ms,
677 queue_duration_ms,
678 payload_type,
679 "error handling message"
680 )
681 }
682 Ok(()) => tracing::info!(
683 total_duration_ms,
684 processing_duration_ms,
685 queue_duration_ms,
686 "finished handling message"
687 ),
688 }
689 }
690 .boxed()
691 }),
692 );
693 if prev_handler.is_some() {
694 panic!("registered a handler for the same message twice");
695 }
696 self
697 }
698
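    /// Registers a handler for a one-way message; no response is sent back to the client.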
699 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
700 where
701 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
702 Fut: 'static + Send + Future<Output = Result<()>>,
703 M: EnvelopedMessage,
704 {
705 self.add_handler(move |envelope, session| handler(envelope.payload, session));
706 self
707 }
708
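    /// Registers a handler for a request. Errors returned by the handler are relayed to the
    /// requesting client; returning `Ok` without calling [`Response::send`] is reported as an
    /// error on the server side.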
709 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
710 where
711 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
712 Fut: Send + Future<Output = Result<()>>,
713 M: RequestMessage,
714 {
715 let handler = Arc::new(handler);
716 self.add_handler(move |envelope, session| {
717 let receipt = envelope.receipt();
718 let handler = handler.clone();
719 async move {
720 let peer = session.peer.clone();
721 let responded = Arc::new(AtomicBool::default());
722 let response = Response {
723 peer: peer.clone(),
724 responded: responded.clone(),
725 receipt,
726 };
727 match (handler)(envelope.payload, response, session).await {
728 Ok(()) => {
729 if responded.load(std::sync::atomic::Ordering::SeqCst) {
730 Ok(())
731 } else {
732 Err(anyhow!("handler did not send a response"))?
733 }
734 }
735 Err(error) => {
736 let proto_err = match &error {
737 Error::Internal(err) => err.to_proto(),
738 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
739 };
740 peer.respond_with_error(receipt, proto_err)?;
741 Err(error)
742 }
743 }
744 }
745 })
746 }
747
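    /// Drives a single client connection: registers it with the peer, sends the initial client
    /// update, then dispatches incoming messages to their handlers until the connection closes
    /// or the server tears down, cleaning up via `connection_lost` afterwards.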
748 pub fn handle_connection(
749 self: &Arc<Self>,
750 connection: Connection,
751 address: String,
752 principal: Principal,
753 zed_version: ZedVersion,
754 user_agent: Option<String>,
755 geoip_country_code: Option<String>,
756 system_id: Option<String>,
757 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
758 executor: Executor,
759 connection_guard: Option<ConnectionGuard>,
760 ) -> impl Future<Output = ()> + use<> {
761 let this = self.clone();
762 let span = info_span!("handle connection", %address,
763 connection_id=field::Empty,
764 user_id=field::Empty,
765 login=field::Empty,
766 impersonator=field::Empty,
767 user_agent=field::Empty,
768 geoip_country_code=field::Empty
769 );
770 principal.update_span(&span);
771 if let Some(user_agent) = user_agent {
772 span.record("user_agent", user_agent);
773 }
774
775 if let Some(country_code) = geoip_country_code.as_ref() {
776 span.record("geoip_country_code", country_code);
777 }
778
779 let mut teardown = self.teardown.subscribe();
780 async move {
781 if *teardown.borrow() {
782 tracing::error!("server is tearing down");
783 return
784 }
785
786 let (connection_id, handle_io, mut incoming_rx) = this
787 .peer
788 .add_connection(connection, {
789 let executor = executor.clone();
790 move |duration| executor.sleep(duration)
791 });
792 tracing::Span::current().record("connection_id", format!("{}", connection_id));
793
794 tracing::info!("connection opened");
795
796 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
797 let http_client = match ReqwestClient::user_agent(&user_agent) {
798 Ok(http_client) => Arc::new(http_client),
799 Err(error) => {
800 tracing::error!(?error, "failed to create HTTP client");
801 return;
802 }
803 };
804
            let supermaven_client = this
                .app_state
                .config
                .supermaven_admin_api_key
                .clone()
                .map(|supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                });
809
810 let session = Session {
811 principal: principal.clone(),
812 connection_id,
813 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
814 peer: this.peer.clone(),
815 connection_pool: this.connection_pool.clone(),
816 app_state: this.app_state.clone(),
817 geoip_country_code,
818 system_id,
819 _executor: executor.clone(),
820 supermaven_client,
821 };
822
823 if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
824 tracing::error!(?error, "failed to send initial client update");
825 return;
826 }
827 drop(connection_guard);
828
829 let handle_io = handle_io.fuse();
830 futures::pin_mut!(handle_io);
831
            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B while
            // client B performs a request to client A: if both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other's request, and the two would deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, in the spirit of making progress.
840 let mut foreground_message_handlers = FuturesUnordered::new();
841 let concurrent_handlers = Arc::new(Semaphore::new(256));
842 loop {
843 let next_message = async {
844 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
845 let message = incoming_rx.next().await;
846 (permit, message)
847 }.fuse();
848 futures::pin_mut!(next_message);
849 futures::select_biased! {
850 _ = teardown.changed().fuse() => return,
851 result = handle_io => {
852 if let Err(error) = result {
853 tracing::error!(?error, "error handling I/O");
854 }
855 break;
856 }
857 _ = foreground_message_handlers.next() => {}
858 next_message = next_message => {
859 let (permit, message) = next_message;
860 if let Some(message) = message {
861 let type_name = message.payload_type_name();
862 // note: we copy all the fields from the parent span so we can query them in the logs.
863 // (https://github.com/tokio-rs/tracing/issues/2670).
864 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
865 user_id=field::Empty,
866 login=field::Empty,
867 impersonator=field::Empty,
868 multi_lsp_query_request=field::Empty,
869 );
870 principal.update_span(&span);
871 let span_enter = span.enter();
872 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
873 let is_background = message.is_background();
874 let handle_message = (handler)(message, session.clone());
875 drop(span_enter);
876
877 let handle_message = async move {
878 handle_message.await;
879 drop(permit);
880 }.instrument(span);
881 if is_background {
882 executor.spawn_detached(handle_message);
883 } else {
884 foreground_message_handlers.push(handle_message);
885 }
886 } else {
887 tracing::error!("no message handler");
888 }
889 } else {
890 tracing::info!("connection closed");
891 break;
892 }
893 }
894 }
895 }
896
897 drop(foreground_message_handlers);
898 tracing::info!("signing out");
899 if let Err(error) = connection_lost(session, teardown, executor).await {
900 tracing::error!(?error, "error signing out");
901 }
902
903 }.instrument(span)
904 }
905
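    /// Sends the post-connect handshake: the `Hello` message, the initial contacts list, plan
    /// information, channel subscriptions (for client versions that auto-subscribe), and any
    /// pending incoming call.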
906 async fn send_initial_client_update(
907 &self,
908 connection_id: ConnectionId,
909 zed_version: ZedVersion,
910 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
911 session: &Session,
912 ) -> Result<()> {
913 self.peer.send(
914 connection_id,
915 proto::Hello {
916 peer_id: Some(connection_id.into()),
917 },
918 )?;
919 tracing::info!("sent hello message");
920 if let Some(send_connection_id) = send_connection_id.take() {
921 let _ = send_connection_id.send(connection_id);
922 }
923
924 match &session.principal {
925 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
926 if !user.connected_once {
927 self.peer.send(connection_id, proto::ShowContacts {})?;
928 self.app_state
929 .db
930 .set_user_connected_once(user.id, true)
931 .await?;
932 }
933
934 update_user_plan(session).await?;
935
936 let contacts = self.app_state.db.get_contacts(user.id).await?;
937
938 {
939 let mut pool = self.connection_pool.lock();
940 pool.add_connection(connection_id, user.id, user.admin, zed_version);
941 self.peer.send(
942 connection_id,
943 build_initial_contacts_update(contacts, &pool),
944 )?;
945 }
946
947 if should_auto_subscribe_to_channels(zed_version) {
948 subscribe_user_to_channels(user.id, session).await?;
949 }
950
951 if let Some(incoming_call) =
952 self.app_state.db.incoming_call_for_user(user.id).await?
953 {
954 self.peer.send(connection_id, incoming_call)?;
955 }
956
957 update_user_contacts(user.id, session).await?;
958 }
959 }
960
961 Ok(())
962 }
963
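    /// Notifies the inviter's active connections that their invite code was redeemed, sending
    /// the new contact along with updated invite info.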
964 pub async fn invite_code_redeemed(
965 self: &Arc<Self>,
966 inviter_id: UserId,
967 invitee_id: UserId,
968 ) -> Result<()> {
969 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
970 if let Some(code) = &user.invite_code {
971 let pool = self.connection_pool.lock();
972 let invitee_contact = contact_for_user(invitee_id, false, &pool);
973 for connection_id in pool.user_connection_ids(inviter_id) {
974 self.peer.send(
975 connection_id,
976 proto::UpdateContacts {
977 contacts: vec![invitee_contact.clone()],
978 ..Default::default()
979 },
980 )?;
981 self.peer.send(
982 connection_id,
983 proto::UpdateInviteInfo {
984 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
985 count: user.invite_count as u32,
986 },
987 )?;
988 }
989 }
990 }
991 Ok(())
992 }
993
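    /// Pushes updated invite info (URL and remaining count) to all of the user's connections.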
994 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
995 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
996 if let Some(invite_code) = &user.invite_code {
997 let pool = self.connection_pool.lock();
998 for connection_id in pool.user_connection_ids(user_id) {
999 self.peer.send(
1000 connection_id,
1001 proto::UpdateInviteInfo {
1002 url: format!(
1003 "{}{}",
1004 self.app_state.config.invite_link_prefix, invite_code
1005 ),
1006 count: user.invite_count as u32,
1007 },
1008 )?;
1009 }
1010 }
1011 }
1012 Ok(())
1013 }
1014
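    /// Sends the given `UpdateUserPlan` message to all of the user's active connections.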
1015 pub async fn update_plan_for_user(
1016 self: &Arc<Self>,
1017 user_id: UserId,
1018 update_user_plan: proto::UpdateUserPlan,
1019 ) -> Result<()> {
1020 let pool = self.connection_pool.lock();
1021 for connection_id in pool.user_connection_ids(user_id) {
1022 self.peer
1023 .send(connection_id, update_user_plan.clone())
1024 .trace_err();
1025 }
1026
1027 Ok(())
1028 }
1029
1030 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1031 /// message on the Collab server.
1032 ///
1033 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1034 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1035 let user = self
1036 .app_state
1037 .db
1038 .get_user_by_id(user_id)
1039 .await?
1040 .context("user not found")?;
1041
1042 let update_user_plan = make_update_user_plan_message(
1043 &user,
1044 user.admin,
1045 &self.app_state.db,
1046 self.app_state.llm_db.clone(),
1047 )
1048 .await?;
1049
1050 self.update_plan_for_user(user_id, update_user_plan).await
1051 }
1052
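    /// Tells all of the user's active connections to refresh their LLM tokens.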
1053 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1054 let pool = self.connection_pool.lock();
1055 for connection_id in pool.user_connection_ids(user_id) {
1056 self.peer
1057 .send(connection_id, proto::RefreshLlmToken {})
1058 .trace_err();
1059 }
1060 }
1061
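    /// Returns a serializable snapshot of the peer and connection pool.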
1062 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1063 ServerSnapshot {
1064 connection_pool: ConnectionPoolGuard {
1065 guard: self.connection_pool.lock(),
1066 _not_send: PhantomData,
1067 },
1068 peer: &self.peer,
1069 }
1070 }
1071}
1072
1073impl Deref for ConnectionPoolGuard<'_> {
1074 type Target = ConnectionPool;
1075
1076 fn deref(&self) -> &Self::Target {
1077 &self.guard
1078 }
1079}
1080
1081impl DerefMut for ConnectionPoolGuard<'_> {
1082 fn deref_mut(&mut self) -> &mut Self::Target {
1083 &mut self.guard
1084 }
1085}
1086
1087impl Drop for ConnectionPoolGuard<'_> {
1088 fn drop(&mut self) {
1089 #[cfg(test)]
1090 self.check_invariants();
1091 }
1092}
1093
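/// Invokes `f` for every receiver id except the optional sender, logging (rather than
/// propagating) any per-receiver failure.
///
/// A sketch mirroring the call sites in this file:
///
/// ```ignore
/// broadcast(Some(session.connection_id), guest_connection_ids.iter().copied(), |conn_id| {
///     session.peer.send(conn_id, message.clone())
/// });
/// ```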
1094fn broadcast<F>(
1095 sender_id: Option<ConnectionId>,
1096 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1097 mut f: F,
1098) where
1099 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1100{
1101 for receiver_id in receiver_ids {
1102 if Some(receiver_id) != sender_id {
1103 if let Err(error) = f(receiver_id) {
1104 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1105 }
1106 }
1107 }
1108}
1109
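/// Typed header carrying the client's RPC protocol version (`x-zed-protocol-version`).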
1110pub struct ProtocolVersion(u32);
1111
1112impl Header for ProtocolVersion {
1113 fn name() -> &'static HeaderName {
1114 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1115 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1116 }
1117
1118 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1119 where
1120 Self: Sized,
1121 I: Iterator<Item = &'i axum::http::HeaderValue>,
1122 {
1123 let version = values
1124 .next()
1125 .ok_or_else(axum::headers::Error::invalid)?
1126 .to_str()
1127 .map_err(|_| axum::headers::Error::invalid())?
1128 .parse()
1129 .map_err(|_| axum::headers::Error::invalid())?;
1130 Ok(Self(version))
1131 }
1132
1133 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1134 values.extend([self.0.to_string().parse().unwrap()]);
1135 }
1136}
1137
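/// Typed header carrying the client's app version (`x-zed-app-version`).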
1138pub struct AppVersionHeader(SemanticVersion);
1139impl Header for AppVersionHeader {
1140 fn name() -> &'static HeaderName {
1141 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1142 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1143 }
1144
1145 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1146 where
1147 Self: Sized,
1148 I: Iterator<Item = &'i axum::http::HeaderValue>,
1149 {
1150 let version = values
1151 .next()
1152 .ok_or_else(axum::headers::Error::invalid)?
1153 .to_str()
1154 .map_err(|_| axum::headers::Error::invalid())?
1155 .parse()
1156 .map_err(|_| axum::headers::Error::invalid())?;
1157 Ok(Self(version))
1158 }
1159
1160 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1161 values.extend([self.0.to_string().parse().unwrap()]);
1162 }
1163}
1164
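/// Builds the router for this service: `/rpc` (WebSocket upgrade, behind header-based auth)
/// and `/metrics`.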
1165pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1166 Router::new()
1167 .route("/rpc", get(handle_websocket_request))
1168 .layer(
1169 ServiceBuilder::new()
1170 .layer(Extension(server.app_state.clone()))
1171 .layer(middleware::from_fn(auth::validate_header)),
1172 )
1173 .route("/metrics", get(handle_metrics))
1174 .layer(Extension(server))
1175}
1176
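/// Handles `GET /rpc`: validates the protocol and app version headers, enforces the connection
/// limit, and upgrades to a WebSocket served by [`Server::handle_connection`].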
1177pub async fn handle_websocket_request(
1178 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1179 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1180 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1181 Extension(server): Extension<Arc<Server>>,
1182 Extension(principal): Extension<Principal>,
1183 user_agent: Option<TypedHeader<UserAgent>>,
1184 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1185 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1186 ws: WebSocketUpgrade,
1187) -> axum::response::Response {
1188 if protocol_version != rpc::PROTOCOL_VERSION {
1189 return (
1190 StatusCode::UPGRADE_REQUIRED,
1191 "client must be upgraded".to_string(),
1192 )
1193 .into_response();
1194 }
1195
1196 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1197 return (
1198 StatusCode::UPGRADE_REQUIRED,
1199 "no version header found".to_string(),
1200 )
1201 .into_response();
1202 };
1203
1204 if !version.can_collaborate() {
1205 return (
1206 StatusCode::UPGRADE_REQUIRED,
1207 "client must be upgraded".to_string(),
1208 )
1209 .into_response();
1210 }
1211
1212 let socket_address = socket_address.to_string();
1213
    // Acquire the connection guard before the WebSocket upgrade, so an over-capacity server can
    // reject with an HTTP error instead of closing the socket after the upgrade.
1215 let connection_guard = match ConnectionGuard::try_acquire() {
1216 Ok(guard) => guard,
1217 Err(()) => {
1218 return (
1219 StatusCode::SERVICE_UNAVAILABLE,
1220 "Too many concurrent connections",
1221 )
1222 .into_response();
1223 }
1224 };
1225
1226 ws.on_upgrade(move |socket| {
1227 let socket = socket
1228 .map_ok(to_tungstenite_message)
1229 .err_into()
1230 .with(|message| async move { to_axum_message(message) });
1231 let connection = Connection::new(Box::pin(socket));
1232 async move {
1233 server
1234 .handle_connection(
1235 connection,
1236 socket_address,
1237 principal,
1238 version,
1239 user_agent.map(|header| header.to_string()),
1240 country_code_header.map(|header| header.to_string()),
1241 system_id_header.map(|header| header.to_string()),
1242 None,
1243 Executor::Production,
1244 Some(connection_guard),
1245 )
1246 .await;
1247 }
1248 })
1249}
1250
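/// Prometheus `/metrics` endpoint: refreshes the connection and shared-project gauges
/// (excluding admins) and returns all registered metrics in text format.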
1251pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1252 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1253 let connections_metric = CONNECTIONS_METRIC
1254 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1255
1256 let connections = server
1257 .connection_pool
1258 .lock()
1259 .connections()
1260 .filter(|connection| !connection.admin)
1261 .count();
1262 connections_metric.set(connections as _);
1263
1264 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1265 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1266 register_int_gauge!(
1267 "shared_projects",
1268 "number of open projects with one or more guests"
1269 )
1270 .unwrap()
1271 });
1272
1273 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1274 shared_projects_metric.set(shared_projects as _);
1275
1276 let encoder = prometheus::TextEncoder::new();
1277 let metric_families = prometheus::gather();
1278 let encoded_metrics = encoder
1279 .encode_to_string(&metric_families)
1280 .map_err(|err| anyhow!("{err}"))?;
1281 Ok(encoded_metrics)
1282}
1283
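/// Cleans up after a dropped connection: the connection is removed immediately, and after
/// [`RECONNECT_TIMEOUT`] any room, call, and channel-buffer state associated with it is
/// released (unless the server is tearing down first).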
1284#[instrument(err, skip(executor))]
1285async fn connection_lost(
1286 session: Session,
1287 mut teardown: watch::Receiver<bool>,
1288 executor: Executor,
1289) -> Result<()> {
1290 session.peer.disconnect(session.connection_id);
1291 session
1292 .connection_pool()
1293 .await
1294 .remove_connection(session.connection_id)?;
1295
1296 session
1297 .db()
1298 .await
1299 .connection_lost(session.connection_id)
1300 .await
1301 .trace_err();
1302
1303 futures::select_biased! {
1304 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1305
1306 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1307 leave_room_for_session(&session, session.connection_id).await.trace_err();
1308 leave_channel_buffers_for_session(&session)
1309 .await
1310 .trace_err();
1311
1312 if !session
1313 .connection_pool()
1314 .await
1315 .is_user_online(session.user_id())
1316 {
1317 let db = session.db().await;
1318 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1319 room_updated(&room, &session.peer);
1320 }
1321 }
1322
1323 update_user_contacts(session.user_id(), &session).await?;
1324 },
1325 _ = teardown.changed().fuse() => {}
1326 }
1327
1328 Ok(())
1329}
1330
1331/// Acknowledges a ping from a client, used to keep the connection alive.
1332async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1333 response.send(proto::Ack {})?;
1334 Ok(())
1335}
1336
/// Creates a new room for calling (outside of channels).
1338async fn create_room(
1339 _request: proto::CreateRoom,
1340 response: Response<proto::CreateRoom>,
1341 session: Session,
1342) -> Result<()> {
1343 let livekit_room = nanoid::nanoid!(30);
1344
1345 let live_kit_connection_info = util::maybe!(async {
1346 let live_kit = session.app_state.livekit_client.as_ref();
1347 let live_kit = live_kit?;
1348 let user_id = session.user_id().to_string();
1349
1350 let token = live_kit
1351 .room_token(&livekit_room, &user_id.to_string())
1352 .trace_err()?;
1353
1354 Some(proto::LiveKitConnectionInfo {
1355 server_url: live_kit.url().into(),
1356 token,
1357 can_publish: true,
1358 })
1359 })
1360 .await;
1361
1362 let room = session
1363 .db()
1364 .await
1365 .create_room(session.user_id(), session.connection_id, &livekit_room)
1366 .await?;
1367
1368 response.send(proto::CreateRoomResponse {
1369 room: Some(room.clone()),
1370 live_kit_connection_info,
1371 })?;
1372
1373 update_user_contacts(session.user_id(), &session).await?;
1374 Ok(())
1375}
1376
1377/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1378async fn join_room(
1379 request: proto::JoinRoom,
1380 response: Response<proto::JoinRoom>,
1381 session: Session,
1382) -> Result<()> {
1383 let room_id = RoomId::from_proto(request.id);
1384
1385 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1386
1387 if let Some(channel_id) = channel_id {
1388 return join_channel_internal(channel_id, Box::new(response), session).await;
1389 }
1390
1391 let joined_room = {
1392 let room = session
1393 .db()
1394 .await
1395 .join_room(room_id, session.user_id(), session.connection_id)
1396 .await?;
1397 room_updated(&room.room, &session.peer);
1398 room.into_inner()
1399 };
1400
1401 for connection_id in session
1402 .connection_pool()
1403 .await
1404 .user_connection_ids(session.user_id())
1405 {
1406 session
1407 .peer
1408 .send(
1409 connection_id,
1410 proto::CallCanceled {
1411 room_id: room_id.to_proto(),
1412 },
1413 )
1414 .trace_err();
1415 }
1416
1417 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1418 {
1419 live_kit
1420 .room_token(
1421 &joined_room.room.livekit_room,
1422 &session.user_id().to_string(),
1423 )
1424 .trace_err()
1425 .map(|token| proto::LiveKitConnectionInfo {
1426 server_url: live_kit.url().into(),
1427 token,
1428 can_publish: true,
1429 })
1430 } else {
1431 None
1432 };
1433
1434 response.send(proto::JoinRoomResponse {
1435 room: Some(joined_room.room),
1436 channel_id: None,
1437 live_kit_connection_info,
1438 })?;
1439
1440 update_user_contacts(session.user_id(), &session).await?;
1441 Ok(())
1442}
1443
/// Rejoins a room in order to reconnect after connection errors.
1445async fn rejoin_room(
1446 request: proto::RejoinRoom,
1447 response: Response<proto::RejoinRoom>,
1448 session: Session,
1449) -> Result<()> {
1450 let room;
1451 let channel;
1452 {
1453 let mut rejoined_room = session
1454 .db()
1455 .await
1456 .rejoin_room(request, session.user_id(), session.connection_id)
1457 .await?;
1458
1459 response.send(proto::RejoinRoomResponse {
1460 room: Some(rejoined_room.room.clone()),
1461 reshared_projects: rejoined_room
1462 .reshared_projects
1463 .iter()
1464 .map(|project| proto::ResharedProject {
1465 id: project.id.to_proto(),
1466 collaborators: project
1467 .collaborators
1468 .iter()
1469 .map(|collaborator| collaborator.to_proto())
1470 .collect(),
1471 })
1472 .collect(),
1473 rejoined_projects: rejoined_room
1474 .rejoined_projects
1475 .iter()
1476 .map(|rejoined_project| rejoined_project.to_proto())
1477 .collect(),
1478 })?;
1479 room_updated(&rejoined_room.room, &session.peer);
1480
1481 for project in &rejoined_room.reshared_projects {
1482 for collaborator in &project.collaborators {
1483 session
1484 .peer
1485 .send(
1486 collaborator.connection_id,
1487 proto::UpdateProjectCollaborator {
1488 project_id: project.id.to_proto(),
1489 old_peer_id: Some(project.old_connection_id.into()),
1490 new_peer_id: Some(session.connection_id.into()),
1491 },
1492 )
1493 .trace_err();
1494 }
1495
1496 broadcast(
1497 Some(session.connection_id),
1498 project
1499 .collaborators
1500 .iter()
1501 .map(|collaborator| collaborator.connection_id),
1502 |connection_id| {
1503 session.peer.forward_send(
1504 session.connection_id,
1505 connection_id,
1506 proto::UpdateProject {
1507 project_id: project.id.to_proto(),
1508 worktrees: project.worktrees.clone(),
1509 },
1510 )
1511 },
1512 );
1513 }
1514
1515 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1516
1517 let rejoined_room = rejoined_room.into_inner();
1518
1519 room = rejoined_room.room;
1520 channel = rejoined_room.channel;
1521 }
1522
1523 if let Some(channel) = channel {
1524 channel_updated(
1525 &channel,
1526 &room,
1527 &session.peer,
1528 &*session.connection_pool().await,
1529 );
1530 }
1531
1532 update_user_contacts(session.user_id(), &session).await?;
1533 Ok(())
1534}
1535
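/// For each rejoined project, tells existing collaborators about the client's new peer id and
/// streams worktree entries, diagnostics, settings files, and repository updates back to the
/// rejoining client.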
1536fn notify_rejoined_projects(
1537 rejoined_projects: &mut Vec<RejoinedProject>,
1538 session: &Session,
1539) -> Result<()> {
1540 for project in rejoined_projects.iter() {
1541 for collaborator in &project.collaborators {
1542 session
1543 .peer
1544 .send(
1545 collaborator.connection_id,
1546 proto::UpdateProjectCollaborator {
1547 project_id: project.id.to_proto(),
1548 old_peer_id: Some(project.old_connection_id.into()),
1549 new_peer_id: Some(session.connection_id.into()),
1550 },
1551 )
1552 .trace_err();
1553 }
1554 }
1555
1556 for project in rejoined_projects {
1557 for worktree in mem::take(&mut project.worktrees) {
1558 // Stream this worktree's entries.
1559 let message = proto::UpdateWorktree {
1560 project_id: project.id.to_proto(),
1561 worktree_id: worktree.id,
1562 abs_path: worktree.abs_path.clone(),
1563 root_name: worktree.root_name,
1564 updated_entries: worktree.updated_entries,
1565 removed_entries: worktree.removed_entries,
1566 scan_id: worktree.scan_id,
1567 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1568 updated_repositories: worktree.updated_repositories,
1569 removed_repositories: worktree.removed_repositories,
1570 };
1571 for update in proto::split_worktree_update(message) {
1572 session.peer.send(session.connection_id, update)?;
1573 }
1574
1575 // Stream this worktree's diagnostics.
1576 for summary in worktree.diagnostic_summaries {
1577 session.peer.send(
1578 session.connection_id,
1579 proto::UpdateDiagnosticSummary {
1580 project_id: project.id.to_proto(),
1581 worktree_id: worktree.id,
1582 summary: Some(summary),
1583 },
1584 )?;
1585 }
1586
1587 for settings_file in worktree.settings_files {
1588 session.peer.send(
1589 session.connection_id,
1590 proto::UpdateWorktreeSettings {
1591 project_id: project.id.to_proto(),
1592 worktree_id: worktree.id,
1593 path: settings_file.path,
1594 content: Some(settings_file.content),
1595 kind: Some(settings_file.kind.to_proto().into()),
1596 },
1597 )?;
1598 }
1599 }
1600
1601 for repository in mem::take(&mut project.updated_repositories) {
1602 for update in split_repository_update(repository) {
1603 session.peer.send(session.connection_id, update)?;
1604 }
1605 }
1606
1607 for id in mem::take(&mut project.removed_repositories) {
1608 session.peer.send(
1609 session.connection_id,
1610 proto::RemoveRepository {
1611 project_id: project.id.to_proto(),
1612 id,
1613 },
1614 )?;
1615 }
1616 }
1617
1618 Ok(())
1619}
1620
/// Leaves the room, disconnecting from it.
1622async fn leave_room(
1623 _: proto::LeaveRoom,
1624 response: Response<proto::LeaveRoom>,
1625 session: Session,
1626) -> Result<()> {
1627 leave_room_for_session(&session, session.connection_id).await?;
1628 response.send(proto::Ack {})?;
1629 Ok(())
1630}
1631
1632/// Updates the permissions of someone else in the room.
1633async fn set_room_participant_role(
1634 request: proto::SetRoomParticipantRole,
1635 response: Response<proto::SetRoomParticipantRole>,
1636 session: Session,
1637) -> Result<()> {
1638 let user_id = UserId::from_proto(request.user_id);
1639 let role = ChannelRole::from(request.role());
1640
1641 let (livekit_room, can_publish) = {
1642 let room = session
1643 .db()
1644 .await
1645 .set_room_participant_role(
1646 session.user_id(),
1647 RoomId::from_proto(request.room_id),
1648 user_id,
1649 role,
1650 )
1651 .await?;
1652
1653 let livekit_room = room.livekit_room.clone();
1654 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1655 room_updated(&room, &session.peer);
1656 (livekit_room, can_publish)
1657 };
1658
1659 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1660 live_kit
1661 .update_participant(
1662 livekit_room.clone(),
1663 request.user_id.to_string(),
1664 livekit_api::proto::ParticipantPermission {
1665 can_subscribe: true,
1666 can_publish,
1667 can_publish_data: can_publish,
1668 hidden: false,
1669 recorder: false,
1670 },
1671 )
1672 .await
1673 .trace_err();
1674 }
1675
1676 response.send(proto::Ack {})?;
1677 Ok(())
1678}
1679
/// Call someone else into the current room.
1681async fn call(
1682 request: proto::Call,
1683 response: Response<proto::Call>,
1684 session: Session,
1685) -> Result<()> {
1686 let room_id = RoomId::from_proto(request.room_id);
1687 let calling_user_id = session.user_id();
1688 let calling_connection_id = session.connection_id;
1689 let called_user_id = UserId::from_proto(request.called_user_id);
1690 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1691 if !session
1692 .db()
1693 .await
1694 .has_contact(calling_user_id, called_user_id)
1695 .await?
1696 {
1697 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1698 }
1699
1700 let incoming_call = {
1701 let (room, incoming_call) = &mut *session
1702 .db()
1703 .await
1704 .call(
1705 room_id,
1706 calling_user_id,
1707 calling_connection_id,
1708 called_user_id,
1709 initial_project_id,
1710 )
1711 .await?;
1712 room_updated(room, &session.peer);
1713 mem::take(incoming_call)
1714 };
1715 update_user_contacts(called_user_id, &session).await?;
1716
1717 let mut calls = session
1718 .connection_pool()
1719 .await
1720 .user_connection_ids(called_user_id)
1721 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1722 .collect::<FuturesUnordered<_>>();
1723
1724 while let Some(call_response) = calls.next().await {
1725 match call_response.as_ref() {
1726 Ok(_) => {
1727 response.send(proto::Ack {})?;
1728 return Ok(());
1729 }
1730 Err(_) => {
1731 call_response.trace_err();
1732 }
1733 }
1734 }
1735
1736 {
1737 let room = session
1738 .db()
1739 .await
1740 .call_failed(room_id, called_user_id)
1741 .await?;
1742 room_updated(&room, &session.peer);
1743 }
1744 update_user_contacts(called_user_id, &session).await?;
1745
1746 Err(anyhow!("failed to ring user"))?
1747}
1748
1749/// Cancel an outgoing call.
1750async fn cancel_call(
1751 request: proto::CancelCall,
1752 response: Response<proto::CancelCall>,
1753 session: Session,
1754) -> Result<()> {
1755 let called_user_id = UserId::from_proto(request.called_user_id);
1756 let room_id = RoomId::from_proto(request.room_id);
1757 {
1758 let room = session
1759 .db()
1760 .await
1761 .cancel_call(room_id, session.connection_id, called_user_id)
1762 .await?;
1763 room_updated(&room, &session.peer);
1764 }
1765
1766 for connection_id in session
1767 .connection_pool()
1768 .await
1769 .user_connection_ids(called_user_id)
1770 {
1771 session
1772 .peer
1773 .send(
1774 connection_id,
1775 proto::CallCanceled {
1776 room_id: room_id.to_proto(),
1777 },
1778 )
1779 .trace_err();
1780 }
1781 response.send(proto::Ack {})?;
1782
1783 update_user_contacts(called_user_id, &session).await?;
1784 Ok(())
1785}
1786
1787/// Decline an incoming call.
1788async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1789 let room_id = RoomId::from_proto(message.room_id);
1790 {
1791 let room = session
1792 .db()
1793 .await
1794 .decline_call(Some(room_id), session.user_id())
1795 .await?
1796 .context("declining call")?;
1797 room_updated(&room, &session.peer);
1798 }
1799
1800 for connection_id in session
1801 .connection_pool()
1802 .await
1803 .user_connection_ids(session.user_id())
1804 {
1805 session
1806 .peer
1807 .send(
1808 connection_id,
1809 proto::CallCanceled {
1810 room_id: room_id.to_proto(),
1811 },
1812 )
1813 .trace_err();
1814 }
1815 update_user_contacts(session.user_id(), &session).await?;
1816 Ok(())
1817}
1818
1819/// Updates other participants in the room with your current location.
1820async fn update_participant_location(
1821 request: proto::UpdateParticipantLocation,
1822 response: Response<proto::UpdateParticipantLocation>,
1823 session: Session,
1824) -> Result<()> {
1825 let room_id = RoomId::from_proto(request.room_id);
1826 let location = request.location.context("invalid location")?;
1827
1828 let db = session.db().await;
1829 let room = db
1830 .update_room_participant_location(room_id, session.connection_id, location)
1831 .await?;
1832
1833 room_updated(&room, &session.peer);
1834 response.send(proto::Ack {})?;
1835 Ok(())
1836}
1837
1838/// Share a project into the room.
1839async fn share_project(
1840 request: proto::ShareProject,
1841 response: Response<proto::ShareProject>,
1842 session: Session,
1843) -> Result<()> {
1844 let (project_id, room) = &*session
1845 .db()
1846 .await
1847 .share_project(
1848 RoomId::from_proto(request.room_id),
1849 session.connection_id,
1850 &request.worktrees,
1851 request.is_ssh_project,
1852 )
1853 .await?;
1854 response.send(proto::ShareProjectResponse {
1855 project_id: project_id.to_proto(),
1856 })?;
1857 room_updated(room, &session.peer);
1858
1859 Ok(())
1860}
1861
1862/// Unshare a project from the room.
1863async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1864 let project_id = ProjectId::from_proto(message.project_id);
1865 unshare_project_internal(project_id, session.connection_id, &session).await
1866}
1867
1868async fn unshare_project_internal(
1869 project_id: ProjectId,
1870 connection_id: ConnectionId,
1871 session: &Session,
1872) -> Result<()> {
1873 let delete = {
1874 let room_guard = session
1875 .db()
1876 .await
1877 .unshare_project(project_id, connection_id)
1878 .await?;
1879
1880 let (delete, room, guest_connection_ids) = &*room_guard;
1881
1882 let message = proto::UnshareProject {
1883 project_id: project_id.to_proto(),
1884 };
1885
1886 broadcast(
1887 Some(connection_id),
1888 guest_connection_ids.iter().copied(),
1889 |conn_id| session.peer.send(conn_id, message.clone()),
1890 );
1891 if let Some(room) = room {
1892 room_updated(room, &session.peer);
1893 }
1894
1895 *delete
1896 };
1897
1898 if delete {
1899 let db = session.db().await;
1900 db.delete_project(project_id).await?;
1901 }
1902
1903 Ok(())
1904}
1905
1906/// Join someone else's shared project.
1907async fn join_project(
1908 request: proto::JoinProject,
1909 response: Response<proto::JoinProject>,
1910 session: Session,
1911) -> Result<()> {
1912 let project_id = ProjectId::from_proto(request.project_id);
1913
1914 tracing::info!(%project_id, "join project");
1915
1916 let db = session.db().await;
1917 let (project, replica_id) = &mut *db
1918 .join_project(
1919 project_id,
1920 session.connection_id,
1921 session.user_id(),
1922 request.committer_name.clone(),
1923 request.committer_email.clone(),
1924 )
1925 .await?;
1926 drop(db);
1927 tracing::info!(%project_id, "join remote project");
1928 let collaborators = project
1929 .collaborators
1930 .iter()
1931 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1932 .map(|collaborator| collaborator.to_proto())
1933 .collect::<Vec<_>>();
1934 let project_id = project.id;
1935 let guest_user_id = session.user_id();
1936
1937 let worktrees = project
1938 .worktrees
1939 .iter()
1940 .map(|(id, worktree)| proto::WorktreeMetadata {
1941 id: *id,
1942 root_name: worktree.root_name.clone(),
1943 visible: worktree.visible,
1944 abs_path: worktree.abs_path.clone(),
1945 })
1946 .collect::<Vec<_>>();
1947
1948 let add_project_collaborator = proto::AddProjectCollaborator {
1949 project_id: project_id.to_proto(),
1950 collaborator: Some(proto::Collaborator {
1951 peer_id: Some(session.connection_id.into()),
1952 replica_id: replica_id.0 as u32,
1953 user_id: guest_user_id.to_proto(),
1954 is_host: false,
1955 committer_name: request.committer_name.clone(),
1956 committer_email: request.committer_email.clone(),
1957 }),
1958 };
1959
1960 for collaborator in &collaborators {
1961 session
1962 .peer
1963 .send(
1964 collaborator.peer_id.unwrap().into(),
1965 add_project_collaborator.clone(),
1966 )
1967 .trace_err();
1968 }
1969
1970 // First, we send the metadata associated with each worktree.
1971 response.send(proto::JoinProjectResponse {
1972 project_id: project.id.0 as u64,
1973 worktrees: worktrees.clone(),
1974 replica_id: replica_id.0 as u32,
1975 collaborators: collaborators.clone(),
1976 language_servers: project.language_servers.clone(),
1977 role: project.role.into(),
1978 })?;
1979
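    // A worktree's contents can exceed a single RPC message, so each worktree
    // is streamed to the joining guest as a sequence of chunked updates,
    // followed by its diagnostic summaries and settings files.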
1980 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1981 // Stream this worktree's entries.
1982 let message = proto::UpdateWorktree {
1983 project_id: project_id.to_proto(),
1984 worktree_id,
1985 abs_path: worktree.abs_path.clone(),
1986 root_name: worktree.root_name,
1987 updated_entries: worktree.entries,
1988 removed_entries: Default::default(),
1989 scan_id: worktree.scan_id,
1990 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1991 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1992 removed_repositories: Default::default(),
1993 };
1994 for update in proto::split_worktree_update(message) {
1995 session.peer.send(session.connection_id, update.clone())?;
1996 }
1997
1998 // Stream this worktree's diagnostics.
1999 for summary in worktree.diagnostic_summaries {
2000 session.peer.send(
2001 session.connection_id,
2002 proto::UpdateDiagnosticSummary {
2003 project_id: project_id.to_proto(),
2004 worktree_id: worktree.id,
2005 summary: Some(summary),
2006 },
2007 )?;
2008 }
2009
2010 for settings_file in worktree.settings_files {
2011 session.peer.send(
2012 session.connection_id,
2013 proto::UpdateWorktreeSettings {
2014 project_id: project_id.to_proto(),
2015 worktree_id: worktree.id,
2016 path: settings_file.path,
2017 content: Some(settings_file.content),
2018 kind: Some(settings_file.kind.to_proto() as i32),
2019 },
2020 )?;
2021 }
2022 }
2023
2024 for repository in mem::take(&mut project.repositories) {
2025 for update in split_repository_update(repository) {
2026 session.peer.send(session.connection_id, update)?;
2027 }
2028 }
2029
2030 for language_server in &project.language_servers {
2031 session.peer.send(
2032 session.connection_id,
2033 proto::UpdateLanguageServer {
2034 project_id: project_id.to_proto(),
2035 server_name: Some(language_server.name.clone()),
2036 language_server_id: language_server.id,
2037 variant: Some(
2038 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2039 proto::LspDiskBasedDiagnosticsUpdated {},
2040 ),
2041 ),
2042 },
2043 )?;
2044 }
2045
2046 Ok(())
2047}
2048
2049/// Leave someone else's shared project.
2050async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2051 let sender_id = session.connection_id;
2052 let project_id = ProjectId::from_proto(request.project_id);
2053 let db = session.db().await;
2054
2055 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2056 tracing::info!(
2057 %project_id,
2058 "leave project"
2059 );
2060
2061 project_left(project, &session);
2062 if let Some(room) = room {
2063 room_updated(room, &session.peer);
2064 }
2065
2066 Ok(())
2067}
2068
2069/// Updates other participants with changes to the project
2070async fn update_project(
2071 request: proto::UpdateProject,
2072 response: Response<proto::UpdateProject>,
2073 session: Session,
2074) -> Result<()> {
2075 let project_id = ProjectId::from_proto(request.project_id);
2076 let (room, guest_connection_ids) = &*session
2077 .db()
2078 .await
2079 .update_project(project_id, session.connection_id, &request.worktrees)
2080 .await?;
2081 broadcast(
2082 Some(session.connection_id),
2083 guest_connection_ids.iter().copied(),
2084 |connection_id| {
2085 session
2086 .peer
2087 .forward_send(session.connection_id, connection_id, request.clone())
2088 },
2089 );
2090 if let Some(room) = room {
2091 room_updated(room, &session.peer);
2092 }
2093 response.send(proto::Ack {})?;
2094
2095 Ok(())
2096}
2097
2098/// Updates other participants with changes to the worktree
2099async fn update_worktree(
2100 request: proto::UpdateWorktree,
2101 response: Response<proto::UpdateWorktree>,
2102 session: Session,
2103) -> Result<()> {
2104 let guest_connection_ids = session
2105 .db()
2106 .await
2107 .update_worktree(&request, session.connection_id)
2108 .await?;
2109
2110 broadcast(
2111 Some(session.connection_id),
2112 guest_connection_ids.iter().copied(),
2113 |connection_id| {
2114 session
2115 .peer
2116 .forward_send(session.connection_id, connection_id, request.clone())
2117 },
2118 );
2119 response.send(proto::Ack {})?;
2120 Ok(())
2121}
2122
2123async fn update_repository(
2124 request: proto::UpdateRepository,
2125 response: Response<proto::UpdateRepository>,
2126 session: Session,
2127) -> Result<()> {
2128 let guest_connection_ids = session
2129 .db()
2130 .await
2131 .update_repository(&request, session.connection_id)
2132 .await?;
2133
2134 broadcast(
2135 Some(session.connection_id),
2136 guest_connection_ids.iter().copied(),
2137 |connection_id| {
2138 session
2139 .peer
2140 .forward_send(session.connection_id, connection_id, request.clone())
2141 },
2142 );
2143 response.send(proto::Ack {})?;
2144 Ok(())
2145}
2146
2147async fn remove_repository(
2148 request: proto::RemoveRepository,
2149 response: Response<proto::RemoveRepository>,
2150 session: Session,
2151) -> Result<()> {
2152 let guest_connection_ids = session
2153 .db()
2154 .await
2155 .remove_repository(&request, session.connection_id)
2156 .await?;
2157
2158 broadcast(
2159 Some(session.connection_id),
2160 guest_connection_ids.iter().copied(),
2161 |connection_id| {
2162 session
2163 .peer
2164 .forward_send(session.connection_id, connection_id, request.clone())
2165 },
2166 );
2167 response.send(proto::Ack {})?;
2168 Ok(())
2169}
2170
2171/// Updates other participants with changes to the diagnostics
2172async fn update_diagnostic_summary(
2173 message: proto::UpdateDiagnosticSummary,
2174 session: Session,
2175) -> Result<()> {
2176 let guest_connection_ids = session
2177 .db()
2178 .await
2179 .update_diagnostic_summary(&message, session.connection_id)
2180 .await?;
2181
2182 broadcast(
2183 Some(session.connection_id),
2184 guest_connection_ids.iter().copied(),
2185 |connection_id| {
2186 session
2187 .peer
2188 .forward_send(session.connection_id, connection_id, message.clone())
2189 },
2190 );
2191
2192 Ok(())
2193}
2194
2195/// Updates other participants with changes to the worktree settings
2196async fn update_worktree_settings(
2197 message: proto::UpdateWorktreeSettings,
2198 session: Session,
2199) -> Result<()> {
2200 let guest_connection_ids = session
2201 .db()
2202 .await
2203 .update_worktree_settings(&message, session.connection_id)
2204 .await?;
2205
2206 broadcast(
2207 Some(session.connection_id),
2208 guest_connection_ids.iter().copied(),
2209 |connection_id| {
2210 session
2211 .peer
2212 .forward_send(session.connection_id, connection_id, message.clone())
2213 },
2214 );
2215
2216 Ok(())
2217}
2218
2219/// Notify other participants that a language server has started.
2220async fn start_language_server(
2221 request: proto::StartLanguageServer,
2222 session: Session,
2223) -> Result<()> {
2224 let guest_connection_ids = session
2225 .db()
2226 .await
2227 .start_language_server(&request, session.connection_id)
2228 .await?;
2229
2230 broadcast(
2231 Some(session.connection_id),
2232 guest_connection_ids.iter().copied(),
2233 |connection_id| {
2234 session
2235 .peer
2236 .forward_send(session.connection_id, connection_id, request.clone())
2237 },
2238 );
2239 Ok(())
2240}
2241
2242/// Notify other participants that a language server has changed.
2243async fn update_language_server(
2244 request: proto::UpdateLanguageServer,
2245 session: Session,
2246) -> Result<()> {
2247 let project_id = ProjectId::from_proto(request.project_id);
2248 let project_connection_ids = session
2249 .db()
2250 .await
2251 .project_connection_ids(project_id, session.connection_id, true)
2252 .await?;
2253 broadcast(
2254 Some(session.connection_id),
2255 project_connection_ids.iter().copied(),
2256 |connection_id| {
2257 session
2258 .peer
2259 .forward_send(session.connection_id, connection_id, request.clone())
2260 },
2261 );
2262 Ok(())
2263}
2264
2265/// Forward a project request to the host. These requests should be read-only,
2266/// as guests are allowed to send them.
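///
/// The request is re-sent to the project's host connection, and the host's
/// response is relayed back to the requesting guest unchanged.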
2267async fn forward_read_only_project_request<T>(
2268 request: T,
2269 response: Response<T>,
2270 session: Session,
2271) -> Result<()>
2272where
2273 T: EntityMessage + RequestMessage,
2274{
2275 let project_id = ProjectId::from_proto(request.remote_entity_id());
2276 let host_connection_id = session
2277 .db()
2278 .await
2279 .host_for_read_only_project_request(project_id, session.connection_id)
2280 .await?;
2281 let payload = session
2282 .peer
2283 .forward_request(session.connection_id, host_connection_id, request)
2284 .await?;
2285 response.send(payload)?;
2286 Ok(())
2287}
2288
2289async fn forward_find_search_candidates_request(
2290 request: proto::FindSearchCandidates,
2291 response: Response<proto::FindSearchCandidates>,
2292 session: Session,
2293) -> Result<()> {
2294 let project_id = ProjectId::from_proto(request.remote_entity_id());
2295 let host_connection_id = session
2296 .db()
2297 .await
2298 .host_for_read_only_project_request(project_id, session.connection_id)
2299 .await?;
2300 let payload = session
2301 .peer
2302 .forward_request(session.connection_id, host_connection_id, request)
2303 .await?;
2304 response.send(payload)?;
2305 Ok(())
2306}
2307
2308/// Forward a project request to the host. These requests are disallowed
2309/// for guests.
2310async fn forward_mutating_project_request<T>(
2311 request: T,
2312 response: Response<T>,
2313 session: Session,
2314) -> Result<()>
2315where
2316 T: EntityMessage + RequestMessage,
2317{
2318 let project_id = ProjectId::from_proto(request.remote_entity_id());
2319
2320 let host_connection_id = session
2321 .db()
2322 .await
2323 .host_for_mutating_project_request(project_id, session.connection_id)
2324 .await?;
2325 let payload = session
2326 .peer
2327 .forward_request(session.connection_id, host_connection_id, request)
2328 .await?;
2329 response.send(payload)?;
2330 Ok(())
2331}
2332
2333async fn multi_lsp_query(
2334 request: MultiLspQuery,
2335 response: Response<MultiLspQuery>,
2336 session: Session,
2337) -> Result<()> {
2338 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2339 forward_mutating_project_request(request, response, session).await
2340}
2341
2342/// Notify other participants that a new buffer has been created
2343async fn create_buffer_for_peer(
2344 request: proto::CreateBufferForPeer,
2345 session: Session,
2346) -> Result<()> {
2347 session
2348 .db()
2349 .await
2350 .check_user_is_project_host(
2351 ProjectId::from_proto(request.project_id),
2352 session.connection_id,
2353 )
2354 .await?;
2355 let peer_id = request.peer_id.context("invalid peer id")?;
2356 session
2357 .peer
2358 .forward_send(session.connection_id, peer_id.into(), request)?;
2359 Ok(())
2360}
2361
2362/// Notify other participants that a buffer has been updated. This is
2363/// allowed for guests as long as the update is limited to selections.
2364async fn update_buffer(
2365 request: proto::UpdateBuffer,
2366 response: Response<proto::UpdateBuffer>,
2367 session: Session,
2368) -> Result<()> {
2369 let project_id = ProjectId::from_proto(request.project_id);
2370 let mut capability = Capability::ReadOnly;
2371
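    // Selection-only updates are allowed for read-only guests; any other
    // buffer operation requires write access, which is enforced below by
    // `connections_for_buffer_update`.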
2372 for op in request.operations.iter() {
2373 match op.variant {
2374 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2375 Some(_) => capability = Capability::ReadWrite,
2376 }
2377 }
2378
2379 let host = {
2380 let guard = session
2381 .db()
2382 .await
2383 .connections_for_buffer_update(project_id, session.connection_id, capability)
2384 .await?;
2385
2386 let (host, guests) = &*guard;
2387
2388 broadcast(
2389 Some(session.connection_id),
2390 guests.clone(),
2391 |connection_id| {
2392 session
2393 .peer
2394 .forward_send(session.connection_id, connection_id, request.clone())
2395 },
2396 );
2397
2398 *host
2399 };
2400
2401 if host != session.connection_id {
2402 session
2403 .peer
2404 .forward_request(session.connection_id, host, request.clone())
2405 .await?;
2406 }
2407
2408 response.send(proto::Ack {})?;
2409 Ok(())
2410}
2411
2412async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2413 let project_id = ProjectId::from_proto(message.project_id);
2414
2415 let operation = message.operation.as_ref().context("invalid operation")?;
2416 let capability = match operation.variant.as_ref() {
2417 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2418 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2419 match buffer_op.variant {
2420 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2421 Capability::ReadOnly
2422 }
2423 _ => Capability::ReadWrite,
2424 }
2425 } else {
2426 Capability::ReadWrite
2427 }
2428 }
2429 Some(_) => Capability::ReadWrite,
2430 None => Capability::ReadOnly,
2431 };
2432
2433 let guard = session
2434 .db()
2435 .await
2436 .connections_for_buffer_update(project_id, session.connection_id, capability)
2437 .await?;
2438
2439 let (host, guests) = &*guard;
2440
2441 broadcast(
2442 Some(session.connection_id),
2443 guests.iter().chain([host]).copied(),
2444 |connection_id| {
2445 session
2446 .peer
2447 .forward_send(session.connection_id, connection_id, message.clone())
2448 },
2449 );
2450
2451 Ok(())
2452}
2453
2454/// Notify other participants that a project has been updated.
2455async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2456 request: T,
2457 session: Session,
2458) -> Result<()> {
2459 let project_id = ProjectId::from_proto(request.remote_entity_id());
2460 let project_connection_ids = session
2461 .db()
2462 .await
2463 .project_connection_ids(project_id, session.connection_id, false)
2464 .await?;
2465
2466 broadcast(
2467 Some(session.connection_id),
2468 project_connection_ids.iter().copied(),
2469 |connection_id| {
2470 session
2471 .peer
2472 .forward_send(session.connection_id, connection_id, request.clone())
2473 },
2474 );
2475 Ok(())
2476}
2477
2478/// Start following another user in a call.
2479async fn follow(
2480 request: proto::Follow,
2481 response: Response<proto::Follow>,
2482 session: Session,
2483) -> Result<()> {
2484 let room_id = RoomId::from_proto(request.room_id);
2485 let project_id = request.project_id.map(ProjectId::from_proto);
2486 let leader_id = request.leader_id.context("invalid leader id")?.into();
2487 let follower_id = session.connection_id;
2488
2489 session
2490 .db()
2491 .await
2492 .check_room_participants(room_id, leader_id, session.connection_id)
2493 .await?;
2494
2495 let response_payload = session
2496 .peer
2497 .forward_request(session.connection_id, leader_id, request)
2498 .await?;
2499 response.send(response_payload)?;
2500
2501 if let Some(project_id) = project_id {
2502 let room = session
2503 .db()
2504 .await
2505 .follow(room_id, project_id, leader_id, follower_id)
2506 .await?;
2507 room_updated(&room, &session.peer);
2508 }
2509
2510 Ok(())
2511}
2512
2513/// Stop following another user in a call.
2514async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2515 let room_id = RoomId::from_proto(request.room_id);
2516 let project_id = request.project_id.map(ProjectId::from_proto);
2517 let leader_id = request.leader_id.context("invalid leader id")?.into();
2518 let follower_id = session.connection_id;
2519
2520 session
2521 .db()
2522 .await
2523 .check_room_participants(room_id, leader_id, session.connection_id)
2524 .await?;
2525
2526 session
2527 .peer
2528 .forward_send(session.connection_id, leader_id, request)?;
2529
2530 if let Some(project_id) = project_id {
2531 let room = session
2532 .db()
2533 .await
2534 .unfollow(room_id, project_id, leader_id, follower_id)
2535 .await?;
2536 room_updated(&room, &session.peer);
2537 }
2538
2539 Ok(())
2540}
2541
2542/// Notify everyone following you of your current location.
2543async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2544 let room_id = RoomId::from_proto(request.room_id);
2545 let database = session.db.lock().await;
2546
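    // When a project id is provided, the update is scoped to that project's
    // connections; otherwise it goes to everyone else in the room.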
2547 let connection_ids = if let Some(project_id) = request.project_id {
2548 let project_id = ProjectId::from_proto(project_id);
2549 database
2550 .project_connection_ids(project_id, session.connection_id, true)
2551 .await?
2552 } else {
2553 database
2554 .room_connection_ids(room_id, session.connection_id)
2555 .await?
2556 };
2557
2558 // For now, don't send view update messages back to that view's current leader.
2559 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2560 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2561 _ => None,
2562 });
2563
2564 for connection_id in connection_ids.iter().cloned() {
2565 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2566 session
2567 .peer
2568 .forward_send(session.connection_id, connection_id, request.clone())?;
2569 }
2570 }
2571 Ok(())
2572}
2573
2574/// Get public data about users.
2575async fn get_users(
2576 request: proto::GetUsers,
2577 response: Response<proto::GetUsers>,
2578 session: Session,
2579) -> Result<()> {
2580 let user_ids = request
2581 .user_ids
2582 .into_iter()
2583 .map(UserId::from_proto)
2584 .collect();
2585 let users = session
2586 .db()
2587 .await
2588 .get_users_by_ids(user_ids)
2589 .await?
2590 .into_iter()
2591 .map(|user| proto::User {
2592 id: user.id.to_proto(),
2593 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2594 github_login: user.github_login,
2595 name: user.name,
2596 })
2597 .collect();
2598 response.send(proto::UsersResponse { users })?;
2599 Ok(())
2600}
2601
2602/// Search for users (to invite) by GitHub login.
2603async fn fuzzy_search_users(
2604 request: proto::FuzzySearchUsers,
2605 response: Response<proto::FuzzySearchUsers>,
2606 session: Session,
2607) -> Result<()> {
2608 let query = request.query;
2609 let users = match query.len() {
2610 0 => vec![],
2611 1 | 2 => session
2612 .db()
2613 .await
2614 .get_user_by_github_login(&query)
2615 .await?
2616 .into_iter()
2617 .collect(),
2618 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2619 };
2620 let users = users
2621 .into_iter()
2622 .filter(|user| user.id != session.user_id())
2623 .map(|user| proto::User {
2624 id: user.id.to_proto(),
2625 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2626 github_login: user.github_login,
2627 name: user.name,
2628 })
2629 .collect();
2630 response.send(proto::UsersResponse { users })?;
2631 Ok(())
2632}
2633
2634/// Send a contact request to another user.
2635async fn request_contact(
2636 request: proto::RequestContact,
2637 response: Response<proto::RequestContact>,
2638 session: Session,
2639) -> Result<()> {
2640 let requester_id = session.user_id();
2641 let responder_id = UserId::from_proto(request.responder_id);
2642 if requester_id == responder_id {
2643 return Err(anyhow!("cannot add yourself as a contact"))?;
2644 }
2645
2646 let notifications = session
2647 .db()
2648 .await
2649 .send_contact_request(requester_id, responder_id)
2650 .await?;
2651
2652 // Update outgoing contact requests of requester
2653 let mut update = proto::UpdateContacts::default();
2654 update.outgoing_requests.push(responder_id.to_proto());
2655 for connection_id in session
2656 .connection_pool()
2657 .await
2658 .user_connection_ids(requester_id)
2659 {
2660 session.peer.send(connection_id, update.clone())?;
2661 }
2662
2663 // Update incoming contact requests of responder
2664 let mut update = proto::UpdateContacts::default();
2665 update
2666 .incoming_requests
2667 .push(proto::IncomingContactRequest {
2668 requester_id: requester_id.to_proto(),
2669 });
2670 let connection_pool = session.connection_pool().await;
2671 for connection_id in connection_pool.user_connection_ids(responder_id) {
2672 session.peer.send(connection_id, update.clone())?;
2673 }
2674
2675 send_notifications(&connection_pool, &session.peer, notifications);
2676
2677 response.send(proto::Ack {})?;
2678 Ok(())
2679}
2680
2681/// Accept or decline a contact request
2682async fn respond_to_contact_request(
2683 request: proto::RespondToContactRequest,
2684 response: Response<proto::RespondToContactRequest>,
2685 session: Session,
2686) -> Result<()> {
2687 let responder_id = session.user_id();
2688 let requester_id = UserId::from_proto(request.requester_id);
2689 let db = session.db().await;
2690 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2691 db.dismiss_contact_notification(responder_id, requester_id)
2692 .await?;
2693 } else {
2694 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2695
2696 let notifications = db
2697 .respond_to_contact_request(responder_id, requester_id, accept)
2698 .await?;
2699 let requester_busy = db.is_user_busy(requester_id).await?;
2700 let responder_busy = db.is_user_busy(responder_id).await?;
2701
2702 let pool = session.connection_pool().await;
2703 // Update responder with new contact
2704 let mut update = proto::UpdateContacts::default();
2705 if accept {
2706 update
2707 .contacts
2708 .push(contact_for_user(requester_id, requester_busy, &pool));
2709 }
2710 update
2711 .remove_incoming_requests
2712 .push(requester_id.to_proto());
2713 for connection_id in pool.user_connection_ids(responder_id) {
2714 session.peer.send(connection_id, update.clone())?;
2715 }
2716
2717 // Update requester with new contact
2718 let mut update = proto::UpdateContacts::default();
2719 if accept {
2720 update
2721 .contacts
2722 .push(contact_for_user(responder_id, responder_busy, &pool));
2723 }
2724 update
2725 .remove_outgoing_requests
2726 .push(responder_id.to_proto());
2727
2728 for connection_id in pool.user_connection_ids(requester_id) {
2729 session.peer.send(connection_id, update.clone())?;
2730 }
2731
2732 send_notifications(&pool, &session.peer, notifications);
2733 }
2734
2735 response.send(proto::Ack {})?;
2736 Ok(())
2737}
2738
2739/// Remove a contact.
2740async fn remove_contact(
2741 request: proto::RemoveContact,
2742 response: Response<proto::RemoveContact>,
2743 session: Session,
2744) -> Result<()> {
2745 let requester_id = session.user_id();
2746 let responder_id = UserId::from_proto(request.user_id);
2747 let db = session.db().await;
2748 let (contact_accepted, deleted_notification_id) =
2749 db.remove_contact(requester_id, responder_id).await?;
2750
2751 let pool = session.connection_pool().await;
2752 // Update outgoing contact requests of requester
2753 let mut update = proto::UpdateContacts::default();
2754 if contact_accepted {
2755 update.remove_contacts.push(responder_id.to_proto());
2756 } else {
2757 update
2758 .remove_outgoing_requests
2759 .push(responder_id.to_proto());
2760 }
2761 for connection_id in pool.user_connection_ids(requester_id) {
2762 session.peer.send(connection_id, update.clone())?;
2763 }
2764
2765 // Update incoming contact requests of responder
2766 let mut update = proto::UpdateContacts::default();
2767 if contact_accepted {
2768 update.remove_contacts.push(requester_id.to_proto());
2769 } else {
2770 update
2771 .remove_incoming_requests
2772 .push(requester_id.to_proto());
2773 }
2774 for connection_id in pool.user_connection_ids(responder_id) {
2775 session.peer.send(connection_id, update.clone())?;
2776 if let Some(notification_id) = deleted_notification_id {
2777 session.peer.send(
2778 connection_id,
2779 proto::DeleteNotification {
2780 notification_id: notification_id.to_proto(),
2781 },
2782 )?;
2783 }
2784 }
2785
2786 response.send(proto::Ack {})?;
2787 Ok(())
2788}
2789
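/// Whether the server should subscribe this client to its channels
/// automatically at connection time. Newer clients request the channel list
/// explicitly via `SubscribeToChannels` instead.
///
/// A rough sketch of the cutoff, assuming a `SemanticVersion::new(major, minor, patch)`
/// constructor is available (not wired up as a doctest, hence `ignore`):
///
/// ```ignore
/// assert!(should_auto_subscribe_to_channels(ZedVersion(SemanticVersion::new(0, 138, 0))));
/// assert!(!should_auto_subscribe_to_channels(ZedVersion(SemanticVersion::new(0, 139, 0))));
/// ```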
2790fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2791 version.0.minor() < 139
2792}
2793
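/// Resolves the plan reported to the client: staff are always treated as Zed
/// Pro; everyone else is derived from their active billing subscription,
/// falling back to the free plan when there is none.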
2794async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2795 if is_staff {
2796 return Ok(proto::Plan::ZedPro);
2797 }
2798
2799 let subscription = db.get_active_billing_subscription(user_id).await?;
2800 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2801
2802 let plan = if let Some(subscription_kind) = subscription_kind {
2803 match subscription_kind {
2804 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2805 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2806 SubscriptionKind::ZedFree => proto::Plan::Free,
2807 }
2808 } else {
2809 proto::Plan::Free
2810 };
2811
2812 Ok(plan)
2813}
2814
2815async fn make_update_user_plan_message(
2816 user: &User,
2817 is_staff: bool,
2818 db: &Arc<Database>,
2819 llm_db: Option<Arc<LlmDatabase>>,
2820) -> Result<proto::UpdateUserPlan> {
2821 let feature_flags = db.get_user_flags(user.id).await?;
2822 let plan = current_plan(db, user.id, is_staff).await?;
2823 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2824 let billing_preferences = db.get_billing_preferences(user.id).await?;
2825
2826 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2827 let subscription = db.get_active_billing_subscription(user.id).await?;
2828
2829 let subscription_period =
2830 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2831
2832 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2833 llm_db
2834 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2835 .await?
2836 } else {
2837 None
2838 };
2839
2840 (subscription_period, usage)
2841 } else {
2842 (None, None)
2843 };
2844
2845 let bypass_account_age_check = feature_flags
2846 .iter()
2847 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2848 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2849 && !bypass_account_age_check
2850 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2851
2852 Ok(proto::UpdateUserPlan {
2853 plan: plan.into(),
2854 trial_started_at: billing_customer
2855 .as_ref()
2856 .and_then(|billing_customer| billing_customer.trial_started_at)
2857 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2858 is_usage_based_billing_enabled: if is_staff {
2859 Some(true)
2860 } else {
2861 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2862 },
2863 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2864 proto::SubscriptionPeriod {
2865 started_at: started_at.timestamp() as u64,
2866 ended_at: ended_at.timestamp() as u64,
2867 }
2868 }),
2869 account_too_young: Some(account_too_young),
2870 has_overdue_invoices: billing_customer
2871 .map(|billing_customer| billing_customer.has_overdue_invoices),
2872 usage: Some(
2873 usage
2874 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2875 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2876 ),
2877 })
2878}
2879
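/// Returns the model-request limit for the given plan, raising the limit for
/// trial users who have the extended-trial feature flag.
///
/// An illustrative sketch of the extended-trial bump (not wired up as a
/// doctest, hence `ignore`):
///
/// ```ignore
/// let flags = vec![AGENT_EXTENDED_TRIAL_FEATURE_FLAG.to_string()];
/// let limit = model_requests_limit(cloud_llm_client::Plan::ZedProTrial, &flags);
/// assert!(matches!(limit, cloud_llm_client::UsageLimit::Limited(1_000)));
/// ```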
2880fn model_requests_limit(
2881 plan: cloud_llm_client::Plan,
2882 feature_flags: &Vec<String>,
2883) -> cloud_llm_client::UsageLimit {
2884 match plan.model_requests_limit() {
2885 cloud_llm_client::UsageLimit::Limited(limit) => {
2886 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2887 && feature_flags
2888 .iter()
2889 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2890 {
2891 1_000
2892 } else {
2893 limit
2894 };
2895
2896 cloud_llm_client::UsageLimit::Limited(limit)
2897 }
2898 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2899 }
2900}
2901
2902fn subscription_usage_to_proto(
2903 plan: proto::Plan,
2904 usage: crate::llm::db::subscription_usage::Model,
2905 feature_flags: &Vec<String>,
2906) -> proto::SubscriptionUsage {
2907 let plan = match plan {
2908 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2909 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2910 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2911 };
2912
2913 proto::SubscriptionUsage {
2914 model_requests_usage_amount: usage.model_requests as u32,
2915 model_requests_usage_limit: Some(proto::UsageLimit {
2916 variant: Some(match model_requests_limit(plan, feature_flags) {
2917 cloud_llm_client::UsageLimit::Limited(limit) => {
2918 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2919 limit: limit as u32,
2920 })
2921 }
2922 cloud_llm_client::UsageLimit::Unlimited => {
2923 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2924 }
2925 }),
2926 }),
2927 edit_predictions_usage_amount: usage.edit_predictions as u32,
2928 edit_predictions_usage_limit: Some(proto::UsageLimit {
2929 variant: Some(match plan.edit_predictions_limit() {
2930 cloud_llm_client::UsageLimit::Limited(limit) => {
2931 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2932 limit: limit as u32,
2933 })
2934 }
2935 cloud_llm_client::UsageLimit::Unlimited => {
2936 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2937 }
2938 }),
2939 }),
2940 }
2941}
2942
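/// Builds a zeroed-out usage message that still carries the plan's limits,
/// used when no usage has been recorded for the current subscription period.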
2943fn make_default_subscription_usage(
2944 plan: proto::Plan,
2945 feature_flags: &Vec<String>,
2946) -> proto::SubscriptionUsage {
2947 let plan = match plan {
2948 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2949 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2950 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2951 };
2952
2953 proto::SubscriptionUsage {
2954 model_requests_usage_amount: 0,
2955 model_requests_usage_limit: Some(proto::UsageLimit {
2956 variant: Some(match model_requests_limit(plan, feature_flags) {
2957 cloud_llm_client::UsageLimit::Limited(limit) => {
2958 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2959 limit: limit as u32,
2960 })
2961 }
2962 cloud_llm_client::UsageLimit::Unlimited => {
2963 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2964 }
2965 }),
2966 }),
2967 edit_predictions_usage_amount: 0,
2968 edit_predictions_usage_limit: Some(proto::UsageLimit {
2969 variant: Some(match plan.edit_predictions_limit() {
2970 cloud_llm_client::UsageLimit::Limited(limit) => {
2971 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2972 limit: limit as u32,
2973 })
2974 }
2975 cloud_llm_client::UsageLimit::Unlimited => {
2976 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2977 }
2978 }),
2979 }),
2980 }
2981}
2982
2983async fn update_user_plan(session: &Session) -> Result<()> {
2984 let db = session.db().await;
2985
2986 let update_user_plan = make_update_user_plan_message(
2987 session.principal.user(),
2988 session.is_staff(),
2989 &db.0,
2990 session.app_state.llm_db.clone(),
2991 )
2992 .await?;
2993
2994 session
2995 .peer
2996 .send(session.connection_id, update_user_plan)
2997 .trace_err();
2998
2999 Ok(())
3000}
3001
3002async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3003 subscribe_user_to_channels(session.user_id(), &session).await?;
3004 Ok(())
3005}
3006
3007async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3008 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3009 let mut pool = session.connection_pool().await;
3010 for membership in &channels_for_user.channel_memberships {
3011 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3012 }
3013 session.peer.send(
3014 session.connection_id,
3015 build_update_user_channels(&channels_for_user),
3016 )?;
3017 session.peer.send(
3018 session.connection_id,
3019 build_channels_update(channels_for_user),
3020 )?;
3021 Ok(())
3022}
3023
3024/// Creates a new channel.
3025async fn create_channel(
3026 request: proto::CreateChannel,
3027 response: Response<proto::CreateChannel>,
3028 session: Session,
3029) -> Result<()> {
3030 let db = session.db().await;
3031
3032 let parent_id = request.parent_id.map(ChannelId::from_proto);
3033 let (channel, membership) = db
3034 .create_channel(&request.name, parent_id, session.user_id())
3035 .await?;
3036
3037 let root_id = channel.root_id();
3038 let channel = Channel::from_model(channel);
3039
3040 response.send(proto::CreateChannelResponse {
3041 channel: Some(channel.to_proto()),
3042 parent_id: request.parent_id,
3043 })?;
3044
3045 let mut connection_pool = session.connection_pool().await;
3046 if let Some(membership) = membership {
3047 connection_pool.subscribe_to_channel(
3048 membership.user_id,
3049 membership.channel_id,
3050 membership.role,
3051 );
3052 let update = proto::UpdateUserChannels {
3053 channel_memberships: vec![proto::ChannelMembership {
3054 channel_id: membership.channel_id.to_proto(),
3055 role: membership.role.into(),
3056 }],
3057 ..Default::default()
3058 };
3059 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3060 session.peer.send(connection_id, update.clone())?;
3061 }
3062 }
3063
3064 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3065 if !role.can_see_channel(channel.visibility) {
3066 continue;
3067 }
3068
3069 let update = proto::UpdateChannels {
3070 channels: vec![channel.to_proto()],
3071 ..Default::default()
3072 };
3073 session.peer.send(connection_id, update.clone())?;
3074 }
3075
3076 Ok(())
3077}
3078
3079/// Delete a channel
3080async fn delete_channel(
3081 request: proto::DeleteChannel,
3082 response: Response<proto::DeleteChannel>,
3083 session: Session,
3084) -> Result<()> {
3085 let db = session.db().await;
3086
3087 let channel_id = request.channel_id;
3088 let (root_channel, removed_channels) = db
3089 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3090 .await?;
3091 response.send(proto::Ack {})?;
3092
3093 // Notify members of removed channels
3094 let mut update = proto::UpdateChannels::default();
3095 update
3096 .delete_channels
3097 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3098
3099 let connection_pool = session.connection_pool().await;
3100 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3101 session.peer.send(connection_id, update.clone())?;
3102 }
3103
3104 Ok(())
3105}
3106
3107/// Invite someone to join a channel.
3108async fn invite_channel_member(
3109 request: proto::InviteChannelMember,
3110 response: Response<proto::InviteChannelMember>,
3111 session: Session,
3112) -> Result<()> {
3113 let db = session.db().await;
3114 let channel_id = ChannelId::from_proto(request.channel_id);
3115 let invitee_id = UserId::from_proto(request.user_id);
3116 let InviteMemberResult {
3117 channel,
3118 notifications,
3119 } = db
3120 .invite_channel_member(
3121 channel_id,
3122 invitee_id,
3123 session.user_id(),
3124 request.role().into(),
3125 )
3126 .await?;
3127
3128 let update = proto::UpdateChannels {
3129 channel_invitations: vec![channel.to_proto()],
3130 ..Default::default()
3131 };
3132
3133 let connection_pool = session.connection_pool().await;
3134 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3135 session.peer.send(connection_id, update.clone())?;
3136 }
3137
3138 send_notifications(&connection_pool, &session.peer, notifications);
3139
3140 response.send(proto::Ack {})?;
3141 Ok(())
3142}
3143
3144/// Remove someone from a channel.
3145async fn remove_channel_member(
3146 request: proto::RemoveChannelMember,
3147 response: Response<proto::RemoveChannelMember>,
3148 session: Session,
3149) -> Result<()> {
3150 let db = session.db().await;
3151 let channel_id = ChannelId::from_proto(request.channel_id);
3152 let member_id = UserId::from_proto(request.user_id);
3153
3154 let RemoveChannelMemberResult {
3155 membership_update,
3156 notification_id,
3157 } = db
3158 .remove_channel_member(channel_id, member_id, session.user_id())
3159 .await?;
3160
3161 let mut connection_pool = session.connection_pool().await;
3162 notify_membership_updated(
3163 &mut connection_pool,
3164 membership_update,
3165 member_id,
3166 &session.peer,
3167 );
3168 for connection_id in connection_pool.user_connection_ids(member_id) {
3169 if let Some(notification_id) = notification_id {
3170 session
3171 .peer
3172 .send(
3173 connection_id,
3174 proto::DeleteNotification {
3175 notification_id: notification_id.to_proto(),
3176 },
3177 )
3178 .trace_err();
3179 }
3180 }
3181
3182 response.send(proto::Ack {})?;
3183 Ok(())
3184}
3185
3186/// Toggle the channel between public and private.
3187/// Care is taken to maintain the invariant that public channels only descend from public channels
3188/// (though members-only channels can appear at any point in the hierarchy).
3189async fn set_channel_visibility(
3190 request: proto::SetChannelVisibility,
3191 response: Response<proto::SetChannelVisibility>,
3192 session: Session,
3193) -> Result<()> {
3194 let db = session.db().await;
3195 let channel_id = ChannelId::from_proto(request.channel_id);
3196 let visibility = request.visibility().into();
3197
3198 let channel_model = db
3199 .set_channel_visibility(channel_id, visibility, session.user_id())
3200 .await?;
3201 let root_id = channel_model.root_id();
3202 let channel = Channel::from_model(channel_model);
3203
3204 let mut connection_pool = session.connection_pool().await;
3205 for (user_id, role) in connection_pool
3206 .channel_user_ids(root_id)
3207 .collect::<Vec<_>>()
3208 .into_iter()
3209 {
3210 let update = if role.can_see_channel(channel.visibility) {
3211 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3212 proto::UpdateChannels {
3213 channels: vec![channel.to_proto()],
3214 ..Default::default()
3215 }
3216 } else {
3217 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3218 proto::UpdateChannels {
3219 delete_channels: vec![channel.id.to_proto()],
3220 ..Default::default()
3221 }
3222 };
3223
3224 for connection_id in connection_pool.user_connection_ids(user_id) {
3225 session.peer.send(connection_id, update.clone())?;
3226 }
3227 }
3228
3229 response.send(proto::Ack {})?;
3230 Ok(())
3231}
3232
3233/// Alter the role for a user in the channel.
3234async fn set_channel_member_role(
3235 request: proto::SetChannelMemberRole,
3236 response: Response<proto::SetChannelMemberRole>,
3237 session: Session,
3238) -> Result<()> {
3239 let db = session.db().await;
3240 let channel_id = ChannelId::from_proto(request.channel_id);
3241 let member_id = UserId::from_proto(request.user_id);
3242 let result = db
3243 .set_channel_member_role(
3244 channel_id,
3245 session.user_id(),
3246 member_id,
3247 request.role().into(),
3248 )
3249 .await?;
3250
3251 match result {
3252 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3253 let mut connection_pool = session.connection_pool().await;
3254 notify_membership_updated(
3255 &mut connection_pool,
3256 membership_update,
3257 member_id,
3258 &session.peer,
3259 )
3260 }
3261 db::SetMemberRoleResult::InviteUpdated(channel) => {
3262 let update = proto::UpdateChannels {
3263 channel_invitations: vec![channel.to_proto()],
3264 ..Default::default()
3265 };
3266
3267 for connection_id in session
3268 .connection_pool()
3269 .await
3270 .user_connection_ids(member_id)
3271 {
3272 session.peer.send(connection_id, update.clone())?;
3273 }
3274 }
3275 }
3276
3277 response.send(proto::Ack {})?;
3278 Ok(())
3279}
3280
3281/// Change the name of a channel
3282async fn rename_channel(
3283 request: proto::RenameChannel,
3284 response: Response<proto::RenameChannel>,
3285 session: Session,
3286) -> Result<()> {
3287 let db = session.db().await;
3288 let channel_id = ChannelId::from_proto(request.channel_id);
3289 let channel_model = db
3290 .rename_channel(channel_id, session.user_id(), &request.name)
3291 .await?;
3292 let root_id = channel_model.root_id();
3293 let channel = Channel::from_model(channel_model);
3294
3295 response.send(proto::RenameChannelResponse {
3296 channel: Some(channel.to_proto()),
3297 })?;
3298
3299 let connection_pool = session.connection_pool().await;
3300 let update = proto::UpdateChannels {
3301 channels: vec![channel.to_proto()],
3302 ..Default::default()
3303 };
3304 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3305 if role.can_see_channel(channel.visibility) {
3306 session.peer.send(connection_id, update.clone())?;
3307 }
3308 }
3309
3310 Ok(())
3311}
3312
3313/// Move a channel to a new parent.
3314async fn move_channel(
3315 request: proto::MoveChannel,
3316 response: Response<proto::MoveChannel>,
3317 session: Session,
3318) -> Result<()> {
3319 let channel_id = ChannelId::from_proto(request.channel_id);
3320 let to = ChannelId::from_proto(request.to);
3321
3322 let (root_id, channels) = session
3323 .db()
3324 .await
3325 .move_channel(channel_id, to, session.user_id())
3326 .await?;
3327
3328 let connection_pool = session.connection_pool().await;
3329 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3330 let channels = channels
3331 .iter()
3332 .filter_map(|channel| {
3333 if role.can_see_channel(channel.visibility) {
3334 Some(channel.to_proto())
3335 } else {
3336 None
3337 }
3338 })
3339 .collect::<Vec<_>>();
3340 if channels.is_empty() {
3341 continue;
3342 }
3343
3344 let update = proto::UpdateChannels {
3345 channels,
3346 ..Default::default()
3347 };
3348
3349 session.peer.send(connection_id, update.clone())?;
3350 }
3351
3352 response.send(Ack {})?;
3353 Ok(())
3354}
3355
3356async fn reorder_channel(
3357 request: proto::ReorderChannel,
3358 response: Response<proto::ReorderChannel>,
3359 session: Session,
3360) -> Result<()> {
3361 let channel_id = ChannelId::from_proto(request.channel_id);
3362 let direction = request.direction();
3363
3364 let updated_channels = session
3365 .db()
3366 .await
3367 .reorder_channel(channel_id, direction, session.user_id())
3368 .await?;
3369
3370 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3371 let connection_pool = session.connection_pool().await;
3372 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3373 let channels = updated_channels
3374 .iter()
3375 .filter_map(|channel| {
3376 if role.can_see_channel(channel.visibility) {
3377 Some(channel.to_proto())
3378 } else {
3379 None
3380 }
3381 })
3382 .collect::<Vec<_>>();
3383
3384 if channels.is_empty() {
3385 continue;
3386 }
3387
3388 let update = proto::UpdateChannels {
3389 channels,
3390 ..Default::default()
3391 };
3392
3393 session.peer.send(connection_id, update.clone())?;
3394 }
3395 }
3396
3397 response.send(Ack {})?;
3398 Ok(())
3399}
3400
3401/// Get the list of channel members
3402async fn get_channel_members(
3403 request: proto::GetChannelMembers,
3404 response: Response<proto::GetChannelMembers>,
3405 session: Session,
3406) -> Result<()> {
3407 let db = session.db().await;
3408 let channel_id = ChannelId::from_proto(request.channel_id);
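    // A limit of zero means "no explicit limit"; fall back to a generous cap.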
3409 let limit = if request.limit == 0 {
3410 u16::MAX as u64
3411 } else {
3412 request.limit
3413 };
3414 let (members, users) = db
3415 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3416 .await?;
3417 response.send(proto::GetChannelMembersResponse { members, users })?;
3418 Ok(())
3419}
3420
3421/// Accept or decline a channel invitation.
3422async fn respond_to_channel_invite(
3423 request: proto::RespondToChannelInvite,
3424 response: Response<proto::RespondToChannelInvite>,
3425 session: Session,
3426) -> Result<()> {
3427 let db = session.db().await;
3428 let channel_id = ChannelId::from_proto(request.channel_id);
3429 let RespondToChannelInvite {
3430 membership_update,
3431 notifications,
3432 } = db
3433 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3434 .await?;
3435
3436 let mut connection_pool = session.connection_pool().await;
3437 if let Some(membership_update) = membership_update {
3438 notify_membership_updated(
3439 &mut connection_pool,
3440 membership_update,
3441 session.user_id(),
3442 &session.peer,
3443 );
3444 } else {
3445 let update = proto::UpdateChannels {
3446 remove_channel_invitations: vec![channel_id.to_proto()],
3447 ..Default::default()
3448 };
3449
3450 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3451 session.peer.send(connection_id, update.clone())?;
3452 }
3453 };
3454
3455 send_notifications(&connection_pool, &session.peer, notifications);
3456
3457 response.send(proto::Ack {})?;
3458
3459 Ok(())
3460}
3461
3462/// Join the channel's room.
3463async fn join_channel(
3464 request: proto::JoinChannel,
3465 response: Response<proto::JoinChannel>,
3466 session: Session,
3467) -> Result<()> {
3468 let channel_id = ChannelId::from_proto(request.channel_id);
3469 join_channel_internal(channel_id, Box::new(response), session).await
3470}
3471
3472trait JoinChannelInternalResponse {
3473 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3474}
3475impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3476 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3477 Response::<proto::JoinChannel>::send(self, result)
3478 }
3479}
3480impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3481 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3482 Response::<proto::JoinRoom>::send(self, result)
3483 }
3484}
3485
3486async fn join_channel_internal(
3487 channel_id: ChannelId,
3488 response: Box<impl JoinChannelInternalResponse>,
3489 session: Session,
3490) -> Result<()> {
3491 let joined_room = {
3492 let mut db = session.db().await;
3493        // If Zed quits without leaving the room and the user re-opens it before the
3494        // RECONNECT_TIMEOUT, make sure to kick the user out of the room they were
3495        // previously in.
3496 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3497 tracing::info!(
3498 stale_connection_id = %connection,
3499 "cleaning up stale connection",
3500 );
3501 drop(db);
3502 leave_room_for_session(&session, connection).await?;
3503 db = session.db().await;
3504 }
3505
3506 let (joined_room, membership_updated, role) = db
3507 .join_channel(channel_id, session.user_id(), session.connection_id)
3508 .await?;
3509
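        // Guests receive a subscribe-only LiveKit token; every other role gets
        // a token that also permits publishing.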
3510 let live_kit_connection_info =
3511 session
3512 .app_state
3513 .livekit_client
3514 .as_ref()
3515 .and_then(|live_kit| {
3516 let (can_publish, token) = if role == ChannelRole::Guest {
3517 (
3518 false,
3519 live_kit
3520 .guest_token(
3521 &joined_room.room.livekit_room,
3522 &session.user_id().to_string(),
3523 )
3524 .trace_err()?,
3525 )
3526 } else {
3527 (
3528 true,
3529 live_kit
3530 .room_token(
3531 &joined_room.room.livekit_room,
3532 &session.user_id().to_string(),
3533 )
3534 .trace_err()?,
3535 )
3536 };
3537
3538 Some(LiveKitConnectionInfo {
3539 server_url: live_kit.url().into(),
3540 token,
3541 can_publish,
3542 })
3543 });
3544
3545 response.send(proto::JoinRoomResponse {
3546 room: Some(joined_room.room.clone()),
3547 channel_id: joined_room
3548 .channel
3549 .as_ref()
3550 .map(|channel| channel.id.to_proto()),
3551 live_kit_connection_info,
3552 })?;
3553
3554 let mut connection_pool = session.connection_pool().await;
3555 if let Some(membership_updated) = membership_updated {
3556 notify_membership_updated(
3557 &mut connection_pool,
3558 membership_updated,
3559 session.user_id(),
3560 &session.peer,
3561 );
3562 }
3563
3564 room_updated(&joined_room.room, &session.peer);
3565
3566 joined_room
3567 };
3568
3569 channel_updated(
3570 &joined_room.channel.context("channel not returned")?,
3571 &joined_room.room,
3572 &session.peer,
3573 &*session.connection_pool().await,
3574 );
3575
3576 update_user_contacts(session.user_id(), &session).await?;
3577 Ok(())
3578}
3579
3580/// Start editing the channel notes
3581async fn join_channel_buffer(
3582 request: proto::JoinChannelBuffer,
3583 response: Response<proto::JoinChannelBuffer>,
3584 session: Session,
3585) -> Result<()> {
3586 let db = session.db().await;
3587 let channel_id = ChannelId::from_proto(request.channel_id);
3588
3589 let open_response = db
3590 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3591 .await?;
3592
3593 let collaborators = open_response.collaborators.clone();
3594 response.send(open_response)?;
3595
3596 let update = UpdateChannelBufferCollaborators {
3597 channel_id: channel_id.to_proto(),
3598 collaborators: collaborators.clone(),
3599 };
3600 channel_buffer_updated(
3601 session.connection_id,
3602 collaborators
3603 .iter()
3604 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3605 &update,
3606 &session.peer,
3607 );
3608
3609 Ok(())
3610}
3611
3612/// Edit the channel notes
3613async fn update_channel_buffer(
3614 request: proto::UpdateChannelBuffer,
3615 session: Session,
3616) -> Result<()> {
3617 let db = session.db().await;
3618 let channel_id = ChannelId::from_proto(request.channel_id);
3619
3620 let (collaborators, epoch, version) = db
3621 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3622 .await?;
3623
3624 channel_buffer_updated(
3625 session.connection_id,
3626 collaborators.clone(),
3627 &proto::UpdateChannelBuffer {
3628 channel_id: channel_id.to_proto(),
3629 operations: request.operations,
3630 },
3631 &session.peer,
3632 );
3633
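    // Channel members who don't have the notes buffer open only need its
    // latest epoch and version, so they can tell that the notes have changed.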
3634 let pool = &*session.connection_pool().await;
3635
3636 let non_collaborators =
3637 pool.channel_connection_ids(channel_id)
3638 .filter_map(|(connection_id, _)| {
3639 if collaborators.contains(&connection_id) {
3640 None
3641 } else {
3642 Some(connection_id)
3643 }
3644 });
3645
3646 broadcast(None, non_collaborators, |peer_id| {
3647 session.peer.send(
3648 peer_id,
3649 proto::UpdateChannels {
3650 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3651 channel_id: channel_id.to_proto(),
3652 epoch: epoch as u64,
3653 version: version.clone(),
3654 }],
3655 ..Default::default()
3656 },
3657 )
3658 });
3659
3660 Ok(())
3661}
3662
3663/// Rejoin the channel notes after a connection blip
3664async fn rejoin_channel_buffers(
3665 request: proto::RejoinChannelBuffers,
3666 response: Response<proto::RejoinChannelBuffers>,
3667 session: Session,
3668) -> Result<()> {
3669 let db = session.db().await;
3670 let buffers = db
3671 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3672 .await?;
3673
3674 for rejoined_buffer in &buffers {
3675 let collaborators_to_notify = rejoined_buffer
3676 .buffer
3677 .collaborators
3678 .iter()
3679 .filter_map(|c| Some(c.peer_id?.into()));
3680 channel_buffer_updated(
3681 session.connection_id,
3682 collaborators_to_notify,
3683 &proto::UpdateChannelBufferCollaborators {
3684 channel_id: rejoined_buffer.buffer.channel_id,
3685 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3686 },
3687 &session.peer,
3688 );
3689 }
3690
3691 response.send(proto::RejoinChannelBuffersResponse {
3692 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3693 })?;
3694
3695 Ok(())
3696}
3697
3698/// Stop editing the channel notes
3699async fn leave_channel_buffer(
3700 request: proto::LeaveChannelBuffer,
3701 response: Response<proto::LeaveChannelBuffer>,
3702 session: Session,
3703) -> Result<()> {
3704 let db = session.db().await;
3705 let channel_id = ChannelId::from_proto(request.channel_id);
3706
3707 let left_buffer = db
3708 .leave_channel_buffer(channel_id, session.connection_id)
3709 .await?;
3710
3711 response.send(Ack {})?;
3712
3713 channel_buffer_updated(
3714 session.connection_id,
3715 left_buffer.connections,
3716 &proto::UpdateChannelBufferCollaborators {
3717 channel_id: channel_id.to_proto(),
3718 collaborators: left_buffer.collaborators,
3719 },
3720 &session.peer,
3721 );
3722
3723 Ok(())
3724}
3725
3726fn channel_buffer_updated<T: EnvelopedMessage>(
3727 sender_id: ConnectionId,
3728 collaborators: impl IntoIterator<Item = ConnectionId>,
3729 message: &T,
3730 peer: &Peer,
3731) {
3732 broadcast(Some(sender_id), collaborators, |peer_id| {
3733 peer.send(peer_id, message.clone())
3734 });
3735}
3736
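/// Fans a batch of freshly created notifications out to every active
/// connection of each recipient. Delivery is best-effort: send failures are
/// logged rather than propagated.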
3737fn send_notifications(
3738 connection_pool: &ConnectionPool,
3739 peer: &Peer,
3740 notifications: db::NotificationBatch,
3741) {
3742 for (user_id, notification) in notifications {
3743 for connection_id in connection_pool.user_connection_ids(user_id) {
3744 if let Err(error) = peer.send(
3745 connection_id,
3746 proto::AddNotification {
3747 notification: Some(notification.clone()),
3748 },
3749 ) {
3750 tracing::error!(
3751 "failed to send notification to {:?} {}",
3752 connection_id,
3753 error
3754 );
3755 }
3756 }
3757 }
3758}
3759
3760/// Send a message to the channel
3761async fn send_channel_message(
3762 request: proto::SendChannelMessage,
3763 response: Response<proto::SendChannelMessage>,
3764 session: Session,
3765) -> Result<()> {
3766 // Validate the message body.
3767 let body = request.body.trim().to_string();
3768 if body.len() > MAX_MESSAGE_LEN {
3769 return Err(anyhow!("message is too long"))?;
3770 }
3771 if body.is_empty() {
3772 return Err(anyhow!("message can't be blank"))?;
3773 }
3774
3775 // TODO: adjust mentions if body is trimmed
3776
3777 let timestamp = OffsetDateTime::now_utc();
3778 let nonce = request.nonce.context("nonce can't be blank")?;
3779
3780 let channel_id = ChannelId::from_proto(request.channel_id);
3781 let CreatedChannelMessage {
3782 message_id,
3783 participant_connection_ids,
3784 notifications,
3785 } = session
3786 .db()
3787 .await
3788 .create_channel_message(
3789 channel_id,
3790 session.user_id(),
3791 &body,
3792 &request.mentions,
3793 timestamp,
3794 nonce.clone().into(),
3795 request.reply_to_message_id.map(MessageId::from_proto),
3796 )
3797 .await?;
3798
3799 let message = proto::ChannelMessage {
3800 sender_id: session.user_id().to_proto(),
3801 id: message_id.to_proto(),
3802 body,
3803 mentions: request.mentions,
3804 timestamp: timestamp.unix_timestamp() as u64,
3805 nonce: Some(nonce),
3806 reply_to_message_id: request.reply_to_message_id,
3807 edited_at: None,
3808 };
3809 broadcast(
3810 Some(session.connection_id),
3811 participant_connection_ids.clone(),
3812 |connection| {
3813 session.peer.send(
3814 connection,
3815 proto::ChannelMessageSent {
3816 channel_id: channel_id.to_proto(),
3817 message: Some(message.clone()),
3818 },
3819 )
3820 },
3821 );
3822 response.send(proto::SendChannelMessageResponse {
3823 message: Some(message),
3824 })?;
3825
3826 let pool = &*session.connection_pool().await;
3827 let non_participants =
3828 pool.channel_connection_ids(channel_id)
3829 .filter_map(|(connection_id, _)| {
3830 if participant_connection_ids.contains(&connection_id) {
3831 None
3832 } else {
3833 Some(connection_id)
3834 }
3835 });
3836 broadcast(None, non_participants, |peer_id| {
3837 session.peer.send(
3838 peer_id,
3839 proto::UpdateChannels {
3840 latest_channel_message_ids: vec![proto::ChannelMessageId {
3841 channel_id: channel_id.to_proto(),
3842 message_id: message_id.to_proto(),
3843 }],
3844 ..Default::default()
3845 },
3846 )
3847 });
3848 send_notifications(pool, &session.peer, notifications);
3849
3850 Ok(())
3851}
3852
3853/// Delete a channel message
3854async fn remove_channel_message(
3855 request: proto::RemoveChannelMessage,
3856 response: Response<proto::RemoveChannelMessage>,
3857 session: Session,
3858) -> Result<()> {
3859 let channel_id = ChannelId::from_proto(request.channel_id);
3860 let message_id = MessageId::from_proto(request.message_id);
3861 let (connection_ids, existing_notification_ids) = session
3862 .db()
3863 .await
3864 .remove_channel_message(channel_id, message_id, session.user_id())
3865 .await?;
3866
3867 broadcast(
3868 Some(session.connection_id),
3869 connection_ids,
3870 move |connection| {
3871 session.peer.send(connection, request.clone())?;
3872
3873 for notification_id in &existing_notification_ids {
3874 session.peer.send(
3875 connection,
3876 proto::DeleteNotification {
3877 notification_id: (*notification_id).to_proto(),
3878 },
3879 )?;
3880 }
3881
3882 Ok(())
3883 },
3884 );
3885 response.send(proto::Ack {})?;
3886 Ok(())
3887}
3888
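/// Edit a channel message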
3889async fn update_channel_message(
3890 request: proto::UpdateChannelMessage,
3891 response: Response<proto::UpdateChannelMessage>,
3892 session: Session,
3893) -> Result<()> {
3894 let channel_id = ChannelId::from_proto(request.channel_id);
3895 let message_id = MessageId::from_proto(request.message_id);
3896 let updated_at = OffsetDateTime::now_utc();
3897 let UpdatedChannelMessage {
3898 message_id,
3899 participant_connection_ids,
3900 notifications,
3901 reply_to_message_id,
3902 timestamp,
3903 deleted_mention_notification_ids,
3904 updated_mention_notifications,
3905 } = session
3906 .db()
3907 .await
3908 .update_channel_message(
3909 channel_id,
3910 message_id,
3911 session.user_id(),
3912 request.body.as_str(),
3913 &request.mentions,
3914 updated_at,
3915 )
3916 .await?;
3917
3918 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3919
3920 let message = proto::ChannelMessage {
3921 sender_id: session.user_id().to_proto(),
3922 id: message_id.to_proto(),
3923 body: request.body.clone(),
3924 mentions: request.mentions.clone(),
3925 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3926 nonce: Some(nonce),
3927 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3928 edited_at: Some(updated_at.unix_timestamp() as u64),
3929 };
3930
3931 response.send(proto::Ack {})?;
3932
3933 let pool = &*session.connection_pool().await;
3934 broadcast(
3935 Some(session.connection_id),
3936 participant_connection_ids,
3937 |connection| {
3938 session.peer.send(
3939 connection,
3940 proto::ChannelMessageUpdate {
3941 channel_id: channel_id.to_proto(),
3942 message: Some(message.clone()),
3943 },
3944 )?;
3945
3946 for notification_id in &deleted_mention_notification_ids {
3947 session.peer.send(
3948 connection,
3949 proto::DeleteNotification {
3950 notification_id: (*notification_id).to_proto(),
3951 },
3952 )?;
3953 }
3954
3955 for notification in &updated_mention_notifications {
3956 session.peer.send(
3957 connection,
3958 proto::UpdateNotification {
3959 notification: Some(notification.clone()),
3960 },
3961 )?;
3962 }
3963
3964 Ok(())
3965 },
3966 );
3967
3968 send_notifications(pool, &session.peer, notifications);
3969
3970 Ok(())
3971}
3972
3973/// Mark a channel message as read
3974async fn acknowledge_channel_message(
3975 request: proto::AckChannelMessage,
3976 session: Session,
3977) -> Result<()> {
3978 let channel_id = ChannelId::from_proto(request.channel_id);
3979 let message_id = MessageId::from_proto(request.message_id);
3980 let notifications = session
3981 .db()
3982 .await
3983 .observe_channel_message(channel_id, session.user_id(), message_id)
3984 .await?;
3985 send_notifications(
3986 &*session.connection_pool().await,
3987 &session.peer,
3988 notifications,
3989 );
3990 Ok(())
3991}
3992
3993/// Mark a buffer version as synced
3994async fn acknowledge_buffer_version(
3995 request: proto::AckBufferOperation,
3996 session: Session,
3997) -> Result<()> {
3998 let buffer_id = BufferId::from_proto(request.buffer_id);
3999 session
4000 .db()
4001 .await
4002 .observe_buffer_version(
4003 buffer_id,
4004 session.user_id(),
4005 request.epoch as i32,
4006 &request.version,
4007 )
4008 .await?;
4009 Ok(())
4010}
4011
4012/// Get a Supermaven API key for the user
4013async fn get_supermaven_api_key(
4014 _request: proto::GetSupermavenApiKey,
4015 response: Response<proto::GetSupermavenApiKey>,
4016 session: Session,
4017) -> Result<()> {
4018 let user_id: String = session.user_id().to_string();
4019 if !session.is_staff() {
4020 return Err(anyhow!("supermaven not enabled for this account"))?;
4021 }
4022
4023 let email = session.email().context("user must have an email")?;
4024
4025 let supermaven_admin_api = session
4026 .supermaven_client
4027 .as_ref()
4028 .context("supermaven not configured")?;
4029
4030 let result = supermaven_admin_api
4031 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4032 .await?;
4033
4034 response.send(proto::GetSupermavenApiKeyResponse {
4035 api_key: result.api_key,
4036 })?;
4037
4038 Ok(())
4039}
4040
4041/// Start receiving chat updates for a channel
4042async fn join_channel_chat(
4043 request: proto::JoinChannelChat,
4044 response: Response<proto::JoinChannelChat>,
4045 session: Session,
4046) -> Result<()> {
4047 let channel_id = ChannelId::from_proto(request.channel_id);
4048
4049 let db = session.db().await;
4050 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4051 .await?;
4052 let messages = db
4053 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4054 .await?;
4055 response.send(proto::JoinChannelChatResponse {
4056 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4057 messages,
4058 })?;
4059 Ok(())
4060}
4061
4062/// Stop receiving chat updates for a channel
4063async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4064 let channel_id = ChannelId::from_proto(request.channel_id);
4065 session
4066 .db()
4067 .await
4068 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4069 .await?;
4070 Ok(())
4071}
4072
4073/// Retrieve the chat history for a channel
4074async fn get_channel_messages(
4075 request: proto::GetChannelMessages,
4076 response: Response<proto::GetChannelMessages>,
4077 session: Session,
4078) -> Result<()> {
4079 let channel_id = ChannelId::from_proto(request.channel_id);
4080 let messages = session
4081 .db()
4082 .await
4083 .get_channel_messages(
4084 channel_id,
4085 session.user_id(),
4086 MESSAGE_COUNT_PER_PAGE,
4087 Some(MessageId::from_proto(request.before_message_id)),
4088 )
4089 .await?;
4090 response.send(proto::GetChannelMessagesResponse {
4091 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4092 messages,
4093 })?;
4094 Ok(())
4095}
4096
4097/// Retrieve specific chat messages
4098async fn get_channel_messages_by_id(
4099 request: proto::GetChannelMessagesById,
4100 response: Response<proto::GetChannelMessagesById>,
4101 session: Session,
4102) -> Result<()> {
4103 let message_ids = request
4104 .message_ids
4105 .iter()
4106 .map(|id| MessageId::from_proto(*id))
4107 .collect::<Vec<_>>();
4108 let messages = session
4109 .db()
4110 .await
4111 .get_channel_messages_by_id(session.user_id(), &message_ids)
4112 .await?;
4113 response.send(proto::GetChannelMessagesResponse {
4114 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4115 messages,
4116 })?;
4117 Ok(())
4118}
4119
/// Retrieve the current user's notifications
4121async fn get_notifications(
4122 request: proto::GetNotifications,
4123 response: Response<proto::GetNotifications>,
4124 session: Session,
4125) -> Result<()> {
4126 let notifications = session
4127 .db()
4128 .await
4129 .get_notifications(
4130 session.user_id(),
4131 NOTIFICATION_COUNT_PER_PAGE,
4132 request.before_id.map(db::NotificationId::from_proto),
4133 )
4134 .await?;
4135 response.send(proto::GetNotificationsResponse {
4136 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4137 notifications,
4138 })?;
4139 Ok(())
4140}
4141
/// Mark a notification as read
4143async fn mark_notification_as_read(
4144 request: proto::MarkNotificationRead,
4145 response: Response<proto::MarkNotificationRead>,
4146 session: Session,
4147) -> Result<()> {
4148 let database = &session.db().await;
4149 let notifications = database
4150 .mark_notification_as_read_by_id(
4151 session.user_id(),
4152 NotificationId::from_proto(request.notification_id),
4153 )
4154 .await?;
4155 send_notifications(
4156 &*session.connection_pool().await,
4157 &session.peer,
4158 notifications,
4159 );
4160 response.send(proto::Ack {})?;
4161 Ok(())
4162}
4163
/// Get the current user's private information
4165async fn get_private_user_info(
4166 _request: proto::GetPrivateUserInfo,
4167 response: Response<proto::GetPrivateUserInfo>,
4168 session: Session,
4169) -> Result<()> {
4170 let db = session.db().await;
4171
4172 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4173 let user = db
4174 .get_user_by_id(session.user_id())
4175 .await?
4176 .context("user not found")?;
4177 let flags = db.get_user_flags(session.user_id()).await?;
4178
4179 response.send(proto::GetPrivateUserInfoResponse {
4180 metrics_id,
4181 staff: user.admin,
4182 flags,
4183 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4184 })?;
4185 Ok(())
4186}
4187
/// Accept the terms of service (ToS) on behalf of the current user
4189async fn accept_terms_of_service(
4190 _request: proto::AcceptTermsOfService,
4191 response: Response<proto::AcceptTermsOfService>,
4192 session: Session,
4193) -> Result<()> {
4194 let db = session.db().await;
4195
4196 let accepted_tos_at = Utc::now();
4197 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4198 .await?;
4199
4200 response.send(proto::AcceptTermsOfServiceResponse {
4201 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4202 })?;
4203
4204 // When the user accepts the terms of service, we want to refresh their LLM
4205 // token to grant access.
4206 session
4207 .peer
4208 .send(session.connection_id, proto::RefreshLlmToken {})?;
4209
4210 Ok(())
4211}
4212
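/// Issue an LLM API token for the current user, lazily creating their Stripe
/// billing customer and Zed Free subscription if they don't exist yet.
/// Fails if the user has not accepted the terms of service.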
4213async fn get_llm_api_token(
4214 _request: proto::GetLlmToken,
4215 response: Response<proto::GetLlmToken>,
4216 session: Session,
4217) -> Result<()> {
4218 let db = session.db().await;
4219
4220 let flags = db.get_user_flags(session.user_id()).await?;
4221
4222 let user_id = session.user_id();
4223 let user = db
4224 .get_user_by_id(user_id)
4225 .await?
4226 .with_context(|| format!("user {user_id} not found"))?;
4227
4228 if user.accepted_tos_at.is_none() {
4229 Err(anyhow!("terms of service not accepted"))?
4230 }
4231
4232 let stripe_client = session
4233 .app_state
4234 .stripe_client
4235 .as_ref()
4236 .context("failed to retrieve Stripe client")?;
4237
4238 let stripe_billing = session
4239 .app_state
4240 .stripe_billing
4241 .as_ref()
4242 .context("failed to retrieve Stripe billing object")?;
4243
4244 let billing_customer = if let Some(billing_customer) =
4245 db.get_billing_customer_by_user_id(user.id).await?
4246 {
4247 billing_customer
4248 } else {
4249 let customer_id = stripe_billing
4250 .find_or_create_customer_by_email(user.email_address.as_deref())
4251 .await?;
4252
4253 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4254 .await?
4255 .context("billing customer not found")?
4256 };
4257
4258 let billing_subscription =
4259 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4260 billing_subscription
4261 } else {
4262 let stripe_customer_id =
4263 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4264
4265 let stripe_subscription = stripe_billing
4266 .subscribe_to_zed_free(stripe_customer_id)
4267 .await?;
4268
4269 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4270 billing_customer_id: billing_customer.id,
4271 kind: Some(SubscriptionKind::ZedFree),
4272 stripe_subscription_id: stripe_subscription.id.to_string(),
4273 stripe_subscription_status: stripe_subscription.status.into(),
4274 stripe_cancellation_reason: None,
4275 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4276 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4277 })
4278 .await?
4279 };
4280
4281 let billing_preferences = db.get_billing_preferences(user.id).await?;
4282
4283 let token = LlmTokenClaims::create(
4284 &user,
4285 session.is_staff(),
4286 billing_customer,
4287 billing_preferences,
4288 &flags,
4289 billing_subscription,
4290 session.system_id.clone(),
4291 &session.app_state.config,
4292 )?;
4293 response.send(proto::GetLlmTokenResponse { token })?;
4294 Ok(())
4295}
4296
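/// Convert a `tungstenite` WebSocket message into its `axum` counterpart.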
4297fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4298 let message = match message {
4299 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4300 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4301 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4302 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4303 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4304 code: frame.code.into(),
4305 reason: frame.reason.as_str().to_owned().into(),
4306 })),
4307 // We should never receive a frame while reading the message, according
4308 // to the `tungstenite` maintainers:
4309 //
4310 // > It cannot occur when you read messages from the WebSocket, but it
4311 // > can be used when you want to send the raw frames (e.g. you want to
4312 // > send the frames to the WebSocket without composing the full message first).
4313 // >
4314 // > — https://github.com/snapview/tungstenite-rs/issues/268
4315 TungsteniteMessage::Frame(_) => {
4316 bail!("received an unexpected frame while reading the message")
4317 }
4318 };
4319
4320 Ok(message)
4321}
4322
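/// Convert an `axum` WebSocket message into its `tungstenite` counterpart.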
4323fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4324 match message {
4325 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4326 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4327 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4328 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4329 AxumMessage::Close(frame) => {
4330 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4331 code: frame.code.into(),
4332 reason: frame.reason.as_ref().into(),
4333 }))
4334 }
4335 }
4336}
4337
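/// Apply a channel membership change to the connection pool and push the
/// resulting channel updates to all of the user's connections.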
4338fn notify_membership_updated(
4339 connection_pool: &mut ConnectionPool,
4340 result: MembershipUpdated,
4341 user_id: UserId,
4342 peer: &Peer,
4343) {
4344 for membership in &result.new_channels.channel_memberships {
4345 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4346 }
4347 for channel_id in &result.removed_channels {
4348 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4349 }
4350
4351 let user_channels_update = proto::UpdateUserChannels {
4352 channel_memberships: result
4353 .new_channels
4354 .channel_memberships
4355 .iter()
4356 .map(|cm| proto::ChannelMembership {
4357 channel_id: cm.channel_id.to_proto(),
4358 role: cm.role.into(),
4359 })
4360 .collect(),
4361 ..Default::default()
4362 };
4363
4364 let mut update = build_channels_update(result.new_channels);
4365 update.delete_channels = result
4366 .removed_channels
4367 .into_iter()
4368 .map(|id| id.to_proto())
4369 .collect();
4370 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4371
4372 for connection_id in connection_pool.user_connection_ids(user_id) {
4373 peer.send(connection_id, user_channels_update.clone())
4374 .trace_err();
4375 peer.send(connection_id, update.clone()).trace_err();
4376 }
4377}
4378
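/// Build an `UpdateUserChannels` message from the user's channel memberships
/// and observed buffer/message positions.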
4379fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4380 proto::UpdateUserChannels {
4381 channel_memberships: channels
4382 .channel_memberships
4383 .iter()
4384 .map(|m| proto::ChannelMembership {
4385 channel_id: m.channel_id.to_proto(),
4386 role: m.role.into(),
4387 })
4388 .collect(),
4389 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4390 observed_channel_message_id: channels.observed_channel_messages.clone(),
4391 }
4392}
4393
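/// Build an `UpdateChannels` message describing the user's channels,
/// invitations, participants, and latest buffer and message versions.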
4394fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4395 let mut update = proto::UpdateChannels::default();
4396
4397 for channel in channels.channels {
4398 update.channels.push(channel.to_proto());
4399 }
4400
4401 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4402 update.latest_channel_message_ids = channels.latest_channel_messages;
4403
4404 for (channel_id, participants) in channels.channel_participants {
4405 update
4406 .channel_participants
4407 .push(proto::ChannelParticipants {
4408 channel_id: channel_id.to_proto(),
4409 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4410 });
4411 }
4412
4413 for channel in channels.invited_channels {
4414 update.channel_invitations.push(channel.to_proto());
4415 }
4416
4417 update
4418}
4419
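/// Build the initial `UpdateContacts` message for a newly connected user,
/// covering accepted contacts and incoming and outgoing contact requests.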
4420fn build_initial_contacts_update(
4421 contacts: Vec<db::Contact>,
4422 pool: &ConnectionPool,
4423) -> proto::UpdateContacts {
4424 let mut update = proto::UpdateContacts::default();
4425
4426 for contact in contacts {
4427 match contact {
4428 db::Contact::Accepted { user_id, busy } => {
4429 update.contacts.push(contact_for_user(user_id, busy, pool));
4430 }
4431 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4432 db::Contact::Incoming { user_id } => {
4433 update
4434 .incoming_requests
4435 .push(proto::IncomingContactRequest {
4436 requester_id: user_id.to_proto(),
4437 })
4438 }
4439 }
4440 }
4441
4442 update
4443}
4444
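/// Construct a `proto::Contact` for the user, reflecting their current online
/// and busy status.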
4445fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4446 proto::Contact {
4447 user_id: user_id.to_proto(),
4448 online: pool.is_user_online(user_id),
4449 busy,
4450 }
4451}
4452
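/// Broadcast the latest room state to every participant in the room.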
4453fn room_updated(room: &proto::Room, peer: &Peer) {
4454 broadcast(
4455 None,
4456 room.participants
4457 .iter()
4458 .filter_map(|participant| Some(participant.peer_id?.into())),
4459 |peer_id| {
4460 peer.send(
4461 peer_id,
4462 proto::RoomUpdated {
4463 room: Some(room.clone()),
4464 },
4465 )
4466 },
4467 );
4468}
4469
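/// Notify channel members who can see the channel about its current call
/// participants.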
4470fn channel_updated(
4471 channel: &db::channel::Model,
4472 room: &proto::Room,
4473 peer: &Peer,
4474 pool: &ConnectionPool,
4475) {
4476 let participants = room
4477 .participants
4478 .iter()
4479 .map(|p| p.user_id)
4480 .collect::<Vec<_>>();
4481
4482 broadcast(
4483 None,
4484 pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
            }),
4489 |peer_id| {
4490 peer.send(
4491 peer_id,
4492 proto::UpdateChannels {
4493 channel_participants: vec![proto::ChannelParticipants {
4494 channel_id: channel.id.to_proto(),
4495 participant_user_ids: participants.clone(),
4496 }],
4497 ..Default::default()
4498 },
4499 )
4500 },
4501 );
4502}
4503
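/// Push the user's latest online/busy status to all of their accepted contacts.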
4504async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4505 let db = session.db().await;
4506
4507 let contacts = db.get_contacts(user_id).await?;
4508 let busy = db.is_user_busy(user_id).await?;
4509
4510 let pool = session.connection_pool().await;
4511 let updated_contact = contact_for_user(user_id, busy, &pool);
4512 for contact in contacts {
4513 if let db::Contact::Accepted {
4514 user_id: contact_user_id,
4515 ..
4516 } = contact
4517 {
4518 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4519 session
4520 .peer
4521 .send(
4522 contact_conn_id,
4523 proto::UpdateContacts {
4524 contacts: vec![updated_contact.clone()],
4525 remove_contacts: Default::default(),
4526 incoming_requests: Default::default(),
4527 remove_incoming_requests: Default::default(),
4528 outgoing_requests: Default::default(),
4529 remove_outgoing_requests: Default::default(),
4530 },
4531 )
4532 .trace_err();
4533 }
4534 }
4535 }
4536 Ok(())
4537}
4538
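/// Remove a connection from its room: notify the remaining participants,
/// cancelled callees, the channel (if any), and affected contacts, and clean
/// up the LiveKit room if the underlying room was deleted.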
4539async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4540 let mut contacts_to_update = HashSet::default();
4541
4542 let room_id;
4543 let canceled_calls_to_user_ids;
4544 let livekit_room;
4545 let delete_livekit_room;
4546 let room;
4547 let channel;
4548
4549 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4550 contacts_to_update.insert(session.user_id());
4551
4552 for project in left_room.left_projects.values() {
4553 project_left(project, session);
4554 }
4555
4556 room_id = RoomId::from_proto(left_room.room.id);
4557 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4558 livekit_room = mem::take(&mut left_room.room.livekit_room);
4559 delete_livekit_room = left_room.deleted;
4560 room = mem::take(&mut left_room.room);
4561 channel = mem::take(&mut left_room.channel);
4562
4563 room_updated(&room, &session.peer);
4564 } else {
4565 return Ok(());
4566 }
4567
4568 if let Some(channel) = channel {
4569 channel_updated(
4570 &channel,
4571 &room,
4572 &session.peer,
4573 &*session.connection_pool().await,
4574 );
4575 }
4576
4577 {
4578 let pool = session.connection_pool().await;
4579 for canceled_user_id in canceled_calls_to_user_ids {
4580 for connection_id in pool.user_connection_ids(canceled_user_id) {
4581 session
4582 .peer
4583 .send(
4584 connection_id,
4585 proto::CallCanceled {
4586 room_id: room_id.to_proto(),
4587 },
4588 )
4589 .trace_err();
4590 }
4591 contacts_to_update.insert(canceled_user_id);
4592 }
4593 }
4594
4595 for contact_user_id in contacts_to_update {
4596 update_user_contacts(contact_user_id, session).await?;
4597 }
4598
4599 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4600 live_kit
4601 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4602 .await
4603 .trace_err();
4604
4605 if delete_livekit_room {
4606 live_kit.delete_room(livekit_room).await.trace_err();
4607 }
4608 }
4609
4610 Ok(())
4611}
4612
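/// Close every channel buffer held by this connection and notify the remaining
/// collaborators.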
4613async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4614 let left_channel_buffers = session
4615 .db()
4616 .await
4617 .leave_channel_buffers(session.connection_id)
4618 .await?;
4619
4620 for left_buffer in left_channel_buffers {
4621 channel_buffer_updated(
4622 session.connection_id,
4623 left_buffer.connections,
4624 &proto::UpdateChannelBufferCollaborators {
4625 channel_id: left_buffer.channel_id.to_proto(),
4626 collaborators: left_buffer.collaborators,
4627 },
4628 &session.peer,
4629 );
4630 }
4631
4632 Ok(())
4633}
4634
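/// Notify the other members of a project that this connection left, sending
/// `UnshareProject` when the project should be unshared and
/// `RemoveProjectCollaborator` otherwise.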
4635fn project_left(project: &db::LeftProject, session: &Session) {
4636 for connection_id in &project.connection_ids {
4637 if project.should_unshare {
4638 session
4639 .peer
4640 .send(
4641 *connection_id,
4642 proto::UnshareProject {
4643 project_id: project.id.to_proto(),
4644 },
4645 )
4646 .trace_err();
4647 } else {
4648 session
4649 .peer
4650 .send(
4651 *connection_id,
4652 proto::RemoveProjectCollaborator {
4653 project_id: project.id.to_proto(),
4654 peer_id: Some(session.connection_id.into()),
4655 },
4656 )
4657 .trace_err();
4658 }
4659 }
4660}
4661
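/// Extension trait for logging errors instead of propagating them.
///
/// `trace_err` records the error via `tracing::error!` and converts the
/// `Result` into an `Option`, e.g. `peer.send(connection_id, update).trace_err();`
/// when a failed send should not abort the surrounding work.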
4662pub trait ResultExt {
4663 type Ok;
4664
4665 fn trace_err(self) -> Option<Self::Ok>;
4666}
4667
4668impl<T, E> ResultExt for Result<T, E>
4669where
4670 E: std::fmt::Debug,
4671{
4672 type Ok = T;
4673
4674 #[track_caller]
4675 fn trace_err(self) -> Option<T> {
4676 match self {
4677 Ok(value) => Some(value),
4678 Err(error) => {
4679 tracing::error!("{:?}", error);
4680 None
4681 }
4682 }
4683 }
4684}