1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// Reconnect grace period. It is not referenced in this part of the file; presumably it is how
/// long a disconnected client may take to reconnect before its state is released (confirm at
/// the use sites).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources left behind by previous server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting longer than that means
/// the old pods are gone by the time we clean up their resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not used in this part of the file. Their meaning
// (page sizes for channel messages / notifications, max channel message length) is inferred
// from their names; confirm at the use sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on simultaneously held `ConnectionGuard`s (checked in `ConnectionGuard::try_acquire`).
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Number of currently held `ConnectionGuard`s: incremented by `try_acquire`, decremented when a
/// guard is dropped.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);

/// Type-erased message handler stored in `Server::handlers`.
///
/// It receives the boxed envelope plus the session, and returns a boxed future that processes
/// the message.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
122struct Response<R> {
123 peer: Arc<Peer>,
124 receipt: Receipt<R>,
125 responded: Arc<AtomicBool>,
126}
127
128impl<R: RequestMessage> Response<R> {
129 fn send(self, payload: R::Response) -> Result<()> {
130 self.responded.store(true, SeqCst);
131 self.peer.respond(self.receipt, payload)?;
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug)]
137pub enum Principal {
138 User(User),
139 Impersonated { user: User, admin: User },
140}
141
142impl Principal {
143 fn user(&self) -> &User {
144 match self {
145 Principal::User(user) => user,
146 Principal::Impersonated { user, .. } => user,
147 }
148 }
149
150 fn update_span(&self, span: &tracing::Span) {
151 match &self {
152 Principal::User(user) => {
153 span.record("user_id", user.id.0);
154 span.record("login", &user.github_login);
155 }
156 Principal::Impersonated { user, admin } => {
157 span.record("user_id", user.id.0);
158 span.record("login", &user.github_login);
159 span.record("impersonator", &admin.github_login);
160 }
161 }
162 }
163}
164
/// Per-connection context handed to every message handler.
///
/// A clone is passed to the handler of each incoming message in `Server::handle_connection`.
#[derive(Clone)]
struct Session {
    /// The identity this connection acts as.
    principal: Principal,
    /// This client's connection id within `peer`.
    connection_id: ConnectionId,
    /// Database handle, accessed through `Session::db`. `handle_connection` creates a fresh
    /// mutex per connection, so database access is serialized per connection.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with the owning `Server`; accessed through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// The GeoIP country code for the user.
    // NOTE(review): `#[allow(unused)]` is presumably here because nothing reads this field yet;
    // confirm, and drop the attribute once it is used.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    _executor: Executor,
}
180
181impl Session {
182 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
183 #[cfg(test)]
184 tokio::task::yield_now().await;
185 let guard = self.db.lock().await;
186 #[cfg(test)]
187 tokio::task::yield_now().await;
188 guard
189 }
190
191 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
192 #[cfg(test)]
193 tokio::task::yield_now().await;
194 let guard = self.connection_pool.lock();
195 ConnectionPoolGuard {
196 guard,
197 _not_send: PhantomData,
198 }
199 }
200
201 fn is_staff(&self) -> bool {
202 match &self.principal {
203 Principal::User(user) => user.admin,
204 Principal::Impersonated { .. } => true,
205 }
206 }
207
208 fn user_id(&self) -> UserId {
209 match &self.principal {
210 Principal::User(user) => user.id,
211 Principal::Impersonated { user, .. } => user.id,
212 }
213 }
214
215 pub fn email(&self) -> Option<String> {
216 match &self.principal {
217 Principal::User(user) => user.email_address.clone(),
218 Principal::Impersonated { user, .. } => user.email_address.clone(),
219 }
220 }
221}
222
223impl Debug for Session {
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 let mut result = f.debug_struct("Session");
226 match &self.principal {
227 Principal::User(user) => {
228 result.field("user", &user.github_login);
229 }
230 Principal::Impersonated { user, admin } => {
231 result.field("user", &user.github_login);
232 result.field("impersonator", &admin.github_login);
233 }
234 }
235 result.field("connection_id", &self.connection_id).finish()
236 }
237}
238
239struct DbHandle(Arc<Database>);
240
241impl Deref for DbHandle {
242 type Target = Database;
243
244 fn deref(&self) -> &Self::Target {
245 self.0.as_ref()
246 }
247}
248
/// The collaboration RPC server: owns the peer, the connection pool and the message-handler table.
pub struct Server {
    /// This server instance's id. Mutated only by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Registered handlers, keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Every connection subscribes and stops when it changes.
    teardown: watch::Sender<bool>,
}
257
/// Connection-pool lock guard that is deliberately `!Send`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. The compiler therefore rejects any
/// `Send`-required future that holds the blocking `parking_lot` lock across an `.await`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
262
/// Serializable view of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't `Serialize`; serialize the pool it derefs to instead.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
269
270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
271where
272 S: Serializer,
273 T: Deref<Target = U>,
274 U: Serialize,
275{
276 Serialize::serialize(value.deref(), serializer)
277}
278
279impl Server {
280 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
281 let mut server = Self {
282 id: parking_lot::Mutex::new(id),
283 peer: Peer::new(id.0 as u32),
284 app_state: app_state.clone(),
285 connection_pool: Default::default(),
286 handlers: Default::default(),
287 teardown: watch::channel(false).0,
288 };
289
290 server
291 .add_request_handler(ping)
292 .add_request_handler(create_room)
293 .add_request_handler(join_room)
294 .add_request_handler(rejoin_room)
295 .add_request_handler(leave_room)
296 .add_request_handler(set_room_participant_role)
297 .add_request_handler(call)
298 .add_request_handler(cancel_call)
299 .add_message_handler(decline_call)
300 .add_request_handler(update_participant_location)
301 .add_request_handler(share_project)
302 .add_message_handler(unshare_project)
303 .add_request_handler(join_project)
304 .add_message_handler(leave_project)
305 .add_request_handler(update_project)
306 .add_request_handler(update_worktree)
307 .add_request_handler(update_repository)
308 .add_request_handler(remove_repository)
309 .add_message_handler(start_language_server)
310 .add_message_handler(update_language_server)
311 .add_message_handler(update_diagnostic_summary)
312 .add_message_handler(update_worktree_settings)
313 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
314 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
317 .add_request_handler(forward_find_search_candidates_request)
318 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
323 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
324 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
325 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
326 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
328 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
329 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
330 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
334 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
335 .add_request_handler(
336 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
337 )
338 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
342 .add_request_handler(
343 forward_read_only_project_request::<proto::LanguageServerIdForName>,
344 )
345 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
346 .add_request_handler(
347 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
348 )
349 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
350 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
353 )
354 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
355 .add_request_handler(
356 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
357 )
358 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
359 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
360 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
361 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
362 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
364 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
365 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
369 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
370 .add_request_handler(
371 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
372 )
373 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
374 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
375 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
376 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
377 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
378 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
379 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
380 .add_message_handler(create_buffer_for_peer)
381 .add_request_handler(update_buffer)
382 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
383 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
386 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
387 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
388 .add_message_handler(
389 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
390 )
391 .add_request_handler(get_users)
392 .add_request_handler(fuzzy_search_users)
393 .add_request_handler(request_contact)
394 .add_request_handler(remove_contact)
395 .add_request_handler(respond_to_contact_request)
396 .add_message_handler(subscribe_to_channels)
397 .add_request_handler(create_channel)
398 .add_request_handler(delete_channel)
399 .add_request_handler(invite_channel_member)
400 .add_request_handler(remove_channel_member)
401 .add_request_handler(set_channel_member_role)
402 .add_request_handler(set_channel_visibility)
403 .add_request_handler(rename_channel)
404 .add_request_handler(join_channel_buffer)
405 .add_request_handler(leave_channel_buffer)
406 .add_message_handler(update_channel_buffer)
407 .add_request_handler(rejoin_channel_buffers)
408 .add_request_handler(get_channel_members)
409 .add_request_handler(respond_to_channel_invite)
410 .add_request_handler(join_channel)
411 .add_request_handler(join_channel_chat)
412 .add_message_handler(leave_channel_chat)
413 .add_request_handler(send_channel_message)
414 .add_request_handler(remove_channel_message)
415 .add_request_handler(update_channel_message)
416 .add_request_handler(get_channel_messages)
417 .add_request_handler(get_channel_messages_by_id)
418 .add_request_handler(get_notifications)
419 .add_request_handler(mark_notification_as_read)
420 .add_request_handler(move_channel)
421 .add_request_handler(reorder_channel)
422 .add_request_handler(follow)
423 .add_message_handler(unfollow)
424 .add_message_handler(update_followers)
425 .add_request_handler(get_private_user_info)
426 .add_request_handler(get_llm_api_token)
427 .add_request_handler(accept_terms_of_service)
428 .add_message_handler(acknowledge_channel_message)
429 .add_message_handler(acknowledge_buffer_version)
430 .add_request_handler(get_supermaven_api_key)
431 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
432 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
433 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
434 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
435 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
436 .add_request_handler(forward_mutating_project_request::<proto::Stash>)
437 .add_request_handler(forward_mutating_project_request::<proto::StashPop>)
438 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
439 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
440 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
441 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
442 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
443 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
444 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
445 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
446 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
447 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
448 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
449 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
450 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
451 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
452 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
453 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
454 .add_message_handler(update_context);
455
456 Arc::new(server)
457 }
458
/// Spawns background cleanup of resources left behind by previous server instances.
///
/// The returned future completes right after spawning; the cleanup itself runs detached on the
/// executor. After waiting for `CLEANUP_TIMEOUT`, it:
/// 1. deletes stale channel-chat participants,
/// 2. fetches stale room and channel-buffer ids via `stale_server_resource_ids`,
/// 3. clears stale channel-buffer collaborators and notifies the remaining connections,
/// 4. clears stale room participants, then notifies room members, canceled callees and affected
///    contacts, and deletes the LiveKit room once it is empty,
/// 5. repeats the chat-participant cleanup, then clears old worktree entries and stale servers.
///
/// NOTE(review): "stale" is presumably "owned by a server other than `server_id`"; confirm in
/// the `db` layer.
///
/// Every step is best-effort: failures are logged via `trace_err` and the cleanup continues.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let livekit_client = self.app_state.livekit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");

            // NOTE(review): the same call is repeated after the room cleanup below; it's unclear
            // whether both passes are needed.
            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Drop stale collaborators from each channel buffer and send the refreshed
                // collaborator list to every connection still in that buffer.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    // Per-room results of the refresh below. They keep these defaults (nothing to
                    // notify, no LiveKit deletion) when the refresh fails.
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut livekit_room = String::new();
                    let mut delete_livekit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        // Both removed participants and canceled callees need a fresh
                        // contact/busy status pushed to their contacts.
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                        delete_livekit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scope the pool lock so it is released before the `.await`s below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push each affected user's updated contact entry to all connections of
                    // their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Only delete the LiveKit room once no participants remain in it.
                    if let Some(live_kit) = livekit_client.as_ref() {
                        if delete_livekit_room {
                            live_kit.delete_room(livekit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_channel_chat_participants(
                    &app_state.config.zed_environment,
                    server_id,
                )
                .await
                .trace_err();

            app_state
                .db
                .clear_old_worktree_entries(server_id)
                .await
                .trace_err();

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
629
630 pub fn teardown(&self) {
631 self.peer.teardown();
632 self.connection_pool.lock().reset();
633 let _ = self.teardown.send(true);
634 }
635
/// Test-only: tears the server down and brings it back up under a new `id`.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    *self.id.lock() = id;
    self.peer.reset(id.0 as u32);
    // Use `send_replace` so the flag is cleared even when every receiver has already gone away.
    // `send` would drop the value in that case and leave the server marked as tearing down.
    self.teardown.send_replace(false);
}
643
/// Test-only: returns the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let current = self.id.lock();
    *current
}
648
649 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
650 where
651 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
652 Fut: 'static + Send + Future<Output = Result<()>>,
653 M: EnvelopedMessage,
654 {
655 let prev_handler = self.handlers.insert(
656 TypeId::of::<M>(),
657 Box::new(move |envelope, session| {
658 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
659 let received_at = envelope.received_at;
660 tracing::info!("message received");
661 let start_time = Instant::now();
662 let future = (handler)(*envelope, session);
663 async move {
664 let result = future.await;
665 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
666 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
667 let queue_duration_ms = total_duration_ms - processing_duration_ms;
668 let payload_type = M::NAME;
669
670 match result {
671 Err(error) => {
672 tracing::error!(
673 ?error,
674 total_duration_ms,
675 processing_duration_ms,
676 queue_duration_ms,
677 payload_type,
678 "error handling message"
679 )
680 }
681 Ok(()) => tracing::info!(
682 total_duration_ms,
683 processing_duration_ms,
684 queue_duration_ms,
685 "finished handling message"
686 ),
687 }
688 }
689 .boxed()
690 }),
691 );
692 if prev_handler.is_some() {
693 panic!("registered a handler for the same message twice");
694 }
695 self
696 }
697
698 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
699 where
700 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
701 Fut: 'static + Send + Future<Output = Result<()>>,
702 M: EnvelopedMessage,
703 {
704 self.add_handler(move |envelope, session| handler(envelope.payload, session));
705 self
706 }
707
708 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
709 where
710 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
711 Fut: Send + Future<Output = Result<()>>,
712 M: RequestMessage,
713 {
714 let handler = Arc::new(handler);
715 self.add_handler(move |envelope, session| {
716 let receipt = envelope.receipt();
717 let handler = handler.clone();
718 async move {
719 let peer = session.peer.clone();
720 let responded = Arc::new(AtomicBool::default());
721 let response = Response {
722 peer: peer.clone(),
723 responded: responded.clone(),
724 receipt,
725 };
726 match (handler)(envelope.payload, response, session).await {
727 Ok(()) => {
728 if responded.load(std::sync::atomic::Ordering::SeqCst) {
729 Ok(())
730 } else {
731 Err(anyhow!("handler did not send a response"))?
732 }
733 }
734 Err(error) => {
735 let proto_err = match &error {
736 Error::Internal(err) => err.to_proto(),
737 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
738 };
739 peer.respond_with_error(receipt, proto_err)?;
740 Err(error)
741 }
742 }
743 }
744 })
745 }
746
/// Drives one client connection from handshake to sign-out.
///
/// Returns a future, instrumented with a "handle connection" span, that:
/// 1. returns immediately if the server is already tearing down,
/// 2. registers the connection with the peer and builds the connection's `Session`,
/// 3. sends the initial client update, then releases `connection_guard`,
/// 4. dispatches incoming messages to the registered handlers until the I/O future finishes,
///    the client closes the connection, or the server tears down,
/// 5. calls `connection_lost` after the loop exits via `break`.
///
/// Teardown and the early failure paths `return` without calling `connection_lost`.
pub fn handle_connection(
    self: &Arc<Self>,
    connection: Connection,
    address: String,
    principal: Principal,
    zed_version: ZedVersion,
    geoip_country_code: Option<String>,
    system_id: Option<String>,
    send_connection_id: Option<oneshot::Sender<ConnectionId>>,
    executor: Executor,
    connection_guard: Option<ConnectionGuard>,
) -> impl Future<Output = ()> + use<> {
    let this = self.clone();
    let span = info_span!("handle connection", %address,
        connection_id=field::Empty,
        user_id=field::Empty,
        login=field::Empty,
        impersonator=field::Empty,
        geoip_country_code=field::Empty
    );
    principal.update_span(&span);
    if let Some(country_code) = geoip_country_code.as_ref() {
        span.record("geoip_country_code", country_code);
    }

    // Subscribe before the future runs so a teardown that happens in between is still observed.
    let mut teardown = self.teardown.subscribe();
    async move {
        if *teardown.borrow() {
            tracing::error!("server is tearing down");
            return
        }

        let (connection_id, handle_io, mut incoming_rx) = this
            .peer
            .add_connection(connection, {
                let executor = executor.clone();
                move |duration| executor.sleep(duration)
            });
        tracing::Span::current().record("connection_id", format!("{}", connection_id));

        tracing::info!("connection opened");

        let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
        let http_client = match ReqwestClient::user_agent(&user_agent) {
            Ok(http_client) => Arc::new(http_client),
            Err(error) => {
                tracing::error!(?error, "failed to create HTTP client");
                return;
            }
        };

        // Only built when a Supermaven admin API key is configured.
        let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                supermaven_admin_api_key.to_string(),
                http_client.clone(),
            )));

        let session = Session {
            principal: principal.clone(),
            connection_id,
            db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
            peer: this.peer.clone(),
            connection_pool: this.connection_pool.clone(),
            app_state: this.app_state.clone(),
            geoip_country_code,
            system_id,
            _executor: executor.clone(),
            supermaven_client,
        };

        if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
            tracing::error!(?error, "failed to send initial client update");
            return;
        }
        // The handshake is done; free the slot in the global connection budget.
        drop(connection_guard);

        let handle_io = handle_io.fuse();
        futures::pin_mut!(handle_io);

        // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
        // This prevents deadlocks when e.g., client A performs a request to client B and
        // client B performs a request to client A. If both clients stop processing further
        // messages until their respective request completes, they won't have a chance to
        // respond to the other client's request and cause a deadlock.
        //
        // This arrangement ensures we will attempt to process earlier messages first, but fall
        // back to processing messages arrived later in the spirit of making progress.
        let mut foreground_message_handlers = FuturesUnordered::new();
        // Caps in-flight handlers for this connection. A permit is acquired before each message
        // is read and released when that message's handler finishes, whether foreground or
        // background.
        let concurrent_handlers = Arc::new(Semaphore::new(512));
        loop {
            let next_message = async {
                let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                let message = incoming_rx.next().await;
                (permit, message)
            }.fuse();
            futures::pin_mut!(next_message);
            // Biased order: teardown first, then I/O completion, then progress on in-flight
            // foreground handlers, and only then a new incoming message.
            futures::select_biased! {
                _ = teardown.changed().fuse() => return,
                result = handle_io => {
                    if let Err(error) = result {
                        tracing::error!(?error, "error handling I/O");
                    }
                    break;
                }
                _ = foreground_message_handlers.next() => {}
                next_message = next_message => {
                    let (permit, message) = next_message;
                    if let Some(message) = message {
                        let type_name = message.payload_type_name();
                        // note: we copy all the fields from the parent span so we can query them in the logs.
                        // (https://github.com/tokio-rs/tracing/issues/2670).
                        // NOTE(review): `geoip_country_code` from the parent span is not copied here.
                        let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                            user_id=field::Empty,
                            login=field::Empty,
                            impersonator=field::Empty,
                        );
                        principal.update_span(&span);
                        let span_enter = span.enter();
                        if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                            let is_background = message.is_background();
                            let handle_message = (handler)(message, session.clone());
                            drop(span_enter);

                            let handle_message = async move {
                                handle_message.await;
                                drop(permit);
                            }.instrument(span);
                            // Background messages run detached; foreground ones are polled by
                            // this loop so their relative ordering is preserved where possible.
                            if is_background {
                                executor.spawn_detached(handle_message);
                            } else {
                                foreground_message_handlers.push(handle_message);
                            }
                        } else {
                            tracing::error!("no message handler");
                        }
                    } else {
                        tracing::info!("connection closed");
                        break;
                    }
                }
            }
        }

        // Cancel any foreground handlers still in flight before signing the user out.
        drop(foreground_message_handlers);
        tracing::info!("signing out");
        if let Err(error) = connection_lost(session, teardown, executor).await {
            tracing::error!(?error, "error signing out");
        }

    }.instrument(span)
}
897
898 async fn send_initial_client_update(
899 &self,
900 connection_id: ConnectionId,
901 zed_version: ZedVersion,
902 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
903 session: &Session,
904 ) -> Result<()> {
905 self.peer.send(
906 connection_id,
907 proto::Hello {
908 peer_id: Some(connection_id.into()),
909 },
910 )?;
911 tracing::info!("sent hello message");
912 if let Some(send_connection_id) = send_connection_id.take() {
913 let _ = send_connection_id.send(connection_id);
914 }
915
916 match &session.principal {
917 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
918 if !user.connected_once {
919 self.peer.send(connection_id, proto::ShowContacts {})?;
920 self.app_state
921 .db
922 .set_user_connected_once(user.id, true)
923 .await?;
924 }
925
926 update_user_plan(session).await?;
927
928 let contacts = self.app_state.db.get_contacts(user.id).await?;
929
930 {
931 let mut pool = self.connection_pool.lock();
932 pool.add_connection(connection_id, user.id, user.admin, zed_version);
933 self.peer.send(
934 connection_id,
935 build_initial_contacts_update(contacts, &pool),
936 )?;
937 }
938
939 if should_auto_subscribe_to_channels(zed_version) {
940 subscribe_user_to_channels(user.id, session).await?;
941 }
942
943 if let Some(incoming_call) =
944 self.app_state.db.incoming_call_for_user(user.id).await?
945 {
946 self.peer.send(connection_id, incoming_call)?;
947 }
948
949 update_user_contacts(user.id, session).await?;
950 }
951 }
952
953 Ok(())
954 }
955
956 pub async fn invite_code_redeemed(
957 self: &Arc<Self>,
958 inviter_id: UserId,
959 invitee_id: UserId,
960 ) -> Result<()> {
961 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
962 if let Some(code) = &user.invite_code {
963 let pool = self.connection_pool.lock();
964 let invitee_contact = contact_for_user(invitee_id, false, &pool);
965 for connection_id in pool.user_connection_ids(inviter_id) {
966 self.peer.send(
967 connection_id,
968 proto::UpdateContacts {
969 contacts: vec![invitee_contact.clone()],
970 ..Default::default()
971 },
972 )?;
973 self.peer.send(
974 connection_id,
975 proto::UpdateInviteInfo {
976 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
977 count: user.invite_count as u32,
978 },
979 )?;
980 }
981 }
982 }
983 Ok(())
984 }
985
986 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
987 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
988 if let Some(invite_code) = &user.invite_code {
989 let pool = self.connection_pool.lock();
990 for connection_id in pool.user_connection_ids(user_id) {
991 self.peer.send(
992 connection_id,
993 proto::UpdateInviteInfo {
994 url: format!(
995 "{}{}",
996 self.app_state.config.invite_link_prefix, invite_code
997 ),
998 count: user.invite_count as u32,
999 },
1000 )?;
1001 }
1002 }
1003 }
1004 Ok(())
1005 }
1006
1007 pub async fn update_plan_for_user(
1008 self: &Arc<Self>,
1009 user_id: UserId,
1010 update_user_plan: proto::UpdateUserPlan,
1011 ) -> Result<()> {
1012 let pool = self.connection_pool.lock();
1013 for connection_id in pool.user_connection_ids(user_id) {
1014 self.peer
1015 .send(connection_id, update_user_plan.clone())
1016 .trace_err();
1017 }
1018
1019 Ok(())
1020 }
1021
1022 /// This is the legacy way of updating the user's plan, where we fetch the data to construct the `UpdateUserPlan`
1023 /// message on the Collab server.
1024 ///
1025 /// The new way is to receive the data from Cloud via the `POST /users/:id/update_plan` endpoint.
1026 pub async fn update_plan_for_user_legacy(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1027 let user = self
1028 .app_state
1029 .db
1030 .get_user_by_id(user_id)
1031 .await?
1032 .context("user not found")?;
1033
1034 let update_user_plan = make_update_user_plan_message(
1035 &user,
1036 user.admin,
1037 &self.app_state.db,
1038 self.app_state.llm_db.clone(),
1039 )
1040 .await?;
1041
1042 self.update_plan_for_user(user_id, update_user_plan).await
1043 }
1044
1045 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1046 let pool = self.connection_pool.lock();
1047 for connection_id in pool.user_connection_ids(user_id) {
1048 self.peer
1049 .send(connection_id, proto::RefreshLlmToken {})
1050 .trace_err();
1051 }
1052 }
1053
1054 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1055 ServerSnapshot {
1056 connection_pool: ConnectionPoolGuard {
1057 guard: self.connection_pool.lock(),
1058 _not_send: PhantomData,
1059 },
1060 peer: &self.peer,
1061 }
1062 }
1063}
1064
/// Lets a `ConnectionPoolGuard` be used wherever `&ConnectionPool` is
/// expected, by borrowing through the wrapped lock guard.
impl Deref for ConnectionPoolGuard<'_> {
    type Target = ConnectionPool;

    fn deref(&self) -> &Self::Target {
        &self.guard
    }
}

