1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// Reconnection grace period.
///
/// NOTE(review): not used in the visible part of this file; presumably bounds how long a
/// disconnected client may take to reconnect before its state is released — confirm at the
/// call site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay between `Server::start` and the background cleanup of stale resources.
// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean up old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// NOTE(review): used outside this chunk; presumably the page size for channel message history.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// NOTE(review): used outside this chunk; presumably the maximum accepted channel message length.
const MAX_MESSAGE_LEN: usize = 1024;
/// NOTE(review): used outside this chunk; presumably the page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of live `ConnectionGuard`s (incremented on acquire, decremented on drop).
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased message handler stored in `Server::handlers`: it receives the boxed envelope and
/// the sender's `Session` and returns a boxed future that performs the work.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
122struct Response<R> {
123 peer: Arc<Peer>,
124 receipt: Receipt<R>,
125 responded: Arc<AtomicBool>,
126}
127
128impl<R: RequestMessage> Response<R> {
129 fn send(self, payload: R::Response) -> Result<()> {
130 self.responded.store(true, SeqCst);
131 self.peer.respond(self.receipt, payload)?;
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug)]
137pub enum Principal {
138 User(User),
139 Impersonated { user: User, admin: User },
140}
141
142impl Principal {
143 fn user(&self) -> &User {
144 match self {
145 Principal::User(user) => user,
146 Principal::Impersonated { user, .. } => user,
147 }
148 }
149
150 fn update_span(&self, span: &tracing::Span) {
151 match &self {
152 Principal::User(user) => {
153 span.record("user_id", user.id.0);
154 span.record("login", &user.github_login);
155 }
156 Principal::Impersonated { user, admin } => {
157 span.record("user_id", user.id.0);
158 span.record("login", &user.github_login);
159 span.record("impersonator", &admin.github_login);
160 }
161 }
162 }
163}
164
165#[derive(Clone)]
166struct Session {
167 principal: Principal,
168 connection_id: ConnectionId,
169 db: Arc<tokio::sync::Mutex<DbHandle>>,
170 peer: Arc<Peer>,
171 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
172 app_state: Arc<AppState>,
173 supermaven_client: Option<Arc<SupermavenAdminApi>>,
174 /// The GeoIP country code for the user.
175 #[allow(unused)]
176 geoip_country_code: Option<String>,
177 system_id: Option<String>,
178 _executor: Executor,
179}
180
181impl Session {
182 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
183 #[cfg(test)]
184 tokio::task::yield_now().await;
185 let guard = self.db.lock().await;
186 #[cfg(test)]
187 tokio::task::yield_now().await;
188 guard
189 }
190
191 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
192 #[cfg(test)]
193 tokio::task::yield_now().await;
194 let guard = self.connection_pool.lock();
195 ConnectionPoolGuard {
196 guard,
197 _not_send: PhantomData,
198 }
199 }
200
201 fn is_staff(&self) -> bool {
202 match &self.principal {
203 Principal::User(user) => user.admin,
204 Principal::Impersonated { .. } => true,
205 }
206 }
207
208 fn user_id(&self) -> UserId {
209 match &self.principal {
210 Principal::User(user) => user.id,
211 Principal::Impersonated { user, .. } => user.id,
212 }
213 }
214
215 pub fn email(&self) -> Option<String> {
216 match &self.principal {
217 Principal::User(user) => user.email_address.clone(),
218 Principal::Impersonated { user, .. } => user.email_address.clone(),
219 }
220 }
221}
222
223impl Debug for Session {
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 let mut result = f.debug_struct("Session");
226 match &self.principal {
227 Principal::User(user) => {
228 result.field("user", &user.github_login);
229 }
230 Principal::Impersonated { user, admin } => {
231 result.field("user", &user.github_login);
232 result.field("impersonator", &admin.github_login);
233 }
234 }
235 result.field("connection_id", &self.connection_id).finish()
236 }
237}
238
239struct DbHandle(Arc<Database>);
240
241impl Deref for DbHandle {
242 type Target = Database;
243
244 fn deref(&self) -> &Self::Target {
245 self.0.as_ref()
246 }
247}
248
/// The collaboration RPC server. It owns the peer, the pool of live connections and the
/// table of message handlers.
pub struct Server {
    /// Current server id. It sits behind a mutex because the test-only `reset` replaces it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag; every connection task subscribes to it in `handle_connection`.
    teardown: watch::Sender<bool>,
}
257
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. As a result, holding it across an
/// `.await` inside a future that must be `Send` is a compile error.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
262
/// Serializable view of the server's peer and connection pool.
///
/// The pool is serialized through its guard via `serialize_deref`. Because the snapshot holds
/// that guard, the pool lock stays held for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
269
270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
271where
272 S: Serializer,
273 T: Deref<Target = U>,
274 U: Serialize,
275{
276 Serialize::serialize(value.deref(), serializer)
277}
278
279impl Server {
280 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
281 let mut server = Self {
282 id: parking_lot::Mutex::new(id),
283 peer: Peer::new(id.0 as u32),
284 app_state: app_state.clone(),
285 connection_pool: Default::default(),
286 handlers: Default::default(),
287 teardown: watch::channel(false).0,
288 };
289
290 server
291 .add_request_handler(ping)
292 .add_request_handler(create_room)
293 .add_request_handler(join_room)
294 .add_request_handler(rejoin_room)
295 .add_request_handler(leave_room)
296 .add_request_handler(set_room_participant_role)
297 .add_request_handler(call)
298 .add_request_handler(cancel_call)
299 .add_message_handler(decline_call)
300 .add_request_handler(update_participant_location)
301 .add_request_handler(share_project)
302 .add_message_handler(unshare_project)
303 .add_request_handler(join_project)
304 .add_message_handler(leave_project)
305 .add_request_handler(update_project)
306 .add_request_handler(update_worktree)
307 .add_request_handler(update_repository)
308 .add_request_handler(remove_repository)
309 .add_message_handler(start_language_server)
310 .add_message_handler(update_language_server)
311 .add_message_handler(update_diagnostic_summary)
312 .add_message_handler(update_worktree_settings)
313 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
314 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
317 .add_request_handler(forward_find_search_candidates_request)
318 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
323 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
324 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
325 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
326 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
328 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
329 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
330 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
334 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
335 .add_request_handler(
336 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
337 )
338 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
342 .add_request_handler(
343 forward_read_only_project_request::<proto::LanguageServerIdForName>,
344 )
345 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
346 .add_request_handler(
347 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
348 )
349 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
350 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
353 )
354 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
355 .add_request_handler(
356 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
357 )
358 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
359 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
360 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
361 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
362 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
364 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
365 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
369 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
370 .add_request_handler(
371 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
372 )
373 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
374 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
375 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
376 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
377 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
378 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
379 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
380 .add_message_handler(create_buffer_for_peer)
381 .add_request_handler(update_buffer)
382 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
383 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
386 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
387 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
388 .add_message_handler(
389 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
390 )
391 .add_request_handler(get_users)
392 .add_request_handler(fuzzy_search_users)
393 .add_request_handler(request_contact)
394 .add_request_handler(remove_contact)
395 .add_request_handler(respond_to_contact_request)
396 .add_message_handler(subscribe_to_channels)
397 .add_request_handler(create_channel)
398 .add_request_handler(delete_channel)
399 .add_request_handler(invite_channel_member)
400 .add_request_handler(remove_channel_member)
401 .add_request_handler(set_channel_member_role)
402 .add_request_handler(set_channel_visibility)
403 .add_request_handler(rename_channel)
404 .add_request_handler(join_channel_buffer)
405 .add_request_handler(leave_channel_buffer)
406 .add_message_handler(update_channel_buffer)
407 .add_request_handler(rejoin_channel_buffers)
408 .add_request_handler(get_channel_members)
409 .add_request_handler(respond_to_channel_invite)
410 .add_request_handler(join_channel)
411 .add_request_handler(join_channel_chat)
412 .add_message_handler(leave_channel_chat)
413 .add_request_handler(send_channel_message)
414 .add_request_handler(remove_channel_message)
415 .add_request_handler(update_channel_message)
416 .add_request_handler(get_channel_messages)
417 .add_request_handler(get_channel_messages_by_id)
418 .add_request_handler(get_notifications)
419 .add_request_handler(mark_notification_as_read)
420 .add_request_handler(move_channel)
421 .add_request_handler(reorder_channel)
422 .add_request_handler(follow)
423 .add_message_handler(unfollow)
424 .add_message_handler(update_followers)
425 .add_request_handler(get_private_user_info)
426 .add_request_handler(get_llm_api_token)
427 .add_request_handler(accept_terms_of_service)
428 .add_message_handler(acknowledge_channel_message)
429 .add_message_handler(acknowledge_buffer_version)
430 .add_request_handler(get_supermaven_api_key)
431 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
432 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
433 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
434 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
435 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
436 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
437 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
438 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
439 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
440 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
441 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
443 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
444 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
445 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
446 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
447 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
448 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
449 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
450 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
451 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
452 .add_message_handler(update_context);
453
454 Arc::new(server)
455 }
456
    /// Starts the server's background maintenance.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. Once [`CLEANUP_TIMEOUT`] has
    /// elapsed, the task cleans up whatever the database reports as stale for this environment
    /// and server id. NOTE(review): "stale" is defined by the `db` queries, presumably resources
    /// owned by servers that are no longer running; confirm there. The cleanup runs in order:
    /// - stale channel chat participants are deleted;
    /// - stale collaborators are cleared from each stale channel buffer, and the remaining
    ///   connections receive the refreshed collaborator list;
    /// - for each stale room:
    ///   - stale participants are cleared;
    ///   - the room (and its channel, if any) is re-broadcast;
    ///   - users whose calls were cancelled receive `CallCanceled`;
    ///   - the accepted contacts of every affected user receive an updated contact entry;
    ///   - the LiveKit room is deleted once the room has no participants left;
    /// - finally, chat participants are swept again, old worktree entries are cleared and stale
    ///   servers are deleted.
    ///
    /// Every database or send error inside the task is logged via `trace_err` and otherwise
    /// ignored, so one failing step does not stop the remaining cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators, then tell the remaining
                    // connections who is still collaborating.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in only if clearing the room's stale participants succeeded;
                        // otherwise they stay empty and the steps below do nothing.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (busy state) to
                        // all connections of that user's accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the `delete_stale_channel_chat_participants` call
                // made at the start of the task — confirm the second sweep is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
627
628 pub fn teardown(&self) {
629 self.peer.teardown();
630 self.connection_pool.lock().reset();
631 let _ = self.teardown.send(true);
632 }
633
634 #[cfg(test)]
635 pub fn reset(&self, id: ServerId) {
636 self.teardown();
637 *self.id.lock() = id;
638 self.peer.reset(id.0 as u32);
639 let _ = self.teardown.send(false);
640 }
641
642 #[cfg(test)]
643 pub fn id(&self) -> ServerId {
644 *self.id.lock()
645 }
646
647 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
648 where
649 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
650 Fut: 'static + Send + Future<Output = Result<()>>,
651 M: EnvelopedMessage,
652 {
653 let prev_handler = self.handlers.insert(
654 TypeId::of::<M>(),
655 Box::new(move |envelope, session| {
656 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
657 let received_at = envelope.received_at;
658 tracing::info!("message received");
659 let start_time = Instant::now();
660 let future = (handler)(*envelope, session);
661 async move {
662 let result = future.await;
663 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
664 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
665 let queue_duration_ms = total_duration_ms - processing_duration_ms;
666 let payload_type = M::NAME;
667
668 match result {
669 Err(error) => {
670 tracing::error!(
671 ?error,
672 total_duration_ms,
673 processing_duration_ms,
674 queue_duration_ms,
675 payload_type,
676 "error handling message"
677 )
678 }
679 Ok(()) => tracing::info!(
680 total_duration_ms,
681 processing_duration_ms,
682 queue_duration_ms,
683 "finished handling message"
684 ),
685 }
686 }
687 .boxed()
688 }),
689 );
690 if prev_handler.is_some() {
691 panic!("registered a handler for the same message twice");
692 }
693 self
694 }
695
696 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
697 where
698 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
699 Fut: 'static + Send + Future<Output = Result<()>>,
700 M: EnvelopedMessage,
701 {
702 self.add_handler(move |envelope, session| handler(envelope.payload, session));
703 self
704 }
705
706 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
707 where
708 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
709 Fut: Send + Future<Output = Result<()>>,
710 M: RequestMessage,
711 {
712 let handler = Arc::new(handler);
713 self.add_handler(move |envelope, session| {
714 let receipt = envelope.receipt();
715 let handler = handler.clone();
716 async move {
717 let peer = session.peer.clone();
718 let responded = Arc::new(AtomicBool::default());
719 let response = Response {
720 peer: peer.clone(),
721 responded: responded.clone(),
722 receipt,
723 };
724 match (handler)(envelope.payload, response, session).await {
725 Ok(()) => {
726 if responded.load(std::sync::atomic::Ordering::SeqCst) {
727 Ok(())
728 } else {
729 Err(anyhow!("handler did not send a response"))?
730 }
731 }
732 Err(error) => {
733 let proto_err = match &error {
734 Error::Internal(err) => err.to_proto(),
735 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
736 };
737 peer.respond_with_error(receipt, proto_err)?;
738 Err(error)
739 }
740 }
741 }
742 })
743 }
744
    /// Builds the future that drives one client connection from handshake to sign-out.
    ///
    /// Nothing happens until the returned future is polled. When run, it:
    /// 1. returns immediately if the teardown flag is already set;
    /// 2. registers `connection` with the peer and performs the handshake
    ///    (`send_initial_client_update`). `connection_guard` is held only for this phase;
    ///    it is dropped right after a successful handshake, or with the future on failure;
    /// 3. dispatches incoming messages to the registered handlers until the connection closes,
    ///    I/O fails, or the teardown flag changes;
    /// 4. after a close or I/O error (but not after teardown, which returns early) calls
    ///    `connection_lost`.
    ///
    /// `send_connection_id`, if provided, receives the connection id once `Hello` has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // Handshake done: free the slot in the concurrent-connection limit.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground and background) may be in flight per connection:
            // a permit is taken before reading each message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O completion, then progress on in-flight
                // foreground handlers, and only then a new incoming message.
                futures::select_biased! {
                    // Teardown skips `connection_lost` entirely.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            // The handler is invoked (its synchronous part runs) inside the span;
                            // the span is exited before the resulting future is instrumented.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are polled
                                // by this loop through `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still running before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
895
    /// Sends the initial burst of messages to a freshly established connection.
    ///
    /// In order, this:
    /// 1. sends `Hello` with the connection's peer id and, when a
    ///    `send_connection_id` channel was supplied, reports the connection id on it;
    /// 2. for users who have never connected before, sends `ShowContacts` and marks
    ///    the user as `connected_once` in the database;
    /// 3. refreshes the user's plan via `update_user_plan`;
    /// 4. registers the connection in the connection pool and sends the initial
    ///    contacts update built from the pool;
    /// 5. subscribes the user to their channels when
    ///    `should_auto_subscribe_to_channels(zed_version)` says so;
    /// 6. forwards a pending incoming call, if there is one;
    /// 7. notifies contacts via `update_user_contacts`.
    ///
    /// # Errors
    /// Returns the first error from sending a message or from a database/helper
    /// call; the remaining steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The result is ignored: a oneshot send only fails if the receiver was dropped.
            let _ = send_connection_id.send(connection_id);
        }

        match &session.principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(session).await?;

                let contacts = self.app_state.db.get_contacts(user.id).await?;

                // Scoped so the connection-pool lock is released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
        }

        Ok(())
    }
