mod connection_pool;

use crate::api::billing::find_or_create_billing_customer;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::db::LlmDatabase;
use crate::llm::{
    AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
    MIN_ACCOUNT_AGE_FOR_LLM_USE,
};
use crate::stripe_client::StripeCustomerId;
use crate::{
    AppState, Error, Result, auth,
    db::{
        self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
        CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
        NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
        RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
    },
    executor::Executor,
};
use anyhow::{Context as _, anyhow, bail};
use async_tungstenite::tungstenite::{
    Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
};
use axum::headers::UserAgent;
use axum::{
    Extension, Router, TypedHeader,
    body::Body,
    extract::{
        ConnectInfo, WebSocketUpgrade,
        ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
    },
    headers::{Header, HeaderName},
    http::StatusCode,
    middleware,
    response::IntoResponse,
    routing::get,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use reqwest_client::ReqwestClient;
use rpc::proto::{MultiLspQuery, split_repository_update};
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};

use futures::{
    FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
    stream::FuturesUnordered,
};
use prometheus::{IntGauge, register_int_gauge};
use rpc::{
    Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
    proto::{
        self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
    },
};
use semantic_version::SemanticVersion;
use serde::{Serialize, Serializer};
use std::{
    any::TypeId,
    future::Future,
    marker::PhantomData,
    mem,
    net::SocketAddr,
    ops::{Deref, DerefMut},
    rc::Rc,
    sync::{
        Arc, OnceLock,
        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
    },
    time::{Duration, Instant},
};
use time::OffsetDateTime;
use tokio::sync::{Semaphore, watch};
use tower::ServiceBuilder;
use tracing::{
    Instrument,
    field::{self},
    info_span, instrument,
};

pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

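/// Type-erased handler invoked for a single incoming message type.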
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;

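/// RAII guard that reserves a slot in the global concurrent-connection count and releases it when dropped.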
pub struct ConnectionGuard;

impl ConnectionGuard {
    pub fn try_acquire() -> Result<Self, ()> {
        let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
        if current_connections >= MAX_CONCURRENT_CONNECTIONS {
            CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
            tracing::error!(
                "too many concurrent connections: {}",
                current_connections + 1
            );
            return Err(());
        }
        Ok(ConnectionGuard)
    }
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
    }
}

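/// Handle used by request handlers to send exactly one response for a received request.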
struct Response<R> {
    peer: Arc<Peer>,
    receipt: Receipt<R>,
    responded: Arc<AtomicBool>,
}

impl<R: RequestMessage> Response<R> {
    fn send(self, payload: R::Response) -> Result<()> {
        self.responded.store(true, SeqCst);
        self.peer.respond(self.receipt, payload)?;
        Ok(())
    }
}

#[derive(Clone, Debug)]
pub enum Principal {
    User(User),
    Impersonated { user: User, admin: User },
}

impl Principal {
    fn user(&self) -> &User {
        match self {
            Principal::User(user) => user,
            Principal::Impersonated { user, .. } => user,
        }
    }

    fn update_span(&self, span: &tracing::Span) {
        match &self {
            Principal::User(user) => {
                span.record("user_id", user.id.0);
                span.record("login", &user.github_login);
            }
            Principal::Impersonated { user, admin } => {
                span.record("user_id", user.id.0);
                span.record("login", &user.github_login);
                span.record("impersonator", &admin.github_login);
            }
        }
    }
}

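/// Per-connection state passed to every message handler.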
#[derive(Clone)]
struct Session {
    principal: Principal,
    connection_id: ConnectionId,
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}

impl Session {
    async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.db.lock().await;
        #[cfg(test)]
        tokio::task::yield_now().await;
        guard
    }

    async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
        #[cfg(test)]
        tokio::task::yield_now().await;
        let guard = self.connection_pool.lock();
        ConnectionPoolGuard {
            guard,
            _not_send: PhantomData,
        }
    }

    fn is_staff(&self) -> bool {
        match &self.principal {
            Principal::User(user) => user.admin,
            Principal::Impersonated { .. } => true,
        }
    }

    fn user_id(&self) -> UserId {
        match &self.principal {
            Principal::User(user) => user.id,
            Principal::Impersonated { user, .. } => user.id,
        }
    }

    pub fn email(&self) -> Option<String> {
        match &self.principal {
            Principal::User(user) => user.email_address.clone(),
            Principal::Impersonated { user, .. } => user.email_address.clone(),
        }
    }
}

impl Debug for Session {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut result = f.debug_struct("Session");
        match &self.principal {
            Principal::User(user) => {
                result.field("user", &user.github_login);
            }
            Principal::Impersonated { user, admin } => {
                result.field("user", &user.github_login);
                result.field("impersonator", &admin.github_login);
            }
        }
        result.field("connection_id", &self.connection_id).finish()
    }
}

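/// Wrapper around the shared database handle that derefs to `Database`.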
struct DbHandle(Arc<Database>);

impl Deref for DbHandle {
    type Target = Database;

    fn deref(&self) -> &Self::Target {
        self.0.as_ref()
    }
}

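/// The collab RPC server: owns the peer, the connection pool, and the registered message handlers.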
pub struct Server {
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    handlers: HashMap<TypeId, MessageHandler>,
    teardown: watch::Sender<bool>,
}

pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}

pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
    T: Deref<Target = U>,
    U: Serialize,
{
    Serialize::serialize(value.deref(), serializer)
}

impl Server {
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(create_room)
            .add_request_handler(join_room)
            .add_request_handler(rejoin_room)
            .add_request_handler(leave_room)
            .add_request_handler(set_room_participant_role)
            .add_request_handler(call)
            .add_request_handler(cancel_call)
            .add_message_handler(decline_call)
            .add_request_handler(update_participant_location)
            .add_request_handler(share_project)
            .add_message_handler(unshare_project)
            .add_request_handler(join_project)
            .add_message_handler(leave_project)
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_request_handler(update_repository)
            .add_request_handler(remove_repository)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
            .add_request_handler(forward_read_only_project_request::<proto::FindSearchCandidates>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
            .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
            .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
            .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
            .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
            .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
            .add_request_handler(
                forward_read_only_project_request::<proto::LanguageServerIdForName>,
            )
            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
            .add_request_handler(
                forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
            .add_request_handler(
                forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
            )
            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
            .add_request_handler(multi_lsp_query)
            .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
            .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
            .add_message_handler(
                broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
            )
            .add_request_handler(get_users)
            .add_request_handler(fuzzy_search_users)
            .add_request_handler(request_contact)
            .add_request_handler(remove_contact)
            .add_request_handler(respond_to_contact_request)
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(create_channel)
            .add_request_handler(delete_channel)
            .add_request_handler(invite_channel_member)
            .add_request_handler(remove_channel_member)
            .add_request_handler(set_channel_member_role)
            .add_request_handler(set_channel_visibility)
            .add_request_handler(rename_channel)
            .add_request_handler(join_channel_buffer)
            .add_request_handler(leave_channel_buffer)
            .add_message_handler(update_channel_buffer)
            .add_request_handler(rejoin_channel_buffers)
            .add_request_handler(get_channel_members)
            .add_request_handler(respond_to_channel_invite)
            .add_request_handler(join_channel)
            .add_request_handler(join_channel_chat)
            .add_message_handler(leave_channel_chat)
            .add_request_handler(send_channel_message)
            .add_request_handler(remove_channel_message)
            .add_request_handler(update_channel_message)
            .add_request_handler(get_channel_messages)
            .add_request_handler(get_channel_messages_by_id)
            .add_request_handler(get_notifications)
            .add_request_handler(mark_notification_as_read)
            .add_request_handler(move_channel)
            .add_request_handler(reorder_channel)
            .add_request_handler(follow)
            .add_message_handler(unfollow)
            .add_message_handler(update_followers)
            .add_request_handler(get_private_user_info)
            .add_request_handler(get_llm_api_token)
            .add_request_handler(accept_terms_of_service)
            .add_message_handler(acknowledge_channel_message)
            .add_message_handler(acknowledge_buffer_version)
            .add_request_handler(get_supermaven_api_key)
            .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
            .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
            .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
            .add_request_handler(forward_mutating_project_request::<proto::Stage>)
            .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
            .add_request_handler(forward_mutating_project_request::<proto::Stash>)
            .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
            .add_request_handler(forward_mutating_project_request::<proto::Commit>)
            .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
            .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
            .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
            .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
            .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
            .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
            .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
            .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
            .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
            .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
            .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
            .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context);

        Arc::new(server)
    }

    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }

    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        let _ = self.teardown.send(true);
    }

    #[cfg(test)]
    pub fn reset(&self, id: ServerId) {
        self.teardown();
        *self.id.lock() = id;
        self.peer.reset(id.0 as u32);
        let _ = self.teardown.send(false);
    }

    #[cfg(test)]
    pub fn id(&self) -> ServerId {
        *self.id.lock()
    }

    fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        let prev_handler = self.handlers.insert(
            TypeId::of::<M>(),
            Box::new(move |envelope, session| {
                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                let received_at = envelope.received_at;
                tracing::info!("message received");
                let start_time = Instant::now();
                let future = (handler)(*envelope, session);
                async move {
                    let result = future.await;
                    let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
                    let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
                    let queue_duration_ms = total_duration_ms - processing_duration_ms;

                    match result {
                        Err(error) => {
                            tracing::error!(
                                ?error,
                                total_duration_ms,
                                processing_duration_ms,
                                queue_duration_ms,
                                "error handling message"
                            )
                        }
                        Ok(()) => tracing::info!(
                            total_duration_ms,
                            processing_duration_ms,
                            queue_duration_ms,
                            "finished handling message"
                        ),
                    }
                }
                .boxed()
            }),
        );
        if prev_handler.is_some() {
            panic!("registered a handler for the same message twice");
        }
        self
    }

    fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Session) -> Fut,
        Fut: 'static + Send + Future<Output = Result<()>>,
        M: EnvelopedMessage,
    {
        self.add_handler(move |envelope, session| handler(envelope.payload, session));
        self
    }

    fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
    where
        F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
        Fut: Send + Future<Output = Result<()>>,
        M: RequestMessage,
    {
        let handler = Arc::new(handler);
        self.add_handler(move |envelope, session| {
            let receipt = envelope.receipt();
            let handler = handler.clone();
            async move {
                let peer = session.peer.clone();
                let responded = Arc::new(AtomicBool::default());
                let response = Response {
                    peer: peer.clone(),
                    responded: responded.clone(),
                    receipt,
                };
                match (handler)(envelope.payload, response, session).await {
                    Ok(()) => {
                        if responded.load(std::sync::atomic::Ordering::SeqCst) {
                            Ok(())
                        } else {
                            Err(anyhow!("handler did not send a response"))?
                        }
                    }
                    Err(error) => {
                        let proto_err = match &error {
                            Error::Internal(err) => err.to_proto(),
                            _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
                        };
                        peer.respond_with_error(receipt, proto_err)?;
                        Err(error)
                    }
                }
            }
        })
    }

    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        user_agent: Option<String>,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            user_agent=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(user_agent) = user_agent {
            span.record("user_agent", user_agent);
        }

        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return;
            }

            let (connection_id, handle_io, mut incoming_rx) =
                this.peer.add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(
                |supermaven_admin_api_key| {
                    Arc::new(SupermavenAdminApi::new(
                        supermaven_admin_api_key.to_string(),
                        http_client.clone(),
                    ))
                },
            );

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this
                .send_initial_client_update(
                    connection_id,
                    zed_version,
                    send_connection_id,
                    &session,
                )
                .await
            {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            const MAX_CONCURRENT_HANDLERS: usize = 256;
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(MAX_CONCURRENT_HANDLERS));
            let get_concurrent_handlers = {
                let concurrent_handlers = concurrent_handlers.clone();
                move || MAX_CONCURRENT_HANDLERS - concurrent_handlers.available_permits()
            };
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    // Cache the concurrent_handlers here, so that we know what the
                    // queue looks like as each handler starts
                    (permit, message, get_concurrent_handlers())
                }
                .fuse();
                futures::pin_mut!(next_message);
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message, concurrent_handlers) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message",
                                %connection_id,
                                %address,
                                type_name,
                                concurrent_handlers,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                multi_lsp_query_request=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            drop(foreground_message_handlers);
            let concurrent_handlers = get_concurrent_handlers();
            tracing::info!(concurrent_handlers, "signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }
        }
        .instrument(span)
    }

    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }

    pub async fn invite_code_redeemed(
        self: &Arc<Self>,
        inviter_id: UserId,
        invitee_id: UserId,
    ) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
            if let Some(code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                let invitee_contact = contact_for_user(invitee_id, false, &pool);
                for connection_id in pool.user_connection_ids(inviter_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateContacts {
                            contacts: vec![invitee_contact.clone()],
                            ..Default::default()
                        },
                    )?;
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
            if let Some(invite_code) = &user.invite_code {
                let pool = self.connection_pool.lock();
                for connection_id in pool.user_connection_ids(user_id) {
                    self.peer.send(
                        connection_id,
                        proto::UpdateInviteInfo {
                            url: format!(
                                "{}{}",
                                self.app_state.config.invite_link_prefix, invite_code
                            ),
                            count: user.invite_count as u32,
                        },
                    )?;
                }
            }
        }
        Ok(())
    }

    pub async fn update_plan_for_user(
        self: &Arc<Self>,
        user_id: UserId,
        update_user_plan: proto::UpdateUserPlan,
    ) -> Result<()> {
        let pool = self.connection_pool.lock();
        for connection_id in pool.user_connection_ids(user_id) {
            self.peer
                .send(connection_id, update_user_plan.clone())
                .trace_err();
        }

        Ok(())
    }

    /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
    /// message on the Collab server.
    ///
    /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
    pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
        let user = self
            .app_state
            .db
            .get_user_by_id(user_id)
            .await?
            .context("user not found")?;

        let update_user_plan = make_update_user_plan_message(
            &user,
            user.admin,
            &self.app_state.db,
            self.app_state.llm_db.clone(),
        )
        .await?;

        self.update_plan_for_user(user_id, update_user_plan).await
    }

    pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
        let pool = self.connection_pool.lock();
        for connection_id in pool.user_connection_ids(user_id) {
            self.peer
                .send(connection_id, proto::RefreshLlmToken {})
                .trace_err();
        }
    }

    pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
        ServerSnapshot {
            connection_pool: ConnectionPoolGuard {
                guard: self.connection_pool.lock(),
                _not_send: PhantomData,
            },
            peer: &self.peer,
        }
    }
}

impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}

impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}

impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}

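/// Invokes `f` for every receiver except the sender, logging (rather than propagating) per-receiver send errors.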
fn broadcast<F>(
    sender_id: Option<ConnectionId>,
    receiver_ids: impl IntoIterator<Item = ConnectionId>,
    mut f: F,
) where
    F: FnMut(ConnectionId) -> anyhow::Result<()>,
{
    for receiver_id in receiver_ids {
        if Some(receiver_id) != sender_id {
            if let Err(error) = f(receiver_id) {
                tracing::error!("failed to send to {:?} {}", receiver_id, error);
            }
        }
    }
}

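/// Typed `x-zed-protocol-version` header carrying the client's RPC protocol version.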
pub struct ProtocolVersion(u32);

impl Header for ProtocolVersion {
    fn name() -> &'static HeaderName {
        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i axum::http::HeaderValue>,
    {
        let version = values
            .next()
            .ok_or_else(axum::headers::Error::invalid)?
            .to_str()
            .map_err(|_| axum::headers::Error::invalid())?
            .parse()
            .map_err(|_| axum::headers::Error::invalid())?;
        Ok(Self(version))
    }

    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
        values.extend([self.0.to_string().parse().unwrap()]);
    }
}

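/// Typed `x-zed-app-version` header carrying the client's semantic app version.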
pub struct AppVersionHeader(SemanticVersion);
impl Header for AppVersionHeader {
    fn name() -> &'static HeaderName {
        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
    }

    fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
    where
        Self: Sized,
        I: Iterator<Item = &'i axum::http::HeaderValue>,
    {
        let version = values
            .next()
            .ok_or_else(axum::headers::Error::invalid)?
            .to_str()
            .map_err(|_| axum::headers::Error::invalid())?
            .parse()
            .map_err(|_| axum::headers::Error::invalid())?;
        Ok(Self(version))
    }

    fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
        values.extend([self.0.to_string().parse().unwrap()]);
    }
}

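/// Builds the router exposing the authenticated `/rpc` WebSocket endpoint and the `/metrics` endpoint.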
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}

pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    user_agent: Option<TypedHeader<UserAgent>>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    user_agent.map(|header| header.to_string()),
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}

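/// Serves Prometheus metrics, including gauges for non-admin connections and shared projects.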
pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let connections_metric = CONNECTIONS_METRIC
        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());

    let connections = server
        .connection_pool
        .lock()
        .connections()
        .filter(|connection| !connection.admin)
        .count();
    connections_metric.set(connections as _);

    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
        register_int_gauge!(
            "shared_projects",
            "number of open projects with one or more guests"
        )
        .unwrap()
    });

    let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
    shared_projects_metric.set(shared_projects as _);

    let encoder = prometheus::TextEncoder::new();
    let metric_families = prometheus::gather();
    let encoded_metrics = encoder
        .encode_to_string(&metric_families)
        .map_err(|err| anyhow!("{err}"))?;
    Ok(encoded_metrics)
}

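/// Runs when a connection is lost: removes the connection immediately, then after `RECONNECT_TIMEOUT`
/// releases the rooms and channel buffers it held (unless the server is tearing down first).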
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}

/// Acknowledges a ping from a client, used to keep the connection alive.
async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
    response.send(proto::Ack {})?;
    Ok(())
}

/// Creates a new room for calling (outside of channels)
async fn create_room(
    _request: proto::CreateRoom,
    response: Response<proto::CreateRoom>,
    session: Session,
) -> Result<()> {
    let livekit_room = nanoid::nanoid!(30);

    let live_kit_connection_info = util::maybe!(async {
        let live_kit = session.app_state.livekit_client.as_ref();
        let live_kit = live_kit?;
        let user_id = session.user_id().to_string();

        let token = live_kit
            .room_token(&livekit_room, &user_id.to_string())
            .trace_err()?;

        Some(proto::LiveKitConnectionInfo {
            server_url: live_kit.url().into(),
            token,
            can_publish: true,
        })
    })
    .await;

    let room = session
        .db()
        .await
        .create_room(session.user_id(), session.connection_id, &livekit_room)
        .await?;

    response.send(proto::CreateRoomResponse {
        room: Some(room.clone()),
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}

/// Join a room from an invitation. Equivalent to joining a channel if there is one.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
    {
        live_kit
            .room_token(
                &joined_room.room.livekit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
            .map(|token| proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}

/// Rejoin room is used to reconnect to a room after connection errors.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}

fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}

/// Leaves the current room, disconnecting from it.
async fn leave_room(
    _: proto::LeaveRoom,
    response: Response<proto::LeaveRoom>,
    session: Session,
) -> Result<()> {
    leave_room_for_session(&session, session.connection_id).await?;
    response.send(proto::Ack {})?;
    Ok(())
}

/// Updates the permissions of someone else in the room.
async fn set_room_participant_role(
    request: proto::SetRoomParticipantRole,
    response: Response<proto::SetRoomParticipantRole>,
    session: Session,
) -> Result<()> {
    let user_id = UserId::from_proto(request.user_id);
    let role = ChannelRole::from(request.role());

    let (livekit_room, can_publish) = {
        let room = session
            .db()
            .await
            .set_room_participant_role(
                session.user_id(),
                RoomId::from_proto(request.room_id),
                user_id,
                role,
            )
            .await?;

        let livekit_room = room.livekit_room.clone();
        let can_publish = ChannelRole::from(request.role()).can_use_microphone();
        room_updated(&room, &session.peer);
        (livekit_room, can_publish)
    };

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .update_participant(
                livekit_room.clone(),
                request.user_id.to_string(),
                livekit_api::proto::ParticipantPermission {
                    can_subscribe: true,
                    can_publish,
                    can_publish_data: can_publish,
                    hidden: false,
                    recorder: false,
                },
            )
            .await
            .trace_err();
    }

    response.send(proto::Ack {})?;
    Ok(())
}

/// Call someone else into the current room.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}

/// Cancel an outgoing call.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: Session,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}

1809/// Decline an incoming call.
1810async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1811 let room_id = RoomId::from_proto(message.room_id);
1812 {
1813 let room = session
1814 .db()
1815 .await
1816 .decline_call(Some(room_id), session.user_id())
1817 .await?
1818 .context("declining call")?;
1819 room_updated(&room, &session.peer);
1820 }
1821
1822 for connection_id in session
1823 .connection_pool()
1824 .await
1825 .user_connection_ids(session.user_id())
1826 {
1827 session
1828 .peer
1829 .send(
1830 connection_id,
1831 proto::CallCanceled {
1832 room_id: room_id.to_proto(),
1833 },
1834 )
1835 .trace_err();
1836 }
1837 update_user_contacts(session.user_id(), &session).await?;
1838 Ok(())
1839}
1840
1841/// Updates other participants in the room with your current location.
1842async fn update_participant_location(
1843 request: proto::UpdateParticipantLocation,
1844 response: Response<proto::UpdateParticipantLocation>,
1845 session: Session,
1846) -> Result<()> {
1847 let room_id = RoomId::from_proto(request.room_id);
1848 let location = request.location.context("invalid location")?;
1849
1850 let db = session.db().await;
1851 let room = db
1852 .update_room_participant_location(room_id, session.connection_id, location)
1853 .await?;
1854
1855 room_updated(&room, &session.peer);
1856 response.send(proto::Ack {})?;
1857 Ok(())
1858}
1859
1860/// Share a project into the room.
1861async fn share_project(
1862 request: proto::ShareProject,
1863 response: Response<proto::ShareProject>,
1864 session: Session,
1865) -> Result<()> {
1866 let (project_id, room) = &*session
1867 .db()
1868 .await
1869 .share_project(
1870 RoomId::from_proto(request.room_id),
1871 session.connection_id,
1872 &request.worktrees,
1873 request.is_ssh_project,
1874 )
1875 .await?;
1876 response.send(proto::ShareProjectResponse {
1877 project_id: project_id.to_proto(),
1878 })?;
1879 room_updated(room, &session.peer);
1880
1881 Ok(())
1882}
1883
1884/// Unshare a project from the room.
1885async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1886 let project_id = ProjectId::from_proto(message.project_id);
1887 unshare_project_internal(project_id, session.connection_id, &session).await
1888}
1889
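/// Shared logic for unsharing a project: notifies guest connections, refreshes the room if
/// one is returned, and deletes the project when the database marks it for deletion.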
1890async fn unshare_project_internal(
1891 project_id: ProjectId,
1892 connection_id: ConnectionId,
1893 session: &Session,
1894) -> Result<()> {
1895 let delete = {
1896 let room_guard = session
1897 .db()
1898 .await
1899 .unshare_project(project_id, connection_id)
1900 .await?;
1901
1902 let (delete, room, guest_connection_ids) = &*room_guard;
1903
1904 let message = proto::UnshareProject {
1905 project_id: project_id.to_proto(),
1906 };
1907
1908 broadcast(
1909 Some(connection_id),
1910 guest_connection_ids.iter().copied(),
1911 |conn_id| session.peer.send(conn_id, message.clone()),
1912 );
1913 if let Some(room) = room {
1914 room_updated(room, &session.peer);
1915 }
1916
1917 *delete
1918 };
1919
1920 if delete {
1921 let db = session.db().await;
1922 db.delete_project(project_id).await?;
1923 }
1924
1925 Ok(())
1926}
1927
/// Join someone else's shared project.
1929async fn join_project(
1930 request: proto::JoinProject,
1931 response: Response<proto::JoinProject>,
1932 session: Session,
1933) -> Result<()> {
1934 let project_id = ProjectId::from_proto(request.project_id);
1935
1936 tracing::info!(%project_id, "join project");
1937
1938 let db = session.db().await;
1939 let (project, replica_id) = &mut *db
1940 .join_project(
1941 project_id,
1942 session.connection_id,
1943 session.user_id(),
1944 request.committer_name.clone(),
1945 request.committer_email.clone(),
1946 )
1947 .await?;
1948 drop(db);
1949 tracing::info!(%project_id, "join remote project");
1950 let collaborators = project
1951 .collaborators
1952 .iter()
1953 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1954 .map(|collaborator| collaborator.to_proto())
1955 .collect::<Vec<_>>();
1956 let project_id = project.id;
1957 let guest_user_id = session.user_id();
1958
1959 let worktrees = project
1960 .worktrees
1961 .iter()
1962 .map(|(id, worktree)| proto::WorktreeMetadata {
1963 id: *id,
1964 root_name: worktree.root_name.clone(),
1965 visible: worktree.visible,
1966 abs_path: worktree.abs_path.clone(),
1967 })
1968 .collect::<Vec<_>>();
1969
1970 let add_project_collaborator = proto::AddProjectCollaborator {
1971 project_id: project_id.to_proto(),
1972 collaborator: Some(proto::Collaborator {
1973 peer_id: Some(session.connection_id.into()),
1974 replica_id: replica_id.0 as u32,
1975 user_id: guest_user_id.to_proto(),
1976 is_host: false,
1977 committer_name: request.committer_name.clone(),
1978 committer_email: request.committer_email.clone(),
1979 }),
1980 };
1981
1982 for collaborator in &collaborators {
1983 session
1984 .peer
1985 .send(
1986 collaborator.peer_id.unwrap().into(),
1987 add_project_collaborator.clone(),
1988 )
1989 .trace_err();
1990 }
1991
1992 // First, we send the metadata associated with each worktree.
1993 let (language_servers, language_server_capabilities) = project
1994 .language_servers
1995 .clone()
1996 .into_iter()
1997 .map(|server| (server.server, server.capabilities))
1998 .unzip();
1999 response.send(proto::JoinProjectResponse {
2000 project_id: project.id.0 as u64,
2001 worktrees: worktrees.clone(),
2002 replica_id: replica_id.0 as u32,
2003 collaborators: collaborators.clone(),
2004 language_servers,
2005 language_server_capabilities,
2006 role: project.role.into(),
2007 })?;
2008
2009 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2010 // Stream this worktree's entries.
2011 let message = proto::UpdateWorktree {
2012 project_id: project_id.to_proto(),
2013 worktree_id,
2014 abs_path: worktree.abs_path.clone(),
2015 root_name: worktree.root_name,
2016 updated_entries: worktree.entries,
2017 removed_entries: Default::default(),
2018 scan_id: worktree.scan_id,
2019 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2020 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
2021 removed_repositories: Default::default(),
2022 };
2023 for update in proto::split_worktree_update(message) {
2024 session.peer.send(session.connection_id, update.clone())?;
2025 }
2026
2027 // Stream this worktree's diagnostics.
2028 for summary in worktree.diagnostic_summaries {
2029 session.peer.send(
2030 session.connection_id,
2031 proto::UpdateDiagnosticSummary {
2032 project_id: project_id.to_proto(),
2033 worktree_id: worktree.id,
2034 summary: Some(summary),
2035 },
2036 )?;
2037 }
2038
2039 for settings_file in worktree.settings_files {
2040 session.peer.send(
2041 session.connection_id,
2042 proto::UpdateWorktreeSettings {
2043 project_id: project_id.to_proto(),
2044 worktree_id: worktree.id,
2045 path: settings_file.path,
2046 content: Some(settings_file.content),
2047 kind: Some(settings_file.kind.to_proto() as i32),
2048 },
2049 )?;
2050 }
2051 }
2052
2053 for repository in mem::take(&mut project.repositories) {
2054 for update in split_repository_update(repository) {
2055 session.peer.send(session.connection_id, update)?;
2056 }
2057 }
2058
2059 for language_server in &project.language_servers {
2060 session.peer.send(
2061 session.connection_id,
2062 proto::UpdateLanguageServer {
2063 project_id: project_id.to_proto(),
2064 server_name: Some(language_server.server.name.clone()),
2065 language_server_id: language_server.server.id,
2066 variant: Some(
2067 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2068 proto::LspDiskBasedDiagnosticsUpdated {},
2069 ),
2070 ),
2071 },
2072 )?;
2073 }
2074
2075 Ok(())
2076}
2077
/// Leave someone else's shared project.
2079async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2080 let sender_id = session.connection_id;
2081 let project_id = ProjectId::from_proto(request.project_id);
2082 let db = session.db().await;
2083
2084 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2085 tracing::info!(
2086 %project_id,
2087 "leave project"
2088 );
2089
2090 project_left(project, &session);
2091 if let Some(room) = room {
2092 room_updated(room, &session.peer);
2093 }
2094
2095 Ok(())
2096}
2097
2098/// Updates other participants with changes to the project
2099async fn update_project(
2100 request: proto::UpdateProject,
2101 response: Response<proto::UpdateProject>,
2102 session: Session,
2103) -> Result<()> {
2104 let project_id = ProjectId::from_proto(request.project_id);
2105 let (room, guest_connection_ids) = &*session
2106 .db()
2107 .await
2108 .update_project(project_id, session.connection_id, &request.worktrees)
2109 .await?;
2110 broadcast(
2111 Some(session.connection_id),
2112 guest_connection_ids.iter().copied(),
2113 |connection_id| {
2114 session
2115 .peer
2116 .forward_send(session.connection_id, connection_id, request.clone())
2117 },
2118 );
2119 if let Some(room) = room {
2120 room_updated(room, &session.peer);
2121 }
2122 response.send(proto::Ack {})?;
2123
2124 Ok(())
2125}
2126
2127/// Updates other participants with changes to the worktree
2128async fn update_worktree(
2129 request: proto::UpdateWorktree,
2130 response: Response<proto::UpdateWorktree>,
2131 session: Session,
2132) -> Result<()> {
2133 let guest_connection_ids = session
2134 .db()
2135 .await
2136 .update_worktree(&request, session.connection_id)
2137 .await?;
2138
2139 broadcast(
2140 Some(session.connection_id),
2141 guest_connection_ids.iter().copied(),
2142 |connection_id| {
2143 session
2144 .peer
2145 .forward_send(session.connection_id, connection_id, request.clone())
2146 },
2147 );
2148 response.send(proto::Ack {})?;
2149 Ok(())
2150}
2151
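/// Updates other participants with changes to a repository in the project.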
2152async fn update_repository(
2153 request: proto::UpdateRepository,
2154 response: Response<proto::UpdateRepository>,
2155 session: Session,
2156) -> Result<()> {
2157 let guest_connection_ids = session
2158 .db()
2159 .await
2160 .update_repository(&request, session.connection_id)
2161 .await?;
2162
2163 broadcast(
2164 Some(session.connection_id),
2165 guest_connection_ids.iter().copied(),
2166 |connection_id| {
2167 session
2168 .peer
2169 .forward_send(session.connection_id, connection_id, request.clone())
2170 },
2171 );
2172 response.send(proto::Ack {})?;
2173 Ok(())
2174}
2175
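/// Notifies other participants that a repository has been removed from the project.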
2176async fn remove_repository(
2177 request: proto::RemoveRepository,
2178 response: Response<proto::RemoveRepository>,
2179 session: Session,
2180) -> Result<()> {
2181 let guest_connection_ids = session
2182 .db()
2183 .await
2184 .remove_repository(&request, session.connection_id)
2185 .await?;
2186
2187 broadcast(
2188 Some(session.connection_id),
2189 guest_connection_ids.iter().copied(),
2190 |connection_id| {
2191 session
2192 .peer
2193 .forward_send(session.connection_id, connection_id, request.clone())
2194 },
2195 );
2196 response.send(proto::Ack {})?;
2197 Ok(())
2198}
2199
2200/// Updates other participants with changes to the diagnostics
2201async fn update_diagnostic_summary(
2202 message: proto::UpdateDiagnosticSummary,
2203 session: Session,
2204) -> Result<()> {
2205 let guest_connection_ids = session
2206 .db()
2207 .await
2208 .update_diagnostic_summary(&message, session.connection_id)
2209 .await?;
2210
2211 broadcast(
2212 Some(session.connection_id),
2213 guest_connection_ids.iter().copied(),
2214 |connection_id| {
2215 session
2216 .peer
2217 .forward_send(session.connection_id, connection_id, message.clone())
2218 },
2219 );
2220
2221 Ok(())
2222}
2223
2224/// Updates other participants with changes to the worktree settings
2225async fn update_worktree_settings(
2226 message: proto::UpdateWorktreeSettings,
2227 session: Session,
2228) -> Result<()> {
2229 let guest_connection_ids = session
2230 .db()
2231 .await
2232 .update_worktree_settings(&message, session.connection_id)
2233 .await?;
2234
2235 broadcast(
2236 Some(session.connection_id),
2237 guest_connection_ids.iter().copied(),
2238 |connection_id| {
2239 session
2240 .peer
2241 .forward_send(session.connection_id, connection_id, message.clone())
2242 },
2243 );
2244
2245 Ok(())
2246}
2247
2248/// Notify other participants that a language server has started.
2249async fn start_language_server(
2250 request: proto::StartLanguageServer,
2251 session: Session,
2252) -> Result<()> {
2253 let guest_connection_ids = session
2254 .db()
2255 .await
2256 .start_language_server(&request, session.connection_id)
2257 .await?;
2258
2259 broadcast(
2260 Some(session.connection_id),
2261 guest_connection_ids.iter().copied(),
2262 |connection_id| {
2263 session
2264 .peer
2265 .forward_send(session.connection_id, connection_id, request.clone())
2266 },
2267 );
2268 Ok(())
2269}
2270
2271/// Notify other participants that a language server has changed.
2272async fn update_language_server(
2273 request: proto::UpdateLanguageServer,
2274 session: Session,
2275) -> Result<()> {
2276 let project_id = ProjectId::from_proto(request.project_id);
2277 let db = session.db().await;
2278
2279 if let Some(proto::update_language_server::Variant::MetadataUpdated(update)) = &request.variant
2280 {
2281 if let Some(capabilities) = update.capabilities.clone() {
2282 db.update_server_capabilities(project_id, request.language_server_id, capabilities)
2283 .await?;
2284 }
2285 }
2286
2287 let project_connection_ids = db
2288 .project_connection_ids(project_id, session.connection_id, true)
2289 .await?;
2290 broadcast(
2291 Some(session.connection_id),
2292 project_connection_ids.iter().copied(),
2293 |connection_id| {
2294 session
2295 .peer
2296 .forward_send(session.connection_id, connection_id, request.clone())
2297 },
2298 );
2299 Ok(())
2300}
2301
/// Forward a project request to the host. These requests must be read-only,
/// since guests are allowed to send them.
2304async fn forward_read_only_project_request<T>(
2305 request: T,
2306 response: Response<T>,
2307 session: Session,
2308) -> Result<()>
2309where
2310 T: EntityMessage + RequestMessage,
2311{
2312 let project_id = ProjectId::from_proto(request.remote_entity_id());
2313 let host_connection_id = session
2314 .db()
2315 .await
2316 .host_for_read_only_project_request(project_id, session.connection_id)
2317 .await?;
2318 let payload = session
2319 .peer
2320 .forward_request(session.connection_id, host_connection_id, request)
2321 .await?;
2322 response.send(payload)?;
2323 Ok(())
2324}
2325
/// Forward a mutating project request to the host. These requests are disallowed
/// for guests.
2328async fn forward_mutating_project_request<T>(
2329 request: T,
2330 response: Response<T>,
2331 session: Session,
2332) -> Result<()>
2333where
2334 T: EntityMessage + RequestMessage,
2335{
2336 let project_id = ProjectId::from_proto(request.remote_entity_id());
2337
2338 let host_connection_id = session
2339 .db()
2340 .await
2341 .host_for_mutating_project_request(project_id, session.connection_id)
2342 .await?;
2343 let payload = session
2344 .peer
2345 .forward_request(session.connection_id, host_connection_id, request)
2346 .await?;
2347 response.send(payload)?;
2348 Ok(())
2349}
2350
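/// Records the query kind for tracing, then forwards the multi-LSP query to the project host.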
2351async fn multi_lsp_query(
2352 request: MultiLspQuery,
2353 response: Response<MultiLspQuery>,
2354 session: Session,
2355) -> Result<()> {
2356 tracing::Span::current().record("multi_lsp_query_request", request.request_str());
2357 tracing::info!("multi_lsp_query message received");
2358 forward_mutating_project_request(request, response, session).await
2359}
2360
2361/// Notify other participants that a new buffer has been created
2362async fn create_buffer_for_peer(
2363 request: proto::CreateBufferForPeer,
2364 session: Session,
2365) -> Result<()> {
2366 session
2367 .db()
2368 .await
2369 .check_user_is_project_host(
2370 ProjectId::from_proto(request.project_id),
2371 session.connection_id,
2372 )
2373 .await?;
2374 let peer_id = request.peer_id.context("invalid peer id")?;
2375 session
2376 .peer
2377 .forward_send(session.connection_id, peer_id.into(), request)?;
2378 Ok(())
2379}
2380
2381/// Notify other participants that a buffer has been updated. This is
2382/// allowed for guests as long as the update is limited to selections.
2383async fn update_buffer(
2384 request: proto::UpdateBuffer,
2385 response: Response<proto::UpdateBuffer>,
2386 session: Session,
2387) -> Result<()> {
2388 let project_id = ProjectId::from_proto(request.project_id);
2389 let mut capability = Capability::ReadOnly;
2390
2391 for op in request.operations.iter() {
2392 match op.variant {
2393 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2394 Some(_) => capability = Capability::ReadWrite,
2395 }
2396 }
2397
2398 let host = {
2399 let guard = session
2400 .db()
2401 .await
2402 .connections_for_buffer_update(project_id, session.connection_id, capability)
2403 .await?;
2404
2405 let (host, guests) = &*guard;
2406
2407 broadcast(
2408 Some(session.connection_id),
2409 guests.clone(),
2410 |connection_id| {
2411 session
2412 .peer
2413 .forward_send(session.connection_id, connection_id, request.clone())
2414 },
2415 );
2416
2417 *host
2418 };
2419
2420 if host != session.connection_id {
2421 session
2422 .peer
2423 .forward_request(session.connection_id, host, request.clone())
2424 .await?;
2425 }
2426
2427 response.send(proto::Ack {})?;
2428 Ok(())
2429}
2430
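/// Broadcasts a context operation to the host and other guests of the project. Guests with
/// read-only access may only send selection updates; other operations require write access.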
2431async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2432 let project_id = ProjectId::from_proto(message.project_id);
2433
2434 let operation = message.operation.as_ref().context("invalid operation")?;
2435 let capability = match operation.variant.as_ref() {
2436 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2437 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2438 match buffer_op.variant {
2439 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2440 Capability::ReadOnly
2441 }
2442 _ => Capability::ReadWrite,
2443 }
2444 } else {
2445 Capability::ReadWrite
2446 }
2447 }
2448 Some(_) => Capability::ReadWrite,
2449 None => Capability::ReadOnly,
2450 };
2451
2452 let guard = session
2453 .db()
2454 .await
2455 .connections_for_buffer_update(project_id, session.connection_id, capability)
2456 .await?;
2457
2458 let (host, guests) = &*guard;
2459
2460 broadcast(
2461 Some(session.connection_id),
2462 guests.iter().chain([host]).copied(),
2463 |connection_id| {
2464 session
2465 .peer
2466 .forward_send(session.connection_id, connection_id, message.clone())
2467 },
2468 );
2469
2470 Ok(())
2471}
2472
/// Broadcast a project-scoped message from the host to the project's other participants.
2474async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2475 request: T,
2476 session: Session,
2477) -> Result<()> {
2478 let project_id = ProjectId::from_proto(request.remote_entity_id());
2479 let project_connection_ids = session
2480 .db()
2481 .await
2482 .project_connection_ids(project_id, session.connection_id, false)
2483 .await?;
2484
2485 broadcast(
2486 Some(session.connection_id),
2487 project_connection_ids.iter().copied(),
2488 |connection_id| {
2489 session
2490 .peer
2491 .forward_send(session.connection_id, connection_id, request.clone())
2492 },
2493 );
2494 Ok(())
2495}
2496
2497/// Start following another user in a call.
2498async fn follow(
2499 request: proto::Follow,
2500 response: Response<proto::Follow>,
2501 session: Session,
2502) -> Result<()> {
2503 let room_id = RoomId::from_proto(request.room_id);
2504 let project_id = request.project_id.map(ProjectId::from_proto);
2505 let leader_id = request.leader_id.context("invalid leader id")?.into();
2506 let follower_id = session.connection_id;
2507
2508 session
2509 .db()
2510 .await
2511 .check_room_participants(room_id, leader_id, session.connection_id)
2512 .await?;
2513
2514 let response_payload = session
2515 .peer
2516 .forward_request(session.connection_id, leader_id, request)
2517 .await?;
2518 response.send(response_payload)?;
2519
2520 if let Some(project_id) = project_id {
2521 let room = session
2522 .db()
2523 .await
2524 .follow(room_id, project_id, leader_id, follower_id)
2525 .await?;
2526 room_updated(&room, &session.peer);
2527 }
2528
2529 Ok(())
2530}
2531
2532/// Stop following another user in a call.
2533async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2534 let room_id = RoomId::from_proto(request.room_id);
2535 let project_id = request.project_id.map(ProjectId::from_proto);
2536 let leader_id = request.leader_id.context("invalid leader id")?.into();
2537 let follower_id = session.connection_id;
2538
2539 session
2540 .db()
2541 .await
2542 .check_room_participants(room_id, leader_id, session.connection_id)
2543 .await?;
2544
2545 session
2546 .peer
2547 .forward_send(session.connection_id, leader_id, request)?;
2548
2549 if let Some(project_id) = project_id {
2550 let room = session
2551 .db()
2552 .await
2553 .unfollow(room_id, project_id, leader_id, follower_id)
2554 .await?;
2555 room_updated(&room, &session.peer);
2556 }
2557
2558 Ok(())
2559}
2560
/// Notify everyone following you of your current location and view state.
2562async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2563 let room_id = RoomId::from_proto(request.room_id);
2564 let database = session.db.lock().await;
2565
2566 let connection_ids = if let Some(project_id) = request.project_id {
2567 let project_id = ProjectId::from_proto(project_id);
2568 database
2569 .project_connection_ids(project_id, session.connection_id, true)
2570 .await?
2571 } else {
2572 database
2573 .room_connection_ids(room_id, session.connection_id)
2574 .await?
2575 };
2576
2577 // For now, don't send view update messages back to that view's current leader.
2578 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2579 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2580 _ => None,
2581 });
2582
2583 for connection_id in connection_ids.iter().cloned() {
2584 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2585 session
2586 .peer
2587 .forward_send(session.connection_id, connection_id, request.clone())?;
2588 }
2589 }
2590 Ok(())
2591}
2592
2593/// Get public data about users.
2594async fn get_users(
2595 request: proto::GetUsers,
2596 response: Response<proto::GetUsers>,
2597 session: Session,
2598) -> Result<()> {
2599 let user_ids = request
2600 .user_ids
2601 .into_iter()
2602 .map(UserId::from_proto)
2603 .collect();
2604 let users = session
2605 .db()
2606 .await
2607 .get_users_by_ids(user_ids)
2608 .await?
2609 .into_iter()
2610 .map(|user| proto::User {
2611 id: user.id.to_proto(),
2612 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2613 github_login: user.github_login,
2614 name: user.name,
2615 })
2616 .collect();
2617 response.send(proto::UsersResponse { users })?;
2618 Ok(())
2619}
2620
/// Search for users (to invite) by GitHub login.
2622async fn fuzzy_search_users(
2623 request: proto::FuzzySearchUsers,
2624 response: Response<proto::FuzzySearchUsers>,
2625 session: Session,
2626) -> Result<()> {
2627 let query = request.query;
2628 let users = match query.len() {
2629 0 => vec![],
2630 1 | 2 => session
2631 .db()
2632 .await
2633 .get_user_by_github_login(&query)
2634 .await?
2635 .into_iter()
2636 .collect(),
2637 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2638 };
2639 let users = users
2640 .into_iter()
2641 .filter(|user| user.id != session.user_id())
2642 .map(|user| proto::User {
2643 id: user.id.to_proto(),
2644 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2645 github_login: user.github_login,
2646 name: user.name,
2647 })
2648 .collect();
2649 response.send(proto::UsersResponse { users })?;
2650 Ok(())
2651}
2652
2653/// Send a contact request to another user.
2654async fn request_contact(
2655 request: proto::RequestContact,
2656 response: Response<proto::RequestContact>,
2657 session: Session,
2658) -> Result<()> {
2659 let requester_id = session.user_id();
2660 let responder_id = UserId::from_proto(request.responder_id);
2661 if requester_id == responder_id {
2662 return Err(anyhow!("cannot add yourself as a contact"))?;
2663 }
2664
2665 let notifications = session
2666 .db()
2667 .await
2668 .send_contact_request(requester_id, responder_id)
2669 .await?;
2670
2671 // Update outgoing contact requests of requester
2672 let mut update = proto::UpdateContacts::default();
2673 update.outgoing_requests.push(responder_id.to_proto());
2674 for connection_id in session
2675 .connection_pool()
2676 .await
2677 .user_connection_ids(requester_id)
2678 {
2679 session.peer.send(connection_id, update.clone())?;
2680 }
2681
2682 // Update incoming contact requests of responder
2683 let mut update = proto::UpdateContacts::default();
2684 update
2685 .incoming_requests
2686 .push(proto::IncomingContactRequest {
2687 requester_id: requester_id.to_proto(),
2688 });
2689 let connection_pool = session.connection_pool().await;
2690 for connection_id in connection_pool.user_connection_ids(responder_id) {
2691 session.peer.send(connection_id, update.clone())?;
2692 }
2693
2694 send_notifications(&connection_pool, &session.peer, notifications);
2695
2696 response.send(proto::Ack {})?;
2697 Ok(())
2698}
2699
2700/// Accept or decline a contact request
2701async fn respond_to_contact_request(
2702 request: proto::RespondToContactRequest,
2703 response: Response<proto::RespondToContactRequest>,
2704 session: Session,
2705) -> Result<()> {
2706 let responder_id = session.user_id();
2707 let requester_id = UserId::from_proto(request.requester_id);
2708 let db = session.db().await;
2709 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2710 db.dismiss_contact_notification(responder_id, requester_id)
2711 .await?;
2712 } else {
2713 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2714
2715 let notifications = db
2716 .respond_to_contact_request(responder_id, requester_id, accept)
2717 .await?;
2718 let requester_busy = db.is_user_busy(requester_id).await?;
2719 let responder_busy = db.is_user_busy(responder_id).await?;
2720
2721 let pool = session.connection_pool().await;
2722 // Update responder with new contact
2723 let mut update = proto::UpdateContacts::default();
2724 if accept {
2725 update
2726 .contacts
2727 .push(contact_for_user(requester_id, requester_busy, &pool));
2728 }
2729 update
2730 .remove_incoming_requests
2731 .push(requester_id.to_proto());
2732 for connection_id in pool.user_connection_ids(responder_id) {
2733 session.peer.send(connection_id, update.clone())?;
2734 }
2735
2736 // Update requester with new contact
2737 let mut update = proto::UpdateContacts::default();
2738 if accept {
2739 update
2740 .contacts
2741 .push(contact_for_user(responder_id, responder_busy, &pool));
2742 }
2743 update
2744 .remove_outgoing_requests
2745 .push(responder_id.to_proto());
2746
2747 for connection_id in pool.user_connection_ids(requester_id) {
2748 session.peer.send(connection_id, update.clone())?;
2749 }
2750
2751 send_notifications(&pool, &session.peer, notifications);
2752 }
2753
2754 response.send(proto::Ack {})?;
2755 Ok(())
2756}
2757
2758/// Remove a contact.
2759async fn remove_contact(
2760 request: proto::RemoveContact,
2761 response: Response<proto::RemoveContact>,
2762 session: Session,
2763) -> Result<()> {
2764 let requester_id = session.user_id();
2765 let responder_id = UserId::from_proto(request.user_id);
2766 let db = session.db().await;
2767 let (contact_accepted, deleted_notification_id) =
2768 db.remove_contact(requester_id, responder_id).await?;
2769
2770 let pool = session.connection_pool().await;
2771 // Update outgoing contact requests of requester
2772 let mut update = proto::UpdateContacts::default();
2773 if contact_accepted {
2774 update.remove_contacts.push(responder_id.to_proto());
2775 } else {
2776 update
2777 .remove_outgoing_requests
2778 .push(responder_id.to_proto());
2779 }
2780 for connection_id in pool.user_connection_ids(requester_id) {
2781 session.peer.send(connection_id, update.clone())?;
2782 }
2783
2784 // Update incoming contact requests of responder
2785 let mut update = proto::UpdateContacts::default();
2786 if contact_accepted {
2787 update.remove_contacts.push(requester_id.to_proto());
2788 } else {
2789 update
2790 .remove_incoming_requests
2791 .push(requester_id.to_proto());
2792 }
2793 for connection_id in pool.user_connection_ids(responder_id) {
2794 session.peer.send(connection_id, update.clone())?;
2795 if let Some(notification_id) = deleted_notification_id {
2796 session.peer.send(
2797 connection_id,
2798 proto::DeleteNotification {
2799 notification_id: notification_id.to_proto(),
2800 },
2801 )?;
2802 }
2803 }
2804
2805 response.send(proto::Ack {})?;
2806 Ok(())
2807}
2808
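/// Whether to subscribe a client to its channels automatically rather than waiting for a
/// `SubscribeToChannels` request; only done for clients older than 0.139.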
2809fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2810 version.0.minor() < 139
2811}
2812
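/// Determines the user's current plan: staff are always treated as Zed Pro; otherwise the
/// plan is derived from the active billing subscription, defaulting to Free.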
2813async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2814 if is_staff {
2815 return Ok(proto::Plan::ZedPro);
2816 }
2817
2818 let subscription = db.get_active_billing_subscription(user_id).await?;
2819 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2820
2821 let plan = if let Some(subscription_kind) = subscription_kind {
2822 match subscription_kind {
2823 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2824 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2825 SubscriptionKind::ZedFree => proto::Plan::Free,
2826 }
2827 } else {
2828 proto::Plan::Free
2829 };
2830
2831 Ok(plan)
2832}
2833
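/// Builds the `UpdateUserPlan` message for a user from their plan, feature flags, billing
/// customer and preferences, current subscription period, and LLM usage for that period.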
2834async fn make_update_user_plan_message(
2835 user: &User,
2836 is_staff: bool,
2837 db: &Arc<Database>,
2838 llm_db: Option<Arc<LlmDatabase>>,
2839) -> Result<proto::UpdateUserPlan> {
2840 let feature_flags = db.get_user_flags(user.id).await?;
2841 let plan = current_plan(db, user.id, is_staff).await?;
2842 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2843 let billing_preferences = db.get_billing_preferences(user.id).await?;
2844
2845 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2846 let subscription = db.get_active_billing_subscription(user.id).await?;
2847
2848 let subscription_period =
2849 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2850
2851 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2852 llm_db
2853 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2854 .await?
2855 } else {
2856 None
2857 };
2858
2859 (subscription_period, usage)
2860 } else {
2861 (None, None)
2862 };
2863
2864 let bypass_account_age_check = feature_flags
2865 .iter()
2866 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2867 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2868 && !bypass_account_age_check
2869 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2870
2871 Ok(proto::UpdateUserPlan {
2872 plan: plan.into(),
2873 trial_started_at: billing_customer
2874 .as_ref()
2875 .and_then(|billing_customer| billing_customer.trial_started_at)
2876 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2877 is_usage_based_billing_enabled: if is_staff {
2878 Some(true)
2879 } else {
2880 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2881 },
2882 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2883 proto::SubscriptionPeriod {
2884 started_at: started_at.timestamp() as u64,
2885 ended_at: ended_at.timestamp() as u64,
2886 }
2887 }),
2888 account_too_young: Some(account_too_young),
2889 has_overdue_invoices: billing_customer
2890 .map(|billing_customer| billing_customer.has_overdue_invoices),
2891 usage: Some(
2892 usage
2893 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2894 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2895 ),
2896 })
2897}
2898
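/// Returns the model-request limit for the given plan; the extended-trial feature flag
/// raises the Zed Pro trial limit to 1,000 requests.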
fn model_requests_limit(
    plan: cloud_llm_client::Plan,
    feature_flags: &[String],
) -> cloud_llm_client::UsageLimit {
2903 match plan.model_requests_limit() {
2904 cloud_llm_client::UsageLimit::Limited(limit) => {
2905 let limit = if plan == cloud_llm_client::Plan::ZedProTrial
2906 && feature_flags
2907 .iter()
2908 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2909 {
2910 1_000
2911 } else {
2912 limit
2913 };
2914
2915 cloud_llm_client::UsageLimit::Limited(limit)
2916 }
2917 cloud_llm_client::UsageLimit::Unlimited => cloud_llm_client::UsageLimit::Unlimited,
2918 }
2919}
2920
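/// Converts a subscription usage record into its protobuf form, attaching the plan's
/// model-request and edit-prediction limits.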
fn subscription_usage_to_proto(
    plan: proto::Plan,
    usage: crate::llm::db::subscription_usage::Model,
    feature_flags: &[String],
) -> proto::SubscriptionUsage {
2926 let plan = match plan {
2927 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2928 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2929 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2930 };
2931
2932 proto::SubscriptionUsage {
2933 model_requests_usage_amount: usage.model_requests as u32,
2934 model_requests_usage_limit: Some(proto::UsageLimit {
2935 variant: Some(match model_requests_limit(plan, feature_flags) {
2936 cloud_llm_client::UsageLimit::Limited(limit) => {
2937 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2938 limit: limit as u32,
2939 })
2940 }
2941 cloud_llm_client::UsageLimit::Unlimited => {
2942 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2943 }
2944 }),
2945 }),
2946 edit_predictions_usage_amount: usage.edit_predictions as u32,
2947 edit_predictions_usage_limit: Some(proto::UsageLimit {
2948 variant: Some(match plan.edit_predictions_limit() {
2949 cloud_llm_client::UsageLimit::Limited(limit) => {
2950 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2951 limit: limit as u32,
2952 })
2953 }
2954 cloud_llm_client::UsageLimit::Unlimited => {
2955 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2956 }
2957 }),
2958 }),
2959 }
2960}
2961
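/// Builds a zeroed usage message for users with no recorded usage in the current period,
/// still carrying the plan's limits.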
fn make_default_subscription_usage(
    plan: proto::Plan,
    feature_flags: &[String],
) -> proto::SubscriptionUsage {
2966 let plan = match plan {
2967 proto::Plan::Free => cloud_llm_client::Plan::ZedFree,
2968 proto::Plan::ZedPro => cloud_llm_client::Plan::ZedPro,
2969 proto::Plan::ZedProTrial => cloud_llm_client::Plan::ZedProTrial,
2970 };
2971
2972 proto::SubscriptionUsage {
2973 model_requests_usage_amount: 0,
2974 model_requests_usage_limit: Some(proto::UsageLimit {
2975 variant: Some(match model_requests_limit(plan, feature_flags) {
2976 cloud_llm_client::UsageLimit::Limited(limit) => {
2977 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2978 limit: limit as u32,
2979 })
2980 }
2981 cloud_llm_client::UsageLimit::Unlimited => {
2982 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2983 }
2984 }),
2985 }),
2986 edit_predictions_usage_amount: 0,
2987 edit_predictions_usage_limit: Some(proto::UsageLimit {
2988 variant: Some(match plan.edit_predictions_limit() {
2989 cloud_llm_client::UsageLimit::Limited(limit) => {
2990 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2991 limit: limit as u32,
2992 })
2993 }
2994 cloud_llm_client::UsageLimit::Unlimited => {
2995 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2996 }
2997 }),
2998 }),
2999 }
3000}
3001
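/// Sends the latest `UpdateUserPlan` message to this session's connection.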
3002async fn update_user_plan(session: &Session) -> Result<()> {
3003 let db = session.db().await;
3004
3005 let update_user_plan = make_update_user_plan_message(
3006 session.principal.user(),
3007 session.is_staff(),
3008 &db.0,
3009 session.app_state.llm_db.clone(),
3010 )
3011 .await?;
3012
3013 session
3014 .peer
3015 .send(session.connection_id, update_user_plan)
3016 .trace_err();
3017
3018 Ok(())
3019}
3020
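/// Handles a client's request to subscribe to the channels it belongs to.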
3021async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3022 subscribe_user_to_channels(session.user_id(), &session).await?;
3023 Ok(())
3024}
3025
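/// Subscribes the user to each of their channel memberships in the connection pool and
/// sends the initial channel state to this connection.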
3026async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3027 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3028 let mut pool = session.connection_pool().await;
3029 for membership in &channels_for_user.channel_memberships {
3030 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3031 }
3032 session.peer.send(
3033 session.connection_id,
3034 build_update_user_channels(&channels_for_user),
3035 )?;
3036 session.peer.send(
3037 session.connection_id,
3038 build_channels_update(channels_for_user),
3039 )?;
3040 Ok(())
3041}
3042
3043/// Creates a new channel.
3044async fn create_channel(
3045 request: proto::CreateChannel,
3046 response: Response<proto::CreateChannel>,
3047 session: Session,
3048) -> Result<()> {
3049 let db = session.db().await;
3050
3051 let parent_id = request.parent_id.map(ChannelId::from_proto);
3052 let (channel, membership) = db
3053 .create_channel(&request.name, parent_id, session.user_id())
3054 .await?;
3055
3056 let root_id = channel.root_id();
3057 let channel = Channel::from_model(channel);
3058
3059 response.send(proto::CreateChannelResponse {
3060 channel: Some(channel.to_proto()),
3061 parent_id: request.parent_id,
3062 })?;
3063
3064 let mut connection_pool = session.connection_pool().await;
3065 if let Some(membership) = membership {
3066 connection_pool.subscribe_to_channel(
3067 membership.user_id,
3068 membership.channel_id,
3069 membership.role,
3070 );
3071 let update = proto::UpdateUserChannels {
3072 channel_memberships: vec![proto::ChannelMembership {
3073 channel_id: membership.channel_id.to_proto(),
3074 role: membership.role.into(),
3075 }],
3076 ..Default::default()
3077 };
3078 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3079 session.peer.send(connection_id, update.clone())?;
3080 }
3081 }
3082
3083 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3084 if !role.can_see_channel(channel.visibility) {
3085 continue;
3086 }
3087
3088 let update = proto::UpdateChannels {
3089 channels: vec![channel.to_proto()],
3090 ..Default::default()
3091 };
3092 session.peer.send(connection_id, update.clone())?;
3093 }
3094
3095 Ok(())
3096}
3097
3098/// Delete a channel
3099async fn delete_channel(
3100 request: proto::DeleteChannel,
3101 response: Response<proto::DeleteChannel>,
3102 session: Session,
3103) -> Result<()> {
3104 let db = session.db().await;
3105
3106 let channel_id = request.channel_id;
3107 let (root_channel, removed_channels) = db
3108 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3109 .await?;
3110 response.send(proto::Ack {})?;
3111
3112 // Notify members of removed channels
3113 let mut update = proto::UpdateChannels::default();
3114 update
3115 .delete_channels
3116 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3117
3118 let connection_pool = session.connection_pool().await;
3119 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3120 session.peer.send(connection_id, update.clone())?;
3121 }
3122
3123 Ok(())
3124}
3125
3126/// Invite someone to join a channel.
3127async fn invite_channel_member(
3128 request: proto::InviteChannelMember,
3129 response: Response<proto::InviteChannelMember>,
3130 session: Session,
3131) -> Result<()> {
3132 let db = session.db().await;
3133 let channel_id = ChannelId::from_proto(request.channel_id);
3134 let invitee_id = UserId::from_proto(request.user_id);
3135 let InviteMemberResult {
3136 channel,
3137 notifications,
3138 } = db
3139 .invite_channel_member(
3140 channel_id,
3141 invitee_id,
3142 session.user_id(),
3143 request.role().into(),
3144 )
3145 .await?;
3146
3147 let update = proto::UpdateChannels {
3148 channel_invitations: vec![channel.to_proto()],
3149 ..Default::default()
3150 };
3151
3152 let connection_pool = session.connection_pool().await;
3153 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3154 session.peer.send(connection_id, update.clone())?;
3155 }
3156
3157 send_notifications(&connection_pool, &session.peer, notifications);
3158
3159 response.send(proto::Ack {})?;
3160 Ok(())
3161}
3162
/// Remove someone from a channel.
3164async fn remove_channel_member(
3165 request: proto::RemoveChannelMember,
3166 response: Response<proto::RemoveChannelMember>,
3167 session: Session,
3168) -> Result<()> {
3169 let db = session.db().await;
3170 let channel_id = ChannelId::from_proto(request.channel_id);
3171 let member_id = UserId::from_proto(request.user_id);
3172
3173 let RemoveChannelMemberResult {
3174 membership_update,
3175 notification_id,
3176 } = db
3177 .remove_channel_member(channel_id, member_id, session.user_id())
3178 .await?;
3179
3180 let mut connection_pool = session.connection_pool().await;
3181 notify_membership_updated(
3182 &mut connection_pool,
3183 membership_update,
3184 member_id,
3185 &session.peer,
3186 );
3187 for connection_id in connection_pool.user_connection_ids(member_id) {
3188 if let Some(notification_id) = notification_id {
3189 session
3190 .peer
3191 .send(
3192 connection_id,
3193 proto::DeleteNotification {
3194 notification_id: notification_id.to_proto(),
3195 },
3196 )
3197 .trace_err();
3198 }
3199 }
3200
3201 response.send(proto::Ack {})?;
3202 Ok(())
3203}
3204
/// Toggle the channel between public and private.
/// Care is taken to maintain the invariant that public channels only descend from public channels
/// (though members-only channels can appear at any point in the hierarchy).
3208async fn set_channel_visibility(
3209 request: proto::SetChannelVisibility,
3210 response: Response<proto::SetChannelVisibility>,
3211 session: Session,
3212) -> Result<()> {
3213 let db = session.db().await;
3214 let channel_id = ChannelId::from_proto(request.channel_id);
3215 let visibility = request.visibility().into();
3216
3217 let channel_model = db
3218 .set_channel_visibility(channel_id, visibility, session.user_id())
3219 .await?;
3220 let root_id = channel_model.root_id();
3221 let channel = Channel::from_model(channel_model);
3222
3223 let mut connection_pool = session.connection_pool().await;
3224 for (user_id, role) in connection_pool
3225 .channel_user_ids(root_id)
3226 .collect::<Vec<_>>()
3227 .into_iter()
3228 {
3229 let update = if role.can_see_channel(channel.visibility) {
3230 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3231 proto::UpdateChannels {
3232 channels: vec![channel.to_proto()],
3233 ..Default::default()
3234 }
3235 } else {
3236 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3237 proto::UpdateChannels {
3238 delete_channels: vec![channel.id.to_proto()],
3239 ..Default::default()
3240 }
3241 };
3242
3243 for connection_id in connection_pool.user_connection_ids(user_id) {
3244 session.peer.send(connection_id, update.clone())?;
3245 }
3246 }
3247
3248 response.send(proto::Ack {})?;
3249 Ok(())
3250}
3251
3252/// Alter the role for a user in the channel.
3253async fn set_channel_member_role(
3254 request: proto::SetChannelMemberRole,
3255 response: Response<proto::SetChannelMemberRole>,
3256 session: Session,
3257) -> Result<()> {
3258 let db = session.db().await;
3259 let channel_id = ChannelId::from_proto(request.channel_id);
3260 let member_id = UserId::from_proto(request.user_id);
3261 let result = db
3262 .set_channel_member_role(
3263 channel_id,
3264 session.user_id(),
3265 member_id,
3266 request.role().into(),
3267 )
3268 .await?;
3269
3270 match result {
3271 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3272 let mut connection_pool = session.connection_pool().await;
3273 notify_membership_updated(
3274 &mut connection_pool,
3275 membership_update,
3276 member_id,
3277 &session.peer,
3278 )
3279 }
3280 db::SetMemberRoleResult::InviteUpdated(channel) => {
3281 let update = proto::UpdateChannels {
3282 channel_invitations: vec![channel.to_proto()],
3283 ..Default::default()
3284 };
3285
3286 for connection_id in session
3287 .connection_pool()
3288 .await
3289 .user_connection_ids(member_id)
3290 {
3291 session.peer.send(connection_id, update.clone())?;
3292 }
3293 }
3294 }
3295
3296 response.send(proto::Ack {})?;
3297 Ok(())
3298}
3299
3300/// Change the name of a channel
3301async fn rename_channel(
3302 request: proto::RenameChannel,
3303 response: Response<proto::RenameChannel>,
3304 session: Session,
3305) -> Result<()> {
3306 let db = session.db().await;
3307 let channel_id = ChannelId::from_proto(request.channel_id);
3308 let channel_model = db
3309 .rename_channel(channel_id, session.user_id(), &request.name)
3310 .await?;
3311 let root_id = channel_model.root_id();
3312 let channel = Channel::from_model(channel_model);
3313
3314 response.send(proto::RenameChannelResponse {
3315 channel: Some(channel.to_proto()),
3316 })?;
3317
3318 let connection_pool = session.connection_pool().await;
3319 let update = proto::UpdateChannels {
3320 channels: vec![channel.to_proto()],
3321 ..Default::default()
3322 };
3323 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3324 if role.can_see_channel(channel.visibility) {
3325 session.peer.send(connection_id, update.clone())?;
3326 }
3327 }
3328
3329 Ok(())
3330}
3331
3332/// Move a channel to a new parent.
3333async fn move_channel(
3334 request: proto::MoveChannel,
3335 response: Response<proto::MoveChannel>,
3336 session: Session,
3337) -> Result<()> {
3338 let channel_id = ChannelId::from_proto(request.channel_id);
3339 let to = ChannelId::from_proto(request.to);
3340
3341 let (root_id, channels) = session
3342 .db()
3343 .await
3344 .move_channel(channel_id, to, session.user_id())
3345 .await?;
3346
3347 let connection_pool = session.connection_pool().await;
3348 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3349 let channels = channels
3350 .iter()
3351 .filter_map(|channel| {
3352 if role.can_see_channel(channel.visibility) {
3353 Some(channel.to_proto())
3354 } else {
3355 None
3356 }
3357 })
3358 .collect::<Vec<_>>();
3359 if channels.is_empty() {
3360 continue;
3361 }
3362
3363 let update = proto::UpdateChannels {
3364 channels,
3365 ..Default::default()
3366 };
3367
3368 session.peer.send(connection_id, update.clone())?;
3369 }
3370
3371 response.send(Ack {})?;
3372 Ok(())
3373}
3374
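/// Reorders a channel in the requested direction and broadcasts the updated channels to
/// members who can see them.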
3375async fn reorder_channel(
3376 request: proto::ReorderChannel,
3377 response: Response<proto::ReorderChannel>,
3378 session: Session,
3379) -> Result<()> {
3380 let channel_id = ChannelId::from_proto(request.channel_id);
3381 let direction = request.direction();
3382
3383 let updated_channels = session
3384 .db()
3385 .await
3386 .reorder_channel(channel_id, direction, session.user_id())
3387 .await?;
3388
3389 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3390 let connection_pool = session.connection_pool().await;
3391 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3392 let channels = updated_channels
3393 .iter()
3394 .filter_map(|channel| {
3395 if role.can_see_channel(channel.visibility) {
3396 Some(channel.to_proto())
3397 } else {
3398 None
3399 }
3400 })
3401 .collect::<Vec<_>>();
3402
3403 if channels.is_empty() {
3404 continue;
3405 }
3406
3407 let update = proto::UpdateChannels {
3408 channels,
3409 ..Default::default()
3410 };
3411
3412 session.peer.send(connection_id, update.clone())?;
3413 }
3414 }
3415
3416 response.send(Ack {})?;
3417 Ok(())
3418}
3419
3420/// Get the list of channel members
3421async fn get_channel_members(
3422 request: proto::GetChannelMembers,
3423 response: Response<proto::GetChannelMembers>,
3424 session: Session,
3425) -> Result<()> {
3426 let db = session.db().await;
3427 let channel_id = ChannelId::from_proto(request.channel_id);
3428 let limit = if request.limit == 0 {
3429 u16::MAX as u64
3430 } else {
3431 request.limit
3432 };
3433 let (members, users) = db
3434 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3435 .await?;
3436 response.send(proto::GetChannelMembersResponse { members, users })?;
3437 Ok(())
3438}
3439
3440/// Accept or decline a channel invitation.
3441async fn respond_to_channel_invite(
3442 request: proto::RespondToChannelInvite,
3443 response: Response<proto::RespondToChannelInvite>,
3444 session: Session,
3445) -> Result<()> {
3446 let db = session.db().await;
3447 let channel_id = ChannelId::from_proto(request.channel_id);
3448 let RespondToChannelInvite {
3449 membership_update,
3450 notifications,
3451 } = db
3452 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3453 .await?;
3454
3455 let mut connection_pool = session.connection_pool().await;
3456 if let Some(membership_update) = membership_update {
3457 notify_membership_updated(
3458 &mut connection_pool,
3459 membership_update,
3460 session.user_id(),
3461 &session.peer,
3462 );
3463 } else {
3464 let update = proto::UpdateChannels {
3465 remove_channel_invitations: vec![channel_id.to_proto()],
3466 ..Default::default()
3467 };
3468
3469 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3470 session.peer.send(connection_id, update.clone())?;
3471 }
3472 };
3473
3474 send_notifications(&connection_pool, &session.peer, notifications);
3475
3476 response.send(proto::Ack {})?;
3477
3478 Ok(())
3479}
3480
/// Join the channel's room.
3482async fn join_channel(
3483 request: proto::JoinChannel,
3484 response: Response<proto::JoinChannel>,
3485 session: Session,
3486) -> Result<()> {
3487 let channel_id = ChannelId::from_proto(request.channel_id);
3488 join_channel_internal(channel_id, Box::new(response), session).await
3489}
3490
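/// Lets `join_channel_internal` reply to either a `JoinChannel` or a `JoinRoom` request with
/// the same `JoinRoomResponse` payload.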
3491trait JoinChannelInternalResponse {
3492 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3493}
3494impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3495 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3496 Response::<proto::JoinChannel>::send(self, result)
3497 }
3498}
3499impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3500 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3501 Response::<proto::JoinRoom>::send(self, result)
3502 }
3503}
3504
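/// Shared implementation of joining a channel's room; callers reply through the
/// `JoinChannelInternalResponse` trait so both `JoinChannel` and `JoinRoom` responses work.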
3505async fn join_channel_internal(
3506 channel_id: ChannelId,
3507 response: Box<impl JoinChannelInternalResponse>,
3508 session: Session,
3509) -> Result<()> {
3510 let joined_room = {
3511 let mut db = session.db().await;
        // If Zed quits without leaving the room and the user re-opens it before
        // RECONNECT_TIMEOUT elapses, make sure to kick the user out of the room they
        // were previously in.
3515 if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
3516 tracing::info!(
3517 stale_connection_id = %connection,
3518 "cleaning up stale connection",
3519 );
3520 drop(db);
3521 leave_room_for_session(&session, connection).await?;
3522 db = session.db().await;
3523 }
3524
3525 let (joined_room, membership_updated, role) = db
3526 .join_channel(channel_id, session.user_id(), session.connection_id)
3527 .await?;
3528
3529 let live_kit_connection_info =
3530 session
3531 .app_state
3532 .livekit_client
3533 .as_ref()
3534 .and_then(|live_kit| {
3535 let (can_publish, token) = if role == ChannelRole::Guest {
3536 (
3537 false,
3538 live_kit
3539 .guest_token(
3540 &joined_room.room.livekit_room,
3541 &session.user_id().to_string(),
3542 )
3543 .trace_err()?,
3544 )
3545 } else {
3546 (
3547 true,
3548 live_kit
3549 .room_token(
3550 &joined_room.room.livekit_room,
3551 &session.user_id().to_string(),
3552 )
3553 .trace_err()?,
3554 )
3555 };
3556
3557 Some(LiveKitConnectionInfo {
3558 server_url: live_kit.url().into(),
3559 token,
3560 can_publish,
3561 })
3562 });
3563
3564 response.send(proto::JoinRoomResponse {
3565 room: Some(joined_room.room.clone()),
3566 channel_id: joined_room
3567 .channel
3568 .as_ref()
3569 .map(|channel| channel.id.to_proto()),
3570 live_kit_connection_info,
3571 })?;
3572
3573 let mut connection_pool = session.connection_pool().await;
3574 if let Some(membership_updated) = membership_updated {
3575 notify_membership_updated(
3576 &mut connection_pool,
3577 membership_updated,
3578 session.user_id(),
3579 &session.peer,
3580 );
3581 }
3582
3583 room_updated(&joined_room.room, &session.peer);
3584
3585 joined_room
3586 };
3587
3588 channel_updated(
3589 &joined_room.channel.context("channel not returned")?,
3590 &joined_room.room,
3591 &session.peer,
3592 &*session.connection_pool().await,
3593 );
3594
3595 update_user_contacts(session.user_id(), &session).await?;
3596 Ok(())
3597}
3598
3599/// Start editing the channel notes
3600async fn join_channel_buffer(
3601 request: proto::JoinChannelBuffer,
3602 response: Response<proto::JoinChannelBuffer>,
3603 session: Session,
3604) -> Result<()> {
3605 let db = session.db().await;
3606 let channel_id = ChannelId::from_proto(request.channel_id);
3607
3608 let open_response = db
3609 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3610 .await?;
3611
3612 let collaborators = open_response.collaborators.clone();
3613 response.send(open_response)?;
3614
3615 let update = UpdateChannelBufferCollaborators {
3616 channel_id: channel_id.to_proto(),
3617 collaborators: collaborators.clone(),
3618 };
3619 channel_buffer_updated(
3620 session.connection_id,
3621 collaborators
3622 .iter()
3623 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3624 &update,
3625 &session.peer,
3626 );
3627
3628 Ok(())
3629}
3630
3631/// Edit the channel notes
3632async fn update_channel_buffer(
3633 request: proto::UpdateChannelBuffer,
3634 session: Session,
3635) -> Result<()> {
3636 let db = session.db().await;
3637 let channel_id = ChannelId::from_proto(request.channel_id);
3638
3639 let (collaborators, epoch, version) = db
3640 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3641 .await?;
3642
3643 channel_buffer_updated(
3644 session.connection_id,
3645 collaborators.clone(),
3646 &proto::UpdateChannelBuffer {
3647 channel_id: channel_id.to_proto(),
3648 operations: request.operations,
3649 },
3650 &session.peer,
3651 );
3652
3653 let pool = &*session.connection_pool().await;
3654
3655 let non_collaborators =
3656 pool.channel_connection_ids(channel_id)
3657 .filter_map(|(connection_id, _)| {
3658 if collaborators.contains(&connection_id) {
3659 None
3660 } else {
3661 Some(connection_id)
3662 }
3663 });
3664
3665 broadcast(None, non_collaborators, |peer_id| {
3666 session.peer.send(
3667 peer_id,
3668 proto::UpdateChannels {
3669 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3670 channel_id: channel_id.to_proto(),
3671 epoch: epoch as u64,
3672 version: version.clone(),
3673 }],
3674 ..Default::default()
3675 },
3676 )
3677 });
3678
3679 Ok(())
3680}
3681
3682/// Rejoin the channel notes after a connection blip
3683async fn rejoin_channel_buffers(
3684 request: proto::RejoinChannelBuffers,
3685 response: Response<proto::RejoinChannelBuffers>,
3686 session: Session,
3687) -> Result<()> {
3688 let db = session.db().await;
3689 let buffers = db
3690 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3691 .await?;
3692
3693 for rejoined_buffer in &buffers {
3694 let collaborators_to_notify = rejoined_buffer
3695 .buffer
3696 .collaborators
3697 .iter()
3698 .filter_map(|c| Some(c.peer_id?.into()));
3699 channel_buffer_updated(
3700 session.connection_id,
3701 collaborators_to_notify,
3702 &proto::UpdateChannelBufferCollaborators {
3703 channel_id: rejoined_buffer.buffer.channel_id,
3704 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3705 },
3706 &session.peer,
3707 );
3708 }
3709
3710 response.send(proto::RejoinChannelBuffersResponse {
3711 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3712 })?;
3713
3714 Ok(())
3715}
3716
3717/// Stop editing the channel notes
3718async fn leave_channel_buffer(
3719 request: proto::LeaveChannelBuffer,
3720 response: Response<proto::LeaveChannelBuffer>,
3721 session: Session,
3722) -> Result<()> {
3723 let db = session.db().await;
3724 let channel_id = ChannelId::from_proto(request.channel_id);
3725
3726 let left_buffer = db
3727 .leave_channel_buffer(channel_id, session.connection_id)
3728 .await?;
3729
3730 response.send(Ack {})?;
3731
3732 channel_buffer_updated(
3733 session.connection_id,
3734 left_buffer.connections,
3735 &proto::UpdateChannelBufferCollaborators {
3736 channel_id: channel_id.to_proto(),
3737 collaborators: left_buffer.collaborators,
3738 },
3739 &session.peer,
3740 );
3741
3742 Ok(())
3743}
3744
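/// Broadcasts a channel-buffer message to the given collaborators, excluding the sender.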
3745fn channel_buffer_updated<T: EnvelopedMessage>(
3746 sender_id: ConnectionId,
3747 collaborators: impl IntoIterator<Item = ConnectionId>,
3748 message: &T,
3749 peer: &Peer,
3750) {
3751 broadcast(Some(sender_id), collaborators, |peer_id| {
3752 peer.send(peer_id, message.clone())
3753 });
3754}
3755
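/// Push a batch of newly created notifications to every active connection of
/// each recipient, logging any send failures instead of returning an error.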
3756fn send_notifications(
3757 connection_pool: &ConnectionPool,
3758 peer: &Peer,
3759 notifications: db::NotificationBatch,
3760) {
3761 for (user_id, notification) in notifications {
3762 for connection_id in connection_pool.user_connection_ids(user_id) {
3763 if let Err(error) = peer.send(
3764 connection_id,
3765 proto::AddNotification {
3766 notification: Some(notification.clone()),
3767 },
3768 ) {
3769 tracing::error!(
3770 "failed to send notification to {:?} {}",
3771 connection_id,
3772 error
3773 );
3774 }
3775 }
3776 }
3777}
3778
3779/// Send a message to the channel
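///
/// A minimal sketch of the request this handler expects, where `channel_id` and
/// `nonce` are placeholders supplied by the caller (only the fields read by this
/// handler are shown; any remaining proto fields stay at their defaults):
///
/// ```ignore
/// let request = proto::SendChannelMessage {
///     channel_id: channel_id.to_proto(),
///     body: "Hello, channel!".to_string(),
///     mentions: Vec::new(),
///     nonce: Some(nonce), // required: requests without a nonce are rejected
///     reply_to_message_id: None,
///     ..Default::default()
/// };
/// ```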
3780async fn send_channel_message(
3781 request: proto::SendChannelMessage,
3782 response: Response<proto::SendChannelMessage>,
3783 session: Session,
3784) -> Result<()> {
3785 // Validate the message body.
3786 let body = request.body.trim().to_string();
3787 if body.len() > MAX_MESSAGE_LEN {
3788 return Err(anyhow!("message is too long"))?;
3789 }
3790 if body.is_empty() {
3791 return Err(anyhow!("message can't be blank"))?;
3792 }
3793
3794 // TODO: adjust mentions if body is trimmed
3795
3796 let timestamp = OffsetDateTime::now_utc();
3797 let nonce = request.nonce.context("nonce can't be blank")?;
3798
3799 let channel_id = ChannelId::from_proto(request.channel_id);
3800 let CreatedChannelMessage {
3801 message_id,
3802 participant_connection_ids,
3803 notifications,
3804 } = session
3805 .db()
3806 .await
3807 .create_channel_message(
3808 channel_id,
3809 session.user_id(),
3810 &body,
3811 &request.mentions,
3812 timestamp,
3813 nonce.clone().into(),
3814 request.reply_to_message_id.map(MessageId::from_proto),
3815 )
3816 .await?;
3817
3818 let message = proto::ChannelMessage {
3819 sender_id: session.user_id().to_proto(),
3820 id: message_id.to_proto(),
3821 body,
3822 mentions: request.mentions,
3823 timestamp: timestamp.unix_timestamp() as u64,
3824 nonce: Some(nonce),
3825 reply_to_message_id: request.reply_to_message_id,
3826 edited_at: None,
3827 };
3828 broadcast(
3829 Some(session.connection_id),
3830 participant_connection_ids.clone(),
3831 |connection| {
3832 session.peer.send(
3833 connection,
3834 proto::ChannelMessageSent {
3835 channel_id: channel_id.to_proto(),
3836 message: Some(message.clone()),
3837 },
3838 )
3839 },
3840 );
3841 response.send(proto::SendChannelMessageResponse {
3842 message: Some(message),
3843 })?;
3844
3845 let pool = &*session.connection_pool().await;
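    // Let channel members who don't have the chat open know the ID of the
    // latest message, without sending them the message body itself.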
3846 let non_participants =
3847 pool.channel_connection_ids(channel_id)
3848 .filter_map(|(connection_id, _)| {
3849 if participant_connection_ids.contains(&connection_id) {
3850 None
3851 } else {
3852 Some(connection_id)
3853 }
3854 });
3855 broadcast(None, non_participants, |peer_id| {
3856 session.peer.send(
3857 peer_id,
3858 proto::UpdateChannels {
3859 latest_channel_message_ids: vec![proto::ChannelMessageId {
3860 channel_id: channel_id.to_proto(),
3861 message_id: message_id.to_proto(),
3862 }],
3863 ..Default::default()
3864 },
3865 )
3866 });
3867 send_notifications(pool, &session.peer, notifications);
3868
3869 Ok(())
3870}
3871
3872/// Delete a channel message
3873async fn remove_channel_message(
3874 request: proto::RemoveChannelMessage,
3875 response: Response<proto::RemoveChannelMessage>,
3876 session: Session,
3877) -> Result<()> {
3878 let channel_id = ChannelId::from_proto(request.channel_id);
3879 let message_id = MessageId::from_proto(request.message_id);
3880 let (connection_ids, existing_notification_ids) = session
3881 .db()
3882 .await
3883 .remove_channel_message(channel_id, message_id, session.user_id())
3884 .await?;
3885
3886 broadcast(
3887 Some(session.connection_id),
3888 connection_ids,
3889 move |connection| {
3890 session.peer.send(connection, request.clone())?;
3891
3892 for notification_id in &existing_notification_ids {
3893 session.peer.send(
3894 connection,
3895 proto::DeleteNotification {
3896 notification_id: (*notification_id).to_proto(),
3897 },
3898 )?;
3899 }
3900
3901 Ok(())
3902 },
3903 );
3904 response.send(proto::Ack {})?;
3905 Ok(())
3906}
3907
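/// Edit a previously sent channel message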
3908async fn update_channel_message(
3909 request: proto::UpdateChannelMessage,
3910 response: Response<proto::UpdateChannelMessage>,
3911 session: Session,
3912) -> Result<()> {
3913 let channel_id = ChannelId::from_proto(request.channel_id);
3914 let message_id = MessageId::from_proto(request.message_id);
3915 let updated_at = OffsetDateTime::now_utc();
3916 let UpdatedChannelMessage {
3917 message_id,
3918 participant_connection_ids,
3919 notifications,
3920 reply_to_message_id,
3921 timestamp,
3922 deleted_mention_notification_ids,
3923 updated_mention_notifications,
3924 } = session
3925 .db()
3926 .await
3927 .update_channel_message(
3928 channel_id,
3929 message_id,
3930 session.user_id(),
3931 request.body.as_str(),
3932 &request.mentions,
3933 updated_at,
3934 )
3935 .await?;
3936
3937 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3938
3939 let message = proto::ChannelMessage {
3940 sender_id: session.user_id().to_proto(),
3941 id: message_id.to_proto(),
3942 body: request.body.clone(),
3943 mentions: request.mentions.clone(),
3944 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3945 nonce: Some(nonce),
3946 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3947 edited_at: Some(updated_at.unix_timestamp() as u64),
3948 };
3949
3950 response.send(proto::Ack {})?;
3951
3952 let pool = &*session.connection_pool().await;
3953 broadcast(
3954 Some(session.connection_id),
3955 participant_connection_ids,
3956 |connection| {
3957 session.peer.send(
3958 connection,
3959 proto::ChannelMessageUpdate {
3960 channel_id: channel_id.to_proto(),
3961 message: Some(message.clone()),
3962 },
3963 )?;
3964
3965 for notification_id in &deleted_mention_notification_ids {
3966 session.peer.send(
3967 connection,
3968 proto::DeleteNotification {
3969 notification_id: (*notification_id).to_proto(),
3970 },
3971 )?;
3972 }
3973
3974 for notification in &updated_mention_notifications {
3975 session.peer.send(
3976 connection,
3977 proto::UpdateNotification {
3978 notification: Some(notification.clone()),
3979 },
3980 )?;
3981 }
3982
3983 Ok(())
3984 },
3985 );
3986
3987 send_notifications(pool, &session.peer, notifications);
3988
3989 Ok(())
3990}
3991
3992/// Mark a channel message as read
3993async fn acknowledge_channel_message(
3994 request: proto::AckChannelMessage,
3995 session: Session,
3996) -> Result<()> {
3997 let channel_id = ChannelId::from_proto(request.channel_id);
3998 let message_id = MessageId::from_proto(request.message_id);
3999 let notifications = session
4000 .db()
4001 .await
4002 .observe_channel_message(channel_id, session.user_id(), message_id)
4003 .await?;
4004 send_notifications(
4005 &*session.connection_pool().await,
4006 &session.peer,
4007 notifications,
4008 );
4009 Ok(())
4010}
4011
4012/// Mark a buffer version as synced
4013async fn acknowledge_buffer_version(
4014 request: proto::AckBufferOperation,
4015 session: Session,
4016) -> Result<()> {
4017 let buffer_id = BufferId::from_proto(request.buffer_id);
4018 session
4019 .db()
4020 .await
4021 .observe_buffer_version(
4022 buffer_id,
4023 session.user_id(),
4024 request.epoch as i32,
4025 &request.version,
4026 )
4027 .await?;
4028 Ok(())
4029}
4030
4031/// Get a Supermaven API key for the user
4032async fn get_supermaven_api_key(
4033 _request: proto::GetSupermavenApiKey,
4034 response: Response<proto::GetSupermavenApiKey>,
4035 session: Session,
4036) -> Result<()> {
4037 let user_id: String = session.user_id().to_string();
4038 if !session.is_staff() {
4039 return Err(anyhow!("supermaven not enabled for this account"))?;
4040 }
4041
4042 let email = session.email().context("user must have an email")?;
4043
4044 let supermaven_admin_api = session
4045 .supermaven_client
4046 .as_ref()
4047 .context("supermaven not configured")?;
4048
4049 let result = supermaven_admin_api
4050 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4051 .await?;
4052
4053 response.send(proto::GetSupermavenApiKeyResponse {
4054 api_key: result.api_key,
4055 })?;
4056
4057 Ok(())
4058}
4059
4060/// Start receiving chat updates for a channel
4061async fn join_channel_chat(
4062 request: proto::JoinChannelChat,
4063 response: Response<proto::JoinChannelChat>,
4064 session: Session,
4065) -> Result<()> {
4066 let channel_id = ChannelId::from_proto(request.channel_id);
4067
4068 let db = session.db().await;
4069 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4070 .await?;
4071 let messages = db
4072 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4073 .await?;
4074 response.send(proto::JoinChannelChatResponse {
4075 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4076 messages,
4077 })?;
4078 Ok(())
4079}
4080
4081/// Stop receiving chat updates for a channel
4082async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4083 let channel_id = ChannelId::from_proto(request.channel_id);
4084 session
4085 .db()
4086 .await
4087 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4088 .await?;
4089 Ok(())
4090}
4091
4092/// Retrieve the chat history for a channel
4093async fn get_channel_messages(
4094 request: proto::GetChannelMessages,
4095 response: Response<proto::GetChannelMessages>,
4096 session: Session,
4097) -> Result<()> {
4098 let channel_id = ChannelId::from_proto(request.channel_id);
4099 let messages = session
4100 .db()
4101 .await
4102 .get_channel_messages(
4103 channel_id,
4104 session.user_id(),
4105 MESSAGE_COUNT_PER_PAGE,
4106 Some(MessageId::from_proto(request.before_message_id)),
4107 )
4108 .await?;
4109 response.send(proto::GetChannelMessagesResponse {
4110 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4111 messages,
4112 })?;
4113 Ok(())
4114}
4115
4116/// Retrieve specific chat messages
4117async fn get_channel_messages_by_id(
4118 request: proto::GetChannelMessagesById,
4119 response: Response<proto::GetChannelMessagesById>,
4120 session: Session,
4121) -> Result<()> {
4122 let message_ids = request
4123 .message_ids
4124 .iter()
4125 .map(|id| MessageId::from_proto(*id))
4126 .collect::<Vec<_>>();
4127 let messages = session
4128 .db()
4129 .await
4130 .get_channel_messages_by_id(session.user_id(), &message_ids)
4131 .await?;
4132 response.send(proto::GetChannelMessagesResponse {
4133 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4134 messages,
4135 })?;
4136 Ok(())
4137}
4138
/// Retrieve the current user's notifications

4140async fn get_notifications(
4141 request: proto::GetNotifications,
4142 response: Response<proto::GetNotifications>,
4143 session: Session,
4144) -> Result<()> {
4145 let notifications = session
4146 .db()
4147 .await
4148 .get_notifications(
4149 session.user_id(),
4150 NOTIFICATION_COUNT_PER_PAGE,
4151 request.before_id.map(db::NotificationId::from_proto),
4152 )
4153 .await?;
4154 response.send(proto::GetNotificationsResponse {
4155 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4156 notifications,
4157 })?;
4158 Ok(())
4159}
4160
/// Mark a notification as read
4162async fn mark_notification_as_read(
4163 request: proto::MarkNotificationRead,
4164 response: Response<proto::MarkNotificationRead>,
4165 session: Session,
4166) -> Result<()> {
4167 let database = &session.db().await;
4168 let notifications = database
4169 .mark_notification_as_read_by_id(
4170 session.user_id(),
4171 NotificationId::from_proto(request.notification_id),
4172 )
4173 .await?;
4174 send_notifications(
4175 &*session.connection_pool().await,
4176 &session.peer,
4177 notifications,
4178 );
4179 response.send(proto::Ack {})?;
4180 Ok(())
4181}
4182
/// Get the current user's private information
4184async fn get_private_user_info(
4185 _request: proto::GetPrivateUserInfo,
4186 response: Response<proto::GetPrivateUserInfo>,
4187 session: Session,
4188) -> Result<()> {
4189 let db = session.db().await;
4190
4191 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4192 let user = db
4193 .get_user_by_id(session.user_id())
4194 .await?
4195 .context("user not found")?;
4196 let flags = db.get_user_flags(session.user_id()).await?;
4197
4198 response.send(proto::GetPrivateUserInfoResponse {
4199 metrics_id,
4200 staff: user.admin,
4201 flags,
4202 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4203 })?;
4204 Ok(())
4205}
4206
/// Accept the terms of service (ToS) on behalf of the current user
4208async fn accept_terms_of_service(
4209 _request: proto::AcceptTermsOfService,
4210 response: Response<proto::AcceptTermsOfService>,
4211 session: Session,
4212) -> Result<()> {
4213 let db = session.db().await;
4214
4215 let accepted_tos_at = Utc::now();
4216 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4217 .await?;
4218
4219 response.send(proto::AcceptTermsOfServiceResponse {
4220 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4221 })?;
4222
4223 // When the user accepts the terms of service, we want to refresh their LLM
4224 // token to grant access.
4225 session
4226 .peer
4227 .send(session.connection_id, proto::RefreshLlmToken {})?;
4228
4229 Ok(())
4230}
4231
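/// Issue an LLM API token for the current user, ensuring they have accepted the
/// terms of service and have a Stripe customer plus an active subscription
/// (defaulting to Zed Free) before the token is created.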
4232async fn get_llm_api_token(
4233 _request: proto::GetLlmToken,
4234 response: Response<proto::GetLlmToken>,
4235 session: Session,
4236) -> Result<()> {
4237 let db = session.db().await;
4238
4239 let flags = db.get_user_flags(session.user_id()).await?;
4240
4241 let user_id = session.user_id();
4242 let user = db
4243 .get_user_by_id(user_id)
4244 .await?
4245 .with_context(|| format!("user {user_id} not found"))?;
4246
4247 if user.accepted_tos_at.is_none() {
4248 Err(anyhow!("terms of service not accepted"))?
4249 }
4250
4251 let stripe_client = session
4252 .app_state
4253 .stripe_client
4254 .as_ref()
4255 .context("failed to retrieve Stripe client")?;
4256
4257 let stripe_billing = session
4258 .app_state
4259 .stripe_billing
4260 .as_ref()
4261 .context("failed to retrieve Stripe billing object")?;
4262
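    // Look up the user's billing customer, creating a Stripe customer keyed by
    // their email address if they don't have one yet.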
4263 let billing_customer = if let Some(billing_customer) =
4264 db.get_billing_customer_by_user_id(user.id).await?
4265 {
4266 billing_customer
4267 } else {
4268 let customer_id = stripe_billing
4269 .find_or_create_customer_by_email(user.email_address.as_deref())
4270 .await?;
4271
4272 find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
4273 .await?
4274 .context("billing customer not found")?
4275 };
4276
4277 let billing_subscription =
4278 if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
4279 billing_subscription
4280 } else {
4281 let stripe_customer_id =
4282 StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
4283
4284 let stripe_subscription = stripe_billing
4285 .subscribe_to_zed_free(stripe_customer_id)
4286 .await?;
4287
4288 db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
4289 billing_customer_id: billing_customer.id,
4290 kind: Some(SubscriptionKind::ZedFree),
4291 stripe_subscription_id: stripe_subscription.id.to_string(),
4292 stripe_subscription_status: stripe_subscription.status.into(),
4293 stripe_cancellation_reason: None,
4294 stripe_current_period_start: Some(stripe_subscription.current_period_start),
4295 stripe_current_period_end: Some(stripe_subscription.current_period_end),
4296 })
4297 .await?
4298 };
4299
4300 let billing_preferences = db.get_billing_preferences(user.id).await?;
4301
4302 let token = LlmTokenClaims::create(
4303 &user,
4304 session.is_staff(),
4305 billing_customer,
4306 billing_preferences,
4307 &flags,
4308 billing_subscription,
4309 session.system_id.clone(),
4310 &session.app_state.config,
4311 )?;
4312 response.send(proto::GetLlmTokenResponse { token })?;
4313 Ok(())
4314}
4315
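/// Convert a `tungstenite` WebSocket message into its `axum` counterpart.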
4316fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4317 let message = match message {
4318 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4319 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4320 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4321 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4322 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4323 code: frame.code.into(),
4324 reason: frame.reason.as_str().to_owned().into(),
4325 })),
4326 // We should never receive a frame while reading the message, according
4327 // to the `tungstenite` maintainers:
4328 //
4329 // > It cannot occur when you read messages from the WebSocket, but it
4330 // > can be used when you want to send the raw frames (e.g. you want to
4331 // > send the frames to the WebSocket without composing the full message first).
4332 // >
4333 // > — https://github.com/snapview/tungstenite-rs/issues/268
4334 TungsteniteMessage::Frame(_) => {
4335 bail!("received an unexpected frame while reading the message")
4336 }
4337 };
4338
4339 Ok(message)
4340}
4341
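/// Convert an `axum` WebSocket message into its `tungstenite` counterpart.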
4342fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4343 match message {
4344 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4345 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4346 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4347 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4348 AxumMessage::Close(frame) => {
4349 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4350 code: frame.code.into(),
4351 reason: frame.reason.as_ref().into(),
4352 }))
4353 }
4354 }
4355}
4356
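/// Apply a channel membership change to the connection pool and push the
/// resulting channel updates to all of the affected user's connections.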
4357fn notify_membership_updated(
4358 connection_pool: &mut ConnectionPool,
4359 result: MembershipUpdated,
4360 user_id: UserId,
4361 peer: &Peer,
4362) {
4363 for membership in &result.new_channels.channel_memberships {
4364 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4365 }
4366 for channel_id in &result.removed_channels {
4367 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4368 }
4369
4370 let user_channels_update = proto::UpdateUserChannels {
4371 channel_memberships: result
4372 .new_channels
4373 .channel_memberships
4374 .iter()
4375 .map(|cm| proto::ChannelMembership {
4376 channel_id: cm.channel_id.to_proto(),
4377 role: cm.role.into(),
4378 })
4379 .collect(),
4380 ..Default::default()
4381 };
4382
4383 let mut update = build_channels_update(result.new_channels);
4384 update.delete_channels = result
4385 .removed_channels
4386 .into_iter()
4387 .map(|id| id.to_proto())
4388 .collect();
4389 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4390
4391 for connection_id in connection_pool.user_connection_ids(user_id) {
4392 peer.send(connection_id, user_channels_update.clone())
4393 .trace_err();
4394 peer.send(connection_id, update.clone()).trace_err();
4395 }
4396}
4397
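/// Build an `UpdateUserChannels` message from a user's memberships and the
/// buffer versions and message IDs they have already observed.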
4398fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4399 proto::UpdateUserChannels {
4400 channel_memberships: channels
4401 .channel_memberships
4402 .iter()
4403 .map(|m| proto::ChannelMembership {
4404 channel_id: m.channel_id.to_proto(),
4405 role: m.role.into(),
4406 })
4407 .collect(),
4408 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4409 observed_channel_message_id: channels.observed_channel_messages.clone(),
4410 }
4411}
4412
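/// Build an `UpdateChannels` message describing the channels a user belongs to,
/// the channels they are invited to, and who is participating in each.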
4413fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4414 let mut update = proto::UpdateChannels::default();
4415
4416 for channel in channels.channels {
4417 update.channels.push(channel.to_proto());
4418 }
4419
4420 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4421 update.latest_channel_message_ids = channels.latest_channel_messages;
4422
4423 for (channel_id, participants) in channels.channel_participants {
4424 update
4425 .channel_participants
4426 .push(proto::ChannelParticipants {
4427 channel_id: channel_id.to_proto(),
4428 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4429 });
4430 }
4431
4432 for channel in channels.invited_channels {
4433 update.channel_invitations.push(channel.to_proto());
4434 }
4435
4436 update
4437}
4438
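/// Build the initial `UpdateContacts` message for a user, covering accepted
/// contacts as well as pending incoming and outgoing requests.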
4439fn build_initial_contacts_update(
4440 contacts: Vec<db::Contact>,
4441 pool: &ConnectionPool,
4442) -> proto::UpdateContacts {
4443 let mut update = proto::UpdateContacts::default();
4444
4445 for contact in contacts {
4446 match contact {
4447 db::Contact::Accepted { user_id, busy } => {
4448 update.contacts.push(contact_for_user(user_id, busy, pool));
4449 }
4450 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4451 db::Contact::Incoming { user_id } => {
4452 update
4453 .incoming_requests
4454 .push(proto::IncomingContactRequest {
4455 requester_id: user_id.to_proto(),
4456 })
4457 }
4458 }
4459 }
4460
4461 update
4462}
4463
4464fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4465 proto::Contact {
4466 user_id: user_id.to_proto(),
4467 online: pool.is_user_online(user_id),
4468 busy,
4469 }
4470}
4471
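/// Broadcast the latest room state to every participant in the room.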
4472fn room_updated(room: &proto::Room, peer: &Peer) {
4473 broadcast(
4474 None,
4475 room.participants
4476 .iter()
4477 .filter_map(|participant| Some(participant.peer_id?.into())),
4478 |peer_id| {
4479 peer.send(
4480 peer_id,
4481 proto::RoomUpdated {
4482 room: Some(room.clone()),
4483 },
4484 )
4485 },
4486 );
4487}
4488
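/// Tell channel members who are allowed to see the channel which users are
/// currently participating in its call.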
4489fn channel_updated(
4490 channel: &db::channel::Model,
4491 room: &proto::Room,
4492 peer: &Peer,
4493 pool: &ConnectionPool,
4494) {
4495 let participants = room
4496 .participants
4497 .iter()
4498 .map(|p| p.user_id)
4499 .collect::<Vec<_>>();
4500
4501 broadcast(
4502 None,
4503 pool.channel_connection_ids(channel.root_id())
            .filter_map(|(connection_id, role)| {
                role.can_see_channel(channel.visibility)
                    .then_some(connection_id)
            }),
4508 |peer_id| {
4509 peer.send(
4510 peer_id,
4511 proto::UpdateChannels {
4512 channel_participants: vec![proto::ChannelParticipants {
4513 channel_id: channel.id.to_proto(),
4514 participant_user_ids: participants.clone(),
4515 }],
4516 ..Default::default()
4517 },
4518 )
4519 },
4520 );
4521}
4522
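/// Push a user's updated online/busy status to all of their accepted contacts.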
4523async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4524 let db = session.db().await;
4525
4526 let contacts = db.get_contacts(user_id).await?;
4527 let busy = db.is_user_busy(user_id).await?;
4528
4529 let pool = session.connection_pool().await;
4530 let updated_contact = contact_for_user(user_id, busy, &pool);
4531 for contact in contacts {
4532 if let db::Contact::Accepted {
4533 user_id: contact_user_id,
4534 ..
4535 } = contact
4536 {
4537 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4538 session
4539 .peer
4540 .send(
4541 contact_conn_id,
4542 proto::UpdateContacts {
4543 contacts: vec![updated_contact.clone()],
4544 remove_contacts: Default::default(),
4545 incoming_requests: Default::default(),
4546 remove_incoming_requests: Default::default(),
4547 outgoing_requests: Default::default(),
4548 remove_outgoing_requests: Default::default(),
4549 },
4550 )
4551 .trace_err();
4552 }
4553 }
4554 }
4555 Ok(())
4556}
4557
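/// Remove a connection from the room it is in, notifying the remaining
/// participants, affected contacts, the associated channel (if any), and LiveKit.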
4558async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4559 let mut contacts_to_update = HashSet::default();
4560
4561 let room_id;
4562 let canceled_calls_to_user_ids;
4563 let livekit_room;
4564 let delete_livekit_room;
4565 let room;
4566 let channel;
4567
4568 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4569 contacts_to_update.insert(session.user_id());
4570
4571 for project in left_room.left_projects.values() {
4572 project_left(project, session);
4573 }
4574
4575 room_id = RoomId::from_proto(left_room.room.id);
4576 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4577 livekit_room = mem::take(&mut left_room.room.livekit_room);
4578 delete_livekit_room = left_room.deleted;
4579 room = mem::take(&mut left_room.room);
4580 channel = mem::take(&mut left_room.channel);
4581
4582 room_updated(&room, &session.peer);
4583 } else {
4584 return Ok(());
4585 }
4586
4587 if let Some(channel) = channel {
4588 channel_updated(
4589 &channel,
4590 &room,
4591 &session.peer,
4592 &*session.connection_pool().await,
4593 );
4594 }
4595
4596 {
4597 let pool = session.connection_pool().await;
4598 for canceled_user_id in canceled_calls_to_user_ids {
4599 for connection_id in pool.user_connection_ids(canceled_user_id) {
4600 session
4601 .peer
4602 .send(
4603 connection_id,
4604 proto::CallCanceled {
4605 room_id: room_id.to_proto(),
4606 },
4607 )
4608 .trace_err();
4609 }
4610 contacts_to_update.insert(canceled_user_id);
4611 }
4612 }
4613
4614 for contact_user_id in contacts_to_update {
4615 update_user_contacts(contact_user_id, session).await?;
4616 }
4617
4618 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4619 live_kit
4620 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4621 .await
4622 .trace_err();
4623
4624 if delete_livekit_room {
4625 live_kit.delete_room(livekit_room).await.trace_err();
4626 }
4627 }
4628
4629 Ok(())
4630}
4631
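/// Release all channel-buffer collaborations held by this connection and notify
/// the remaining collaborators on each buffer.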
4632async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4633 let left_channel_buffers = session
4634 .db()
4635 .await
4636 .leave_channel_buffers(session.connection_id)
4637 .await?;
4638
4639 for left_buffer in left_channel_buffers {
4640 channel_buffer_updated(
4641 session.connection_id,
4642 left_buffer.connections,
4643 &proto::UpdateChannelBufferCollaborators {
4644 channel_id: left_buffer.channel_id.to_proto(),
4645 collaborators: left_buffer.collaborators,
4646 },
4647 &session.peer,
4648 );
4649 }
4650
4651 Ok(())
4652}
4653
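/// Notify a project's other collaborators that this connection has left,
/// unsharing the project entirely when `should_unshare` is set.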
4654fn project_left(project: &db::LeftProject, session: &Session) {
4655 for connection_id in &project.connection_ids {
4656 if project.should_unshare {
4657 session
4658 .peer
4659 .send(
4660 *connection_id,
4661 proto::UnshareProject {
4662 project_id: project.id.to_proto(),
4663 },
4664 )
4665 .trace_err();
4666 } else {
4667 session
4668 .peer
4669 .send(
4670 *connection_id,
4671 proto::RemoveProjectCollaborator {
4672 project_id: project.id.to_proto(),
4673 peer_id: Some(session.connection_id.into()),
4674 },
4675 )
4676 .trace_err();
4677 }
4678 }
4679}
4680
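/// Extension trait for logging errors instead of propagating them: `trace_err`
/// records the error via `tracing::error!` and converts the `Result` into an
/// `Option`.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// // A failed parse is logged and flattened to `None` rather than bubbling up.
/// let port: Option<u16> = "not-a-port".parse::<u16>().trace_err();
/// assert!(port.is_none());
/// ```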
4681pub trait ResultExt {
4682 type Ok;
4683
4684 fn trace_err(self) -> Option<Self::Ok>;
4685}
4686
4687impl<T, E> ResultExt for Result<T, E>
4688where
4689 E: std::fmt::Debug,
4690{
4691 type Ok = T;
4692
4693 #[track_caller]
4694 fn trace_err(self) -> Option<T> {
4695 match self {
4696 Ok(value) => Some(value),
4697 Err(error) => {
4698 tracing::error!("{:?}", error);
4699 None
4700 }
4701 }
4702 }
4703}