/// Mutable counterpart of the `Deref` impl: exposes `&mut ConnectionPool`
/// through the wrapped lock guard.
impl DerefMut for ConnectionPoolGuard<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.guard
    }
}

/// In test builds, checks the pool's invariants whenever a guard is dropped.
/// `Drop::drop` runs before the `guard` field itself is dropped, so the check
/// happens while the lock is still held.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // This statement is compiled out entirely in non-test builds.
        #[cfg(test)]
        self.check_invariants();
    }
}
1085
1086fn broadcast<F>(
1087 sender_id: Option<ConnectionId>,
1088 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1089 mut f: F,
1090) where
1091 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1092{
1093 for receiver_id in receiver_ids {
1094 if Some(receiver_id) != sender_id {
1095 if let Err(error) = f(receiver_id) {
1096 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1097 }
1098 }
1099 }
1100}
1101
1102pub struct ProtocolVersion(u32);
1103
1104impl Header for ProtocolVersion {
1105 fn name() -> &'static HeaderName {
1106 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1107 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1108 }
1109
1110 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1111 where
1112 Self: Sized,
1113 I: Iterator<Item = &'i axum::http::HeaderValue>,
1114 {
1115 let version = values
1116 .next()
1117 .ok_or_else(axum::headers::Error::invalid)?
1118 .to_str()
1119 .map_err(|_| axum::headers::Error::invalid())?
1120 .parse()
1121 .map_err(|_| axum::headers::Error::invalid())?;
1122 Ok(Self(version))
1123 }
1124
1125 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1126 values.extend([self.0.to_string().parse().unwrap()]);
1127 }
1128}
1129
1130pub struct AppVersionHeader(SemanticVersion);
1131impl Header for AppVersionHeader {
1132 fn name() -> &'static HeaderName {
1133 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1134 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1135 }
1136
1137 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1138 where
1139 Self: Sized,
1140 I: Iterator<Item = &'i axum::http::HeaderValue>,
1141 {
1142 let version = values
1143 .next()
1144 .ok_or_else(axum::headers::Error::invalid)?
1145 .to_str()
1146 .map_err(|_| axum::headers::Error::invalid())?
1147 .parse()
1148 .map_err(|_| axum::headers::Error::invalid())?;
1149 Ok(Self(version))
1150 }
1151
1152 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1153 values.extend([self.0.to_string().parse().unwrap()]);
1154 }
1155}
1156
/// Builds the RPC router.
///
/// Route/layer order is significant: axum's `Router::layer` only wraps routes
/// that were registered *before* the call.
/// - `/rpc` (websocket upgrade) is wrapped by the `ServiceBuilder` layer, i.e.
///   the `AppState` extension plus the `auth::validate_header` middleware.
/// - `/metrics` is registered after that layer and is therefore not wrapped
///   by the auth middleware.
/// - The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Applies only to the routes registered above (currently `/rpc`).
        // In a `ServiceBuilder`, earlier layers wrap later ones, so the
        // `AppState` extension is inserted before the auth middleware runs
        // (presumably because `validate_header` reads it — confirm in `auth`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1168