953
954 pub async fn invite_code_redeemed(
955 self: &Arc<Self>,
956 inviter_id: UserId,
957 invitee_id: UserId,
958 ) -> Result<()> {
959 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
960 if let Some(code) = &user.invite_code {
961 let pool = self.connection_pool.lock();
962 let invitee_contact = contact_for_user(invitee_id, false, &pool);
963 for connection_id in pool.user_connection_ids(inviter_id) {
964 self.peer.send(
965 connection_id,
966 proto::UpdateContacts {
967 contacts: vec![invitee_contact.clone()],
968 ..Default::default()
969 },
970 )?;
971 self.peer.send(
972 connection_id,
973 proto::UpdateInviteInfo {
974 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
975 count: user.invite_count as u32,
976 },
977 )?;
978 }
979 }
980 }
981 Ok(())
982 }
983
984 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
985 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
986 if let Some(invite_code) = &user.invite_code {
987 let pool = self.connection_pool.lock();
988 for connection_id in pool.user_connection_ids(user_id) {
989 self.peer.send(
990 connection_id,
991 proto::UpdateInviteInfo {
992 url: format!(
993 "{}{}",
994 self.app_state.config.invite_link_prefix, invite_code
995 ),
996 count: user.invite_count as u32,
997 },
998 )?;
999 }
1000 }
1001 }
1002 Ok(())
1003 }
1004
1005 pub async fn update_plan_for_user(
1006 self: &Arc<Self>,
1007 user_id: UserId,
1008 update_user_plan: proto::UpdateUserPlan,
1009 ) -> Result<()> {
1010 let pool = self.connection_pool.lock();
1011 for connection_id in pool.user_connection_ids(user_id) {
1012 self.peer
1013 .send(connection_id, update_user_plan.clone())
1014 .trace_err();
1015 }
1016
1017 Ok(())
1018 }
1019
1020 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1021 /// message on the Collab server.
1022 ///
1023 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1024 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1025 let user = self
1026 .app_state
1027 .db
1028 .get_user_by_id(user_id)
1029 .await?
1030 .context("user not found")?;
1031
1032 let update_user_plan = make_update_user_plan_message(
1033 &user,
1034 user.admin,
1035 &self.app_state.db,
1036 self.app_state.llm_db.clone(),
1037 )
1038 .await?;
1039
1040 self.update_plan_for_user(user_id, update_user_plan).await
1041 }
1042
1043 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1044 let pool = self.connection_pool.lock();
1045 for connection_id in pool.user_connection_ids(user_id) {
1046 self.peer
1047 .send(connection_id, proto::RefreshLlmToken {})
1048 .trace_err();
1049 }
1050 }
1051
1052 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1053 ServerSnapshot {
1054 connection_pool: ConnectionPoolGuard {
1055 guard: self.connection_pool.lock(),
1056 _not_send: PhantomData,
1057 },
1058 peer: &self.peer,
1059 }
1060 }
1061}
1062
1063impl Deref for ConnectionPoolGuard<'_> {
1064 type Target = ConnectionPool;
1065
1066 fn deref(&self) -> &Self::Target {
1067 &self.guard
1068 }
1069}
1070
1071impl DerefMut for ConnectionPoolGuard<'_> {
1072 fn deref_mut(&mut self) -> &mut Self::Target {
1073 &mut self.guard
1074 }
1075}
1076
1077impl Drop for ConnectionPoolGuard<'_> {
1078 fn drop(&mut self) {
1079 #[cfg(test)]
1080 self.check_invariants();
1081 }
1082}
1083
1084fn broadcast<F>(
1085 sender_id: Option<ConnectionId>,
1086 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1087 mut f: F,
1088) where
1089 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1090{
1091 for receiver_id in receiver_ids {
1092 if Some(receiver_id) != sender_id {
1093 if let Err(error) = f(receiver_id) {
1094 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1095 }
1096 }
1097 }
1098}
1099
1100pub struct ProtocolVersion(u32);
1101
1102impl Header for ProtocolVersion {
1103 fn name() -> &'static HeaderName {
1104 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1105 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1106 }
1107
1108 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1109 where
1110 Self: Sized,
1111 I: Iterator<Item = &'i axum::http::HeaderValue>,
1112 {
1113 let version = values
1114 .next()
1115 .ok_or_else(axum::headers::Error::invalid)?
1116 .to_str()
1117 .map_err(|_| axum::headers::Error::invalid())?
1118 .parse()
1119 .map_err(|_| axum::headers::Error::invalid())?;
1120 Ok(Self(version))
1121 }
1122
1123 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1124 values.extend([self.0.to_string().parse().unwrap()]);
1125 }
1126}
1127
1128pub struct AppVersionHeader(SemanticVersion);
1129impl Header for AppVersionHeader {
1130 fn name() -> &'static HeaderName {
1131 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1132 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1133 }
1134
1135 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1136 where
1137 Self: Sized,
1138 I: Iterator<Item = &'i axum::http::HeaderValue>,
1139 {
1140 let version = values
1141 .next()
1142 .ok_or_else(axum::headers::Error::invalid)?
1143 .to_str()
1144 .map_err(|_| axum::headers::Error::invalid())?
1145 .parse()
1146 .map_err(|_| axum::headers::Error::invalid())?;
1147 Ok(Self(version))
1148 }
1149
1150 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1151 values.extend([self.0.to_string().parse().unwrap()]);
1152 }
1153}
1154
1155pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1156 Router::new()
1157 .route("/rpc", get(handle_websocket_request))
1158 .layer(
1159 ServiceBuilder::new()
1160 .layer(Extension(server.app_state.clone()))
1161 .layer(middleware::from_fn(auth::validate_header)),
1162 )
1163 .route("/metrics", get(handle_metrics))
1164 .layer(Extension(server))
1165}
1166
/// Axum handler for `GET /rpc`: validates the client, then upgrades to a WebSocket.
///
/// Responds with `426 Upgrade Required` when the client's protocol version differs
/// from `rpc::PROTOCOL_VERSION`, when the app-version header is missing, or when
/// that version cannot collaborate. Responds with `503 Service Unavailable` when a
/// `ConnectionGuard` cannot be acquired. Otherwise it upgrades the socket, adapts
/// it between axum and tungstenite message types, and hands the connection to
/// `Server::handle_connection` with the production executor.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade
    // (so an over-capacity request is rejected with a plain HTTP 503 instead of
    // being upgraded first).
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Adapt axum WebSocket messages to the tungstenite messages `Connection` expects.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            // The guard is moved into the connection task, tying it to this connection.
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1238
1239pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1240 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1241 let connections_metric = CONNECTIONS_METRIC
1242 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1243
1244 let connections = server
1245 .connection_pool
1246 .lock()
1247 .connections()
1248 .filter(|connection| !connection.admin)
1249 .count();
1250 connections_metric.set(connections as _);
1251
1252 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1253 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1254 register_int_gauge!(
1255 "shared_projects",
1256 "number of open projects with one or more guests"
1257 )
1258 .unwrap()
1259 });
1260
1261 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1262 shared_projects_metric.set(shared_projects as _);
1263
1264 let encoder = prometheus::TextEncoder::new();
1265 let metric_families = prometheus::gather();
1266 let encoded_metrics = encoder
1267 .encode_to_string(&metric_families)
1268 .map_err(|err| anyhow!("{err}"))?;
1269 Ok(encoded_metrics)
1270}
1271
1272#[instrument(err, skip(executor))]
1273async fn connection_lost(
1274 session: Session,
1275 mut teardown: watch::Receiver<bool>,
1276 executor: Executor,
1277) -> Result<()> {
1278 session.peer.disconnect(session.connection_id);
1279 session
1280 .connection_pool()
1281 .await
1282 .remove_connection(session.connection_id)?;
1283
1284 session
1285 .db()
1286 .await
1287 .connection_lost(session.connection_id)
1288 .await
1289 .trace_err();
1290
1291 futures::select_biased! {
1292 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1293
1294 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1295 leave_room_for_session(&session, session.connection_id).await.trace_err();
1296 leave_channel_buffers_for_session(&session)
1297 .await
1298 .trace_err();
1299
1300 if !session
1301 .connection_pool()
1302 .await
1303 .is_user_online(session.user_id())
1304 {
1305 let db = session.db().await;
1306 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1307 room_updated(&room, &session.peer);
1308 }
1309 }
1310
1311 update_user_contacts(session.user_id(), &session).await?;
1312 },
1313 _ = teardown.changed().fuse() => {}
1314 }
1315
1316 Ok(())
1317}
1318
1319/// Acknowledges a ping from a client, used to keep the connection alive.
1320async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1321 response.send(proto::Ack {})?;
1322 Ok(())
1323}
1324
1325/// Creates a new room for calling (outside of channels)
1326async fn create_room(
1327 _request: proto::CreateRoom,
1328 response: Response<proto::CreateRoom>,
1329 session: Session,
1330) -> Result<()> {
1331 let livekit_room = nanoid::nanoid!(30);
1332
1333 let live_kit_connection_info = util::maybe!(async {
1334 let live_kit = session.app_state.livekit_client.as_ref();
1335 let live_kit = live_kit?;
1336 let user_id = session.user_id().to_string();
1337
1338 let token = live_kit
1339 .room_token(&livekit_room, &user_id.to_string())
1340 .trace_err()?;
1341
1342 Some(proto::LiveKitConnectionInfo {
1343 server_url: live_kit.url().into(),
1344 token,
1345 can_publish: true,
1346 })
1347 })
1348 .await;
1349
1350 let room = session
1351 .db()
1352 .await
1353 .create_room(session.user_id(), session.connection_id, &livekit_room)
1354 .await?;
1355
1356 response.send(proto::CreateRoomResponse {
1357 room: Some(room.clone()),
1358 live_kit_connection_info,
1359 })?;
1360
1361 update_user_contacts(session.user_id(), &session).await?;
1362 Ok(())
1363}
1364
1365/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1366async fn join_room(
1367 request: proto::JoinRoom,
1368 response: Response<proto::JoinRoom>,
1369 session: Session,
1370) -> Result<()> {
1371 let room_id = RoomId::from_proto(request.id);
1372
1373 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1374
1375 if let Some(channel_id) = channel_id {
1376 return join_channel_internal(channel_id, Box::new(response), session).await;
1377 }
1378
1379 let joined_room = {
1380 let room = session
1381 .db()
1382 .await
1383 .join_room(room_id, session.user_id(), session.connection_id)
1384 .await?;
1385 room_updated(&room.room, &session.peer);
1386 room.into_inner()
1387 };
1388
1389 for connection_id in session
1390 .connection_pool()
1391 .await
1392 .user_connection_ids(session.user_id())
1393 {
1394 session
1395 .peer
1396 .send(
1397 connection_id,
1398 proto::CallCanceled {
1399 room_id: room_id.to_proto(),
1400 },
1401 )
1402 .trace_err();
1403 }
1404
1405 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1406 {
1407 live_kit
1408 .room_token(
1409 &joined_room.room.livekit_room,
1410 &session.user_id().to_string(),
1411 )
1412 .trace_err()
1413 .map(|token| proto::LiveKitConnectionInfo {
1414 server_url: live_kit.url().into(),
1415 token,
1416 can_publish: true,
1417 })
1418 } else {
1419 None
1420 };
1421
1422 response.send(proto::JoinRoomResponse {
1423 room: Some(joined_room.room),
1424 channel_id: None,
1425 live_kit_connection_info,
1426 })?;
1427
1428 update_user_contacts(session.user_id(), &session).await?;
1429 Ok(())
1430}
1431
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Rejoins in the database, replies with the room and the reshared/rejoined
/// projects, and broadcasts the room. For each reshared project it tells the
/// collaborators the caller's new peer id and forwards the project's worktree list
/// to everyone except the caller. Rejoined projects are replayed through
/// `notify_rejoined_projects`. After that scope ends, the room's channel (if any)
/// is broadcast and the caller's contacts are notified.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so `rejoined_room` (and whatever it holds) is released before the
    // awaits that follow the block.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator the old connection id was replaced by this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-broadcast the reshared project's worktrees, skipping the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1523