/// Axum handler for `GET /rpc`: validates the client's version headers,
/// reserves a connection slot, and upgrades the request to a websocket that
/// is served by [`Server::handle_connection`].
///
/// Before upgrading, the request is rejected with:
/// - `426 Upgrade Required` when the `x-zed-protocol-version` header differs
///   from `rpc::PROTOCOL_VERSION`, when the `x-zed-app-version` header is
///   absent, or when that version reports it cannot collaborate;
/// - `503 Service Unavailable` when `ConnectionGuard::try_acquire` fails.
pub async fn handle_websocket_request(
    TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
    app_version_header: Option<TypedHeader<AppVersionHeader>>,
    ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
    Extension(server): Extension<Arc<Server>>,
    Extension(principal): Extension<Principal>,
    country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
    system_id_header: Option<TypedHeader<SystemIdHeader>>,
    ws: WebSocketUpgrade,
) -> axum::response::Response {
    if protocol_version != rpc::PROTOCOL_VERSION {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    // The app-version header is optional at the extractor level but required
    // here.
    let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "no version header found".to_string(),
        )
            .into_response();
    };

    if !version.can_collaborate() {
        return (
            StatusCode::UPGRADE_REQUIRED,
            "client must be upgraded".to_string(),
        )
            .into_response();
    }

    let socket_address = socket_address.to_string();

    // Acquire connection guard before WebSocket upgrade, so that an
    // over-capacity server answers with 503 instead of upgrading. The guard
    // is moved into the upgrade callback and passed to `handle_connection`.
    let connection_guard = match ConnectionGuard::try_acquire() {
        Ok(guard) => guard,
        Err(()) => {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                "Too many concurrent connections",
            )
                .into_response();
        }
    };

    ws.on_upgrade(move |socket| {
        // Convert axum websocket messages to tungstenite messages on the way
        // in (and back on the way out) before wrapping the stream in an
        // `rpc::Connection`.
        let socket = socket
            .map_ok(to_tungstenite_message)
            .err_into()
            .with(|message| async move { to_axum_message(message) });
        let connection = Connection::new(Box::pin(socket));
        async move {
            server
                .handle_connection(
                    connection,
                    socket_address,
                    principal,
                    version,
                    country_code_header.map(|header| header.to_string()),
                    system_id_header.map(|header| header.to_string()),
                    None,
                    Executor::Production,
                    Some(connection_guard),
                )
                .await;
        }
    })
}
1240
1241pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1242 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1243 let connections_metric = CONNECTIONS_METRIC
1244 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1245
1246 let connections = server
1247 .connection_pool
1248 .lock()
1249 .connections()
1250 .filter(|connection| !connection.admin)
1251 .count();
1252 connections_metric.set(connections as _);
1253
1254 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1255 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1256 register_int_gauge!(
1257 "shared_projects",
1258 "number of open projects with one or more guests"
1259 )
1260 .unwrap()
1261 });
1262
1263 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1264 shared_projects_metric.set(shared_projects as _);
1265
1266 let encoder = prometheus::TextEncoder::new();
1267 let metric_families = prometheus::gather();
1268 let encoded_metrics = encoder
1269 .encode_to_string(&metric_families)
1270 .map_err(|err| anyhow!("{err}"))?;
1271 Ok(encoded_metrics)
1272}
1273
1274#[instrument(err, skip(executor))]
1275async fn connection_lost(
1276 session: Session,
1277 mut teardown: watch::Receiver<bool>,
1278 executor: Executor,
1279) -> Result<()> {
1280 session.peer.disconnect(session.connection_id);
1281 session
1282 .connection_pool()
1283 .await
1284 .remove_connection(session.connection_id)?;
1285
1286 session
1287 .db()
1288 .await
1289 .connection_lost(session.connection_id)
1290 .await
1291 .trace_err();
1292
1293 futures::select_biased! {
1294 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1295
1296 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1297 leave_room_for_session(&session, session.connection_id).await.trace_err();
1298 leave_channel_buffers_for_session(&session)
1299 .await
1300 .trace_err();
1301
1302 if !session
1303 .connection_pool()
1304 .await
1305 .is_user_online(session.user_id())
1306 {
1307 let db = session.db().await;
1308 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1309 room_updated(&room, &session.peer);
1310 }
1311 }
1312
1313 update_user_contacts(session.user_id(), &session).await?;
1314 },
1315 _ = teardown.changed().fuse() => {}
1316 }
1317
1318 Ok(())
1319}
1320
1321/// Acknowledges a ping from a client, used to keep the connection alive.
1322async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1323 response.send(proto::Ack {})?;
1324 Ok(())
1325}
1326
1327/// Creates a new room for calling (outside of channels)
1328async fn create_room(
1329 _request: proto::CreateRoom,
1330 response: Response<proto::CreateRoom>,
1331 session: Session,
1332) -> Result<()> {
1333 let livekit_room = nanoid::nanoid!(30);
1334
1335 let live_kit_connection_info = util::maybe!(async {
1336 let live_kit = session.app_state.livekit_client.as_ref();
1337 let live_kit = live_kit?;
1338 let user_id = session.user_id().to_string();
1339
1340 let token = live_kit
1341 .room_token(&livekit_room, &user_id.to_string())
1342 .trace_err()?;
1343
1344 Some(proto::LiveKitConnectionInfo {
1345 server_url: live_kit.url().into(),
1346 token,
1347 can_publish: true,
1348 })
1349 })
1350 .await;
1351
1352 let room = session
1353 .db()
1354 .await
1355 .create_room(session.user_id(), session.connection_id, &livekit_room)
1356 .await?;
1357
1358 response.send(proto::CreateRoomResponse {
1359 room: Some(room.clone()),
1360 live_kit_connection_info,
1361 })?;
1362
1363 update_user_contacts(session.user_id(), &session).await?;
1364 Ok(())
1365}
1366
1367/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1368async fn join_room(
1369 request: proto::JoinRoom,
1370 response: Response<proto::JoinRoom>,
1371 session: Session,
1372) -> Result<()> {
1373 let room_id = RoomId::from_proto(request.id);
1374
1375 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1376
1377 if let Some(channel_id) = channel_id {
1378 return join_channel_internal(channel_id, Box::new(response), session).await;
1379 }
1380
1381 let joined_room = {
1382 let room = session
1383 .db()
1384 .await
1385 .join_room(room_id, session.user_id(), session.connection_id)
1386 .await?;
1387 room_updated(&room.room, &session.peer);
1388 room.into_inner()
1389 };
1390
1391 for connection_id in session
1392 .connection_pool()
1393 .await
1394 .user_connection_ids(session.user_id())
1395 {
1396 session
1397 .peer
1398 .send(
1399 connection_id,
1400 proto::CallCanceled {
1401 room_id: room_id.to_proto(),
1402 },
1403 )
1404 .trace_err();
1405 }
1406
1407 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1408 {
1409 live_kit
1410 .room_token(
1411 &joined_room.room.livekit_room,
1412 &session.user_id().to_string(),
1413 )
1414 .trace_err()
1415 .map(|token| proto::LiveKitConnectionInfo {
1416 server_url: live_kit.url().into(),
1417 token,
1418 can_publish: true,
1419 })
1420 } else {
1421 None
1422 };
1423
1424 response.send(proto::JoinRoomResponse {
1425 room: Some(joined_room.room),
1426 channel_id: None,
1427 live_kit_connection_info,
1428 })?;
1429
1430 update_user_contacts(session.user_id(), &session).await?;
1431 Ok(())
1432}
1433
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The room membership is restored in the database. The reply carries:
/// - the room;
/// - `reshared_projects` (presumably projects this user hosts — confirm in
///   `db::rejoin_room`);
/// - `rejoined_projects` (presumably projects it had joined as a guest).
///
/// It then does the following, in order:
/// - broadcasts the room update;
/// - for each reshared project, tells every collaborator the peer id changed
///   from the old connection to this one, and forwards the project's worktree
///   list to collaborators other than this connection;
/// - replays rejoined-project state via `notify_rejoined_projects`;
/// - if the room belongs to a channel, sends a channel update;
/// - refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    // Extracted from the rejoin result inside the scope below so they can be
    // used after the `rejoined_room` guard has been dropped.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Point every collaborator at this connection's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to the other
            // collaborators (this connection is excluded).
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1525
1526fn notify_rejoined_projects(
1527 rejoined_projects: &mut Vec<RejoinedProject>,
1528 session: &Session,
1529) -> Result<()> {
1530 for project in rejoined_projects.iter() {
1531 for collaborator in &project.collaborators {
1532 session
1533 .peer
1534 .send(
1535 collaborator.connection_id,
1536 proto::UpdateProjectCollaborator {
1537 project_id: project.id.to_proto(),
1538 old_peer_id: Some(project.old_connection_id.into()),
1539 new_peer_id: Some(session.connection_id.into()),
1540 },
1541 )
1542 .trace_err();
1543 }
1544 }
1545
1546 for project in rejoined_projects {
1547 for worktree in mem::take(&mut project.worktrees) {
1548 // Stream this worktree's entries.
1549 let message = proto::UpdateWorktree {
1550 project_id: project.id.to_proto(),
1551 worktree_id: worktree.id,
1552 abs_path: worktree.abs_path.clone(),
1553 root_name: worktree.root_name,
1554 updated_entries: worktree.updated_entries,
1555 removed_entries: worktree.removed_entries,
1556 scan_id: worktree.scan_id,
1557 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1558 updated_repositories: worktree.updated_repositories,
1559 removed_repositories: worktree.removed_repositories,
1560 };
1561 for update in proto::split_worktree_update(message) {
1562 session.peer.send(session.connection_id, update)?;
1563 }
1564
1565 // Stream this worktree's diagnostics.
1566 for summary in worktree.diagnostic_summaries {
1567 session.peer.send(
1568 session.connection_id,
1569 proto::UpdateDiagnosticSummary {
1570 project_id: project.id.to_proto(),
1571 worktree_id: worktree.id,
1572 summary: Some(summary),
1573 },
1574 )?;
1575 }
1576
1577 for settings_file in worktree.settings_files {
1578 session.peer.send(
1579 session.connection_id,
1580 proto::UpdateWorktreeSettings {
1581 project_id: project.id.to_proto(),
1582 worktree_id: worktree.id,
1583 path: settings_file.path,
1584 content: Some(settings_file.content),
1585 kind: Some(settings_file.kind.to_proto().into()),
1586 },
1587 )?;
1588 }
1589 }
1590
1591 for repository in mem::take(&mut project.updated_repositories) {
1592 for update in split_repository_update(repository) {
1593 session.peer.send(session.connection_id, update)?;
1594 }
1595 }
1596
1597 for id in mem::take(&mut project.removed_repositories) {
1598 session.peer.send(
1599 session.connection_id,
1600 proto::RemoveRepository {
1601 project_id: project.id.to_proto(),
1602 id,
1603 },
1604 )?;
1605 }
1606 }
1607
1608 Ok(())
1609}
1610
1611/// leave room disconnects from the room.
1612async fn leave_room(
1613 _: proto::LeaveRoom,
1614 response: Response<proto::LeaveRoom>,
1615 session: Session,
1616) -> Result<()> {
1617 leave_room_for_session(&session, session.connection_id).await?;
1618 response.send(proto::Ack {})?;
1619 Ok(())
1620}
1621
1622/// Updates the permissions of someone else in the room.
1623async fn set_room_participant_role(
1624 request: proto::SetRoomParticipantRole,
1625 response: Response<proto::SetRoomParticipantRole>,
1626 session: Session,
1627) -> Result<()> {
1628 let user_id = UserId::from_proto(request.user_id);
1629 let role = ChannelRole::from(request.role());
1630
1631 let (livekit_room, can_publish) = {
1632 let room = session
1633 .db()
1634 .await
1635 .set_room_participant_role(
1636 session.user_id(),
1637 RoomId::from_proto(request.room_id),
1638 user_id,
1639 role,
1640 )
1641 .await?;
1642
1643 let livekit_room = room.livekit_room.clone();
1644 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1645 room_updated(&room, &session.peer);
1646 (livekit_room, can_publish)
1647 };
1648
1649 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1650 live_kit
1651 .update_participant(
1652 livekit_room.clone(),
1653 request.user_id.to_string(),
1654 livekit_api::proto::ParticipantPermission {
1655 can_subscribe: true,
1656 can_publish,
1657 can_publish_data: can_publish,
1658 hidden: false,
1659 recorder: false,
1660 },
1661 )
1662 .await
1663 .trace_err();
1664 }
1665
1666 response.send(proto::Ack {})?;
1667 Ok(())
1668}
1669
1670/// Call someone else into the current room
1671async fn call(
1672 request: proto::Call,
1673 response: Response<proto::Call>,
1674 session: Session,
1675) -> Result<()> {
1676 let room_id = RoomId::from_proto(request.room_id);
1677 let calling_user_id = session.user_id();
1678 let calling_connection_id = session.connection_id;
1679 let called_user_id = UserId::from_proto(request.called_user_id);
1680 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1681 if !session
1682 .db()
1683 .await
1684 .has_contact(calling_user_id, called_user_id)
1685 .await?
1686 {
1687 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1688 }
1689
1690 let incoming_call = {
1691 let (room, incoming_call) = &mut *session
1692 .db()
1693 .await
1694 .call(
1695 room_id,
1696 calling_user_id,
1697 calling_connection_id,
1698 called_user_id,
1699 initial_project_id,
1700 )
1701 .await?;
1702 room_updated(room, &session.peer);
1703 mem::take(incoming_call)
1704 };
1705 update_user_contacts(called_user_id, &session).await?;
1706
1707 let mut calls = session
1708 .connection_pool()
1709 .await
1710 .user_connection_ids(called_user_id)
1711 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1712 .collect::<FuturesUnordered<_>>();
1713
1714 while let Some(call_response) = calls.next().await {
1715 match call_response.as_ref() {
1716 Ok(_) => {
1717 response.send(proto::Ack {})?;
1718 return Ok(());
1719 }
1720 Err(_) => {
1721 call_response.trace_err();
1722 }
1723 }
1724 }
1725
1726 {
1727 let room = session
1728 .db()
1729 .await
1730 .call_failed(room_id, called_user_id)
1731 .await?;
1732 room_updated(&room, &session.peer);
1733 }
1734 update_user_contacts(called_user_id, &session).await?;
1735
1736 Err(anyhow!("failed to ring user"))?
1737}
1738
1739/// Cancel an outgoing call.
1740async fn cancel_call(
1741 request: proto::CancelCall,
1742 response: Response<proto::CancelCall>,
1743 session: Session,
1744) -> Result<()> {
1745 let called_user_id = UserId::from_proto(request.called_user_id);
1746 let room_id = RoomId::from_proto(request.room_id);
1747 {
1748 let room = session
1749 .db()
1750 .await
1751 .cancel_call(room_id, session.connection_id, called_user_id)
1752 .await?;
1753 room_updated(&room, &session.peer);
1754 }
1755
1756 for connection_id in session
1757 .connection_pool()
1758 .await
1759 .user_connection_ids(called_user_id)
1760 {
1761 session
1762 .peer
1763 .send(
1764 connection_id,
1765 proto::CallCanceled {
1766 room_id: room_id.to_proto(),
1767 },
1768 )
1769 .trace_err();
1770 }
1771 response.send(proto::Ack {})?;
1772
1773 update_user_contacts(called_user_id, &session).await?;
1774 Ok(())
1775}
1776
1777/// Decline an incoming call.
1778async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1779 let room_id = RoomId::from_proto(message.room_id);
1780 {
1781 let room = session
1782 .db()
1783 .await
1784 .decline_call(Some(room_id), session.user_id())
1785 .await?
1786 .context("declining call")?;
1787 room_updated(&room, &session.peer);
1788 }
1789
1790 for connection_id in session
1791 .connection_pool()
1792 .await
1793 .user_connection_ids(session.user_id())
1794 {
1795 session
1796 .peer
1797 .send(
1798 connection_id,
1799 proto::CallCanceled {
1800 room_id: room_id.to_proto(),
1801 },
1802 )
1803 .trace_err();
1804 }
1805 update_user_contacts(session.user_id(), &session).await?;
1806 Ok(())
1807}
1808
1809/// Updates other participants in the room with your current location.
1810async fn update_participant_location(
1811 request: proto::UpdateParticipantLocation,
1812 response: Response<proto::UpdateParticipantLocation>,
1813 session: Session,
1814) -> Result<()> {
1815 let room_id = RoomId::from_proto(request.room_id);
1816 let location = request.location.context("invalid location")?;
1817
1818 let db = session.db().await;
1819 let room = db
1820 .update_room_participant_location(room_id, session.connection_id, location)
1821 .await?;
1822
1823 room_updated(&room, &session.peer);
1824 response.send(proto::Ack {})?;
1825 Ok(())
1826}
1827
1828/// Share a project into the room.
1829async fn share_project(
1830 request: proto::ShareProject,
1831 response: Response<proto::ShareProject>,
1832 session: Session,
1833) -> Result<()> {
1834 let (project_id, room) = &*session
1835 .db()
1836 .await
1837 .share_project(
1838 RoomId::from_proto(request.room_id),
1839 session.connection_id,
1840 &request.worktrees,
1841 request.is_ssh_project,
1842 )
1843 .await?;
1844 response.send(proto::ShareProjectResponse {
1845 project_id: project_id.to_proto(),
1846 })?;
1847 room_updated(room, &session.peer);
1848
1849 Ok(())
1850}
1851
1852/// Unshare a project from the room.
1853async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1854 let project_id = ProjectId::from_proto(message.project_id);
1855 unshare_project_internal(project_id, session.connection_id, &session).await
1856}
1857
1858async fn unshare_project_internal(
1859 project_id: ProjectId,
1860 connection_id: ConnectionId,
1861 session: &Session,
1862) -> Result<()> {
1863 let delete = {
1864 let room_guard = session
1865 .db()
1866 .await
1867 .unshare_project(project_id, connection_id)
1868 .await?;
1869
1870 let (delete, room, guest_connection_ids) = &*room_guard;
1871
1872 let message = proto::UnshareProject {
1873 project_id: project_id.to_proto(),
1874 };
1875
1876 broadcast(
1877 Some(connection_id),
1878 guest_connection_ids.iter().copied(),
1879 |conn_id| session.peer.send(conn_id, message.clone()),
1880 );
1881 if let Some(room) = room {
1882 room_updated(room, &session.peer);
1883 }
1884
1885 *delete
1886 };
1887
1888 if delete {
1889 let db = session.db().await;
1890 db.delete_project(project_id).await?;
1891 }
1892
1893 Ok(())
1894}
1895
1896/// Join someone elses shared project.
1897async fn join_project(
1898 request: proto::JoinProject,
1899 response: Response<proto::JoinProject>,
1900 session: Session,
1901) -> Result<()> {
1902 let project_id = ProjectId::from_proto(request.project_id);
1903
1904 tracing::info!(%project_id, "join project");
1905
1906 let db = session.db().await;
1907 let (project, replica_id) = &mut *db
1908 .join_project(
1909 project_id,
1910 session.connection_id,
1911 session.user_id(),
1912 request.committer_name.clone(),
1913 request.committer_email.clone(),
1914 )
1915 .await?;
1916 drop(db);
1917 tracing::info!(%project_id, "join remote project");
1918 let collaborators = project
1919 .collaborators
1920 .iter()
1921 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1922 .map(|collaborator| collaborator.to_proto())
1923 .collect::<Vec<_>>();
1924 let project_id = project.id;
1925 let guest_user_id = session.user_id();
1926
1927 let worktrees = project
1928 .worktrees
1929 .iter()
1930 .map(|(id, worktree)| proto::WorktreeMetadata {
1931 id: *id,
1932 root_name: worktree.root_name.clone(),
1933 visible: worktree.visible,
1934 abs_path: worktree.abs_path.clone(),
1935 })
1936 .collect::<Vec<_>>();
1937
1938 let add_project_collaborator = proto::AddProjectCollaborator {
1939 project_id: project_id.to_proto(),
1940 collaborator: Some(proto::Collaborator {
1941 peer_id: Some(session.connection_id.into()),
1942 replica_id: replica_id.0 as u32,
1943 user_id: guest_user_id.to_proto(),
1944 is_host: false,
1945 committer_name: request.committer_name.clone(),
1946 committer_email: request.committer_email.clone(),
1947 }),
1948 };
1949
1950 for collaborator in &collaborators {
1951 session
1952 .peer
1953 .send(
1954 collaborator.peer_id.unwrap().into(),
1955 add_project_collaborator.clone(),
1956 )
1957 .trace_err();
1958 }
1959
1960 // First, we send the metadata associated with each worktree.
1961 response.send(proto::JoinProjectResponse {
1962 project_id: project.id.0 as u64,
1963 worktrees: worktrees.clone(),
1964 replica_id: replica_id.0 as u32,
1965 collaborators: collaborators.clone(),
1966 language_servers: project.language_servers.clone(),
1967 role: project.role.into(),
1968 })?;
1969
1970 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1971 // Stream this worktree's entries.
1972 let message = proto::UpdateWorktree {
1973 project_id: project_id.to_proto(),
1974 worktree_id,
1975 abs_path: worktree.abs_path.clone(),
1976 root_name: worktree.root_name,
1977 updated_entries: worktree.entries,
1978 removed_entries: Default::default(),
1979 scan_id: worktree.scan_id,
1980 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1981 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1982 removed_repositories: Default::default(),
1983 };
1984 for update in proto::split_worktree_update(message) {
1985 session.peer.send(session.connection_id, update.clone())?;
1986 }
1987
1988 // Stream this worktree's diagnostics.
1989 for summary in worktree.diagnostic_summaries {
1990 session.peer.send(
1991 session.connection_id,
1992 proto::UpdateDiagnosticSummary {
1993 project_id: project_id.to_proto(),
1994 worktree_id: worktree.id,
1995 summary: Some(summary),
1996 },
1997 )?;
1998 }
1999
2000 for settings_file in worktree.settings_files {
2001 session.peer.send(
2002 session.connection_id,
2003 proto::UpdateWorktreeSettings {
2004 project_id: project_id.to_proto(),
2005 worktree_id: worktree.id,
2006 path: settings_file.path,
2007 content: Some(settings_file.content),
2008 kind: Some(settings_file.kind.to_proto() as i32),
2009 },
2010 )?;
2011 }
2012 }
2013
2014 for repository in mem::take(&mut project.repositories) {
2015 for update in split_repository_update(repository) {
2016 session.peer.send(session.connection_id, update)?;
2017 }
2018 }
2019
2020 for language_server in &project.language_servers {
2021 session.peer.send(
2022 session.connection_id,
2023 proto::UpdateLanguageServer {
2024 project_id: project_id.to_proto(),
2025 server_name: Some(language_server.name.clone()),
2026 language_server_id: language_server.id,
2027 variant: Some(
2028 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2029 proto::LspDiskBasedDiagnosticsUpdated {},
2030 ),
2031 ),
2032 },
2033 )?;
2034 }
2035
2036 Ok(())
2037}
2038
2039/// Leave someone elses shared project.
2040async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2041 let sender_id = session.connection_id;
2042 let project_id = ProjectId::from_proto(request.project_id);
2043 let db = session.db().await;
2044
2045 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2046 tracing::info!(
2047 %project_id,
2048 "leave project"
2049 );
2050
2051 project_left(project, &session);
2052 if let Some(room) = room {
2053 room_updated(room, &session.peer);
2054 }
2055
2056 Ok(())
2057}
2058
2059/// Updates other participants with changes to the project
2060async fn update_project(
2061 request: proto::UpdateProject,
2062 response: Response<proto::UpdateProject>,
2063 session: Session,
2064) -> Result<()> {
2065 let project_id = ProjectId::from_proto(request.project_id);
2066 let (room, guest_connection_ids) = &*session
2067 .db()
2068 .await
2069 .update_project(project_id, session.connection_id, &request.worktrees)
2070 .await?;
2071 broadcast(
2072 Some(session.connection_id),
2073 guest_connection_ids.iter().copied(),
2074 |connection_id| {
2075 session
2076 .peer
2077 .forward_send(session.connection_id, connection_id, request.clone())
2078 },
2079 );
2080 if let Some(room) = room {
2081 room_updated(room, &session.peer);
2082 }
2083 response.send(proto::Ack {})?;
2084
2085 Ok(())
2086}
2087
2088/// Updates other participants with changes to the worktree
2089async fn update_worktree(
2090 request: proto::UpdateWorktree,
2091 response: Response<proto::UpdateWorktree>,
2092 session: Session,
2093) -> Result<()> {
2094 let guest_connection_ids = session
2095 .db()
2096 .await
2097 .update_worktree(&request, session.connection_id)
2098 .await?;
2099
2100 broadcast(
2101 Some(session.connection_id),
2102 guest_connection_ids.iter().copied(),
2103 |connection_id| {
2104 session
2105 .peer
2106 .forward_send(session.connection_id, connection_id, request.clone())
2107 },
2108 );
2109 response.send(proto::Ack {})?;
2110 Ok(())
2111}
2112
2113async fn update_repository(
2114 request: proto::UpdateRepository,
2115 response: Response<proto::UpdateRepository>,
2116 session: Session,
2117) -> Result<()> {
2118 let guest_connection_ids = session
2119 .db()
2120 .await
2121 .update_repository(&request, session.connection_id)
2122 .await?;
2123
2124 broadcast(
2125 Some(session.connection_id),
2126 guest_connection_ids.iter().copied(),
2127 |connection_id| {
2128 session
2129 .peer
2130 .forward_send(session.connection_id, connection_id, request.clone())
2131 },
2132 );
2133 response.send(proto::Ack {})?;
2134 Ok(())
2135}
2136
2137async fn remove_repository(
2138 request: proto::RemoveRepository,
2139 response: Response<proto::RemoveRepository>,
2140 session: Session,
2141) -> Result<()> {
2142 let guest_connection_ids = session
2143 .db()
2144 .await
2145 .remove_repository(&request, session.connection_id)
2146 .await?;
2147
2148 broadcast(
2149 Some(session.connection_id),
2150 guest_connection_ids.iter().copied(),
2151 |connection_id| {
2152 session
2153 .peer
2154 .forward_send(session.connection_id, connection_id, request.clone())
2155 },
2156 );
2157 response.send(proto::Ack {})?;
2158 Ok(())
2159}
2160
2161/// Updates other participants with changes to the diagnostics
2162async fn update_diagnostic_summary(
2163 message: proto::UpdateDiagnosticSummary,
2164 session: Session,
2165) -> Result<()> {
2166 let guest_connection_ids = session
2167 .db()
2168 .await
2169 .update_diagnostic_summary(&message, session.connection_id)
2170 .await?;
2171
2172 broadcast(
2173 Some(session.connection_id),
2174 guest_connection_ids.iter().copied(),
2175 |connection_id| {
2176 session
2177 .peer
2178 .forward_send(session.connection_id, connection_id, message.clone())
2179 },
2180 );
2181
2182 Ok(())
2183}
2184
2185/// Updates other participants with changes to the worktree settings
2186async fn update_worktree_settings(
2187 message: proto::UpdateWorktreeSettings,
2188 session: Session,
2189) -> Result<()> {
2190 let guest_connection_ids = session
2191 .db()
2192 .await
2193 .update_worktree_settings(&message, session.connection_id)
2194 .await?;
2195
2196 broadcast(
2197 Some(session.connection_id),
2198 guest_connection_ids.iter().copied(),
2199 |connection_id| {
2200 session
2201 .peer
2202 .forward_send(session.connection_id, connection_id, message.clone())
2203 },
2204 );
2205
2206 Ok(())
2207}
2208
2209/// Notify other participants that a language server has started.
2210async fn start_language_server(
2211 request: proto::StartLanguageServer,
2212 session: Session,
2213) -> Result<()> {
2214 let guest_connection_ids = session
2215 .db()
2216 .await
2217 .start_language_server(&request, session.connection_id)
2218 .await?;
2219
2220 broadcast(
2221 Some(session.connection_id),
2222 guest_connection_ids.iter().copied(),
2223 |connection_id| {
2224 session
2225 .peer
2226 .forward_send(session.connection_id, connection_id, request.clone())
2227 },
2228 );
2229 Ok(())
2230}
2231
2232/// Notify other participants that a language server has changed.
2233async fn update_language_server(
2234 request: proto::UpdateLanguageServer,
2235 session: Session,
2236) -> Result<()> {
2237 let project_id = ProjectId::from_proto(request.project_id);
2238 let project_connection_ids = session
2239 .db()
2240 .await
2241 .project_connection_ids(project_id, session.connection_id, true)
2242 .await?;
2243 broadcast(
2244 Some(session.connection_id),
2245 project_connection_ids.iter().copied(),
2246 |connection_id| {
2247 session
2248 .peer
2249 .forward_send(session.connection_id, connection_id, request.clone())
2250 },
2251 );
2252 Ok(())
2253}
2254
2255/// forward a project request to the host. These requests should be read only
2256/// as guests are allowed to send them.
2257async fn forward_read_only_project_request<T>(
2258 request: T,
2259 response: Response<T>,
2260 session: Session,
2261) -> Result<()>
2262where
2263 T: EntityMessage + RequestMessage,
2264{
2265 let project_id = ProjectId::from_proto(request.remote_entity_id());
2266 let host_connection_id = session
2267 .db()
2268 .await
2269 .host_for_read_only_project_request(project_id, session.connection_id)
2270 .await?;
2271 let payload = session
2272 .peer
2273 .forward_request(session.connection_id, host_connection_id, request)
2274 .await?;
2275 response.send(payload)?;
2276 Ok(())
2277}
2278
2279async fn forward_find_search_candidates_request(
2280 request: proto::FindSearchCandidates,
2281 response: Response<proto::FindSearchCandidates>,
2282 session: Session,
2283) -> Result<()> {
2284 let project_id = ProjectId::from_proto(request.remote_entity_id());
2285 let host_connection_id = session
2286 .db()
2287 .await
2288 .host_for_read_only_project_request(project_id, session.connection_id)
2289 .await?;
2290 let payload = session
2291 .peer
2292 .forward_request(session.connection_id, host_connection_id, request)
2293 .await?;
2294 response.send(payload)?;
2295 Ok(())
2296}
2297
2298/// forward a project request to the host. These requests are disallowed
2299/// for guests.
2300async fn forward_mutating_project_request<T>(
2301 request: T,
2302 response: Response<T>,
2303 session: Session,
2304) -> Result<()>
2305where
2306 T: EntityMessage + RequestMessage,
2307{
2308 let project_id = ProjectId::from_proto(request.remote_entity_id());
2309
2310 let host_connection_id = session
2311 .db()
2312 .await
2313 .host_for_mutating_project_request(project_id, session.connection_id)
2314 .await?;
2315 let payload = session
2316 .peer
2317 .forward_request(session.connection_id, host_connection_id, request)
2318 .await?;
2319 response.send(payload)?;
2320 Ok(())
2321}
2322
2323/// Notify other participants that a new buffer has been created
2324async fn create_buffer_for_peer(
2325 request: proto::CreateBufferForPeer,
2326 session: Session,
2327) -> Result<()> {
2328 session
2329 .db()
2330 .await
2331 .check_user_is_project_host(
2332 ProjectId::from_proto(request.project_id),
2333 session.connection_id,
2334 )
2335 .await?;
2336 let peer_id = request.peer_id.context("invalid peer id")?;
2337 session
2338 .peer
2339 .forward_send(session.connection_id, peer_id.into(), request)?;
2340 Ok(())
2341}
2342
2343/// Notify other participants that a buffer has been updated. This is
2344/// allowed for guests as long as the update is limited to selections.
2345async fn update_buffer(
2346 request: proto::UpdateBuffer,
2347 response: Response<proto::UpdateBuffer>,
2348 session: Session,
2349) -> Result<()> {
2350 let project_id = ProjectId::from_proto(request.project_id);
2351 let mut capability = Capability::ReadOnly;
2352
2353 for op in request.operations.iter() {
2354 match op.variant {
2355 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2356 Some(_) => capability = Capability::ReadWrite,
2357 }
2358 }
2359
2360 let host = {
2361 let guard = session
2362 .db()
2363 .await
2364 .connections_for_buffer_update(project_id, session.connection_id, capability)
2365 .await?;
2366
2367 let (host, guests) = &*guard;
2368
2369 broadcast(
2370 Some(session.connection_id),
2371 guests.clone(),
2372 |connection_id| {
2373 session
2374 .peer
2375 .forward_send(session.connection_id, connection_id, request.clone())
2376 },
2377 );
2378
2379 *host
2380 };
2381
2382 if host != session.connection_id {
2383 session
2384 .peer
2385 .forward_request(session.connection_id, host, request.clone())
2386 .await?;
2387 }
2388
2389 response.send(proto::Ack {})?;
2390 Ok(())
2391}
2392
2393async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2394 let project_id = ProjectId::from_proto(message.project_id);
2395
2396 let operation = message.operation.as_ref().context("invalid operation")?;
2397 let capability = match operation.variant.as_ref() {
2398 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2399 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2400 match buffer_op.variant {
2401 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2402 Capability::ReadOnly
2403 }
2404 _ => Capability::ReadWrite,
2405 }
2406 } else {
2407 Capability::ReadWrite
2408 }
2409 }
2410 Some(_) => Capability::ReadWrite,
2411 None => Capability::ReadOnly,
2412 };
2413
2414 let guard = session
2415 .db()
2416 .await
2417 .connections_for_buffer_update(project_id, session.connection_id, capability)
2418 .await?;
2419
2420 let (host, guests) = &*guard;
2421
2422 broadcast(
2423 Some(session.connection_id),
2424 guests.iter().chain([host]).copied(),
2425 |connection_id| {
2426 session
2427 .peer
2428 .forward_send(session.connection_id, connection_id, message.clone())
2429 },
2430 );
2431
2432 Ok(())
2433}
2434
2435/// Notify other participants that a project has been updated.
2436async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2437 request: T,
2438 session: Session,
2439) -> Result<()> {
2440 let project_id = ProjectId::from_proto(request.remote_entity_id());
2441 let project_connection_ids = session
2442 .db()
2443 .await
2444 .project_connection_ids(project_id, session.connection_id, false)
2445 .await?;
2446
2447 broadcast(
2448 Some(session.connection_id),
2449 project_connection_ids.iter().copied(),
2450 |connection_id| {
2451 session
2452 .peer
2453 .forward_send(session.connection_id, connection_id, request.clone())
2454 },
2455 );
2456 Ok(())
2457}
2458
2459/// Start following another user in a call.
2460async fn follow(
2461 request: proto::Follow,
2462 response: Response<proto::Follow>,
2463 session: Session,
2464) -> Result<()> {
2465 let room_id = RoomId::from_proto(request.room_id);
2466 let project_id = request.project_id.map(ProjectId::from_proto);
2467 let leader_id = request.leader_id.context("invalid leader id")?.into();
2468 let follower_id = session.connection_id;
2469
2470 session
2471 .db()
2472 .await
2473 .check_room_participants(room_id, leader_id, session.connection_id)
2474 .await?;
2475
2476 let response_payload = session
2477 .peer
2478 .forward_request(session.connection_id, leader_id, request)
2479 .await?;
2480 response.send(response_payload)?;
2481
2482 if let Some(project_id) = project_id {
2483 let room = session
2484 .db()
2485 .await
2486 .follow(room_id, project_id, leader_id, follower_id)
2487 .await?;
2488 room_updated(&room, &session.peer);
2489 }
2490
2491 Ok(())
2492}
2493
2494/// Stop following another user in a call.
2495async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2496 let room_id = RoomId::from_proto(request.room_id);
2497 let project_id = request.project_id.map(ProjectId::from_proto);
2498 let leader_id = request.leader_id.context("invalid leader id")?.into();
2499 let follower_id = session.connection_id;
2500
2501 session
2502 .db()
2503 .await
2504 .check_room_participants(room_id, leader_id, session.connection_id)
2505 .await?;
2506
2507 session
2508 .peer
2509 .forward_send(session.connection_id, leader_id, request)?;
2510
2511 if let Some(project_id) = project_id {
2512 let room = session
2513 .db()
2514 .await
2515 .unfollow(room_id, project_id, leader_id, follower_id)
2516 .await?;
2517 room_updated(&room, &session.peer);
2518 }
2519
2520 Ok(())
2521}
2522
2523/// Notify everyone following you of your current location.
2524async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2525 let room_id = RoomId::from_proto(request.room_id);
2526 let database = session.db.lock().await;
2527
2528 let connection_ids = if let Some(project_id) = request.project_id {
2529 let project_id = ProjectId::from_proto(project_id);
2530 database
2531 .project_connection_ids(project_id, session.connection_id, true)
2532 .await?
2533 } else {
2534 database
2535 .room_connection_ids(room_id, session.connection_id)
2536 .await?
2537 };
2538
2539 // For now, don't send view update messages back to that view's current leader.
2540 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2541 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2542 _ => None,
2543 });
2544
2545 for connection_id in connection_ids.iter().cloned() {
2546 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2547 session
2548 .peer
2549 .forward_send(session.connection_id, connection_id, request.clone())?;
2550 }
2551 }
2552 Ok(())
2553}
2554
2555/// Get public data about users.
2556async fn get_users(
2557 request: proto::GetUsers,
2558 response: Response<proto::GetUsers>,
2559 session: Session,
2560) -> Result<()> {
2561 let user_ids = request
2562 .user_ids
2563 .into_iter()
2564 .map(UserId::from_proto)
2565 .collect();
2566 let users = session
2567 .db()
2568 .await
2569 .get_users_by_ids(user_ids)
2570 .await?
2571 .into_iter()
2572 .map(|user| proto::User {
2573 id: user.id.to_proto(),
2574 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2575 github_login: user.github_login,
2576 name: user.name,
2577 })
2578 .collect();
2579 response.send(proto::UsersResponse { users })?;
2580 Ok(())
2581}
2582
2583/// Search for users (to invite) buy Github login
2584async fn fuzzy_search_users(
2585 request: proto::FuzzySearchUsers,
2586 response: Response<proto::FuzzySearchUsers>,
2587 session: Session,
2588) -> Result<()> {
2589 let query = request.query;
2590 let users = match query.len() {
2591 0 => vec![],
2592 1 | 2 => session
2593 .db()
2594 .await
2595 .get_user_by_github_login(&query)
2596 .await?
2597 .into_iter()
2598 .collect(),
2599 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2600 };
2601 let users = users
2602 .into_iter()
2603 .filter(|user| user.id != session.user_id())
2604 .map(|user| proto::User {
2605 id: user.id.to_proto(),
2606 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2607 github_login: user.github_login,
2608 name: user.name,
2609 })
2610 .collect();
2611 response.send(proto::UsersResponse { users })?;
2612 Ok(())
2613}
2614
2615/// Send a contact request to another user.
2616async fn request_contact(
2617 request: proto::RequestContact,
2618 response: Response<proto::RequestContact>,
2619 session: Session,
2620) -> Result<()> {
2621 let requester_id = session.user_id();
2622 let responder_id = UserId::from_proto(request.responder_id);
2623 if requester_id == responder_id {
2624 return Err(anyhow!("cannot add yourself as a contact"))?;
2625 }
2626
2627 let notifications = session
2628 .db()
2629 .await
2630 .send_contact_request(requester_id, responder_id)
2631 .await?;
2632
2633 // Update outgoing contact requests of requester
2634 let mut update = proto::UpdateContacts::default();
2635 update.outgoing_requests.push(responder_id.to_proto());
2636 for connection_id in session
2637 .connection_pool()
2638 .await
2639 .user_connection_ids(requester_id)
2640 {
2641 session.peer.send(connection_id, update.clone())?;
2642 }
2643
2644 // Update incoming contact requests of responder
2645 let mut update = proto::UpdateContacts::default();
2646 update
2647 .incoming_requests
2648 .push(proto::IncomingContactRequest {
2649 requester_id: requester_id.to_proto(),
2650 });
2651 let connection_pool = session.connection_pool().await;
2652 for connection_id in connection_pool.user_connection_ids(responder_id) {
2653 session.peer.send(connection_id, update.clone())?;
2654 }
2655
2656 send_notifications(&connection_pool, &session.peer, notifications);
2657
2658 response.send(proto::Ack {})?;
2659 Ok(())
2660}
2661
/// Accept or decline a contact request
///
/// A `Dismiss` response only calls `dismiss_contact_notification`; no contact updates are
/// sent in that case.
///
/// Any other response is recorded with `respond_to_contact_request`, with `accept == true`
/// only for `Accept`. Then:
/// - every connection of the responder gets an `UpdateContacts` removing the incoming request;
/// - every connection of the requester gets one removing the outgoing request;
/// - on acceptance each side's update also carries the other user as a contact, including
///   their current busy state;
/// - notifications returned by the database are delivered afterwards.
///
/// The sender is acknowledged in every case.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: Session,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    // The database guard is held for the rest of the function, including while the
    // connection pool is locked below.
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Every non-`Dismiss` response other than `Accept` is passed on with `accept == false`.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy flags are read before the connection pool is locked.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
2719
2720/// Remove a contact.
2721async fn remove_contact(
2722 request: proto::RemoveContact,
2723 response: Response<proto::RemoveContact>,
2724 session: Session,
2725) -> Result<()> {
2726 let requester_id = session.user_id();
2727 let responder_id = UserId::from_proto(request.user_id);
2728 let db = session.db().await;
2729 let (contact_accepted, deleted_notification_id) =
2730 db.remove_contact(requester_id, responder_id).await?;
2731
2732 let pool = session.connection_pool().await;
2733 // Update outgoing contact requests of requester
2734 let mut update = proto::UpdateContacts::default();
2735 if contact_accepted {
2736 update.remove_contacts.push(responder_id.to_proto());
2737 } else {
2738 update
2739 .remove_outgoing_requests
2740 .push(responder_id.to_proto());
2741 }
2742 for connection_id in pool.user_connection_ids(requester_id) {
2743 session.peer.send(connection_id, update.clone())?;
2744 }
2745
2746 // Update incoming contact requests of responder
2747 let mut update = proto::UpdateContacts::default();
2748 if contact_accepted {
2749 update.remove_contacts.push(requester_id.to_proto());
2750 } else {
2751 update
2752 .remove_incoming_requests
2753 .push(requester_id.to_proto());
2754 }
2755 for connection_id in pool.user_connection_ids(responder_id) {
2756 session.peer.send(connection_id, update.clone())?;
2757 if let Some(notification_id) = deleted_notification_id {
2758 session.peer.send(
2759 connection_id,
2760 proto::DeleteNotification {
2761 notification_id: notification_id.to_proto(),
2762 },
2763 )?;
2764 }
2765 }
2766
2767 response.send(proto::Ack {})?;
2768 Ok(())
2769}
2770
2771fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2772 version.0.minor() < 139
2773}
2774
2775async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2776 if is_staff {
2777 return Ok(proto::Plan::ZedPro);
2778 }
2779
2780 let subscription = db.get_active_billing_subscription(user_id).await?;
2781 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2782
2783 let plan = if let Some(subscription_kind) = subscription_kind {
2784 match subscription_kind {
2785 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2786 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2787 SubscriptionKind::ZedFree => proto::Plan::Free,
2788 }
2789 } else {
2790 proto::Plan::Free
2791 };
2792
2793 Ok(plan)
2794}
2795
2796async fn make_update_user_plan_message(
2797 user: &User,
2798 is_staff: bool,
2799 db: &Arc<Database>,
2800 llm_db: Option<Arc<LlmDatabase>>,
2801) -> Result<proto::UpdateUserPlan> {
2802 let feature_flags = db.get_user_flags(user.id).await?;
2803 let plan = current_plan(db, user.id, is_staff).await?;
2804 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2805 let billing_preferences = db.get_billing_preferences(user.id).await?;
2806
2807 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2808 let subscription = db.get_active_billing_subscription(user.id).await?;
2809
2810 let subscription_period =
2811 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2812
2813 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2814 llm_db
2815 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2816 .await?
2817 } else {
2818 None
2819 };
2820
2821 (subscription_period, usage)
2822 } else {
2823 (None, None)
2824 };
2825
2826 let bypass_account_age_check = feature_flags
2827 .iter()
2828 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2829 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2830 && !bypass_account_age_check
2831 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2832
2833 Ok(proto::UpdateUserPlan {
2834 plan: plan.into(),
2835 trial_started_at: billing_customer
2836 .as_ref()
2837 .and_then(|billing_customer| billing_customer.trial_started_at)
2838 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2839 is_usage_based_billing_enabled: if is_staff {
2840 Some(true)
2841 } else {
2842 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2843 },
2844 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2845 proto::SubscriptionPeriod {
2846 started_at: started_at.timestamp() as u64,
2847 ended_at: ended_at.timestamp() as u64,
2848 }
2849 }),
2850 account_too_young: Some(account_too_young),
2851 has_overdue_invoices: billing_customer
2852 .map(|billing_customer| billing_customer.has_overdue_invoices),
2853 usage: Some(
2854 usage
2855 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2856 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2857 ),
2858 })
2859}
2860
2861fn model_requests_limit(
2862 plan: zed_llm_client::Plan,
2863 feature_flags: &Vec<String>,
2864) -> zed_llm_client::UsageLimit {
2865 match plan.model_requests_limit() {
2866 zed_llm_client::UsageLimit::Limited(limit) => {
2867 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2868 && feature_flags
2869 .iter()
2870 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2871 {
2872 1_000
2873 } else {
2874 limit
2875 };
2876
2877 zed_llm_client::UsageLimit::Limited(limit)
2878 }
2879 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2880 }
2881}
2882
2883fn subscription_usage_to_proto(
2884 plan: proto::Plan,
2885 usage: crate::llm::db::subscription_usage::Model,
2886 feature_flags: &Vec<String>,
2887) -> proto::SubscriptionUsage {
2888 let plan = match plan {
2889 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2890 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2891 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2892 };
2893
2894 proto::SubscriptionUsage {
2895 model_requests_usage_amount: usage.model_requests as u32,
2896 model_requests_usage_limit: Some(proto::UsageLimit {
2897 variant: Some(match model_requests_limit(plan, feature_flags) {
2898 zed_llm_client::UsageLimit::Limited(limit) => {
2899 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2900 limit: limit as u32,
2901 })
2902 }
2903 zed_llm_client::UsageLimit::Unlimited => {
2904 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2905 }
2906 }),
2907 }),
2908 edit_predictions_usage_amount: usage.edit_predictions as u32,
2909 edit_predictions_usage_limit: Some(proto::UsageLimit {
2910 variant: Some(match plan.edit_predictions_limit() {
2911 zed_llm_client::UsageLimit::Limited(limit) => {
2912 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2913 limit: limit as u32,
2914 })
2915 }
2916 zed_llm_client::UsageLimit::Unlimited => {
2917 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2918 }
2919 }),
2920 }),
2921 }
2922}
2923
2924fn make_default_subscription_usage(
2925 plan: proto::Plan,
2926 feature_flags: &Vec<String>,
2927) -> proto::SubscriptionUsage {
2928 let plan = match plan {
2929 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2930 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2931 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2932 };
2933
2934 proto::SubscriptionUsage {
2935 model_requests_usage_amount: 0,
2936 model_requests_usage_limit: Some(proto::UsageLimit {
2937 variant: Some(match model_requests_limit(plan, feature_flags) {
2938 zed_llm_client::UsageLimit::Limited(limit) => {
2939 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2940 limit: limit as u32,
2941 })
2942 }
2943 zed_llm_client::UsageLimit::Unlimited => {
2944 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2945 }
2946 }),
2947 }),
2948 edit_predictions_usage_amount: 0,
2949 edit_predictions_usage_limit: Some(proto::UsageLimit {
2950 variant: Some(match plan.edit_predictions_limit() {
2951 zed_llm_client::UsageLimit::Limited(limit) => {
2952 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2953 limit: limit as u32,
2954 })
2955 }
2956 zed_llm_client::UsageLimit::Unlimited => {
2957 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2958 }
2959 }),
2960 }),
2961 }
2962}
2963
2964async fn update_user_plan(session: &Session) -> Result<()> {
2965 let db = session.db().await;
2966
2967 let update_user_plan = make_update_user_plan_message(
2968 session.principal.user(),
2969 session.is_staff(),
2970 &db.0,
2971 session.app_state.llm_db.clone(),
2972 )
2973 .await?;
2974
2975 session
2976 .peer
2977 .send(session.connection_id, update_user_plan)
2978 .trace_err();
2979
2980 Ok(())
2981}
2982
2983async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2984 subscribe_user_to_channels(session.user_id(), &session).await?;
2985 Ok(())
2986}
2987
2988async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2989 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2990 let mut pool = session.connection_pool().await;
2991 for membership in &channels_for_user.channel_memberships {
2992 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2993 }
2994 session.peer.send(
2995 session.connection_id,
2996 build_update_user_channels(&channels_for_user),
2997 )?;
2998 session.peer.send(
2999 session.connection_id,
3000 build_channels_update(channels_for_user),
3001 )?;
3002 Ok(())
3003}
3004
3005/// Creates a new channel.
3006async fn create_channel(
3007 request: proto::CreateChannel,
3008 response: Response<proto::CreateChannel>,
3009 session: Session,
3010) -> Result<()> {
3011 let db = session.db().await;
3012
3013 let parent_id = request.parent_id.map(ChannelId::from_proto);
3014 let (channel, membership) = db
3015 .create_channel(&request.name, parent_id, session.user_id())
3016 .await?;
3017
3018 let root_id = channel.root_id();
3019 let channel = Channel::from_model(channel);
3020
3021 response.send(proto::CreateChannelResponse {
3022 channel: Some(channel.to_proto()),
3023 parent_id: request.parent_id,
3024 })?;
3025
3026 let mut connection_pool = session.connection_pool().await;
3027 if let Some(membership) = membership {
3028 connection_pool.subscribe_to_channel(
3029 membership.user_id,
3030 membership.channel_id,
3031 membership.role,
3032 );
3033 let update = proto::UpdateUserChannels {
3034 channel_memberships: vec![proto::ChannelMembership {
3035 channel_id: membership.channel_id.to_proto(),
3036 role: membership.role.into(),
3037 }],
3038 ..Default::default()
3039 };
3040 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3041 session.peer.send(connection_id, update.clone())?;
3042 }
3043 }
3044
3045 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3046 if !role.can_see_channel(channel.visibility) {
3047 continue;
3048 }
3049
3050 let update = proto::UpdateChannels {
3051 channels: vec![channel.to_proto()],
3052 ..Default::default()
3053 };
3054 session.peer.send(connection_id, update.clone())?;
3055 }
3056
3057 Ok(())
3058}
3059
3060/// Delete a channel
3061async fn delete_channel(
3062 request: proto::DeleteChannel,
3063 response: Response<proto::DeleteChannel>,
3064 session: Session,
3065) -> Result<()> {
3066 let db = session.db().await;
3067
3068 let channel_id = request.channel_id;
3069 let (root_channel, removed_channels) = db
3070 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3071 .await?;
3072 response.send(proto::Ack {})?;
3073
3074 // Notify members of removed channels
3075 let mut update = proto::UpdateChannels::default();
3076 update
3077 .delete_channels
3078 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3079
3080 let connection_pool = session.connection_pool().await;
3081 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3082 session.peer.send(connection_id, update.clone())?;
3083 }
3084
3085 Ok(())
3086}
3087
3088/// Invite someone to join a channel.
3089async fn invite_channel_member(
3090 request: proto::InviteChannelMember,
3091 response: Response<proto::InviteChannelMember>,
3092 session: Session,
3093) -> Result<()> {
3094 let db = session.db().await;
3095 let channel_id = ChannelId::from_proto(request.channel_id);
3096 let invitee_id = UserId::from_proto(request.user_id);
3097 let InviteMemberResult {
3098 channel,
3099 notifications,
3100 } = db
3101 .invite_channel_member(
3102 channel_id,
3103 invitee_id,
3104 session.user_id(),
3105 request.role().into(),
3106 )
3107 .await?;
3108
3109 let update = proto::UpdateChannels {
3110 channel_invitations: vec![channel.to_proto()],
3111 ..Default::default()
3112 };
3113
3114 let connection_pool = session.connection_pool().await;
3115 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3116 session.peer.send(connection_id, update.clone())?;
3117 }
3118
3119 send_notifications(&connection_pool, &session.peer, notifications);
3120
3121 response.send(proto::Ack {})?;
3122 Ok(())
3123}
3124
3125/// remove someone from a channel
3126async fn remove_channel_member(
3127 request: proto::RemoveChannelMember,
3128 response: Response<proto::RemoveChannelMember>,
3129 session: Session,
3130) -> Result<()> {
3131 let db = session.db().await;
3132 let channel_id = ChannelId::from_proto(request.channel_id);
3133 let member_id = UserId::from_proto(request.user_id);
3134
3135 let RemoveChannelMemberResult {
3136 membership_update,
3137 notification_id,
3138 } = db
3139 .remove_channel_member(channel_id, member_id, session.user_id())
3140 .await?;
3141
3142 let mut connection_pool = session.connection_pool().await;
3143 notify_membership_updated(
3144 &mut connection_pool,
3145 membership_update,
3146 member_id,
3147 &session.peer,
3148 );
3149 for connection_id in connection_pool.user_connection_ids(member_id) {
3150 if let Some(notification_id) = notification_id {
3151 session
3152 .peer
3153 .send(
3154 connection_id,
3155 proto::DeleteNotification {
3156 notification_id: notification_id.to_proto(),
3157 },
3158 )
3159 .trace_err();
3160 }
3161 }
3162
3163 response.send(proto::Ack {})?;
3164 Ok(())
3165}
3166
3167/// Toggle the channel between public and private.
3168/// Care is taken to maintain the invariant that public channels only descend from public channels,
3169/// (though members-only channels can appear at any point in the hierarchy).
3170async fn set_channel_visibility(
3171 request: proto::SetChannelVisibility,
3172 response: Response<proto::SetChannelVisibility>,
3173 session: Session,
3174) -> Result<()> {
3175 let db = session.db().await;
3176 let channel_id = ChannelId::from_proto(request.channel_id);
3177 let visibility = request.visibility().into();
3178
3179 let channel_model = db
3180 .set_channel_visibility(channel_id, visibility, session.user_id())
3181 .await?;
3182 let root_id = channel_model.root_id();
3183 let channel = Channel::from_model(channel_model);
3184
3185 let mut connection_pool = session.connection_pool().await;
3186 for (user_id, role) in connection_pool
3187 .channel_user_ids(root_id)
3188 .collect::<Vec<_>>()
3189 .into_iter()
3190 {
3191 let update = if role.can_see_channel(channel.visibility) {
3192 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3193 proto::UpdateChannels {
3194 channels: vec![channel.to_proto()],
3195 ..Default::default()
3196 }
3197 } else {
3198 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3199 proto::UpdateChannels {
3200 delete_channels: vec![channel.id.to_proto()],
3201 ..Default::default()
3202 }
3203 };
3204
3205 for connection_id in connection_pool.user_connection_ids(user_id) {
3206 session.peer.send(connection_id, update.clone())?;
3207 }
3208 }
3209
3210 response.send(proto::Ack {})?;
3211 Ok(())
3212}
3213
3214/// Alter the role for a user in the channel.
3215async fn set_channel_member_role(
3216 request: proto::SetChannelMemberRole,
3217 response: Response<proto::SetChannelMemberRole>,
3218 session: Session,
3219) -> Result<()> {
3220 let db = session.db().await;
3221 let channel_id = ChannelId::from_proto(request.channel_id);
3222 let member_id = UserId::from_proto(request.user_id);
3223 let result = db
3224 .set_channel_member_role(
3225 channel_id,
3226 session.user_id(),
3227 member_id,
3228 request.role().into(),
3229 )
3230 .await?;
3231
3232 match result {
3233 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3234 let mut connection_pool = session.connection_pool().await;
3235 notify_membership_updated(
3236 &mut connection_pool,
3237 membership_update,
3238 member_id,
3239 &session.peer,
3240 )
3241 }
3242 db::SetMemberRoleResult::InviteUpdated(channel) => {
3243 let update = proto::UpdateChannels {
3244 channel_invitations: vec![channel.to_proto()],
3245 ..Default::default()
3246 };
3247
3248 for connection_id in session
3249 .connection_pool()
3250 .await
3251 .user_connection_ids(member_id)
3252 {
3253 session.peer.send(connection_id, update.clone())?;
3254 }
3255 }
3256 }
3257
3258 response.send(proto::Ack {})?;
3259 Ok(())
3260}
3261
3262/// Change the name of a channel
3263async fn rename_channel(
3264 request: proto::RenameChannel,
3265 response: Response<proto::RenameChannel>,
3266 session: Session,
3267) -> Result<()> {
3268 let db = session.db().await;
3269 let channel_id = ChannelId::from_proto(request.channel_id);
3270 let channel_model = db
3271 .rename_channel(channel_id, session.user_id(), &request.name)
3272 .await?;
3273 let root_id = channel_model.root_id();
3274 let channel = Channel::from_model(channel_model);
3275
3276 response.send(proto::RenameChannelResponse {
3277 channel: Some(channel.to_proto()),
3278 })?;
3279
3280 let connection_pool = session.connection_pool().await;
3281 let update = proto::UpdateChannels {
3282 channels: vec![channel.to_proto()],
3283 ..Default::default()
3284 };
3285 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3286 if role.can_see_channel(channel.visibility) {
3287 session.peer.send(connection_id, update.clone())?;
3288 }
3289 }
3290
3291 Ok(())
3292}
3293
3294/// Move a channel to a new parent.
3295async fn move_channel(
3296 request: proto::MoveChannel,
3297 response: Response<proto::MoveChannel>,
3298 session: Session,
3299) -> Result<()> {
3300 let channel_id = ChannelId::from_proto(request.channel_id);
3301 let to = ChannelId::from_proto(request.to);
3302
3303 let (root_id, channels) = session
3304 .db()
3305 .await
3306 .move_channel(channel_id, to, session.user_id())
3307 .await?;
3308
3309 let connection_pool = session.connection_pool().await;
3310 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3311 let channels = channels
3312 .iter()
3313 .filter_map(|channel| {
3314 if role.can_see_channel(channel.visibility) {
3315 Some(channel.to_proto())
3316 } else {
3317 None
3318 }
3319 })
3320 .collect::<Vec<_>>();
3321 if channels.is_empty() {
3322 continue;
3323 }
3324
3325 let update = proto::UpdateChannels {
3326 channels,
3327 ..Default::default()
3328 };
3329
3330 session.peer.send(connection_id, update.clone())?;
3331 }
3332
3333 response.send(Ack {})?;
3334 Ok(())
3335}
3336
3337async fn reorder_channel(
3338 request: proto::ReorderChannel,
3339 response: Response<proto::ReorderChannel>,
3340 session: Session,
3341) -> Result<()> {
3342 let channel_id = ChannelId::from_proto(request.channel_id);
3343 let direction = request.direction();
3344
3345 let updated_channels = session
3346 .db()
3347 .await
3348 .reorder_channel(channel_id, direction, session.user_id())
3349 .await?;
3350
3351 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3352 let connection_pool = session.connection_pool().await;
3353 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3354 let channels = updated_channels
3355 .iter()
3356 .filter_map(|channel| {
3357 if role.can_see_channel(channel.visibility) {
3358 Some(channel.to_proto())
3359 } else {
3360 None
3361 }
3362 })
3363 .collect::<Vec<_>>();
3364
3365 if channels.is_empty() {
3366 continue;
3367 }
3368
3369 let update = proto::UpdateChannels {
3370 channels,
3371 ..Default::default()
3372 };
3373
3374 session.peer.send(connection_id, update.clone())?;
3375 }
3376 }
3377
3378 response.send(Ack {})?;
3379 Ok(())
3380}
3381
3382/// Get the list of channel members
3383async fn get_channel_members(
3384 request: proto::GetChannelMembers,
3385 response: Response<proto::GetChannelMembers>,
3386 session: Session,
3387) -> Result<()> {
3388 let db = session.db().await;
3389 let channel_id = ChannelId::from_proto(request.channel_id);
3390 let limit = if request.limit == 0 {
3391 u16::MAX as u64
3392 } else {
3393 request.limit
3394 };
3395 let (members, users) = db
3396 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3397 .await?;
3398 response.send(proto::GetChannelMembersResponse { members, users })?;
3399 Ok(())
3400}
3401
3402/// Accept or decline a channel invitation.
3403async fn respond_to_channel_invite(
3404 request: proto::RespondToChannelInvite,
3405 response: Response<proto::RespondToChannelInvite>,
3406 session: Session,
3407) -> Result<()> {
3408 let db = session.db().await;
3409 let channel_id = ChannelId::from_proto(request.channel_id);
3410 let RespondToChannelInvite {
3411 membership_update,
3412 notifications,
3413 } = db
3414 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3415 .await?;
3416
3417 let mut connection_pool = session.connection_pool().await;
3418 if let Some(membership_update) = membership_update {
3419 notify_membership_updated(
3420 &mut connection_pool,
3421 membership_update,
3422 session.user_id(),
3423 &session.peer,
3424 );
3425 } else {
3426 let update = proto::UpdateChannels {
3427 remove_channel_invitations: vec![channel_id.to_proto()],
3428 ..Default::default()
3429 };
3430
3431 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3432 session.peer.send(connection_id, update.clone())?;
3433 }
3434 };
3435
3436 send_notifications(&connection_pool, &session.peer, notifications);
3437
3438 response.send(proto::Ack {})?;
3439
3440 Ok(())
3441}
3442
3443/// Join the channels' room
3444async fn join_channel(
3445 request: proto::JoinChannel,
3446 response: Response<proto::JoinChannel>,
3447 session: Session,
3448) -> Result<()> {
3449 let channel_id = ChannelId::from_proto(request.channel_id);
3450 join_channel_internal(channel_id, Box::new(response), session).await
3451}
3452
/// A response handle that can complete a channel join.
///
/// `join_channel_internal` is generic over this trait so the same join logic
/// can reply with a `proto::JoinRoomResponse` to either a `JoinChannel` or a
/// `JoinRoom` request.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Delegates to `Response::<proto::JoinChannel>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Delegates to `Response::<proto::JoinRoom>::send`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3466
/// Join the room belonging to `channel_id` on behalf of the session's user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database handle is dropped while leaving and then
///    re-acquired.
/// 2. Join the channel's room in the database.
/// 3. If a LiveKit client is configured, create a LiveKit token:
///    - guests get a guest token with `can_publish = false`;
///    - every other role gets a room token with `can_publish = true`.
///    A token error is traced and yields no LiveKit info rather than failing
///    the join.
/// 4. Send the `JoinRoomResponse`, then broadcast any membership update and
///    call `room_updated`.
/// 5. After the database handle is dropped, call `channel_updated` and refresh
///    the user's contacts.
///
/// Returns an error in these cases:
/// - a database call fails;
/// - leaving the stale room fails;
/// - the response cannot be sent;
/// - the joined room carries no channel. This check runs after the response
///   has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Everything in this block runs while `db` is held. The block evaluates to
    // the joined room, so the broadcasts after it run once `db` is dropped.
    let joined_room = {
        // `mut` because the handle is dropped and re-acquired below when a stale
        // connection has to be cleaned up.
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then take it again.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails:
        // `trace_err` logs the error and `?` turns it into `None` for this closure.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and cannot publish; every other role
                    // gets a full room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        // Only present when joining changed the user's channel membership.
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3560
3561/// Start editing the channel notes
3562async fn join_channel_buffer(
3563 request: proto::JoinChannelBuffer,
3564 response: Response<proto::JoinChannelBuffer>,
3565 session: Session,
3566) -> Result<()> {
3567 let db = session.db().await;
3568 let channel_id = ChannelId::from_proto(request.channel_id);
3569
3570 let open_response = db
3571 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3572 .await?;
3573
3574 let collaborators = open_response.collaborators.clone();
3575 response.send(open_response)?;
3576
3577 let update = UpdateChannelBufferCollaborators {
3578 channel_id: channel_id.to_proto(),
3579 collaborators: collaborators.clone(),
3580 };
3581 channel_buffer_updated(
3582 session.connection_id,
3583 collaborators
3584 .iter()
3585 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3586 &update,
3587 &session.peer,
3588 );
3589
3590 Ok(())
3591}
3592
3593/// Edit the channel notes
3594async fn update_channel_buffer(
3595 request: proto::UpdateChannelBuffer,
3596 session: Session,
3597) -> Result<()> {
3598 let db = session.db().await;
3599 let channel_id = ChannelId::from_proto(request.channel_id);
3600
3601 let (collaborators, epoch, version) = db
3602 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3603 .await?;
3604
3605 channel_buffer_updated(
3606 session.connection_id,
3607 collaborators.clone(),
3608 &proto::UpdateChannelBuffer {
3609 channel_id: channel_id.to_proto(),
3610 operations: request.operations,
3611 },
3612 &session.peer,
3613 );
3614
3615 let pool = &*session.connection_pool().await;
3616
3617 let non_collaborators =
3618 pool.channel_connection_ids(channel_id)
3619 .filter_map(|(connection_id, _)| {
3620 if collaborators.contains(&connection_id) {
3621 None
3622 } else {
3623 Some(connection_id)
3624 }
3625 });
3626
3627 broadcast(None, non_collaborators, |peer_id| {
3628 session.peer.send(
3629 peer_id,
3630 proto::UpdateChannels {
3631 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3632 channel_id: channel_id.to_proto(),
3633 epoch: epoch as u64,
3634 version: version.clone(),
3635 }],
3636 ..Default::default()
3637 },
3638 )
3639 });
3640
3641 Ok(())
3642}
3643
3644/// Rejoin the channel notes after a connection blip
3645async fn rejoin_channel_buffers(
3646 request: proto::RejoinChannelBuffers,
3647 response: Response<proto::RejoinChannelBuffers>,
3648 session: Session,
3649) -> Result<()> {
3650 let db = session.db().await;
3651 let buffers = db
3652 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3653 .await?;
3654
3655 for rejoined_buffer in &buffers {
3656 let collaborators_to_notify = rejoined_buffer
3657 .buffer
3658 .collaborators
3659 .iter()
3660 .filter_map(|c| Some(c.peer_id?.into()));
3661 channel_buffer_updated(
3662 session.connection_id,
3663 collaborators_to_notify,
3664 &proto::UpdateChannelBufferCollaborators {
3665 channel_id: rejoined_buffer.buffer.channel_id,
3666 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3667 },
3668 &session.peer,
3669 );
3670 }
3671
3672 response.send(proto::RejoinChannelBuffersResponse {
3673 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3674 })?;
3675
3676 Ok(())
3677}
3678
3679/// Stop editing the channel notes
3680async fn leave_channel_buffer(
3681 request: proto::LeaveChannelBuffer,
3682 response: Response<proto::LeaveChannelBuffer>,
3683 session: Session,
3684) -> Result<()> {
3685 let db = session.db().await;
3686 let channel_id = ChannelId::from_proto(request.channel_id);
3687
3688 let left_buffer = db
3689 .leave_channel_buffer(channel_id, session.connection_id)
3690 .await?;
3691
3692 response.send(Ack {})?;
3693
3694 channel_buffer_updated(
3695 session.connection_id,
3696 left_buffer.connections,
3697 &proto::UpdateChannelBufferCollaborators {
3698 channel_id: channel_id.to_proto(),
3699 collaborators: left_buffer.collaborators,
3700 },
3701 &session.peer,
3702 );
3703
3704 Ok(())
3705}
3706
3707fn channel_buffer_updated<T: EnvelopedMessage>(
3708 sender_id: ConnectionId,
3709 collaborators: impl IntoIterator<Item = ConnectionId>,
3710 message: &T,
3711 peer: &Peer,
3712) {
3713 broadcast(Some(sender_id), collaborators, |peer_id| {
3714 peer.send(peer_id, message.clone())
3715 });
3716}
3717
3718fn send_notifications(
3719 connection_pool: &ConnectionPool,
3720 peer: &Peer,
3721 notifications: db::NotificationBatch,
3722) {
3723 for (user_id, notification) in notifications {
3724 for connection_id in connection_pool.user_connection_ids(user_id) {
3725 if let Err(error) = peer.send(
3726 connection_id,
3727 proto::AddNotification {
3728 notification: Some(notification.clone()),
3729 },
3730 ) {
3731 tracing::error!(
3732 "failed to send notification to {:?} {}",
3733 connection_id,
3734 error
3735 );
3736 }
3737 }
3738 }
3739}
3740
/// Send a message to the channel.
///
/// The body is trimmed, then rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty. A nonce is also required. Once the message is stored:
/// - the other chat participants receive it as `ChannelMessageSent`;
/// - the sender receives it in the `SendChannelMessageResponse`;
/// - connections from `channel_connection_ids(channel_id)` that are not chat
///   participants receive only the new latest message id;
/// - any notifications returned by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    // (`return Err(..)?` converts the `anyhow::Error` into this module's `Error`.)
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and forwarded unchanged below. If
    // mention ranges are offsets into the untrimmed body, trimming leading
    // whitespace shifts them. Confirm against the client.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // The sender's connection is passed to `broadcast` as the sender
    // (presumably skipped there). The sender gets the message via the response below.
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections for the channel that are not in the chat only learn the
    // latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3833