1524fn notify_rejoined_projects(
1525 rejoined_projects: &mut Vec<RejoinedProject>,
1526 session: &Session,
1527) -> Result<()> {
1528 for project in rejoined_projects.iter() {
1529 for collaborator in &project.collaborators {
1530 session
1531 .peer
1532 .send(
1533 collaborator.connection_id,
1534 proto::UpdateProjectCollaborator {
1535 project_id: project.id.to_proto(),
1536 old_peer_id: Some(project.old_connection_id.into()),
1537 new_peer_id: Some(session.connection_id.into()),
1538 },
1539 )
1540 .trace_err();
1541 }
1542 }
1543
1544 for project in rejoined_projects {
1545 for worktree in mem::take(&mut project.worktrees) {
1546 // Stream this worktree's entries.
1547 let message = proto::UpdateWorktree {
1548 project_id: project.id.to_proto(),
1549 worktree_id: worktree.id,
1550 abs_path: worktree.abs_path.clone(),
1551 root_name: worktree.root_name,
1552 updated_entries: worktree.updated_entries,
1553 removed_entries: worktree.removed_entries,
1554 scan_id: worktree.scan_id,
1555 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1556 updated_repositories: worktree.updated_repositories,
1557 removed_repositories: worktree.removed_repositories,
1558 };
1559 for update in proto::split_worktree_update(message) {
1560 session.peer.send(session.connection_id, update)?;
1561 }
1562
1563 // Stream this worktree's diagnostics.
1564 for summary in worktree.diagnostic_summaries {
1565 session.peer.send(
1566 session.connection_id,
1567 proto::UpdateDiagnosticSummary {
1568 project_id: project.id.to_proto(),
1569 worktree_id: worktree.id,
1570 summary: Some(summary),
1571 },
1572 )?;
1573 }
1574
1575 for settings_file in worktree.settings_files {
1576 session.peer.send(
1577 session.connection_id,
1578 proto::UpdateWorktreeSettings {
1579 project_id: project.id.to_proto(),
1580 worktree_id: worktree.id,
1581 path: settings_file.path,
1582 content: Some(settings_file.content),
1583 kind: Some(settings_file.kind.to_proto().into()),
1584 },
1585 )?;
1586 }
1587 }
1588
1589 for repository in mem::take(&mut project.updated_repositories) {
1590 for update in split_repository_update(repository) {
1591 session.peer.send(session.connection_id, update)?;
1592 }
1593 }
1594
1595 for id in mem::take(&mut project.removed_repositories) {
1596 session.peer.send(
1597 session.connection_id,
1598 proto::RemoveRepository {
1599 project_id: project.id.to_proto(),
1600 id,
1601 },
1602 )?;
1603 }
1604 }
1605
1606 Ok(())
1607}
1608
1609/// leave room disconnects from the room.
1610async fn leave_room(
1611 _: proto::LeaveRoom,
1612 response: Response<proto::LeaveRoom>,
1613 session: Session,
1614) -> Result<()> {
1615 leave_room_for_session(&session, session.connection_id).await?;
1616 response.send(proto::Ack {})?;
1617 Ok(())
1618}
1619
1620/// Updates the permissions of someone else in the room.
1621async fn set_room_participant_role(
1622 request: proto::SetRoomParticipantRole,
1623 response: Response<proto::SetRoomParticipantRole>,
1624 session: Session,
1625) -> Result<()> {
1626 let user_id = UserId::from_proto(request.user_id);
1627 let role = ChannelRole::from(request.role());
1628
1629 let (livekit_room, can_publish) = {
1630 let room = session
1631 .db()
1632 .await
1633 .set_room_participant_role(
1634 session.user_id(),
1635 RoomId::from_proto(request.room_id),
1636 user_id,
1637 role,
1638 )
1639 .await?;
1640
1641 let livekit_room = room.livekit_room.clone();
1642 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1643 room_updated(&room, &session.peer);
1644 (livekit_room, can_publish)
1645 };
1646
1647 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1648 live_kit
1649 .update_participant(
1650 livekit_room.clone(),
1651 request.user_id.to_string(),
1652 livekit_api::proto::ParticipantPermission {
1653 can_subscribe: true,
1654 can_publish,
1655 can_publish_data: can_publish,
1656 hidden: false,
1657 recorder: false,
1658 },
1659 )
1660 .await
1661 .trace_err();
1662 }
1663
1664 response.send(proto::Ack {})?;
1665 Ok(())
1666}
1667
1668/// Call someone else into the current room
1669async fn call(
1670 request: proto::Call,
1671 response: Response<proto::Call>,
1672 session: Session,
1673) -> Result<()> {
1674 let room_id = RoomId::from_proto(request.room_id);
1675 let calling_user_id = session.user_id();
1676 let calling_connection_id = session.connection_id;
1677 let called_user_id = UserId::from_proto(request.called_user_id);
1678 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1679 if !session
1680 .db()
1681 .await
1682 .has_contact(calling_user_id, called_user_id)
1683 .await?
1684 {
1685 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1686 }
1687
1688 let incoming_call = {
1689 let (room, incoming_call) = &mut *session
1690 .db()
1691 .await
1692 .call(
1693 room_id,
1694 calling_user_id,
1695 calling_connection_id,
1696 called_user_id,
1697 initial_project_id,
1698 )
1699 .await?;
1700 room_updated(room, &session.peer);
1701 mem::take(incoming_call)
1702 };
1703 update_user_contacts(called_user_id, &session).await?;
1704
1705 let mut calls = session
1706 .connection_pool()
1707 .await
1708 .user_connection_ids(called_user_id)
1709 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1710 .collect::<FuturesUnordered<_>>();
1711
1712 while let Some(call_response) = calls.next().await {
1713 match call_response.as_ref() {
1714 Ok(_) => {
1715 response.send(proto::Ack {})?;
1716 return Ok(());
1717 }
1718 Err(_) => {
1719 call_response.trace_err();
1720 }
1721 }
1722 }
1723
1724 {
1725 let room = session
1726 .db()
1727 .await
1728 .call_failed(room_id, called_user_id)
1729 .await?;
1730 room_updated(&room, &session.peer);
1731 }
1732 update_user_contacts(called_user_id, &session).await?;
1733
1734 Err(anyhow!("failed to ring user"))?
1735}
1736
1737/// Cancel an outgoing call.
1738async fn cancel_call(
1739 request: proto::CancelCall,
1740 response: Response<proto::CancelCall>,
1741 session: Session,
1742) -> Result<()> {
1743 let called_user_id = UserId::from_proto(request.called_user_id);
1744 let room_id = RoomId::from_proto(request.room_id);
1745 {
1746 let room = session
1747 .db()
1748 .await
1749 .cancel_call(room_id, session.connection_id, called_user_id)
1750 .await?;
1751 room_updated(&room, &session.peer);
1752 }
1753
1754 for connection_id in session
1755 .connection_pool()
1756 .await
1757 .user_connection_ids(called_user_id)
1758 {
1759 session
1760 .peer
1761 .send(
1762 connection_id,
1763 proto::CallCanceled {
1764 room_id: room_id.to_proto(),
1765 },
1766 )
1767 .trace_err();
1768 }
1769 response.send(proto::Ack {})?;
1770
1771 update_user_contacts(called_user_id, &session).await?;
1772 Ok(())
1773}
1774
1775/// Decline an incoming call.
1776async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1777 let room_id = RoomId::from_proto(message.room_id);
1778 {
1779 let room = session
1780 .db()
1781 .await
1782 .decline_call(Some(room_id), session.user_id())
1783 .await?
1784 .context("declining call")?;
1785 room_updated(&room, &session.peer);
1786 }
1787
1788 for connection_id in session
1789 .connection_pool()
1790 .await
1791 .user_connection_ids(session.user_id())
1792 {
1793 session
1794 .peer
1795 .send(
1796 connection_id,
1797 proto::CallCanceled {
1798 room_id: room_id.to_proto(),
1799 },
1800 )
1801 .trace_err();
1802 }
1803 update_user_contacts(session.user_id(), &session).await?;
1804 Ok(())
1805}
1806
1807/// Updates other participants in the room with your current location.
1808async fn update_participant_location(
1809 request: proto::UpdateParticipantLocation,
1810 response: Response<proto::UpdateParticipantLocation>,
1811 session: Session,
1812) -> Result<()> {
1813 let room_id = RoomId::from_proto(request.room_id);
1814 let location = request.location.context("invalid location")?;
1815
1816 let db = session.db().await;
1817 let room = db
1818 .update_room_participant_location(room_id, session.connection_id, location)
1819 .await?;
1820
1821 room_updated(&room, &session.peer);
1822 response.send(proto::Ack {})?;
1823 Ok(())
1824}
1825
1826/// Share a project into the room.
1827async fn share_project(
1828 request: proto::ShareProject,
1829 response: Response<proto::ShareProject>,
1830 session: Session,
1831) -> Result<()> {
1832 let (project_id, room) = &*session
1833 .db()
1834 .await
1835 .share_project(
1836 RoomId::from_proto(request.room_id),
1837 session.connection_id,
1838 &request.worktrees,
1839 request.is_ssh_project,
1840 )
1841 .await?;
1842 response.send(proto::ShareProjectResponse {
1843 project_id: project_id.to_proto(),
1844 })?;
1845 room_updated(room, &session.peer);
1846
1847 Ok(())
1848}
1849
1850/// Unshare a project from the room.
1851async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1852 let project_id = ProjectId::from_proto(message.project_id);
1853 unshare_project_internal(project_id, session.connection_id, &session).await
1854}
1855
1856async fn unshare_project_internal(
1857 project_id: ProjectId,
1858 connection_id: ConnectionId,
1859 session: &Session,
1860) -> Result<()> {
1861 let delete = {
1862 let room_guard = session
1863 .db()
1864 .await
1865 .unshare_project(project_id, connection_id)
1866 .await?;
1867
1868 let (delete, room, guest_connection_ids) = &*room_guard;
1869
1870 let message = proto::UnshareProject {
1871 project_id: project_id.to_proto(),
1872 };
1873
1874 broadcast(
1875 Some(connection_id),
1876 guest_connection_ids.iter().copied(),
1877 |conn_id| session.peer.send(conn_id, message.clone()),
1878 );
1879 if let Some(room) = room {
1880 room_updated(room, &session.peer);
1881 }
1882
1883 *delete
1884 };
1885
1886 if delete {
1887 let db = session.db().await;
1888 db.delete_project(project_id).await?;
1889 }
1890
1891 Ok(())
1892}
1893
1894/// Join someone elses shared project.
1895async fn join_project(
1896 request: proto::JoinProject,
1897 response: Response<proto::JoinProject>,
1898 session: Session,
1899) -> Result<()> {
1900 let project_id = ProjectId::from_proto(request.project_id);
1901
1902 tracing::info!(%project_id, "join project");
1903
1904 let db = session.db().await;
1905 let (project, replica_id) = &mut *db
1906 .join_project(
1907 project_id,
1908 session.connection_id,
1909 session.user_id(),
1910 request.committer_name.clone(),
1911 request.committer_email.clone(),
1912 )
1913 .await?;
1914 drop(db);
1915 tracing::info!(%project_id, "join remote project");
1916 let collaborators = project
1917 .collaborators
1918 .iter()
1919 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1920 .map(|collaborator| collaborator.to_proto())
1921 .collect::<Vec<_>>();
1922 let project_id = project.id;
1923 let guest_user_id = session.user_id();
1924
1925 let worktrees = project
1926 .worktrees
1927 .iter()
1928 .map(|(id, worktree)| proto::WorktreeMetadata {
1929 id: *id,
1930 root_name: worktree.root_name.clone(),
1931 visible: worktree.visible,
1932 abs_path: worktree.abs_path.clone(),
1933 })
1934 .collect::<Vec<_>>();
1935
1936 let add_project_collaborator = proto::AddProjectCollaborator {
1937 project_id: project_id.to_proto(),
1938 collaborator: Some(proto::Collaborator {
1939 peer_id: Some(session.connection_id.into()),
1940 replica_id: replica_id.0 as u32,
1941 user_id: guest_user_id.to_proto(),
1942 is_host: false,
1943 committer_name: request.committer_name.clone(),
1944 committer_email: request.committer_email.clone(),
1945 }),
1946 };
1947
1948 for collaborator in &collaborators {
1949 session
1950 .peer
1951 .send(
1952 collaborator.peer_id.unwrap().into(),
1953 add_project_collaborator.clone(),
1954 )
1955 .trace_err();
1956 }
1957
1958 // First, we send the metadata associated with each worktree.
1959 response.send(proto::JoinProjectResponse {
1960 project_id: project.id.0 as u64,
1961 worktrees: worktrees.clone(),
1962 replica_id: replica_id.0 as u32,
1963 collaborators: collaborators.clone(),
1964 language_servers: project.language_servers.clone(),
1965 role: project.role.into(),
1966 })?;
1967
1968 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1969 // Stream this worktree's entries.
1970 let message = proto::UpdateWorktree {
1971 project_id: project_id.to_proto(),
1972 worktree_id,
1973 abs_path: worktree.abs_path.clone(),
1974 root_name: worktree.root_name,
1975 updated_entries: worktree.entries,
1976 removed_entries: Default::default(),
1977 scan_id: worktree.scan_id,
1978 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1979 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1980 removed_repositories: Default::default(),
1981 };
1982 for update in proto::split_worktree_update(message) {
1983 session.peer.send(session.connection_id, update.clone())?;
1984 }
1985
1986 // Stream this worktree's diagnostics.
1987 for summary in worktree.diagnostic_summaries {
1988 session.peer.send(
1989 session.connection_id,
1990 proto::UpdateDiagnosticSummary {
1991 project_id: project_id.to_proto(),
1992 worktree_id: worktree.id,
1993 summary: Some(summary),
1994 },
1995 )?;
1996 }
1997
1998 for settings_file in worktree.settings_files {
1999 session.peer.send(
2000 session.connection_id,
2001 proto::UpdateWorktreeSettings {
2002 project_id: project_id.to_proto(),
2003 worktree_id: worktree.id,
2004 path: settings_file.path,
2005 content: Some(settings_file.content),
2006 kind: Some(settings_file.kind.to_proto() as i32),
2007 },
2008 )?;
2009 }
2010 }
2011
2012 for repository in mem::take(&mut project.repositories) {
2013 for update in split_repository_update(repository) {
2014 session.peer.send(session.connection_id, update)?;
2015 }
2016 }
2017
2018 for language_server in &project.language_servers {
2019 session.peer.send(
2020 session.connection_id,
2021 proto::UpdateLanguageServer {
2022 project_id: project_id.to_proto(),
2023 server_name: Some(language_server.name.clone()),
2024 language_server_id: language_server.id,
2025 variant: Some(
2026 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2027 proto::LspDiskBasedDiagnosticsUpdated {},
2028 ),
2029 ),
2030 },
2031 )?;
2032 }
2033
2034 Ok(())
2035}
2036
2037/// Leave someone elses shared project.
2038async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2039 let sender_id = session.connection_id;
2040 let project_id = ProjectId::from_proto(request.project_id);
2041 let db = session.db().await;
2042
2043 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2044 tracing::info!(
2045 %project_id,
2046 "leave project"
2047 );
2048
2049 project_left(project, &session);
2050 if let Some(room) = room {
2051 room_updated(room, &session.peer);
2052 }
2053
2054 Ok(())
2055}
2056
/// Updates other participants with changes to the project.
///
/// Passes `request.worktrees` to the database's `update_project` for the
/// sender's connection. It then forwards the unchanged request to every guest
/// connection the database returned (the sender is excluded). If the project
/// belongs to a room, the room state is rebroadcast. Finally the sender is
/// acknowledged.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // The database result is borrowed through `&*`, so the returned value
    // (possibly a guard — confirm against the db layer) must outlive the
    // broadcast below.
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2085
2086/// Updates other participants with changes to the worktree
2087async fn update_worktree(
2088 request: proto::UpdateWorktree,
2089 response: Response<proto::UpdateWorktree>,
2090 session: Session,
2091) -> Result<()> {
2092 let guest_connection_ids = session
2093 .db()
2094 .await
2095 .update_worktree(&request, session.connection_id)
2096 .await?;
2097
2098 broadcast(
2099 Some(session.connection_id),
2100 guest_connection_ids.iter().copied(),
2101 |connection_id| {
2102 session
2103 .peer
2104 .forward_send(session.connection_id, connection_id, request.clone())
2105 },
2106 );
2107 response.send(proto::Ack {})?;
2108 Ok(())
2109}
2110
2111async fn update_repository(
2112 request: proto::UpdateRepository,
2113 response: Response<proto::UpdateRepository>,
2114 session: Session,
2115) -> Result<()> {
2116 let guest_connection_ids = session
2117 .db()
2118 .await
2119 .update_repository(&request, session.connection_id)
2120 .await?;
2121
2122 broadcast(
2123 Some(session.connection_id),
2124 guest_connection_ids.iter().copied(),
2125 |connection_id| {
2126 session
2127 .peer
2128 .forward_send(session.connection_id, connection_id, request.clone())
2129 },
2130 );
2131 response.send(proto::Ack {})?;
2132 Ok(())
2133}
2134
2135async fn remove_repository(
2136 request: proto::RemoveRepository,
2137 response: Response<proto::RemoveRepository>,
2138 session: Session,
2139) -> Result<()> {
2140 let guest_connection_ids = session
2141 .db()
2142 .await
2143 .remove_repository(&request, session.connection_id)
2144 .await?;
2145
2146 broadcast(
2147 Some(session.connection_id),
2148 guest_connection_ids.iter().copied(),
2149 |connection_id| {
2150 session
2151 .peer
2152 .forward_send(session.connection_id, connection_id, request.clone())
2153 },
2154 );
2155 response.send(proto::Ack {})?;
2156 Ok(())
2157}
2158
2159/// Updates other participants with changes to the diagnostics
2160async fn update_diagnostic_summary(
2161 message: proto::UpdateDiagnosticSummary,
2162 session: Session,
2163) -> Result<()> {
2164 let guest_connection_ids = session
2165 .db()
2166 .await
2167 .update_diagnostic_summary(&message, session.connection_id)
2168 .await?;
2169
2170 broadcast(
2171 Some(session.connection_id),
2172 guest_connection_ids.iter().copied(),
2173 |connection_id| {
2174 session
2175 .peer
2176 .forward_send(session.connection_id, connection_id, message.clone())
2177 },
2178 );
2179
2180 Ok(())
2181}
2182
2183/// Updates other participants with changes to the worktree settings
2184async fn update_worktree_settings(
2185 message: proto::UpdateWorktreeSettings,
2186 session: Session,
2187) -> Result<()> {
2188 let guest_connection_ids = session
2189 .db()
2190 .await
2191 .update_worktree_settings(&message, session.connection_id)
2192 .await?;
2193
2194 broadcast(
2195 Some(session.connection_id),
2196 guest_connection_ids.iter().copied(),
2197 |connection_id| {
2198 session
2199 .peer
2200 .forward_send(session.connection_id, connection_id, message.clone())
2201 },
2202 );
2203
2204 Ok(())
2205}
2206
2207/// Notify other participants that a language server has started.
2208async fn start_language_server(
2209 request: proto::StartLanguageServer,
2210 session: Session,
2211) -> Result<()> {
2212 let guest_connection_ids = session
2213 .db()
2214 .await
2215 .start_language_server(&request, session.connection_id)
2216 .await?;
2217
2218 broadcast(
2219 Some(session.connection_id),
2220 guest_connection_ids.iter().copied(),
2221 |connection_id| {
2222 session
2223 .peer
2224 .forward_send(session.connection_id, connection_id, request.clone())
2225 },
2226 );
2227 Ok(())
2228}
2229
2230/// Notify other participants that a language server has changed.
2231async fn update_language_server(
2232 request: proto::UpdateLanguageServer,
2233 session: Session,
2234) -> Result<()> {
2235 let project_id = ProjectId::from_proto(request.project_id);
2236 let project_connection_ids = session
2237 .db()
2238 .await
2239 .project_connection_ids(project_id, session.connection_id, true)
2240 .await?;
2241 broadcast(
2242 Some(session.connection_id),
2243 project_connection_ids.iter().copied(),
2244 |connection_id| {
2245 session
2246 .peer
2247 .forward_send(session.connection_id, connection_id, request.clone())
2248 },
2249 );
2250 Ok(())
2251}
2252
2253/// forward a project request to the host. These requests should be read only
2254/// as guests are allowed to send them.
2255async fn forward_read_only_project_request<T>(
2256 request: T,
2257 response: Response<T>,
2258 session: Session,
2259) -> Result<()>
2260where
2261 T: EntityMessage + RequestMessage,
2262{
2263 let project_id = ProjectId::from_proto(request.remote_entity_id());
2264 let host_connection_id = session
2265 .db()
2266 .await
2267 .host_for_read_only_project_request(project_id, session.connection_id)
2268 .await?;
2269 let payload = session
2270 .peer
2271 .forward_request(session.connection_id, host_connection_id, request)
2272 .await?;
2273 response.send(payload)?;
2274 Ok(())
2275}
2276
2277async fn forward_find_search_candidates_request(
2278 request: proto::FindSearchCandidates,
2279 response: Response<proto::FindSearchCandidates>,
2280 session: Session,
2281) -> Result<()> {
2282 let project_id = ProjectId::from_proto(request.remote_entity_id());
2283 let host_connection_id = session
2284 .db()
2285 .await
2286 .host_for_read_only_project_request(project_id, session.connection_id)
2287 .await?;
2288 let payload = session
2289 .peer
2290 .forward_request(session.connection_id, host_connection_id, request)
2291 .await?;
2292 response.send(payload)?;
2293 Ok(())
2294}
2295
2296/// forward a project request to the host. These requests are disallowed
2297/// for guests.
2298async fn forward_mutating_project_request<T>(
2299 request: T,
2300 response: Response<T>,
2301 session: Session,
2302) -> Result<()>
2303where
2304 T: EntityMessage + RequestMessage,
2305{
2306 let project_id = ProjectId::from_proto(request.remote_entity_id());
2307
2308 let host_connection_id = session
2309 .db()
2310 .await
2311 .host_for_mutating_project_request(project_id, session.connection_id)
2312 .await?;
2313 let payload = session
2314 .peer
2315 .forward_request(session.connection_id, host_connection_id, request)
2316 .await?;
2317 response.send(payload)?;
2318 Ok(())
2319}
2320
2321/// Notify other participants that a new buffer has been created
2322async fn create_buffer_for_peer(
2323 request: proto::CreateBufferForPeer,
2324 session: Session,
2325) -> Result<()> {
2326 session
2327 .db()
2328 .await
2329 .check_user_is_project_host(
2330 ProjectId::from_proto(request.project_id),
2331 session.connection_id,
2332 )
2333 .await?;
2334 let peer_id = request.peer_id.context("invalid peer id")?;
2335 session
2336 .peer
2337 .forward_send(session.connection_id, peer_id.into(), request)?;
2338 Ok(())
2339}
2340
2341/// Notify other participants that a buffer has been updated. This is
2342/// allowed for guests as long as the update is limited to selections.
2343async fn update_buffer(
2344 request: proto::UpdateBuffer,
2345 response: Response<proto::UpdateBuffer>,
2346 session: Session,
2347) -> Result<()> {
2348 let project_id = ProjectId::from_proto(request.project_id);
2349 let mut capability = Capability::ReadOnly;
2350
2351 for op in request.operations.iter() {
2352 match op.variant {
2353 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2354 Some(_) => capability = Capability::ReadWrite,
2355 }
2356 }
2357
2358 let host = {
2359 let guard = session
2360 .db()
2361 .await
2362 .connections_for_buffer_update(project_id, session.connection_id, capability)
2363 .await?;
2364
2365 let (host, guests) = &*guard;
2366
2367 broadcast(
2368 Some(session.connection_id),
2369 guests.clone(),
2370 |connection_id| {
2371 session
2372 .peer
2373 .forward_send(session.connection_id, connection_id, request.clone())
2374 },
2375 );
2376
2377 *host
2378 };
2379
2380 if host != session.connection_id {
2381 session
2382 .peer
2383 .forward_request(session.connection_id, host, request.clone())
2384 .await?;
2385 }
2386
2387 response.send(proto::Ack {})?;
2388 Ok(())
2389}
2390
2391async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2392 let project_id = ProjectId::from_proto(message.project_id);
2393
2394 let operation = message.operation.as_ref().context("invalid operation")?;
2395 let capability = match operation.variant.as_ref() {
2396 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2397 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2398 match buffer_op.variant {
2399 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2400 Capability::ReadOnly
2401 }
2402 _ => Capability::ReadWrite,
2403 }
2404 } else {
2405 Capability::ReadWrite
2406 }
2407 }
2408 Some(_) => Capability::ReadWrite,
2409 None => Capability::ReadOnly,
2410 };
2411
2412 let guard = session
2413 .db()
2414 .await
2415 .connections_for_buffer_update(project_id, session.connection_id, capability)
2416 .await?;
2417
2418 let (host, guests) = &*guard;
2419
2420 broadcast(
2421 Some(session.connection_id),
2422 guests.iter().chain([host]).copied(),
2423 |connection_id| {
2424 session
2425 .peer
2426 .forward_send(session.connection_id, connection_id, message.clone())
2427 },
2428 );
2429
2430 Ok(())
2431}
2432
2433/// Notify other participants that a project has been updated.
2434async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2435 request: T,
2436 session: Session,
2437) -> Result<()> {
2438 let project_id = ProjectId::from_proto(request.remote_entity_id());
2439 let project_connection_ids = session
2440 .db()
2441 .await
2442 .project_connection_ids(project_id, session.connection_id, false)
2443 .await?;
2444
2445 broadcast(
2446 Some(session.connection_id),
2447 project_connection_ids.iter().copied(),
2448 |connection_id| {
2449 session
2450 .peer
2451 .forward_send(session.connection_id, connection_id, request.clone())
2452 },
2453 );
2454 Ok(())
2455}
2456
2457/// Start following another user in a call.
2458async fn follow(
2459 request: proto::Follow,
2460 response: Response<proto::Follow>,
2461 session: Session,
2462) -> Result<()> {
2463 let room_id = RoomId::from_proto(request.room_id);
2464 let project_id = request.project_id.map(ProjectId::from_proto);
2465 let leader_id = request.leader_id.context("invalid leader id")?.into();
2466 let follower_id = session.connection_id;
2467
2468 session
2469 .db()
2470 .await
2471 .check_room_participants(room_id, leader_id, session.connection_id)
2472 .await?;
2473
2474 let response_payload = session
2475 .peer
2476 .forward_request(session.connection_id, leader_id, request)
2477 .await?;
2478 response.send(response_payload)?;
2479
2480 if let Some(project_id) = project_id {
2481 let room = session
2482 .db()
2483 .await
2484 .follow(room_id, project_id, leader_id, follower_id)
2485 .await?;
2486 room_updated(&room, &session.peer);
2487 }
2488
2489 Ok(())
2490}
2491
2492/// Stop following another user in a call.
2493async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2494 let room_id = RoomId::from_proto(request.room_id);
2495 let project_id = request.project_id.map(ProjectId::from_proto);
2496 let leader_id = request.leader_id.context("invalid leader id")?.into();
2497 let follower_id = session.connection_id;
2498
2499 session
2500 .db()
2501 .await
2502 .check_room_participants(room_id, leader_id, session.connection_id)
2503 .await?;
2504
2505 session
2506 .peer
2507 .forward_send(session.connection_id, leader_id, request)?;
2508
2509 if let Some(project_id) = project_id {
2510 let room = session
2511 .db()
2512 .await
2513 .unfollow(room_id, project_id, leader_id, follower_id)
2514 .await?;
2515 room_updated(&room, &session.peer);
2516 }
2517
2518 Ok(())
2519}
2520
/// Notify everyone following you of your current location.
///
/// Recipients depend on the request:
/// - When `request.project_id` is set, they are the project's connections.
/// - Otherwise, they are the room's connections.
///
/// The sender is always skipped. For `UpdateView` payloads, the connection
/// matching the payload's `leader_id` is skipped too.
async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): this locks `session.db` directly instead of going through
    // `session.db()` like the other handlers here. The guard is held for the
    // rest of the function, including while forwarding below.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    // A failed `forward_send` returns early via `?`, so later recipients are
    // skipped in that case.
    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
2552
2553/// Get public data about users.
2554async fn get_users(
2555 request: proto::GetUsers,
2556 response: Response<proto::GetUsers>,
2557 session: Session,
2558) -> Result<()> {
2559 let user_ids = request
2560 .user_ids
2561 .into_iter()
2562 .map(UserId::from_proto)
2563 .collect();
2564 let users = session
2565 .db()
2566 .await
2567 .get_users_by_ids(user_ids)
2568 .await?
2569 .into_iter()
2570 .map(|user| proto::User {
2571 id: user.id.to_proto(),
2572 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2573 github_login: user.github_login,
2574 name: user.name,
2575 })
2576 .collect();
2577 response.send(proto::UsersResponse { users })?;
2578 Ok(())
2579}
2580
2581/// Search for users (to invite) buy Github login
2582async fn fuzzy_search_users(
2583 request: proto::FuzzySearchUsers,
2584 response: Response<proto::FuzzySearchUsers>,
2585 session: Session,
2586) -> Result<()> {
2587 let query = request.query;
2588 let users = match query.len() {
2589 0 => vec![],
2590 1 | 2 => session
2591 .db()
2592 .await
2593 .get_user_by_github_login(&query)
2594 .await?
2595 .into_iter()
2596 .collect(),
2597 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2598 };
2599 let users = users
2600 .into_iter()
2601 .filter(|user| user.id != session.user_id())
2602 .map(|user| proto::User {
2603 id: user.id.to_proto(),
2604 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2605 github_login: user.github_login,
2606 name: user.name,
2607 })
2608 .collect();
2609 response.send(proto::UsersResponse { users })?;
2610 Ok(())
2611}
2612
2613/// Send a contact request to another user.
2614async fn request_contact(
2615 request: proto::RequestContact,
2616 response: Response<proto::RequestContact>,
2617 session: Session,
2618) -> Result<()> {
2619 let requester_id = session.user_id();
2620 let responder_id = UserId::from_proto(request.responder_id);
2621 if requester_id == responder_id {
2622 return Err(anyhow!("cannot add yourself as a contact"))?;
2623 }
2624
2625 let notifications = session
2626 .db()
2627 .await
2628 .send_contact_request(requester_id, responder_id)
2629 .await?;
2630
2631 // Update outgoing contact requests of requester
2632 let mut update = proto::UpdateContacts::default();
2633 update.outgoing_requests.push(responder_id.to_proto());
2634 for connection_id in session
2635 .connection_pool()
2636 .await
2637 .user_connection_ids(requester_id)
2638 {
2639 session.peer.send(connection_id, update.clone())?;
2640 }
2641
2642 // Update incoming contact requests of responder
2643 let mut update = proto::UpdateContacts::default();
2644 update
2645 .incoming_requests
2646 .push(proto::IncomingContactRequest {
2647 requester_id: requester_id.to_proto(),
2648 });
2649 let connection_pool = session.connection_pool().await;
2650 for connection_id in connection_pool.user_connection_ids(responder_id) {
2651 session.peer.send(connection_id, update.clone())?;
2652 }
2653
2654 send_notifications(&connection_pool, &session.peer, notifications);
2655
2656 response.send(proto::Ack {})?;
2657 Ok(())
2658}
2659
/// Accept, decline, or dismiss a contact request.
///
/// - `Dismiss` only dismisses the related contact notification in the database.
/// - `Accept` and any other value (treated as a decline) record the response
///   in the database.
///
/// For accept and decline, both users' connections are then updated:
/// - The responder loses the incoming request.
/// - The requester loses the outgoing request.
/// - On accept, each side also gains the other as a contact, with their
///   current busy status.
///
/// Any notifications from the database are delivered, then the sender is
/// acknowledged.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // `db` stays bound for the rest of the function, including while the
    // connection pool is used below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every response other than `Accept` is handled as a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is looked up for both users so each can be shown
        // correctly in the other's contact list.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2717