3834/// Delete a channel message
3835async fn remove_channel_message(
3836 request: proto::RemoveChannelMessage,
3837 response: Response<proto::RemoveChannelMessage>,
3838 session: Session,
3839) -> Result<()> {
3840 let channel_id = ChannelId::from_proto(request.channel_id);
3841 let message_id = MessageId::from_proto(request.message_id);
3842 let (connection_ids, existing_notification_ids) = session
3843 .db()
3844 .await
3845 .remove_channel_message(channel_id, message_id, session.user_id())
3846 .await?;
3847
3848 broadcast(
3849 Some(session.connection_id),
3850 connection_ids,
3851 move |connection| {
3852 session.peer.send(connection, request.clone())?;
3853
3854 for notification_id in &existing_notification_ids {
3855 session.peer.send(
3856 connection,
3857 proto::DeleteNotification {
3858 notification_id: (*notification_id).to_proto(),
3859 },
3860 )?;
3861 }
3862
3863 Ok(())
3864 },
3865 );
3866 response.send(proto::Ack {})?;
3867 Ok(())
3868}
3869
3870async fn update_channel_message(
3871 request: proto::UpdateChannelMessage,
3872 response: Response<proto::UpdateChannelMessage>,
3873 session: Session,
3874) -> Result<()> {
3875 let channel_id = ChannelId::from_proto(request.channel_id);
3876 let message_id = MessageId::from_proto(request.message_id);
3877 let updated_at = OffsetDateTime::now_utc();
3878 let UpdatedChannelMessage {
3879 message_id,
3880 participant_connection_ids,
3881 notifications,
3882 reply_to_message_id,
3883 timestamp,
3884 deleted_mention_notification_ids,
3885 updated_mention_notifications,
3886 } = session
3887 .db()
3888 .await
3889 .update_channel_message(
3890 channel_id,
3891 message_id,
3892 session.user_id(),
3893 request.body.as_str(),
3894 &request.mentions,
3895 updated_at,
3896 )
3897 .await?;
3898
3899 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3900
3901 let message = proto::ChannelMessage {
3902 sender_id: session.user_id().to_proto(),
3903 id: message_id.to_proto(),
3904 body: request.body.clone(),
3905 mentions: request.mentions.clone(),
3906 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3907 nonce: Some(nonce),
3908 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3909 edited_at: Some(updated_at.unix_timestamp() as u64),
3910 };
3911
3912 response.send(proto::Ack {})?;
3913
3914 let pool = &*session.connection_pool().await;
3915 broadcast(
3916 Some(session.connection_id),
3917 participant_connection_ids,
3918 |connection| {
3919 session.peer.send(
3920 connection,
3921 proto::ChannelMessageUpdate {
3922 channel_id: channel_id.to_proto(),
3923 message: Some(message.clone()),
3924 },
3925 )?;
3926
3927 for notification_id in &deleted_mention_notification_ids {
3928 session.peer.send(
3929 connection,
3930 proto::DeleteNotification {
3931 notification_id: (*notification_id).to_proto(),
3932 },
3933 )?;
3934 }
3935
3936 for notification in &updated_mention_notifications {
3937 session.peer.send(
3938 connection,
3939 proto::UpdateNotification {
3940 notification: Some(notification.clone()),
3941 },
3942 )?;
3943 }
3944
3945 Ok(())
3946 },
3947 );
3948
3949 send_notifications(pool, &session.peer, notifications);
3950
3951 Ok(())
3952}
3953
3954/// Mark a channel message as read
3955async fn acknowledge_channel_message(
3956 request: proto::AckChannelMessage,
3957 session: Session,
3958) -> Result<()> {
3959 let channel_id = ChannelId::from_proto(request.channel_id);
3960 let message_id = MessageId::from_proto(request.message_id);
3961 let notifications = session
3962 .db()
3963 .await
3964 .observe_channel_message(channel_id, session.user_id(), message_id)
3965 .await?;
3966 send_notifications(
3967 &*session.connection_pool().await,
3968 &session.peer,
3969 notifications,
3970 );
3971 Ok(())
3972}
3973
3974/// Mark a buffer version as synced
3975async fn acknowledge_buffer_version(
3976 request: proto::AckBufferOperation,
3977 session: Session,
3978) -> Result<()> {
3979 let buffer_id = BufferId::from_proto(request.buffer_id);
3980 session
3981 .db()
3982 .await
3983 .observe_buffer_version(
3984 buffer_id,
3985 session.user_id(),
3986 request.epoch as i32,
3987 &request.version,
3988 )
3989 .await?;
3990 Ok(())
3991}
3992
3993/// Get a Supermaven API key for the user
3994async fn get_supermaven_api_key(
3995 _request: proto::GetSupermavenApiKey,
3996 response: Response<proto::GetSupermavenApiKey>,
3997 session: Session,
3998) -> Result<()> {
3999 let user_id: String = session.user_id().to_string();
4000 if !session.is_staff() {
4001 return Err(anyhow!("supermaven not enabled for this account"))?;
4002 }
4003
4004 let email = session.email().context("user must have an email")?;
4005
4006 let supermaven_admin_api = session
4007 .supermaven_client
4008 .as_ref()
4009 .context("supermaven not configured")?;
4010
4011 let result = supermaven_admin_api
4012 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4013 .await?;
4014
4015 response.send(proto::GetSupermavenApiKeyResponse {
4016 api_key: result.api_key,
4017 })?;
4018
4019 Ok(())
4020}
4021
4022/// Start receiving chat updates for a channel
4023async fn join_channel_chat(
4024 request: proto::JoinChannelChat,
4025 response: Response<proto::JoinChannelChat>,
4026 session: Session,
4027) -> Result<()> {
4028 let channel_id = ChannelId::from_proto(request.channel_id);
4029
4030 let db = session.db().await;
4031 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4032 .await?;
4033 let messages = db
4034 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4035 .await?;
4036 response.send(proto::JoinChannelChatResponse {
4037 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4038 messages,
4039 })?;
4040 Ok(())
4041}
4042
4043/// Stop receiving chat updates for a channel
4044async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4045 let channel_id = ChannelId::from_proto(request.channel_id);
4046 session
4047 .db()
4048 .await
4049 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4050 .await?;
4051 Ok(())
4052}
4053
4054/// Retrieve the chat history for a channel
4055async fn get_channel_messages(
4056 request: proto::GetChannelMessages,
4057 response: Response<proto::GetChannelMessages>,
4058 session: Session,
4059) -> Result<()> {
4060 let channel_id = ChannelId::from_proto(request.channel_id);
4061 let messages = session
4062 .db()
4063 .await
4064 .get_channel_messages(
4065 channel_id,
4066 session.user_id(),
4067 MESSAGE_COUNT_PER_PAGE,
4068 Some(MessageId::from_proto(request.before_message_id)),
4069 )
4070 .await?;
4071 response.send(proto::GetChannelMessagesResponse {
4072 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4073 messages,
4074 })?;
4075 Ok(())
4076}
4077
4078/// Retrieve specific chat messages
4079async fn get_channel_messages_by_id(
4080 request: proto::GetChannelMessagesById,
4081 response: Response<proto::GetChannelMessagesById>,
4082 session: Session,
4083) -> Result<()> {
4084 let message_ids = request
4085 .message_ids
4086 .iter()
4087 .map(|id| MessageId::from_proto(*id))
4088 .collect::<Vec<_>>();
4089 let messages = session
4090 .db()
4091 .await
4092 .get_channel_messages_by_id(session.user_id(), &message_ids)
4093 .await?;
4094 response.send(proto::GetChannelMessagesResponse {
4095 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4096 messages,
4097 })?;
4098 Ok(())
4099}
4100
4101/// Retrieve the current users notifications
4102async fn get_notifications(
4103 request: proto::GetNotifications,
4104 response: Response<proto::GetNotifications>,
4105 session: Session,
4106) -> Result<()> {
4107 let notifications = session
4108 .db()
4109 .await
4110 .get_notifications(
4111 session.user_id(),
4112 NOTIFICATION_COUNT_PER_PAGE,
4113 request.before_id.map(db::NotificationId::from_proto),
4114 )
4115 .await?;
4116 response.send(proto::GetNotificationsResponse {
4117 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4118 notifications,
4119 })?;
4120 Ok(())
4121}
4122
4123/// Mark notifications as read
4124async fn mark_notification_as_read(
4125 request: proto::MarkNotificationRead,
4126 response: Response<proto::MarkNotificationRead>,
4127 session: Session,
4128) -> Result<()> {
4129 let database = &session.db().await;
4130 let notifications = database
4131 .mark_notification_as_read_by_id(
4132 session.user_id(),
4133 NotificationId::from_proto(request.notification_id),
4134 )
4135 .await?;
4136 send_notifications(
4137 &*session.connection_pool().await,
4138 &session.peer,
4139 notifications,
4140 );
4141 response.send(proto::Ack {})?;
4142 Ok(())
4143}
4144
4145/// Get the current users information
4146async fn get_private_user_info(
4147 _request: proto::GetPrivateUserInfo,
4148 response: Response<proto::GetPrivateUserInfo>,
4149 session: Session,
4150) -> Result<()> {
4151 let db = session.db().await;
4152
4153 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4154 let user = db
4155 .get_user_by_id(session.user_id())
4156 .await?
4157 .context("user not found")?;
4158 let flags = db.get_user_flags(session.user_id()).await?;
4159
4160 response.send(proto::GetPrivateUserInfoResponse {
4161 metrics_id,
4162 staff: user.admin,
4163 flags,
4164 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4165 })?;
4166 Ok(())
4167}
4168
4169/// Accept the terms of service (tos) on behalf of the current user
4170async fn accept_terms_of_service(
4171 _request: proto::AcceptTermsOfService,
4172 response: Response<proto::AcceptTermsOfService>,
4173 session: Session,
4174) -> Result<()> {
4175 let db = session.db().await;
4176
4177 let accepted_tos_at = Utc::now();
4178 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4179 .await?;
4180
4181 response.send(proto::AcceptTermsOfServiceResponse {
4182 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4183 })?;
4184
4185 // When the user accepts the terms of service, we want to refresh their LLM
4186 // token to grant access.
4187 session
4188 .peer
4189 .send(session.connection_id, proto::RefreshLlmToken {})?;
4190
4191 Ok(())
4192}
4193
/// Issue an LLM API token for the current user.
///
/// Preconditions:
/// - the user exists and has accepted the terms of service;
/// - both a Stripe client and a Stripe billing object are configured.
///
/// Missing billing data is created on demand:
/// - with no billing customer, a Stripe customer is found or created by email
///   and a billing customer is found or created for it;
/// - with no active billing subscription, the customer is subscribed to Zed Free
///   and that subscription is recorded in the database.
///
/// The token is then built from these inputs:
/// - the user and their staff status;
/// - the billing customer, preferences and subscription;
/// - the feature flags;
/// - the session's system id and the app config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer. Otherwise resolve the Stripe customer by
    // email and find or create the matching billing customer record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Users without an active subscription are subscribed to Zed Free, and the
    // resulting Stripe subscription is recorded.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4277