/// Remove a contact, or withdraw a pending contact request.
///
/// The database reports two things:
/// - whether the two users were accepted contacts;
/// - the id of any notification it deleted.
///
/// Accepted contacts are removed from both users' contact lists. Otherwise the
/// pending request is withdrawn from the requester's outgoing list and the
/// responder's incoming list. The responder's connections are also told to
/// delete the notification, if one was removed.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: Session,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact if it was accepted, otherwise
    // drop the outgoing request.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the responder: drop the contact if it was accepted, otherwise
    // drop the incoming request. Also retract the deleted notification.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2768
2769fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2770 version.0.minor() < 139
2771}
2772
2773async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2774 if is_staff {
2775 return Ok(proto::Plan::ZedPro);
2776 }
2777
2778 let subscription = db.get_active_billing_subscription(user_id).await?;
2779 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2780
2781 let plan = if let Some(subscription_kind) = subscription_kind {
2782 match subscription_kind {
2783 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2784 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2785 SubscriptionKind::ZedFree => proto::Plan::Free,
2786 }
2787 } else {
2788 proto::Plan::Free
2789 };
2790
2791 Ok(plan)
2792}
2793
/// Builds the `UpdateUserPlan` message describing a user's plan, billing
/// state, and usage for the current subscription period.
///
/// Subscription period and usage are only computed when an LLM database is
/// supplied; otherwise both are `None`. Without stored usage, the `usage`
/// field falls back to `make_default_subscription_usage` (zero usage with the
/// plan's limits).
async fn make_update_user_plan_message(
    user: &User,
    is_staff: bool,
    db: &Arc<Database>,
    llm_db: Option<Arc<LlmDatabase>>,
) -> Result<proto::UpdateUserPlan> {
    let feature_flags = db.get_user_flags(user.id).await?;
    let plan = current_plan(db, user.id, is_staff).await?;
    let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let (subscription_period, usage) = if let Some(llm_db) = llm_db {
        // NOTE(review): for non-staff users `current_plan` above already
        // queried the active subscription, so it is fetched twice here.
        let subscription = db.get_active_billing_subscription(user.id).await?;

        let subscription_period =
            crate::db::billing_subscription::Model::current_period(subscription, is_staff);

        let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
            llm_db
                .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
                .await?
        } else {
            None
        };

        (subscription_period, usage)
    } else {
        (None, None)
    };

    // The account-age restriction applies to every plan except Zed Pro, unless
    // the user carries the bypass feature flag.
    let bypass_account_age_check = feature_flags
        .iter()
        .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
    let account_too_young = !matches!(plan, proto::Plan::ZedPro)
        && !bypass_account_age_check
        && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;

    Ok(proto::UpdateUserPlan {
        plan: plan.into(),
        // chrono's `timestamp()` yields Unix seconds; the `as u64` cast would
        // wrap a pre-1970 value.
        trial_started_at: billing_customer
            .as_ref()
            .and_then(|billing_customer| billing_customer.trial_started_at)
            .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
        // Staff always have usage-based billing reported as enabled.
        is_usage_based_billing_enabled: if is_staff {
            Some(true)
        } else {
            billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
        },
        subscription_period: subscription_period.map(|(started_at, ended_at)| {
            proto::SubscriptionPeriod {
                started_at: started_at.timestamp() as u64,
                ended_at: ended_at.timestamp() as u64,
            }
        }),
        account_too_young: Some(account_too_young),
        has_overdue_invoices: billing_customer
            .map(|billing_customer| billing_customer.has_overdue_invoices),
        usage: Some(
            usage
                .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
                .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
        ),
    })
}
2858
2859fn model_requests_limit(
2860 plan: zed_llm_client::Plan,
2861 feature_flags: &Vec<String>,
2862) -> zed_llm_client::UsageLimit {
2863 match plan.model_requests_limit() {
2864 zed_llm_client::UsageLimit::Limited(limit) => {
2865 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2866 && feature_flags
2867 .iter()
2868 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2869 {
2870 1_000
2871 } else {
2872 limit
2873 };
2874
2875 zed_llm_client::UsageLimit::Limited(limit)
2876 }
2877 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2878 }
2879}
2880
2881fn subscription_usage_to_proto(
2882 plan: proto::Plan,
2883 usage: crate::llm::db::subscription_usage::Model,
2884 feature_flags: &Vec<String>,
2885) -> proto::SubscriptionUsage {
2886 let plan = match plan {
2887 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2888 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2889 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2890 };
2891
2892 proto::SubscriptionUsage {
2893 model_requests_usage_amount: usage.model_requests as u32,
2894 model_requests_usage_limit: Some(proto::UsageLimit {
2895 variant: Some(match model_requests_limit(plan, feature_flags) {
2896 zed_llm_client::UsageLimit::Limited(limit) => {
2897 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2898 limit: limit as u32,
2899 })
2900 }
2901 zed_llm_client::UsageLimit::Unlimited => {
2902 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2903 }
2904 }),
2905 }),
2906 edit_predictions_usage_amount: usage.edit_predictions as u32,
2907 edit_predictions_usage_limit: Some(proto::UsageLimit {
2908 variant: Some(match plan.edit_predictions_limit() {
2909 zed_llm_client::UsageLimit::Limited(limit) => {
2910 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2911 limit: limit as u32,
2912 })
2913 }
2914 zed_llm_client::UsageLimit::Unlimited => {
2915 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2916 }
2917 }),
2918 }),
2919 }
2920}
2921
2922fn make_default_subscription_usage(
2923 plan: proto::Plan,
2924 feature_flags: &Vec<String>,
2925) -> proto::SubscriptionUsage {
2926 let plan = match plan {
2927 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2928 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2929 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2930 };
2931
2932 proto::SubscriptionUsage {
2933 model_requests_usage_amount: 0,
2934 model_requests_usage_limit: Some(proto::UsageLimit {
2935 variant: Some(match model_requests_limit(plan, feature_flags) {
2936 zed_llm_client::UsageLimit::Limited(limit) => {
2937 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2938 limit: limit as u32,
2939 })
2940 }
2941 zed_llm_client::UsageLimit::Unlimited => {
2942 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2943 }
2944 }),
2945 }),
2946 edit_predictions_usage_amount: 0,
2947 edit_predictions_usage_limit: Some(proto::UsageLimit {
2948 variant: Some(match plan.edit_predictions_limit() {
2949 zed_llm_client::UsageLimit::Limited(limit) => {
2950 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2951 limit: limit as u32,
2952 })
2953 }
2954 zed_llm_client::UsageLimit::Unlimited => {
2955 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2956 }
2957 }),
2958 }),
2959 }
2960}
2961
2962async fn update_user_plan(session: &Session) -> Result<()> {
2963 let db = session.db().await;
2964
2965 let update_user_plan = make_update_user_plan_message(
2966 session.principal.user(),
2967 session.is_staff(),
2968 &db.0,
2969 session.app_state.llm_db.clone(),
2970 )
2971 .await?;
2972
2973 session
2974 .peer
2975 .send(session.connection_id, update_user_plan)
2976 .trace_err();
2977
2978 Ok(())
2979}
2980
2981async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2982 subscribe_user_to_channels(session.user_id(), &session).await?;
2983 Ok(())
2984}
2985
2986async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2987 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2988 let mut pool = session.connection_pool().await;
2989 for membership in &channels_for_user.channel_memberships {
2990 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2991 }
2992 session.peer.send(
2993 session.connection_id,
2994 build_update_user_channels(&channels_for_user),
2995 )?;
2996 session.peer.send(
2997 session.connection_id,
2998 build_channels_update(channels_for_user),
2999 )?;
3000 Ok(())
3001}
3002
3003/// Creates a new channel.
3004async fn create_channel(
3005 request: proto::CreateChannel,
3006 response: Response<proto::CreateChannel>,
3007 session: Session,
3008) -> Result<()> {
3009 let db = session.db().await;
3010
3011 let parent_id = request.parent_id.map(ChannelId::from_proto);
3012 let (channel, membership) = db
3013 .create_channel(&request.name, parent_id, session.user_id())
3014 .await?;
3015
3016 let root_id = channel.root_id();
3017 let channel = Channel::from_model(channel);
3018
3019 response.send(proto::CreateChannelResponse {
3020 channel: Some(channel.to_proto()),
3021 parent_id: request.parent_id,
3022 })?;
3023
3024 let mut connection_pool = session.connection_pool().await;
3025 if let Some(membership) = membership {
3026 connection_pool.subscribe_to_channel(
3027 membership.user_id,
3028 membership.channel_id,
3029 membership.role,
3030 );
3031 let update = proto::UpdateUserChannels {
3032 channel_memberships: vec![proto::ChannelMembership {
3033 channel_id: membership.channel_id.to_proto(),
3034 role: membership.role.into(),
3035 }],
3036 ..Default::default()
3037 };
3038 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3039 session.peer.send(connection_id, update.clone())?;
3040 }
3041 }
3042
3043 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3044 if !role.can_see_channel(channel.visibility) {
3045 continue;
3046 }
3047
3048 let update = proto::UpdateChannels {
3049 channels: vec![channel.to_proto()],
3050 ..Default::default()
3051 };
3052 session.peer.send(connection_id, update.clone())?;
3053 }
3054
3055 Ok(())
3056}
3057
3058/// Delete a channel
3059async fn delete_channel(
3060 request: proto::DeleteChannel,
3061 response: Response<proto::DeleteChannel>,
3062 session: Session,
3063) -> Result<()> {
3064 let db = session.db().await;
3065
3066 let channel_id = request.channel_id;
3067 let (root_channel, removed_channels) = db
3068 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3069 .await?;
3070 response.send(proto::Ack {})?;
3071
3072 // Notify members of removed channels
3073 let mut update = proto::UpdateChannels::default();
3074 update
3075 .delete_channels
3076 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3077
3078 let connection_pool = session.connection_pool().await;
3079 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3080 session.peer.send(connection_id, update.clone())?;
3081 }
3082
3083 Ok(())
3084}
3085
3086/// Invite someone to join a channel.
3087async fn invite_channel_member(
3088 request: proto::InviteChannelMember,
3089 response: Response<proto::InviteChannelMember>,
3090 session: Session,
3091) -> Result<()> {
3092 let db = session.db().await;
3093 let channel_id = ChannelId::from_proto(request.channel_id);
3094 let invitee_id = UserId::from_proto(request.user_id);
3095 let InviteMemberResult {
3096 channel,
3097 notifications,
3098 } = db
3099 .invite_channel_member(
3100 channel_id,
3101 invitee_id,
3102 session.user_id(),
3103 request.role().into(),
3104 )
3105 .await?;
3106
3107 let update = proto::UpdateChannels {
3108 channel_invitations: vec![channel.to_proto()],
3109 ..Default::default()
3110 };
3111
3112 let connection_pool = session.connection_pool().await;
3113 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3114 session.peer.send(connection_id, update.clone())?;
3115 }
3116
3117 send_notifications(&connection_pool, &session.peer, notifications);
3118
3119 response.send(proto::Ack {})?;
3120 Ok(())
3121}
3122
3123/// remove someone from a channel
3124async fn remove_channel_member(
3125 request: proto::RemoveChannelMember,
3126 response: Response<proto::RemoveChannelMember>,
3127 session: Session,
3128) -> Result<()> {
3129 let db = session.db().await;
3130 let channel_id = ChannelId::from_proto(request.channel_id);
3131 let member_id = UserId::from_proto(request.user_id);
3132
3133 let RemoveChannelMemberResult {
3134 membership_update,
3135 notification_id,
3136 } = db
3137 .remove_channel_member(channel_id, member_id, session.user_id())
3138 .await?;
3139
3140 let mut connection_pool = session.connection_pool().await;
3141 notify_membership_updated(
3142 &mut connection_pool,
3143 membership_update,
3144 member_id,
3145 &session.peer,
3146 );
3147 for connection_id in connection_pool.user_connection_ids(member_id) {
3148 if let Some(notification_id) = notification_id {
3149 session
3150 .peer
3151 .send(
3152 connection_id,
3153 proto::DeleteNotification {
3154 notification_id: notification_id.to_proto(),
3155 },
3156 )
3157 .trace_err();
3158 }
3159 }
3160
3161 response.send(proto::Ack {})?;
3162 Ok(())
3163}
3164
/// Toggle the channel between public and private.
///
/// The intent is to maintain the invariant that public channels only descend
/// from public channels, though members-only channels can appear at any point
/// in the hierarchy. NOTE(review): that check is presumably enforced inside
/// `db.set_channel_visibility`; it is not visible in this handler.
///
/// After the database applies the change on behalf of the session's user,
/// every user tracked under the channel's root is re-evaluated:
/// - users whose role can see the channel under its new visibility are
///   subscribed to it and sent its current state;
/// - all others are unsubscribed and told to delete it.
///
/// The request is acknowledged only after all updates have been sent.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect the users first: the loop body mutates `connection_pool`
    // (subscribe/unsubscribe), so it cannot keep borrowing it via the iterator.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            // The channel is (now) visible to this user: track and send it.
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            // The channel is hidden from this user: stop tracking it and tell
            // their clients to drop it.
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3211
3212/// Alter the role for a user in the channel.
3213async fn set_channel_member_role(
3214 request: proto::SetChannelMemberRole,
3215 response: Response<proto::SetChannelMemberRole>,
3216 session: Session,
3217) -> Result<()> {
3218 let db = session.db().await;
3219 let channel_id = ChannelId::from_proto(request.channel_id);
3220 let member_id = UserId::from_proto(request.user_id);
3221 let result = db
3222 .set_channel_member_role(
3223 channel_id,
3224 session.user_id(),
3225 member_id,
3226 request.role().into(),
3227 )
3228 .await?;
3229
3230 match result {
3231 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3232 let mut connection_pool = session.connection_pool().await;
3233 notify_membership_updated(
3234 &mut connection_pool,
3235 membership_update,
3236 member_id,
3237 &session.peer,
3238 )
3239 }
3240 db::SetMemberRoleResult::InviteUpdated(channel) => {
3241 let update = proto::UpdateChannels {
3242 channel_invitations: vec![channel.to_proto()],
3243 ..Default::default()
3244 };
3245
3246 for connection_id in session
3247 .connection_pool()
3248 .await
3249 .user_connection_ids(member_id)
3250 {
3251 session.peer.send(connection_id, update.clone())?;
3252 }
3253 }
3254 }
3255
3256 response.send(proto::Ack {})?;
3257 Ok(())
3258}
3259
3260/// Change the name of a channel
3261async fn rename_channel(
3262 request: proto::RenameChannel,
3263 response: Response<proto::RenameChannel>,
3264 session: Session,
3265) -> Result<()> {
3266 let db = session.db().await;
3267 let channel_id = ChannelId::from_proto(request.channel_id);
3268 let channel_model = db
3269 .rename_channel(channel_id, session.user_id(), &request.name)
3270 .await?;
3271 let root_id = channel_model.root_id();
3272 let channel = Channel::from_model(channel_model);
3273
3274 response.send(proto::RenameChannelResponse {
3275 channel: Some(channel.to_proto()),
3276 })?;
3277
3278 let connection_pool = session.connection_pool().await;
3279 let update = proto::UpdateChannels {
3280 channels: vec![channel.to_proto()],
3281 ..Default::default()
3282 };
3283 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3284 if role.can_see_channel(channel.visibility) {
3285 session.peer.send(connection_id, update.clone())?;
3286 }
3287 }
3288
3289 Ok(())
3290}
3291
3292/// Move a channel to a new parent.
3293async fn move_channel(
3294 request: proto::MoveChannel,
3295 response: Response<proto::MoveChannel>,
3296 session: Session,
3297) -> Result<()> {
3298 let channel_id = ChannelId::from_proto(request.channel_id);
3299 let to = ChannelId::from_proto(request.to);
3300
3301 let (root_id, channels) = session
3302 .db()
3303 .await
3304 .move_channel(channel_id, to, session.user_id())
3305 .await?;
3306
3307 let connection_pool = session.connection_pool().await;
3308 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3309 let channels = channels
3310 .iter()
3311 .filter_map(|channel| {
3312 if role.can_see_channel(channel.visibility) {
3313 Some(channel.to_proto())
3314 } else {
3315 None
3316 }
3317 })
3318 .collect::<Vec<_>>();
3319 if channels.is_empty() {
3320 continue;
3321 }
3322
3323 let update = proto::UpdateChannels {
3324 channels,
3325 ..Default::default()
3326 };
3327
3328 session.peer.send(connection_id, update.clone())?;
3329 }
3330
3331 response.send(Ack {})?;
3332 Ok(())
3333}
3334
3335async fn reorder_channel(
3336 request: proto::ReorderChannel,
3337 response: Response<proto::ReorderChannel>,
3338 session: Session,
3339) -> Result<()> {
3340 let channel_id = ChannelId::from_proto(request.channel_id);
3341 let direction = request.direction();
3342
3343 let updated_channels = session
3344 .db()
3345 .await
3346 .reorder_channel(channel_id, direction, session.user_id())
3347 .await?;
3348
3349 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3350 let connection_pool = session.connection_pool().await;
3351 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3352 let channels = updated_channels
3353 .iter()
3354 .filter_map(|channel| {
3355 if role.can_see_channel(channel.visibility) {
3356 Some(channel.to_proto())
3357 } else {
3358 None
3359 }
3360 })
3361 .collect::<Vec<_>>();
3362
3363 if channels.is_empty() {
3364 continue;
3365 }
3366
3367 let update = proto::UpdateChannels {
3368 channels,
3369 ..Default::default()
3370 };
3371
3372 session.peer.send(connection_id, update.clone())?;
3373 }
3374 }
3375
3376 response.send(Ack {})?;
3377 Ok(())
3378}
3379
3380/// Get the list of channel members
3381async fn get_channel_members(
3382 request: proto::GetChannelMembers,
3383 response: Response<proto::GetChannelMembers>,
3384 session: Session,
3385) -> Result<()> {
3386 let db = session.db().await;
3387 let channel_id = ChannelId::from_proto(request.channel_id);
3388 let limit = if request.limit == 0 {
3389 u16::MAX as u64
3390 } else {
3391 request.limit
3392 };
3393 let (members, users) = db
3394 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3395 .await?;
3396 response.send(proto::GetChannelMembersResponse { members, users })?;
3397 Ok(())
3398}
3399
3400/// Accept or decline a channel invitation.
3401async fn respond_to_channel_invite(
3402 request: proto::RespondToChannelInvite,
3403 response: Response<proto::RespondToChannelInvite>,
3404 session: Session,
3405) -> Result<()> {
3406 let db = session.db().await;
3407 let channel_id = ChannelId::from_proto(request.channel_id);
3408 let RespondToChannelInvite {
3409 membership_update,
3410 notifications,
3411 } = db
3412 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3413 .await?;
3414
3415 let mut connection_pool = session.connection_pool().await;
3416 if let Some(membership_update) = membership_update {
3417 notify_membership_updated(
3418 &mut connection_pool,
3419 membership_update,
3420 session.user_id(),
3421 &session.peer,
3422 );
3423 } else {
3424 let update = proto::UpdateChannels {
3425 remove_channel_invitations: vec![channel_id.to_proto()],
3426 ..Default::default()
3427 };
3428
3429 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3430 session.peer.send(connection_id, update.clone())?;
3431 }
3432 };
3433
3434 send_notifications(&connection_pool, &session.peer, notifications);
3435
3436 response.send(proto::Ack {})?;
3437
3438 Ok(())
3439}
3440
3441/// Join the channels' room
3442async fn join_channel(
3443 request: proto::JoinChannel,
3444 response: Response<proto::JoinChannel>,
3445 session: Session,
3446) -> Result<()> {
3447 let channel_id = ChannelId::from_proto(request.channel_id);
3448 join_channel_internal(channel_id, Box::new(response), session).await
3449}
3450
/// Abstracts over the response handles that `join_channel_internal` can reply
/// to, so the same join logic can answer both `JoinChannel` and `JoinRoom`
/// requests with a `proto::JoinRoomResponse`.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Delegates to the inherent `Response::send`. Inherent methods take
        // precedence over this trait method in path resolution, so this does
        // not recurse.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for the `JoinRoom` request type.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3464
/// Shared implementation for joining a channel's room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave
///    that room first. The db guard is released while doing so and then
///    re-acquired.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the user's channel role.
/// 3. If a LiveKit client is configured, mint connection info. Guests get a
///    guest token and `can_publish = false`; every other role gets a room
///    token and `can_publish = true`. A token error is logged via `trace_err`
///    and results in no connection info instead of failing the join.
/// 4. Reply to the caller, propagate the membership update (if any), and send
///    room updates.
/// 5. After the inner block ends and its db and connection-pool guards are
///    dropped, broadcast the channel update and refresh the user's contacts.
///
/// # Errors
///
/// Fails if any database call fails, if sending the response fails, or if the
/// joined room has no associated channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard before leaving the room; presumably
            // `leave_room_for_session` needs to acquire it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or token creation failed
        // (the `?` operators below return `None` from the closure after
        // `trace_err` logs the error).
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join but not publish.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The db and connection-pool guards from the block above have been dropped
    // here; a fresh pool guard is taken just for this call.
    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3558
3559/// Start editing the channel notes
3560async fn join_channel_buffer(
3561 request: proto::JoinChannelBuffer,
3562 response: Response<proto::JoinChannelBuffer>,
3563 session: Session,
3564) -> Result<()> {
3565 let db = session.db().await;
3566 let channel_id = ChannelId::from_proto(request.channel_id);
3567
3568 let open_response = db
3569 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3570 .await?;
3571
3572 let collaborators = open_response.collaborators.clone();
3573 response.send(open_response)?;
3574
3575 let update = UpdateChannelBufferCollaborators {
3576 channel_id: channel_id.to_proto(),
3577 collaborators: collaborators.clone(),
3578 };
3579 channel_buffer_updated(
3580 session.connection_id,
3581 collaborators
3582 .iter()
3583 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3584 &update,
3585 &session.peer,
3586 );
3587
3588 Ok(())
3589}
3590
3591/// Edit the channel notes
3592async fn update_channel_buffer(
3593 request: proto::UpdateChannelBuffer,
3594 session: Session,
3595) -> Result<()> {
3596 let db = session.db().await;
3597 let channel_id = ChannelId::from_proto(request.channel_id);
3598
3599 let (collaborators, epoch, version) = db
3600 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3601 .await?;
3602
3603 channel_buffer_updated(
3604 session.connection_id,
3605 collaborators.clone(),
3606 &proto::UpdateChannelBuffer {
3607 channel_id: channel_id.to_proto(),
3608 operations: request.operations,
3609 },
3610 &session.peer,
3611 );
3612
3613 let pool = &*session.connection_pool().await;
3614
3615 let non_collaborators =
3616 pool.channel_connection_ids(channel_id)
3617 .filter_map(|(connection_id, _)| {
3618 if collaborators.contains(&connection_id) {
3619 None
3620 } else {
3621 Some(connection_id)
3622 }
3623 });
3624
3625 broadcast(None, non_collaborators, |peer_id| {
3626 session.peer.send(
3627 peer_id,
3628 proto::UpdateChannels {
3629 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3630 channel_id: channel_id.to_proto(),
3631 epoch: epoch as u64,
3632 version: version.clone(),
3633 }],
3634 ..Default::default()
3635 },
3636 )
3637 });
3638
3639 Ok(())
3640}
3641
3642/// Rejoin the channel notes after a connection blip
3643async fn rejoin_channel_buffers(
3644 request: proto::RejoinChannelBuffers,
3645 response: Response<proto::RejoinChannelBuffers>,
3646 session: Session,
3647) -> Result<()> {
3648 let db = session.db().await;
3649 let buffers = db
3650 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3651 .await?;
3652
3653 for rejoined_buffer in &buffers {
3654 let collaborators_to_notify = rejoined_buffer
3655 .buffer
3656 .collaborators
3657 .iter()
3658 .filter_map(|c| Some(c.peer_id?.into()));
3659 channel_buffer_updated(
3660 session.connection_id,
3661 collaborators_to_notify,
3662 &proto::UpdateChannelBufferCollaborators {
3663 channel_id: rejoined_buffer.buffer.channel_id,
3664 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3665 },
3666 &session.peer,
3667 );
3668 }
3669
3670 response.send(proto::RejoinChannelBuffersResponse {
3671 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3672 })?;
3673
3674 Ok(())
3675}
3676
3677/// Stop editing the channel notes
3678async fn leave_channel_buffer(
3679 request: proto::LeaveChannelBuffer,
3680 response: Response<proto::LeaveChannelBuffer>,
3681 session: Session,
3682) -> Result<()> {
3683 let db = session.db().await;
3684 let channel_id = ChannelId::from_proto(request.channel_id);
3685
3686 let left_buffer = db
3687 .leave_channel_buffer(channel_id, session.connection_id)
3688 .await?;
3689
3690 response.send(Ack {})?;
3691
3692 channel_buffer_updated(
3693 session.connection_id,
3694 left_buffer.connections,
3695 &proto::UpdateChannelBufferCollaborators {
3696 channel_id: channel_id.to_proto(),
3697 collaborators: left_buffer.collaborators,
3698 },
3699 &session.peer,
3700 );
3701
3702 Ok(())
3703}
3704
3705fn channel_buffer_updated<T: EnvelopedMessage>(
3706 sender_id: ConnectionId,
3707 collaborators: impl IntoIterator<Item = ConnectionId>,
3708 message: &T,
3709 peer: &Peer,
3710) {
3711 broadcast(Some(sender_id), collaborators, |peer_id| {
3712 peer.send(peer_id, message.clone())
3713 });
3714}
3715
3716fn send_notifications(
3717 connection_pool: &ConnectionPool,
3718 peer: &Peer,
3719 notifications: db::NotificationBatch,
3720) {
3721 for (user_id, notification) in notifications {
3722 for connection_id in connection_pool.user_connection_ids(user_id) {
3723 if let Err(error) = peer.send(
3724 connection_id,
3725 proto::AddNotification {
3726 notification: Some(notification.clone()),
3727 },
3728 ) {
3729 tracing::error!(
3730 "failed to send notification to {:?} {}",
3731 connection_id,
3732 error
3733 );
3734 }
3735 }
3736 }
3737}
3738
/// Send a chat message to a channel on behalf of the session's user.
///
/// Steps, in order:
/// 1. The body is trimmed and rejected if it is longer than `MAX_MESSAGE_LEN`
///    or empty. A missing `nonce` is also rejected.
/// 2. The message is stored, and the database returns the new message id, the
///    connections currently participating in the channel chat, and any
///    notifications produced (e.g. for mentions; NOTE(review): inferred from
///    `request.mentions` being passed in).
/// 3. Participant connections receive `ChannelMessageSent` via `broadcast`,
///    with the requester passed as the sender. The requester gets the stored
///    message in the response.
/// 4. Connections subscribed to the channel that are not chat participants
///    receive only the channel's latest message id. Then the notifications are
///    delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is forwarded unchanged below while
    // `body` is the trimmed text, so any offsets in the mentions presumably
    // refer to the untrimmed body — confirm.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Subscribers who do not have the chat open only learn the latest message
    // id for the channel.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3831