4278fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4279 let message = match message {
4280 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4281 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4282 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4283 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4284 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4285 code: frame.code.into(),
4286 reason: frame.reason.as_str().to_owned().into(),
4287 })),
4288 // We should never receive a frame while reading the message, according
4289 // to the `tungstenite` maintainers:
4290 //
4291 // > It cannot occur when you read messages from the WebSocket, but it
4292 // > can be used when you want to send the raw frames (e.g. you want to
4293 // > send the frames to the WebSocket without composing the full message first).
4294 // >
4295 // > — https://github.com/snapview/tungstenite-rs/issues/268
4296 TungsteniteMessage::Frame(_) => {
4297 bail!("received an unexpected frame while reading the message")
4298 }
4299 };
4300
4301 Ok(message)
4302}
4303
4304fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4305 match message {
4306 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4307 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4308 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4309 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4310 AxumMessage::Close(frame) => {
4311 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4312 code: frame.code.into(),
4313 reason: frame.reason.as_ref().into(),
4314 }))
4315 }
4316 }
4317}
4318
4319fn notify_membership_updated(
4320 connection_pool: &mut ConnectionPool,
4321 result: MembershipUpdated,
4322 user_id: UserId,
4323 peer: &Peer,
4324) {
4325 for membership in &result.new_channels.channel_memberships {
4326 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4327 }
4328 for channel_id in &result.removed_channels {
4329 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4330 }
4331
4332 let user_channels_update = proto::UpdateUserChannels {
4333 channel_memberships: result
4334 .new_channels
4335 .channel_memberships
4336 .iter()
4337 .map(|cm| proto::ChannelMembership {
4338 channel_id: cm.channel_id.to_proto(),
4339 role: cm.role.into(),
4340 })
4341 .collect(),
4342 ..Default::default()
4343 };
4344
4345 let mut update = build_channels_update(result.new_channels);
4346 update.delete_channels = result
4347 .removed_channels
4348 .into_iter()
4349 .map(|id| id.to_proto())
4350 .collect();
4351 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4352
4353 for connection_id in connection_pool.user_connection_ids(user_id) {
4354 peer.send(connection_id, user_channels_update.clone())
4355 .trace_err();
4356 peer.send(connection_id, update.clone()).trace_err();
4357 }
4358}
4359
4360fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4361 proto::UpdateUserChannels {
4362 channel_memberships: channels
4363 .channel_memberships
4364 .iter()
4365 .map(|m| proto::ChannelMembership {
4366 channel_id: m.channel_id.to_proto(),
4367 role: m.role.into(),
4368 })
4369 .collect(),
4370 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4371 observed_channel_message_id: channels.observed_channel_messages.clone(),
4372 }
4373}
4374
4375fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4376 let mut update = proto::UpdateChannels::default();
4377
4378 for channel in channels.channels {
4379 update.channels.push(channel.to_proto());
4380 }
4381
4382 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4383 update.latest_channel_message_ids = channels.latest_channel_messages;
4384
4385 for (channel_id, participants) in channels.channel_participants {
4386 update
4387 .channel_participants
4388 .push(proto::ChannelParticipants {
4389 channel_id: channel_id.to_proto(),
4390 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4391 });
4392 }
4393
4394 for channel in channels.invited_channels {
4395 update.channel_invitations.push(channel.to_proto());
4396 }
4397
4398 update
4399}
4400
4401fn build_initial_contacts_update(
4402 contacts: Vec<db::Contact>,
4403 pool: &ConnectionPool,
4404) -> proto::UpdateContacts {
4405 let mut update = proto::UpdateContacts::default();
4406
4407 for contact in contacts {
4408 match contact {
4409 db::Contact::Accepted { user_id, busy } => {
4410 update.contacts.push(contact_for_user(user_id, busy, pool));
4411 }
4412 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4413 db::Contact::Incoming { user_id } => {
4414 update
4415 .incoming_requests
4416 .push(proto::IncomingContactRequest {
4417 requester_id: user_id.to_proto(),
4418 })
4419 }
4420 }
4421 }
4422
4423 update
4424}
4425
4426fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4427 proto::Contact {
4428 user_id: user_id.to_proto(),
4429 online: pool.is_user_online(user_id),
4430 busy,
4431 }
4432}
4433
4434fn room_updated(room: &proto::Room, peer: &Peer) {
4435 broadcast(
4436 None,
4437 room.participants
4438 .iter()
4439 .filter_map(|participant| Some(participant.peer_id?.into())),
4440 |peer_id| {
4441 peer.send(
4442 peer_id,
4443 proto::RoomUpdated {
4444 room: Some(room.clone()),
4445 },
4446 )
4447 },
4448 );
4449}
4450
4451fn channel_updated(
4452 channel: &db::channel::Model,
4453 room: &proto::Room,
4454 peer: &Peer,
4455 pool: &ConnectionPool,
4456) {
4457 let participants = room
4458 .participants
4459 .iter()
4460 .map(|p| p.user_id)
4461 .collect::<Vec<_>>();
4462
4463 broadcast(
4464 None,
4465 pool.channel_connection_ids(channel.root_id())
4466 .filter_map(|(channel_id, role)| {
4467 role.can_see_channel(channel.visibility)
4468 .then_some(channel_id)
4469 }),
4470 |peer_id| {
4471 peer.send(
4472 peer_id,
4473 proto::UpdateChannels {
4474 channel_participants: vec![proto::ChannelParticipants {
4475 channel_id: channel.id.to_proto(),
4476 participant_user_ids: participants.clone(),
4477 }],
4478 ..Default::default()
4479 },
4480 )
4481 },
4482 );
4483}
4484
4485async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4486 let db = session.db().await;
4487
4488 let contacts = db.get_contacts(user_id).await?;
4489 let busy = db.is_user_busy(user_id).await?;
4490
4491 let pool = session.connection_pool().await;
4492 let updated_contact = contact_for_user(user_id, busy, &pool);
4493 for contact in contacts {
4494 if let db::Contact::Accepted {
4495 user_id: contact_user_id,
4496 ..
4497 } = contact
4498 {
4499 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4500 session
4501 .peer
4502 .send(
4503 contact_conn_id,
4504 proto::UpdateContacts {
4505 contacts: vec![updated_contact.clone()],
4506 remove_contacts: Default::default(),
4507 incoming_requests: Default::default(),
4508 remove_incoming_requests: Default::default(),
4509 outgoing_requests: Default::default(),
4510 remove_outgoing_requests: Default::default(),
4511 },
4512 )
4513 .trace_err();
4514 }
4515 }
4516 }
4517 Ok(())
4518}
4519
/// Removes `connection_id` from the room it is in, if any, and notifies everyone affected.
///
/// Steps, in order:
/// 1. Ask the database to leave the room. If the connection was not in a room, return `Ok(())`
///    immediately.
/// 2. Notify the connections of every project left as a side effect (`project_left`), and
///    broadcast the updated room to its participants.
/// 3. If the room belongs to a channel, broadcast the new participant list to channel members.
/// 4. Send `CallCanceled` to every connection of users whose pending calls were canceled.
/// 5. Refresh the contact entries of the leaving user and of those canceled callees.
/// 6. If a LiveKit client is configured, remove this user from the LiveKit room. Also delete
///    the LiveKit room when the database reported the room as deleted.
///
/// # Errors
///
/// Propagates database errors from leaving the room and from refreshing contacts. Failures to
/// send messages or to talk to LiveKit are only logged via `trace_err`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entries must be re-sent once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are moved out of `left_room` inside the `if let` below so
    // that they outlive it; the `else` branch returns before any of them is read.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary in this `if let`
    // scrutinee and therefore stays alive for the whole `Some` branch. If it is a lock guard,
    // it is held while `project_left`/`room_updated` send messages; confirm this is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        // The leaving user's own entry must be refreshed too (`update_user_contacts`
        // recomputes its busy flag).
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken from `left_room.room` before the room itself is moved out two lines below.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so that `pool` is released before `update_user_contacts`, which acquires the
    // connection pool again itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        // `livekit_room` is cloned here because `delete_room` below needs it by value.
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4593
4594async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4595 let left_channel_buffers = session
4596 .db()
4597 .await
4598 .leave_channel_buffers(session.connection_id)
4599 .await?;
4600
4601 for left_buffer in left_channel_buffers {
4602 channel_buffer_updated(
4603 session.connection_id,
4604 left_buffer.connections,
4605 &proto::UpdateChannelBufferCollaborators {
4606 channel_id: left_buffer.channel_id.to_proto(),
4607 collaborators: left_buffer.collaborators,
4608 },
4609 &session.peer,
4610 );
4611 }
4612
4613 Ok(())
4614}
4615
4616fn project_left(project: &db::LeftProject, session: &Session) {
4617 for connection_id in &project.connection_ids {
4618 if project.should_unshare {
4619 session
4620 .peer
4621 .send(
4622 *connection_id,
4623 proto::UnshareProject {
4624 project_id: project.id.to_proto(),
4625 },
4626 )
4627 .trace_err();
4628 } else {
4629 session
4630 .peer
4631 .send(
4632 *connection_id,
4633 proto::RemoveProjectCollaborator {
4634 project_id: project.id.to_proto(),
4635 peer_id: Some(session.connection_id.into()),
4636 },
4637 )
4638 .trace_err();
4639 }
4640 }
4641}
4642
4643pub trait ResultExt {
4644 type Ok;
4645
4646 fn trace_err(self) -> Option<Self::Ok>;
4647}
4648
4649impl<T, E> ResultExt for Result<T, E>
4650where
4651 E: std::fmt::Debug,
4652{
4653 type Ok = T;
4654
4655 #[track_caller]
4656 fn trace_err(self) -> Option<T> {
4657 match self {
4658 Ok(value) => Some(value),
4659 Err(error) => {
4660 tracing::error!("{:?}", error);
4661 None
4662 }
4663 }
4664 }
4665}