3832/// Delete a channel message
3833async fn remove_channel_message(
3834 request: proto::RemoveChannelMessage,
3835 response: Response<proto::RemoveChannelMessage>,
3836 session: Session,
3837) -> Result<()> {
3838 let channel_id = ChannelId::from_proto(request.channel_id);
3839 let message_id = MessageId::from_proto(request.message_id);
3840 let (connection_ids, existing_notification_ids) = session
3841 .db()
3842 .await
3843 .remove_channel_message(channel_id, message_id, session.user_id())
3844 .await?;
3845
3846 broadcast(
3847 Some(session.connection_id),
3848 connection_ids,
3849 move |connection| {
3850 session.peer.send(connection, request.clone())?;
3851
3852 for notification_id in &existing_notification_ids {
3853 session.peer.send(
3854 connection,
3855 proto::DeleteNotification {
3856 notification_id: (*notification_id).to_proto(),
3857 },
3858 )?;
3859 }
3860
3861 Ok(())
3862 },
3863 );
3864 response.send(proto::Ack {})?;
3865 Ok(())
3866}
3867
3868async fn update_channel_message(
3869 request: proto::UpdateChannelMessage,
3870 response: Response<proto::UpdateChannelMessage>,
3871 session: Session,
3872) -> Result<()> {
3873 let channel_id = ChannelId::from_proto(request.channel_id);
3874 let message_id = MessageId::from_proto(request.message_id);
3875 let updated_at = OffsetDateTime::now_utc();
3876 let UpdatedChannelMessage {
3877 message_id,
3878 participant_connection_ids,
3879 notifications,
3880 reply_to_message_id,
3881 timestamp,
3882 deleted_mention_notification_ids,
3883 updated_mention_notifications,
3884 } = session
3885 .db()
3886 .await
3887 .update_channel_message(
3888 channel_id,
3889 message_id,
3890 session.user_id(),
3891 request.body.as_str(),
3892 &request.mentions,
3893 updated_at,
3894 )
3895 .await?;
3896
3897 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3898
3899 let message = proto::ChannelMessage {
3900 sender_id: session.user_id().to_proto(),
3901 id: message_id.to_proto(),
3902 body: request.body.clone(),
3903 mentions: request.mentions.clone(),
3904 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3905 nonce: Some(nonce),
3906 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3907 edited_at: Some(updated_at.unix_timestamp() as u64),
3908 };
3909
3910 response.send(proto::Ack {})?;
3911
3912 let pool = &*session.connection_pool().await;
3913 broadcast(
3914 Some(session.connection_id),
3915 participant_connection_ids,
3916 |connection| {
3917 session.peer.send(
3918 connection,
3919 proto::ChannelMessageUpdate {
3920 channel_id: channel_id.to_proto(),
3921 message: Some(message.clone()),
3922 },
3923 )?;
3924
3925 for notification_id in &deleted_mention_notification_ids {
3926 session.peer.send(
3927 connection,
3928 proto::DeleteNotification {
3929 notification_id: (*notification_id).to_proto(),
3930 },
3931 )?;
3932 }
3933
3934 for notification in &updated_mention_notifications {
3935 session.peer.send(
3936 connection,
3937 proto::UpdateNotification {
3938 notification: Some(notification.clone()),
3939 },
3940 )?;
3941 }
3942
3943 Ok(())
3944 },
3945 );
3946
3947 send_notifications(pool, &session.peer, notifications);
3948
3949 Ok(())
3950}
3951
3952/// Mark a channel message as read
3953async fn acknowledge_channel_message(
3954 request: proto::AckChannelMessage,
3955 session: Session,
3956) -> Result<()> {
3957 let channel_id = ChannelId::from_proto(request.channel_id);
3958 let message_id = MessageId::from_proto(request.message_id);
3959 let notifications = session
3960 .db()
3961 .await
3962 .observe_channel_message(channel_id, session.user_id(), message_id)
3963 .await?;
3964 send_notifications(
3965 &*session.connection_pool().await,
3966 &session.peer,
3967 notifications,
3968 );
3969 Ok(())
3970}
3971
3972/// Mark a buffer version as synced
3973async fn acknowledge_buffer_version(
3974 request: proto::AckBufferOperation,
3975 session: Session,
3976) -> Result<()> {
3977 let buffer_id = BufferId::from_proto(request.buffer_id);
3978 session
3979 .db()
3980 .await
3981 .observe_buffer_version(
3982 buffer_id,
3983 session.user_id(),
3984 request.epoch as i32,
3985 &request.version,
3986 )
3987 .await?;
3988 Ok(())
3989}
3990
3991/// Get a Supermaven API key for the user
3992async fn get_supermaven_api_key(
3993 _request: proto::GetSupermavenApiKey,
3994 response: Response<proto::GetSupermavenApiKey>,
3995 session: Session,
3996) -> Result<()> {
3997 let user_id: String = session.user_id().to_string();
3998 if !session.is_staff() {
3999 return Err(anyhow!("supermaven not enabled for this account"))?;
4000 }
4001
4002 let email = session.email().context("user must have an email")?;
4003
4004 let supermaven_admin_api = session
4005 .supermaven_client
4006 .as_ref()
4007 .context("supermaven not configured")?;
4008
4009 let result = supermaven_admin_api
4010 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4011 .await?;
4012
4013 response.send(proto::GetSupermavenApiKeyResponse {
4014 api_key: result.api_key,
4015 })?;
4016
4017 Ok(())
4018}
4019
4020/// Start receiving chat updates for a channel
4021async fn join_channel_chat(
4022 request: proto::JoinChannelChat,
4023 response: Response<proto::JoinChannelChat>,
4024 session: Session,
4025) -> Result<()> {
4026 let channel_id = ChannelId::from_proto(request.channel_id);
4027
4028 let db = session.db().await;
4029 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4030 .await?;
4031 let messages = db
4032 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4033 .await?;
4034 response.send(proto::JoinChannelChatResponse {
4035 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4036 messages,
4037 })?;
4038 Ok(())
4039}
4040
4041/// Stop receiving chat updates for a channel
4042async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4043 let channel_id = ChannelId::from_proto(request.channel_id);
4044 session
4045 .db()
4046 .await
4047 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4048 .await?;
4049 Ok(())
4050}
4051
4052/// Retrieve the chat history for a channel
4053async fn get_channel_messages(
4054 request: proto::GetChannelMessages,
4055 response: Response<proto::GetChannelMessages>,
4056 session: Session,
4057) -> Result<()> {
4058 let channel_id = ChannelId::from_proto(request.channel_id);
4059 let messages = session
4060 .db()
4061 .await
4062 .get_channel_messages(
4063 channel_id,
4064 session.user_id(),
4065 MESSAGE_COUNT_PER_PAGE,
4066 Some(MessageId::from_proto(request.before_message_id)),
4067 )
4068 .await?;
4069 response.send(proto::GetChannelMessagesResponse {
4070 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4071 messages,
4072 })?;
4073 Ok(())
4074}
4075
4076/// Retrieve specific chat messages
4077async fn get_channel_messages_by_id(
4078 request: proto::GetChannelMessagesById,
4079 response: Response<proto::GetChannelMessagesById>,
4080 session: Session,
4081) -> Result<()> {
4082 let message_ids = request
4083 .message_ids
4084 .iter()
4085 .map(|id| MessageId::from_proto(*id))
4086 .collect::<Vec<_>>();
4087 let messages = session
4088 .db()
4089 .await
4090 .get_channel_messages_by_id(session.user_id(), &message_ids)
4091 .await?;
4092 response.send(proto::GetChannelMessagesResponse {
4093 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4094 messages,
4095 })?;
4096 Ok(())
4097}
4098
4099/// Retrieve the current users notifications
4100async fn get_notifications(
4101 request: proto::GetNotifications,
4102 response: Response<proto::GetNotifications>,
4103 session: Session,
4104) -> Result<()> {
4105 let notifications = session
4106 .db()
4107 .await
4108 .get_notifications(
4109 session.user_id(),
4110 NOTIFICATION_COUNT_PER_PAGE,
4111 request.before_id.map(db::NotificationId::from_proto),
4112 )
4113 .await?;
4114 response.send(proto::GetNotificationsResponse {
4115 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4116 notifications,
4117 })?;
4118 Ok(())
4119}
4120
4121/// Mark notifications as read
4122async fn mark_notification_as_read(
4123 request: proto::MarkNotificationRead,
4124 response: Response<proto::MarkNotificationRead>,
4125 session: Session,
4126) -> Result<()> {
4127 let database = &session.db().await;
4128 let notifications = database
4129 .mark_notification_as_read_by_id(
4130 session.user_id(),
4131 NotificationId::from_proto(request.notification_id),
4132 )
4133 .await?;
4134 send_notifications(
4135 &*session.connection_pool().await,
4136 &session.peer,
4137 notifications,
4138 );
4139 response.send(proto::Ack {})?;
4140 Ok(())
4141}
4142
4143/// Get the current users information
4144async fn get_private_user_info(
4145 _request: proto::GetPrivateUserInfo,
4146 response: Response<proto::GetPrivateUserInfo>,
4147 session: Session,
4148) -> Result<()> {
4149 let db = session.db().await;
4150
4151 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4152 let user = db
4153 .get_user_by_id(session.user_id())
4154 .await?
4155 .context("user not found")?;
4156 let flags = db.get_user_flags(session.user_id()).await?;
4157
4158 response.send(proto::GetPrivateUserInfoResponse {
4159 metrics_id,
4160 staff: user.admin,
4161 flags,
4162 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4163 })?;
4164 Ok(())
4165}
4166
4167/// Accept the terms of service (tos) on behalf of the current user
4168async fn accept_terms_of_service(
4169 _request: proto::AcceptTermsOfService,
4170 response: Response<proto::AcceptTermsOfService>,
4171 session: Session,
4172) -> Result<()> {
4173 let db = session.db().await;
4174
4175 let accepted_tos_at = Utc::now();
4176 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4177 .await?;
4178
4179 response.send(proto::AcceptTermsOfServiceResponse {
4180 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4181 })?;
4182
4183 // When the user accepts the terms of service, we want to refresh their LLM
4184 // token to grant access.
4185 session
4186 .peer
4187 .send(session.connection_id, proto::RefreshLlmToken {})?;
4188
4189 Ok(())
4190}
4191
/// Issue an LLM API token for the session's user.
///
/// Requirements:
/// - the user exists and has accepted the terms of service;
/// - both the Stripe client and Stripe billing are configured.
///
/// Missing billing state is filled in on demand:
/// - if the user has no billing customer yet, one is found or created by the
///   user's email;
/// - if the user has no active billing subscription, they are subscribed to
///   Zed Free and the subscription is recorded in the database.
///
/// The token is then built by `LlmTokenClaims::create` from the user, staff
/// status, billing customer, billing preferences, feature flags, billing
/// subscription, system id, and server config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // LLM access is gated on accepting the terms of service.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer, or find/create the Stripe customer by
    // email and record it.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription, or subscribe the customer to Zed Free in
    // Stripe and store the new subscription.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4275
4276fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4277 let message = match message {
4278 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4279 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4280 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4281 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4282 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4283 code: frame.code.into(),
4284 reason: frame.reason.as_str().to_owned().into(),
4285 })),
4286 // We should never receive a frame while reading the message, according
4287 // to the `tungstenite` maintainers:
4288 //
4289 // > It cannot occur when you read messages from the WebSocket, but it
4290 // > can be used when you want to send the raw frames (e.g. you want to
4291 // > send the frames to the WebSocket without composing the full message first).
4292 // >
4293 // > — https://github.com/snapview/tungstenite-rs/issues/268
4294 TungsteniteMessage::Frame(_) => {
4295 bail!("received an unexpected frame while reading the message")
4296 }
4297 };
4298
4299 Ok(message)
4300}
4301
4302fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4303 match message {
4304 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4305 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4306 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4307 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4308 AxumMessage::Close(frame) => {
4309 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4310 code: frame.code.into(),
4311 reason: frame.reason.as_ref().into(),
4312 }))
4313 }
4314 }
4315}
4316
4317fn notify_membership_updated(
4318 connection_pool: &mut ConnectionPool,
4319 result: MembershipUpdated,
4320 user_id: UserId,
4321 peer: &Peer,
4322) {
4323 for membership in &result.new_channels.channel_memberships {
4324 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4325 }
4326 for channel_id in &result.removed_channels {
4327 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4328 }
4329
4330 let user_channels_update = proto::UpdateUserChannels {
4331 channel_memberships: result
4332 .new_channels
4333 .channel_memberships
4334 .iter()
4335 .map(|cm| proto::ChannelMembership {
4336 channel_id: cm.channel_id.to_proto(),
4337 role: cm.role.into(),
4338 })
4339 .collect(),
4340 ..Default::default()
4341 };
4342
4343 let mut update = build_channels_update(result.new_channels);
4344 update.delete_channels = result
4345 .removed_channels
4346 .into_iter()
4347 .map(|id| id.to_proto())
4348 .collect();
4349 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4350
4351 for connection_id in connection_pool.user_connection_ids(user_id) {
4352 peer.send(connection_id, user_channels_update.clone())
4353 .trace_err();
4354 peer.send(connection_id, update.clone()).trace_err();
4355 }
4356}
4357
4358fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4359 proto::UpdateUserChannels {
4360 channel_memberships: channels
4361 .channel_memberships
4362 .iter()
4363 .map(|m| proto::ChannelMembership {
4364 channel_id: m.channel_id.to_proto(),
4365 role: m.role.into(),
4366 })
4367 .collect(),
4368 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4369 observed_channel_message_id: channels.observed_channel_messages.clone(),
4370 }
4371}
4372
4373fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4374 let mut update = proto::UpdateChannels::default();
4375
4376 for channel in channels.channels {
4377 update.channels.push(channel.to_proto());
4378 }
4379
4380 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4381 update.latest_channel_message_ids = channels.latest_channel_messages;
4382
4383 for (channel_id, participants) in channels.channel_participants {
4384 update
4385 .channel_participants
4386 .push(proto::ChannelParticipants {
4387 channel_id: channel_id.to_proto(),
4388 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4389 });
4390 }
4391
4392 for channel in channels.invited_channels {
4393 update.channel_invitations.push(channel.to_proto());
4394 }
4395
4396 update
4397}
4398
4399fn build_initial_contacts_update(
4400 contacts: Vec<db::Contact>,
4401 pool: &ConnectionPool,
4402) -> proto::UpdateContacts {
4403 let mut update = proto::UpdateContacts::default();
4404
4405 for contact in contacts {
4406 match contact {
4407 db::Contact::Accepted { user_id, busy } => {
4408 update.contacts.push(contact_for_user(user_id, busy, pool));
4409 }
4410 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4411 db::Contact::Incoming { user_id } => {
4412 update
4413 .incoming_requests
4414 .push(proto::IncomingContactRequest {
4415 requester_id: user_id.to_proto(),
4416 })
4417 }
4418 }
4419 }
4420
4421 update
4422}
4423
4424fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4425 proto::Contact {
4426 user_id: user_id.to_proto(),
4427 online: pool.is_user_online(user_id),
4428 busy,
4429 }
4430}
4431
4432fn room_updated(room: &proto::Room, peer: &Peer) {
4433 broadcast(
4434 None,
4435 room.participants
4436 .iter()
4437 .filter_map(|participant| Some(participant.peer_id?.into())),
4438 |peer_id| {
4439 peer.send(
4440 peer_id,
4441 proto::RoomUpdated {
4442 room: Some(room.clone()),
4443 },
4444 )
4445 },
4446 );
4447}
4448
4449fn channel_updated(
4450 channel: &db::channel::Model,
4451 room: &proto::Room,
4452 peer: &Peer,
4453 pool: &ConnectionPool,
4454) {
4455 let participants = room
4456 .participants
4457 .iter()
4458 .map(|p| p.user_id)
4459 .collect::<Vec<_>>();
4460
4461 broadcast(
4462 None,
4463 pool.channel_connection_ids(channel.root_id())
4464 .filter_map(|(channel_id, role)| {
4465 role.can_see_channel(channel.visibility)
4466 .then_some(channel_id)
4467 }),
4468 |peer_id| {
4469 peer.send(
4470 peer_id,
4471 proto::UpdateChannels {
4472 channel_participants: vec![proto::ChannelParticipants {
4473 channel_id: channel.id.to_proto(),
4474 participant_user_ids: participants.clone(),
4475 }],
4476 ..Default::default()
4477 },
4478 )
4479 },
4480 );
4481}
4482
4483async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4484 let db = session.db().await;
4485
4486 let contacts = db.get_contacts(user_id).await?;
4487 let busy = db.is_user_busy(user_id).await?;
4488
4489 let pool = session.connection_pool().await;
4490 let updated_contact = contact_for_user(user_id, busy, &pool);
4491 for contact in contacts {
4492 if let db::Contact::Accepted {
4493 user_id: contact_user_id,
4494 ..
4495 } = contact
4496 {
4497 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4498 session
4499 .peer
4500 .send(
4501 contact_conn_id,
4502 proto::UpdateContacts {
4503 contacts: vec![updated_contact.clone()],
4504 remove_contacts: Default::default(),
4505 incoming_requests: Default::default(),
4506 remove_incoming_requests: Default::default(),
4507 outgoing_requests: Default::default(),
4508 remove_outgoing_requests: Default::default(),
4509 },
4510 )
4511 .trace_err();
4512 }
4513 }
4514 }
4515 Ok(())
4516}
4517
/// Removes `connection_id` from the room it is in and propagates the consequences.
///
/// Does nothing and returns `Ok(())` if `leave_room` reports that no room was left. Otherwise:
/// - notifies collaborators of projects left along with the room;
/// - broadcasts the updated room, and the channel's participant list if the room has a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the leaving user and those users;
/// - removes the participant from LiveKit, deleting the LiveKit room when `left_room.deleted`
///   was set.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. A failure while updating
/// contacts returns before the LiveKit cleanup runs. Peer-send and LiveKit failures are only
/// logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front so they can be initialized inside the `if let` below and used after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db()` guard is a temporary in this scrutinee and appears to
    // stay alive for the whole `then` branch, so the calls below run while it is held. Confirm
    // before restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order-sensitive: this is taken before `room` is moved out, so the `room` broadcast
        // below carries the default value in `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not returned.
    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4591
4592async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4593 let left_channel_buffers = session
4594 .db()
4595 .await
4596 .leave_channel_buffers(session.connection_id)
4597 .await?;
4598
4599 for left_buffer in left_channel_buffers {
4600 channel_buffer_updated(
4601 session.connection_id,
4602 left_buffer.connections,
4603 &proto::UpdateChannelBufferCollaborators {
4604 channel_id: left_buffer.channel_id.to_proto(),
4605 collaborators: left_buffer.collaborators,
4606 },
4607 &session.peer,
4608 );
4609 }
4610
4611 Ok(())
4612}
4613
4614fn project_left(project: &db::LeftProject, session: &Session) {
4615 for connection_id in &project.connection_ids {
4616 if project.should_unshare {
4617 session
4618 .peer
4619 .send(
4620 *connection_id,
4621 proto::UnshareProject {
4622 project_id: project.id.to_proto(),
4623 },
4624 )
4625 .trace_err();
4626 } else {
4627 session
4628 .peer
4629 .send(
4630 *connection_id,
4631 proto::RemoveProjectCollaborator {
4632 project_id: project.id.to_proto(),
4633 peer_id: Some(session.connection_id.into()),
4634 },
4635 )
4636 .trace_err();
4637 }
4638 }
4639}
4640
4641pub trait ResultExt {
4642 type Ok;
4643
4644 fn trace_err(self) -> Option<Self::Ok>;
4645}
4646
4647impl<T, E> ResultExt for Result<T, E>
4648where
4649 E: std::fmt::Debug,
4650{
4651 type Ok = T;
4652
4653 #[track_caller]
4654 fn trace_err(self) -> Option<T> {
4655 match self {
4656 Ok(value) => Some(value),
4657 Err(error) => {
4658 tracing::error!("{:?}", error);
4659 None
4660 }
4661 }
4662 }
4663}