1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{
8 AGENT_EXTENDED_TRIAL_FEATURE_FLAG, BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG, LlmTokenClaims,
9 MIN_ACCOUNT_AGE_FOR_LLM_USE,
10};
11use crate::stripe_client::StripeCustomerId;
12use crate::{
13 AppState, Error, Result, auth,
14 db::{
15 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
16 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
17 NotificationId, ProjectId, RejoinedProject, RemoveChannelMemberResult,
18 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
19 },
20 executor::Executor,
21};
22use anyhow::{Context as _, anyhow, bail};
23use async_tungstenite::tungstenite::{
24 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
25};
26use axum::{
27 Extension, Router, TypedHeader,
28 body::Body,
29 extract::{
30 ConnectInfo, WebSocketUpgrade,
31 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
32 },
33 headers::{Header, HeaderName},
34 http::StatusCode,
35 middleware,
36 response::IntoResponse,
37 routing::get,
38};
39use chrono::Utc;
40use collections::{HashMap, HashSet};
41pub use connection_pool::{ConnectionPool, ZedVersion};
42use core::fmt::{self, Debug, Formatter};
43use reqwest_client::ReqwestClient;
44use rpc::proto::split_repository_update;
45use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
46
47use futures::{
48 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
49 stream::FuturesUnordered,
50};
51use prometheus::{IntGauge, register_int_gauge};
52use rpc::{
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54 proto::{
55 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
56 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
57 },
58};
59use semantic_version::SemanticVersion;
60use serde::{Serialize, Serializer};
61use std::{
62 any::TypeId,
63 future::Future,
64 marker::PhantomData,
65 mem,
66 net::SocketAddr,
67 ops::{Deref, DerefMut},
68 rc::Rc,
69 sync::{
70 Arc, OnceLock,
71 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
72 },
73 time::{Duration, Instant},
74};
75use time::OffsetDateTime;
76use tokio::sync::{Semaphore, watch};
77use tower::ServiceBuilder;
78use tracing::{
79 Instrument,
80 field::{self},
81 info_span, instrument,
82};
83
/// Reconnection grace period; its consumers live outside this section of the file.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` sweeps resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting longer than that
/// means those pods are gone by the time their old resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Number of channel messages per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum message length (units and enforcement live outside this view).
const MAX_MESSAGE_LEN: usize = 1024;
/// Number of notifications per page.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Ceiling enforced by `ConnectionGuard::try_acquire`.
const MAX_CONCURRENT_CONNECTIONS: usize = 512;

/// Live count of slots currently reserved through `ConnectionGuard`.
static CONCURRENT_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
95
96type MessageHandler =
97 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
98
99pub struct ConnectionGuard;
100
101impl ConnectionGuard {
102 pub fn try_acquire() -> Result<Self, ()> {
103 let current_connections = CONCURRENT_CONNECTIONS.fetch_add(1, SeqCst);
104 if current_connections >= MAX_CONCURRENT_CONNECTIONS {
105 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
106 tracing::error!(
107 "too many concurrent connections: {}",
108 current_connections + 1
109 );
110 return Err(());
111 }
112 Ok(ConnectionGuard)
113 }
114}
115
116impl Drop for ConnectionGuard {
117 fn drop(&mut self) {
118 CONCURRENT_CONNECTIONS.fetch_sub(1, SeqCst);
119 }
120}
121
122struct Response<R> {
123 peer: Arc<Peer>,
124 receipt: Receipt<R>,
125 responded: Arc<AtomicBool>,
126}
127
128impl<R: RequestMessage> Response<R> {
129 fn send(self, payload: R::Response) -> Result<()> {
130 self.responded.store(true, SeqCst);
131 self.peer.respond(self.receipt, payload)?;
132 Ok(())
133 }
134}
135
136#[derive(Clone, Debug)]
137pub enum Principal {
138 User(User),
139 Impersonated { user: User, admin: User },
140}
141
142impl Principal {
143 fn user(&self) -> &User {
144 match self {
145 Principal::User(user) => user,
146 Principal::Impersonated { user, .. } => user,
147 }
148 }
149
150 fn update_span(&self, span: &tracing::Span) {
151 match &self {
152 Principal::User(user) => {
153 span.record("user_id", user.id.0);
154 span.record("login", &user.github_login);
155 }
156 Principal::Impersonated { user, admin } => {
157 span.record("user_id", user.id.0);
158 span.record("login", &user.github_login);
159 span.record("impersonator", &admin.github_login);
160 }
161 }
162 }
163}
164
165#[derive(Clone)]
166struct Session {
167 principal: Principal,
168 connection_id: ConnectionId,
169 db: Arc<tokio::sync::Mutex<DbHandle>>,
170 peer: Arc<Peer>,
171 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
172 app_state: Arc<AppState>,
173 supermaven_client: Option<Arc<SupermavenAdminApi>>,
174 /// The GeoIP country code for the user.
175 #[allow(unused)]
176 geoip_country_code: Option<String>,
177 system_id: Option<String>,
178 _executor: Executor,
179}
180
181impl Session {
182 async fn db(&self) -> tokio::sync::MutexGuard<'_, DbHandle> {
183 #[cfg(test)]
184 tokio::task::yield_now().await;
185 let guard = self.db.lock().await;
186 #[cfg(test)]
187 tokio::task::yield_now().await;
188 guard
189 }
190
191 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
192 #[cfg(test)]
193 tokio::task::yield_now().await;
194 let guard = self.connection_pool.lock();
195 ConnectionPoolGuard {
196 guard,
197 _not_send: PhantomData,
198 }
199 }
200
201 fn is_staff(&self) -> bool {
202 match &self.principal {
203 Principal::User(user) => user.admin,
204 Principal::Impersonated { .. } => true,
205 }
206 }
207
208 fn user_id(&self) -> UserId {
209 match &self.principal {
210 Principal::User(user) => user.id,
211 Principal::Impersonated { user, .. } => user.id,
212 }
213 }
214
215 pub fn email(&self) -> Option<String> {
216 match &self.principal {
217 Principal::User(user) => user.email_address.clone(),
218 Principal::Impersonated { user, .. } => user.email_address.clone(),
219 }
220 }
221}
222
223impl Debug for Session {
224 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
225 let mut result = f.debug_struct("Session");
226 match &self.principal {
227 Principal::User(user) => {
228 result.field("user", &user.github_login);
229 }
230 Principal::Impersonated { user, admin } => {
231 result.field("user", &user.github_login);
232 result.field("impersonator", &admin.github_login);
233 }
234 }
235 result.field("connection_id", &self.connection_id).finish()
236 }
237}
238
239struct DbHandle(Arc<Database>);
240
241impl Deref for DbHandle {
242 type Target = Database;
243
244 fn deref(&self) -> &Self::Target {
245 self.0.as_ref()
246 }
247}
248
249pub struct Server {
250 id: parking_lot::Mutex<ServerId>,
251 peer: Arc<Peer>,
252 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
253 app_state: Arc<AppState>,
254 handlers: HashMap<TypeId, MessageHandler>,
255 teardown: watch::Sender<bool>,
256}
257
258pub(crate) struct ConnectionPoolGuard<'a> {
259 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
260 _not_send: PhantomData<Rc<()>>,
261}
262
263#[derive(Serialize)]
264pub struct ServerSnapshot<'a> {
265 peer: &'a Peer,
266 #[serde(serialize_with = "serialize_deref")]
267 connection_pool: ConnectionPoolGuard<'a>,
268}
269
270pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
271where
272 S: Serializer,
273 T: Deref<Target = U>,
274 U: Serialize,
275{
276 Serialize::serialize(value.deref(), serializer)
277}
278
279impl Server {
280 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
281 let mut server = Self {
282 id: parking_lot::Mutex::new(id),
283 peer: Peer::new(id.0 as u32),
284 app_state: app_state.clone(),
285 connection_pool: Default::default(),
286 handlers: Default::default(),
287 teardown: watch::channel(false).0,
288 };
289
290 server
291 .add_request_handler(ping)
292 .add_request_handler(create_room)
293 .add_request_handler(join_room)
294 .add_request_handler(rejoin_room)
295 .add_request_handler(leave_room)
296 .add_request_handler(set_room_participant_role)
297 .add_request_handler(call)
298 .add_request_handler(cancel_call)
299 .add_message_handler(decline_call)
300 .add_request_handler(update_participant_location)
301 .add_request_handler(share_project)
302 .add_message_handler(unshare_project)
303 .add_request_handler(join_project)
304 .add_message_handler(leave_project)
305 .add_request_handler(update_project)
306 .add_request_handler(update_worktree)
307 .add_request_handler(update_repository)
308 .add_request_handler(remove_repository)
309 .add_message_handler(start_language_server)
310 .add_message_handler(update_language_server)
311 .add_message_handler(update_diagnostic_summary)
312 .add_message_handler(update_worktree_settings)
313 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
314 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
315 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
316 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
317 .add_request_handler(forward_find_search_candidates_request)
318 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
319 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
320 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
321 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
322 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
323 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
324 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
325 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
326 .add_request_handler(forward_read_only_project_request::<proto::GetColorPresentation>)
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
328 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
329 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
330 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
331 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
332 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
333 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
334 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
335 .add_request_handler(
336 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
337 )
338 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
339 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
340 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
341 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
342 .add_request_handler(
343 forward_read_only_project_request::<proto::LanguageServerIdForName>,
344 )
345 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentDiagnostics>)
346 .add_request_handler(
347 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
348 )
349 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
350 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
351 .add_request_handler(
352 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
353 )
354 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
355 .add_request_handler(
356 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
357 )
358 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
359 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
360 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
361 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
362 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
363 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
364 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
365 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
366 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
367 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
368 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
369 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
370 .add_request_handler(
371 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
372 )
373 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
374 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
375 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
376 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
377 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
378 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
379 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
380 .add_message_handler(create_buffer_for_peer)
381 .add_request_handler(update_buffer)
382 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
383 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
384 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
385 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
386 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
387 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
388 .add_message_handler(
389 broadcast_project_message_from_host::<proto::PullWorkspaceDiagnostics>,
390 )
391 .add_request_handler(get_users)
392 .add_request_handler(fuzzy_search_users)
393 .add_request_handler(request_contact)
394 .add_request_handler(remove_contact)
395 .add_request_handler(respond_to_contact_request)
396 .add_message_handler(subscribe_to_channels)
397 .add_request_handler(create_channel)
398 .add_request_handler(delete_channel)
399 .add_request_handler(invite_channel_member)
400 .add_request_handler(remove_channel_member)
401 .add_request_handler(set_channel_member_role)
402 .add_request_handler(set_channel_visibility)
403 .add_request_handler(rename_channel)
404 .add_request_handler(join_channel_buffer)
405 .add_request_handler(leave_channel_buffer)
406 .add_message_handler(update_channel_buffer)
407 .add_request_handler(rejoin_channel_buffers)
408 .add_request_handler(get_channel_members)
409 .add_request_handler(respond_to_channel_invite)
410 .add_request_handler(join_channel)
411 .add_request_handler(join_channel_chat)
412 .add_message_handler(leave_channel_chat)
413 .add_request_handler(send_channel_message)
414 .add_request_handler(remove_channel_message)
415 .add_request_handler(update_channel_message)
416 .add_request_handler(get_channel_messages)
417 .add_request_handler(get_channel_messages_by_id)
418 .add_request_handler(get_notifications)
419 .add_request_handler(mark_notification_as_read)
420 .add_request_handler(move_channel)
421 .add_request_handler(reorder_channel)
422 .add_request_handler(follow)
423 .add_message_handler(unfollow)
424 .add_message_handler(update_followers)
425 .add_request_handler(get_private_user_info)
426 .add_request_handler(get_llm_api_token)
427 .add_request_handler(accept_terms_of_service)
428 .add_message_handler(acknowledge_channel_message)
429 .add_message_handler(acknowledge_buffer_version)
430 .add_request_handler(get_supermaven_api_key)
431 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
432 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
433 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
434 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
435 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
436 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
437 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
438 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
439 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
440 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
441 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
442 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
443 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
444 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
445 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
446 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
447 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
448 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
449 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
450 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
451 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
452 .add_message_handler(update_context);
453
454 Arc::new(server)
455 }
456
    /// Schedules the post-startup cleanup of resources left behind by stale servers.
    ///
    /// A detached task waits for `CLEANUP_TIMEOUT`, then:
    /// - deletes stale channel chat participants;
    /// - clears stale channel buffer collaborators and tells the remaining collaborators;
    /// - clears stale room participants, broadcasts the refreshed room (and channel),
    ///   sends `CallCanceled` for canceled calls, and pushes contact updates for affected
    ///   users, deleting the LiveKit room when nobody is left;
    /// - finally deletes stale chat participants again, old worktree entries, and stale
    ///   server records.
    ///
    /// Failures inside the task are logged through `trace_err` and never surface here;
    /// this function itself always returns `Ok(())` once the task is spawned.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to every remaining connection.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Populated only if the room refresh below succeeds; otherwise
                        // nothing is canceled, updated, or deleted for this room.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            // The LiveKit room is removed only once nobody is left in it.
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry (including
                        // busy status) to all connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): this repeats the call made before the room sweep —
                // presumably to catch participants that went stale in the meantime;
                // confirm it is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
627
628 pub fn teardown(&self) {
629 self.peer.teardown();
630 self.connection_pool.lock().reset();
631 let _ = self.teardown.send(true);
632 }
633
634 #[cfg(test)]
635 pub fn reset(&self, id: ServerId) {
636 self.teardown();
637 *self.id.lock() = id;
638 self.peer.reset(id.0 as u32);
639 let _ = self.teardown.send(false);
640 }
641
642 #[cfg(test)]
643 pub fn id(&self) -> ServerId {
644 *self.id.lock()
645 }
646
647 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
648 where
649 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
650 Fut: 'static + Send + Future<Output = Result<()>>,
651 M: EnvelopedMessage,
652 {
653 let prev_handler = self.handlers.insert(
654 TypeId::of::<M>(),
655 Box::new(move |envelope, session| {
656 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
657 let received_at = envelope.received_at;
658 tracing::info!("message received");
659 let start_time = Instant::now();
660 let future = (handler)(*envelope, session);
661 async move {
662 let result = future.await;
663 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
664 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
665 let queue_duration_ms = total_duration_ms - processing_duration_ms;
666 let payload_type = M::NAME;
667
668 match result {
669 Err(error) => {
670 tracing::error!(
671 ?error,
672 total_duration_ms,
673 processing_duration_ms,
674 queue_duration_ms,
675 payload_type,
676 "error handling message"
677 )
678 }
679 Ok(()) => tracing::info!(
680 total_duration_ms,
681 processing_duration_ms,
682 queue_duration_ms,
683 "finished handling message"
684 ),
685 }
686 }
687 .boxed()
688 }),
689 );
690 if prev_handler.is_some() {
691 panic!("registered a handler for the same message twice");
692 }
693 self
694 }
695
696 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
697 where
698 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
699 Fut: 'static + Send + Future<Output = Result<()>>,
700 M: EnvelopedMessage,
701 {
702 self.add_handler(move |envelope, session| handler(envelope.payload, session));
703 self
704 }
705
706 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
707 where
708 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
709 Fut: Send + Future<Output = Result<()>>,
710 M: RequestMessage,
711 {
712 let handler = Arc::new(handler);
713 self.add_handler(move |envelope, session| {
714 let receipt = envelope.receipt();
715 let handler = handler.clone();
716 async move {
717 let peer = session.peer.clone();
718 let responded = Arc::new(AtomicBool::default());
719 let response = Response {
720 peer: peer.clone(),
721 responded: responded.clone(),
722 receipt,
723 };
724 match (handler)(envelope.payload, response, session).await {
725 Ok(()) => {
726 if responded.load(std::sync::atomic::Ordering::SeqCst) {
727 Ok(())
728 } else {
729 Err(anyhow!("handler did not send a response"))?
730 }
731 }
732 Err(error) => {
733 let proto_err = match &error {
734 Error::Internal(err) => err.to_proto(),
735 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
736 };
737 peer.respond_with_error(receipt, proto_err)?;
738 Err(error)
739 }
740 }
741 }
742 })
743 }
744
    /// Drives one client connection from handshake to sign-out.
    ///
    /// The returned future (`use<>`: it captures no borrow of `self`) performs these steps:
    /// - registers the connection with the peer;
    /// - builds the per-connection [`Session`];
    /// - sends the initial client update;
    /// - dispatches incoming messages to the registered handlers.
    ///
    /// Dispatch continues until the I/O future completes, the client closes the message
    /// stream, or the server tears down. In the first two cases it then calls
    /// `connection_lost`; on teardown it returns without doing so.
    ///
    /// `connection_guard` is held until the initial client update has been sent.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
        connection_guard: Option<ConnectionGuard>,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts so a teardown issued in between is seen.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }

            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            // NOTE(review): the two early `return`s below happen after `add_connection`
            // but never poll `handle_io` or call `connection_lost`. If the peer / pool only
            // forget a connection from those paths, these exits leak its entry — confirm.
            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }
            // The connection is fully established: release its concurrency slot.
            drop(connection_guard);

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground or background) run per connection: a permit
            // is taken before reading the next message and released when its handler ends.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased order: teardown first, then I/O completion, then progress on
                // in-flight foreground handlers, and only then new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background messages run detached; foreground ones are
                                // polled by this loop via `foreground_message_handlers`.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
895
896 async fn send_initial_client_update(
897 &self,
898 connection_id: ConnectionId,
899 zed_version: ZedVersion,
900 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
901 session: &Session,
902 ) -> Result<()> {
903 self.peer.send(
904 connection_id,
905 proto::Hello {
906 peer_id: Some(connection_id.into()),
907 },
908 )?;
909 tracing::info!("sent hello message");
910 if let Some(send_connection_id) = send_connection_id.take() {
911 let _ = send_connection_id.send(connection_id);
912 }
913
914 match &session.principal {
915 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
916 if !user.connected_once {
917 self.peer.send(connection_id, proto::ShowContacts {})?;
918 self.app_state
919 .db
920 .set_user_connected_once(user.id, true)
921 .await?;
922 }
923
924 update_user_plan(session).await?;
925
926 let contacts = self.app_state.db.get_contacts(user.id).await?;
927
928 {
929 let mut pool = self.connection_pool.lock();
930 pool.add_connection(connection_id, user.id, user.admin, zed_version);
931 self.peer.send(
932 connection_id,
933 build_initial_contacts_update(contacts, &pool),
934 )?;
935 }
936
937 if should_auto_subscribe_to_channels(zed_version) {
938 subscribe_user_to_channels(user.id, session).await?;
939 }
940
941 if let Some(incoming_call) =
942 self.app_state.db.incoming_call_for_user(user.id).await?
943 {
944 self.peer.send(connection_id, incoming_call)?;
945 }
946
947 update_user_contacts(user.id, session).await?;
948 }
949 }
950
951 Ok(())
952 }
953
954 pub async fn invite_code_redeemed(
955 self: &Arc<Self>,
956 inviter_id: UserId,
957 invitee_id: UserId,
958 ) -> Result<()> {
959 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
960 if let Some(code) = &user.invite_code {
961 let pool = self.connection_pool.lock();
962 let invitee_contact = contact_for_user(invitee_id, false, &pool);
963 for connection_id in pool.user_connection_ids(inviter_id) {
964 self.peer.send(
965 connection_id,
966 proto::UpdateContacts {
967 contacts: vec![invitee_contact.clone()],
968 ..Default::default()
969 },
970 )?;
971 self.peer.send(
972 connection_id,
973 proto::UpdateInviteInfo {
974 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
975 count: user.invite_count as u32,
976 },
977 )?;
978 }
979 }
980 }
981 Ok(())
982 }
983
984 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
985 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
986 if let Some(invite_code) = &user.invite_code {
987 let pool = self.connection_pool.lock();
988 for connection_id in pool.user_connection_ids(user_id) {
989 self.peer.send(
990 connection_id,
991 proto::UpdateInviteInfo {
992 url: format!(
993 "{}{}",
994 self.app_state.config.invite_link_prefix, invite_code
995 ),
996 count: user.invite_count as u32,
997 },
998 )?;
999 }
1000 }
1001 }
1002 Ok(())
1003 }
1004
1005 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1006 let user = self
1007 .app_state
1008 .db
1009 .get_user_by_id(user_id)
1010 .await?
1011 .context("user not found")?;
1012
1013 let update_user_plan = make_update_user_plan_message(
1014 &user,
1015 user.admin,
1016 &self.app_state.db,
1017 self.app_state.llm_db.clone(),
1018 )
1019 .await?;
1020
1021 let pool = self.connection_pool.lock();
1022 for connection_id in pool.user_connection_ids(user_id) {
1023 self.peer
1024 .send(connection_id, update_user_plan.clone())
1025 .trace_err();
1026 }
1027
1028 Ok(())
1029 }
1030
1031 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
1032 let pool = self.connection_pool.lock();
1033 for connection_id in pool.user_connection_ids(user_id) {
1034 self.peer
1035 .send(connection_id, proto::RefreshLlmToken {})
1036 .trace_err();
1037 }
1038 }
1039
1040 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot<'_> {
1041 ServerSnapshot {
1042 connection_pool: ConnectionPoolGuard {
1043 guard: self.connection_pool.lock(),
1044 _not_send: PhantomData,
1045 },
1046 peer: &self.peer,
1047 }
1048 }
1049}
1050
1051impl Deref for ConnectionPoolGuard<'_> {
1052 type Target = ConnectionPool;
1053
1054 fn deref(&self) -> &Self::Target {
1055 &self.guard
1056 }
1057}
1058
1059impl DerefMut for ConnectionPoolGuard<'_> {
1060 fn deref_mut(&mut self) -> &mut Self::Target {
1061 &mut self.guard
1062 }
1063}
1064
1065impl Drop for ConnectionPoolGuard<'_> {
1066 fn drop(&mut self) {
1067 #[cfg(test)]
1068 self.check_invariants();
1069 }
1070}
1071
1072fn broadcast<F>(
1073 sender_id: Option<ConnectionId>,
1074 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1075 mut f: F,
1076) where
1077 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1078{
1079 for receiver_id in receiver_ids {
1080 if Some(receiver_id) != sender_id {
1081 if let Err(error) = f(receiver_id) {
1082 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1083 }
1084 }
1085 }
1086}
1087
1088pub struct ProtocolVersion(u32);
1089
1090impl Header for ProtocolVersion {
1091 fn name() -> &'static HeaderName {
1092 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1093 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1094 }
1095
1096 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1097 where
1098 Self: Sized,
1099 I: Iterator<Item = &'i axum::http::HeaderValue>,
1100 {
1101 let version = values
1102 .next()
1103 .ok_or_else(axum::headers::Error::invalid)?
1104 .to_str()
1105 .map_err(|_| axum::headers::Error::invalid())?
1106 .parse()
1107 .map_err(|_| axum::headers::Error::invalid())?;
1108 Ok(Self(version))
1109 }
1110
1111 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1112 values.extend([self.0.to_string().parse().unwrap()]);
1113 }
1114}
1115
1116pub struct AppVersionHeader(SemanticVersion);
1117impl Header for AppVersionHeader {
1118 fn name() -> &'static HeaderName {
1119 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1120 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1121 }
1122
1123 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1124 where
1125 Self: Sized,
1126 I: Iterator<Item = &'i axum::http::HeaderValue>,
1127 {
1128 let version = values
1129 .next()
1130 .ok_or_else(axum::headers::Error::invalid)?
1131 .to_str()
1132 .map_err(|_| axum::headers::Error::invalid())?
1133 .parse()
1134 .map_err(|_| axum::headers::Error::invalid())?;
1135 Ok(Self(version))
1136 }
1137
1138 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1139 values.extend([self.0.to_string().parse().unwrap()]);
1140 }
1141}
1142
/// Builds the collab RPC router.
///
/// `/rpc` is registered first and wrapped (via `ServiceBuilder`) with an
/// `AppState` extension and the `auth::validate_header` middleware. Because
/// `Router::layer` only applies to routes added before it, `/metrics` (added
/// afterwards) is not behind that middleware. The final `Extension(server)`
/// layer is added after both routes and so applies to both.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Auth + app-state layers: cover only the routes registered above.
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer, so `/metrics` is not authenticated here.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1154
1155pub async fn handle_websocket_request(
1156 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1157 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1158 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1159 Extension(server): Extension<Arc<Server>>,
1160 Extension(principal): Extension<Principal>,
1161 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1162 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1163 ws: WebSocketUpgrade,
1164) -> axum::response::Response {
1165 if protocol_version != rpc::PROTOCOL_VERSION {
1166 return (
1167 StatusCode::UPGRADE_REQUIRED,
1168 "client must be upgraded".to_string(),
1169 )
1170 .into_response();
1171 }
1172
1173 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1174 return (
1175 StatusCode::UPGRADE_REQUIRED,
1176 "no version header found".to_string(),
1177 )
1178 .into_response();
1179 };
1180
1181 if !version.can_collaborate() {
1182 return (
1183 StatusCode::UPGRADE_REQUIRED,
1184 "client must be upgraded".to_string(),
1185 )
1186 .into_response();
1187 }
1188
1189 let socket_address = socket_address.to_string();
1190
1191 // Acquire connection guard before WebSocket upgrade
1192 let connection_guard = match ConnectionGuard::try_acquire() {
1193 Ok(guard) => guard,
1194 Err(()) => {
1195 return (
1196 StatusCode::SERVICE_UNAVAILABLE,
1197 "Too many concurrent connections",
1198 )
1199 .into_response();
1200 }
1201 };
1202
1203 ws.on_upgrade(move |socket| {
1204 let socket = socket
1205 .map_ok(to_tungstenite_message)
1206 .err_into()
1207 .with(|message| async move { to_axum_message(message) });
1208 let connection = Connection::new(Box::pin(socket));
1209 async move {
1210 server
1211 .handle_connection(
1212 connection,
1213 socket_address,
1214 principal,
1215 version,
1216 country_code_header.map(|header| header.to_string()),
1217 system_id_header.map(|header| header.to_string()),
1218 None,
1219 Executor::Production,
1220 Some(connection_guard),
1221 )
1222 .await;
1223 }
1224 })
1225}
1226
1227pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1228 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1229 let connections_metric = CONNECTIONS_METRIC
1230 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1231
1232 let connections = server
1233 .connection_pool
1234 .lock()
1235 .connections()
1236 .filter(|connection| !connection.admin)
1237 .count();
1238 connections_metric.set(connections as _);
1239
1240 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1241 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1242 register_int_gauge!(
1243 "shared_projects",
1244 "number of open projects with one or more guests"
1245 )
1246 .unwrap()
1247 });
1248
1249 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1250 shared_projects_metric.set(shared_projects as _);
1251
1252 let encoder = prometheus::TextEncoder::new();
1253 let metric_families = prometheus::gather();
1254 let encoded_metrics = encoder
1255 .encode_to_string(&metric_families)
1256 .map_err(|err| anyhow!("{err}"))?;
1257 Ok(encoded_metrics)
1258}
1259
1260#[instrument(err, skip(executor))]
1261async fn connection_lost(
1262 session: Session,
1263 mut teardown: watch::Receiver<bool>,
1264 executor: Executor,
1265) -> Result<()> {
1266 session.peer.disconnect(session.connection_id);
1267 session
1268 .connection_pool()
1269 .await
1270 .remove_connection(session.connection_id)?;
1271
1272 session
1273 .db()
1274 .await
1275 .connection_lost(session.connection_id)
1276 .await
1277 .trace_err();
1278
1279 futures::select_biased! {
1280 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1281
1282 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1283 leave_room_for_session(&session, session.connection_id).await.trace_err();
1284 leave_channel_buffers_for_session(&session)
1285 .await
1286 .trace_err();
1287
1288 if !session
1289 .connection_pool()
1290 .await
1291 .is_user_online(session.user_id())
1292 {
1293 let db = session.db().await;
1294 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1295 room_updated(&room, &session.peer);
1296 }
1297 }
1298
1299 update_user_contacts(session.user_id(), &session).await?;
1300 },
1301 _ = teardown.changed().fuse() => {}
1302 }
1303
1304 Ok(())
1305}
1306
1307/// Acknowledges a ping from a client, used to keep the connection alive.
1308async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1309 response.send(proto::Ack {})?;
1310 Ok(())
1311}
1312
1313/// Creates a new room for calling (outside of channels)
1314async fn create_room(
1315 _request: proto::CreateRoom,
1316 response: Response<proto::CreateRoom>,
1317 session: Session,
1318) -> Result<()> {
1319 let livekit_room = nanoid::nanoid!(30);
1320
1321 let live_kit_connection_info = util::maybe!(async {
1322 let live_kit = session.app_state.livekit_client.as_ref();
1323 let live_kit = live_kit?;
1324 let user_id = session.user_id().to_string();
1325
1326 let token = live_kit
1327 .room_token(&livekit_room, &user_id.to_string())
1328 .trace_err()?;
1329
1330 Some(proto::LiveKitConnectionInfo {
1331 server_url: live_kit.url().into(),
1332 token,
1333 can_publish: true,
1334 })
1335 })
1336 .await;
1337
1338 let room = session
1339 .db()
1340 .await
1341 .create_room(session.user_id(), session.connection_id, &livekit_room)
1342 .await?;
1343
1344 response.send(proto::CreateRoomResponse {
1345 room: Some(room.clone()),
1346 live_kit_connection_info,
1347 })?;
1348
1349 update_user_contacts(session.user_id(), &session).await?;
1350 Ok(())
1351}
1352
1353/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1354async fn join_room(
1355 request: proto::JoinRoom,
1356 response: Response<proto::JoinRoom>,
1357 session: Session,
1358) -> Result<()> {
1359 let room_id = RoomId::from_proto(request.id);
1360
1361 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1362
1363 if let Some(channel_id) = channel_id {
1364 return join_channel_internal(channel_id, Box::new(response), session).await;
1365 }
1366
1367 let joined_room = {
1368 let room = session
1369 .db()
1370 .await
1371 .join_room(room_id, session.user_id(), session.connection_id)
1372 .await?;
1373 room_updated(&room.room, &session.peer);
1374 room.into_inner()
1375 };
1376
1377 for connection_id in session
1378 .connection_pool()
1379 .await
1380 .user_connection_ids(session.user_id())
1381 {
1382 session
1383 .peer
1384 .send(
1385 connection_id,
1386 proto::CallCanceled {
1387 room_id: room_id.to_proto(),
1388 },
1389 )
1390 .trace_err();
1391 }
1392
1393 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1394 {
1395 live_kit
1396 .room_token(
1397 &joined_room.room.livekit_room,
1398 &session.user_id().to_string(),
1399 )
1400 .trace_err()
1401 .map(|token| proto::LiveKitConnectionInfo {
1402 server_url: live_kit.url().into(),
1403 token,
1404 can_publish: true,
1405 })
1406 } else {
1407 None
1408 };
1409
1410 response.send(proto::JoinRoomResponse {
1411 room: Some(joined_room.room),
1412 channel_id: None,
1413 live_kit_connection_info,
1414 })?;
1415
1416 update_user_contacts(session.user_id(), &session).await?;
1417 Ok(())
1418}
1419
1420/// Rejoin room is used to reconnect to a room after connection errors.
1421async fn rejoin_room(
1422 request: proto::RejoinRoom,
1423 response: Response<proto::RejoinRoom>,
1424 session: Session,
1425) -> Result<()> {
1426 let room;
1427 let channel;
1428 {
1429 let mut rejoined_room = session
1430 .db()
1431 .await
1432 .rejoin_room(request, session.user_id(), session.connection_id)
1433 .await?;
1434
1435 response.send(proto::RejoinRoomResponse {
1436 room: Some(rejoined_room.room.clone()),
1437 reshared_projects: rejoined_room
1438 .reshared_projects
1439 .iter()
1440 .map(|project| proto::ResharedProject {
1441 id: project.id.to_proto(),
1442 collaborators: project
1443 .collaborators
1444 .iter()
1445 .map(|collaborator| collaborator.to_proto())
1446 .collect(),
1447 })
1448 .collect(),
1449 rejoined_projects: rejoined_room
1450 .rejoined_projects
1451 .iter()
1452 .map(|rejoined_project| rejoined_project.to_proto())
1453 .collect(),
1454 })?;
1455 room_updated(&rejoined_room.room, &session.peer);
1456
1457 for project in &rejoined_room.reshared_projects {
1458 for collaborator in &project.collaborators {
1459 session
1460 .peer
1461 .send(
1462 collaborator.connection_id,
1463 proto::UpdateProjectCollaborator {
1464 project_id: project.id.to_proto(),
1465 old_peer_id: Some(project.old_connection_id.into()),
1466 new_peer_id: Some(session.connection_id.into()),
1467 },
1468 )
1469 .trace_err();
1470 }
1471
1472 broadcast(
1473 Some(session.connection_id),
1474 project
1475 .collaborators
1476 .iter()
1477 .map(|collaborator| collaborator.connection_id),
1478 |connection_id| {
1479 session.peer.forward_send(
1480 session.connection_id,
1481 connection_id,
1482 proto::UpdateProject {
1483 project_id: project.id.to_proto(),
1484 worktrees: project.worktrees.clone(),
1485 },
1486 )
1487 },
1488 );
1489 }
1490
1491 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1492
1493 let rejoined_room = rejoined_room.into_inner();
1494
1495 room = rejoined_room.room;
1496 channel = rejoined_room.channel;
1497 }
1498
1499 if let Some(channel) = channel {
1500 channel_updated(
1501 &channel,
1502 &room,
1503 &session.peer,
1504 &*session.connection_pool().await,
1505 );
1506 }
1507
1508 update_user_contacts(session.user_id(), &session).await?;
1509 Ok(())
1510}
1511
1512fn notify_rejoined_projects(
1513 rejoined_projects: &mut Vec<RejoinedProject>,
1514 session: &Session,
1515) -> Result<()> {
1516 for project in rejoined_projects.iter() {
1517 for collaborator in &project.collaborators {
1518 session
1519 .peer
1520 .send(
1521 collaborator.connection_id,
1522 proto::UpdateProjectCollaborator {
1523 project_id: project.id.to_proto(),
1524 old_peer_id: Some(project.old_connection_id.into()),
1525 new_peer_id: Some(session.connection_id.into()),
1526 },
1527 )
1528 .trace_err();
1529 }
1530 }
1531
1532 for project in rejoined_projects {
1533 for worktree in mem::take(&mut project.worktrees) {
1534 // Stream this worktree's entries.
1535 let message = proto::UpdateWorktree {
1536 project_id: project.id.to_proto(),
1537 worktree_id: worktree.id,
1538 abs_path: worktree.abs_path.clone(),
1539 root_name: worktree.root_name,
1540 updated_entries: worktree.updated_entries,
1541 removed_entries: worktree.removed_entries,
1542 scan_id: worktree.scan_id,
1543 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1544 updated_repositories: worktree.updated_repositories,
1545 removed_repositories: worktree.removed_repositories,
1546 };
1547 for update in proto::split_worktree_update(message) {
1548 session.peer.send(session.connection_id, update)?;
1549 }
1550
1551 // Stream this worktree's diagnostics.
1552 for summary in worktree.diagnostic_summaries {
1553 session.peer.send(
1554 session.connection_id,
1555 proto::UpdateDiagnosticSummary {
1556 project_id: project.id.to_proto(),
1557 worktree_id: worktree.id,
1558 summary: Some(summary),
1559 },
1560 )?;
1561 }
1562
1563 for settings_file in worktree.settings_files {
1564 session.peer.send(
1565 session.connection_id,
1566 proto::UpdateWorktreeSettings {
1567 project_id: project.id.to_proto(),
1568 worktree_id: worktree.id,
1569 path: settings_file.path,
1570 content: Some(settings_file.content),
1571 kind: Some(settings_file.kind.to_proto().into()),
1572 },
1573 )?;
1574 }
1575 }
1576
1577 for repository in mem::take(&mut project.updated_repositories) {
1578 for update in split_repository_update(repository) {
1579 session.peer.send(session.connection_id, update)?;
1580 }
1581 }
1582
1583 for id in mem::take(&mut project.removed_repositories) {
1584 session.peer.send(
1585 session.connection_id,
1586 proto::RemoveRepository {
1587 project_id: project.id.to_proto(),
1588 id,
1589 },
1590 )?;
1591 }
1592 }
1593
1594 Ok(())
1595}
1596
1597/// leave room disconnects from the room.
1598async fn leave_room(
1599 _: proto::LeaveRoom,
1600 response: Response<proto::LeaveRoom>,
1601 session: Session,
1602) -> Result<()> {
1603 leave_room_for_session(&session, session.connection_id).await?;
1604 response.send(proto::Ack {})?;
1605 Ok(())
1606}
1607
1608/// Updates the permissions of someone else in the room.
1609async fn set_room_participant_role(
1610 request: proto::SetRoomParticipantRole,
1611 response: Response<proto::SetRoomParticipantRole>,
1612 session: Session,
1613) -> Result<()> {
1614 let user_id = UserId::from_proto(request.user_id);
1615 let role = ChannelRole::from(request.role());
1616
1617 let (livekit_room, can_publish) = {
1618 let room = session
1619 .db()
1620 .await
1621 .set_room_participant_role(
1622 session.user_id(),
1623 RoomId::from_proto(request.room_id),
1624 user_id,
1625 role,
1626 )
1627 .await?;
1628
1629 let livekit_room = room.livekit_room.clone();
1630 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1631 room_updated(&room, &session.peer);
1632 (livekit_room, can_publish)
1633 };
1634
1635 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1636 live_kit
1637 .update_participant(
1638 livekit_room.clone(),
1639 request.user_id.to_string(),
1640 livekit_api::proto::ParticipantPermission {
1641 can_subscribe: true,
1642 can_publish,
1643 can_publish_data: can_publish,
1644 hidden: false,
1645 recorder: false,
1646 },
1647 )
1648 .await
1649 .trace_err();
1650 }
1651
1652 response.send(proto::Ack {})?;
1653 Ok(())
1654}
1655
1656/// Call someone else into the current room
1657async fn call(
1658 request: proto::Call,
1659 response: Response<proto::Call>,
1660 session: Session,
1661) -> Result<()> {
1662 let room_id = RoomId::from_proto(request.room_id);
1663 let calling_user_id = session.user_id();
1664 let calling_connection_id = session.connection_id;
1665 let called_user_id = UserId::from_proto(request.called_user_id);
1666 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1667 if !session
1668 .db()
1669 .await
1670 .has_contact(calling_user_id, called_user_id)
1671 .await?
1672 {
1673 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1674 }
1675
1676 let incoming_call = {
1677 let (room, incoming_call) = &mut *session
1678 .db()
1679 .await
1680 .call(
1681 room_id,
1682 calling_user_id,
1683 calling_connection_id,
1684 called_user_id,
1685 initial_project_id,
1686 )
1687 .await?;
1688 room_updated(room, &session.peer);
1689 mem::take(incoming_call)
1690 };
1691 update_user_contacts(called_user_id, &session).await?;
1692
1693 let mut calls = session
1694 .connection_pool()
1695 .await
1696 .user_connection_ids(called_user_id)
1697 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1698 .collect::<FuturesUnordered<_>>();
1699
1700 while let Some(call_response) = calls.next().await {
1701 match call_response.as_ref() {
1702 Ok(_) => {
1703 response.send(proto::Ack {})?;
1704 return Ok(());
1705 }
1706 Err(_) => {
1707 call_response.trace_err();
1708 }
1709 }
1710 }
1711
1712 {
1713 let room = session
1714 .db()
1715 .await
1716 .call_failed(room_id, called_user_id)
1717 .await?;
1718 room_updated(&room, &session.peer);
1719 }
1720 update_user_contacts(called_user_id, &session).await?;
1721
1722 Err(anyhow!("failed to ring user"))?
1723}
1724
1725/// Cancel an outgoing call.
1726async fn cancel_call(
1727 request: proto::CancelCall,
1728 response: Response<proto::CancelCall>,
1729 session: Session,
1730) -> Result<()> {
1731 let called_user_id = UserId::from_proto(request.called_user_id);
1732 let room_id = RoomId::from_proto(request.room_id);
1733 {
1734 let room = session
1735 .db()
1736 .await
1737 .cancel_call(room_id, session.connection_id, called_user_id)
1738 .await?;
1739 room_updated(&room, &session.peer);
1740 }
1741
1742 for connection_id in session
1743 .connection_pool()
1744 .await
1745 .user_connection_ids(called_user_id)
1746 {
1747 session
1748 .peer
1749 .send(
1750 connection_id,
1751 proto::CallCanceled {
1752 room_id: room_id.to_proto(),
1753 },
1754 )
1755 .trace_err();
1756 }
1757 response.send(proto::Ack {})?;
1758
1759 update_user_contacts(called_user_id, &session).await?;
1760 Ok(())
1761}
1762
1763/// Decline an incoming call.
1764async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1765 let room_id = RoomId::from_proto(message.room_id);
1766 {
1767 let room = session
1768 .db()
1769 .await
1770 .decline_call(Some(room_id), session.user_id())
1771 .await?
1772 .context("declining call")?;
1773 room_updated(&room, &session.peer);
1774 }
1775
1776 for connection_id in session
1777 .connection_pool()
1778 .await
1779 .user_connection_ids(session.user_id())
1780 {
1781 session
1782 .peer
1783 .send(
1784 connection_id,
1785 proto::CallCanceled {
1786 room_id: room_id.to_proto(),
1787 },
1788 )
1789 .trace_err();
1790 }
1791 update_user_contacts(session.user_id(), &session).await?;
1792 Ok(())
1793}
1794
1795/// Updates other participants in the room with your current location.
1796async fn update_participant_location(
1797 request: proto::UpdateParticipantLocation,
1798 response: Response<proto::UpdateParticipantLocation>,
1799 session: Session,
1800) -> Result<()> {
1801 let room_id = RoomId::from_proto(request.room_id);
1802 let location = request.location.context("invalid location")?;
1803
1804 let db = session.db().await;
1805 let room = db
1806 .update_room_participant_location(room_id, session.connection_id, location)
1807 .await?;
1808
1809 room_updated(&room, &session.peer);
1810 response.send(proto::Ack {})?;
1811 Ok(())
1812}
1813
1814/// Share a project into the room.
1815async fn share_project(
1816 request: proto::ShareProject,
1817 response: Response<proto::ShareProject>,
1818 session: Session,
1819) -> Result<()> {
1820 let (project_id, room) = &*session
1821 .db()
1822 .await
1823 .share_project(
1824 RoomId::from_proto(request.room_id),
1825 session.connection_id,
1826 &request.worktrees,
1827 request.is_ssh_project,
1828 )
1829 .await?;
1830 response.send(proto::ShareProjectResponse {
1831 project_id: project_id.to_proto(),
1832 })?;
1833 room_updated(room, &session.peer);
1834
1835 Ok(())
1836}
1837
1838/// Unshare a project from the room.
1839async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1840 let project_id = ProjectId::from_proto(message.project_id);
1841 unshare_project_internal(project_id, session.connection_id, &session).await
1842}
1843
1844async fn unshare_project_internal(
1845 project_id: ProjectId,
1846 connection_id: ConnectionId,
1847 session: &Session,
1848) -> Result<()> {
1849 let delete = {
1850 let room_guard = session
1851 .db()
1852 .await
1853 .unshare_project(project_id, connection_id)
1854 .await?;
1855
1856 let (delete, room, guest_connection_ids) = &*room_guard;
1857
1858 let message = proto::UnshareProject {
1859 project_id: project_id.to_proto(),
1860 };
1861
1862 broadcast(
1863 Some(connection_id),
1864 guest_connection_ids.iter().copied(),
1865 |conn_id| session.peer.send(conn_id, message.clone()),
1866 );
1867 if let Some(room) = room {
1868 room_updated(room, &session.peer);
1869 }
1870
1871 *delete
1872 };
1873
1874 if delete {
1875 let db = session.db().await;
1876 db.delete_project(project_id).await?;
1877 }
1878
1879 Ok(())
1880}
1881
1882/// Join someone elses shared project.
1883async fn join_project(
1884 request: proto::JoinProject,
1885 response: Response<proto::JoinProject>,
1886 session: Session,
1887) -> Result<()> {
1888 let project_id = ProjectId::from_proto(request.project_id);
1889
1890 tracing::info!(%project_id, "join project");
1891
1892 let db = session.db().await;
1893 let (project, replica_id) = &mut *db
1894 .join_project(
1895 project_id,
1896 session.connection_id,
1897 session.user_id(),
1898 request.committer_name.clone(),
1899 request.committer_email.clone(),
1900 )
1901 .await?;
1902 drop(db);
1903 tracing::info!(%project_id, "join remote project");
1904 let collaborators = project
1905 .collaborators
1906 .iter()
1907 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1908 .map(|collaborator| collaborator.to_proto())
1909 .collect::<Vec<_>>();
1910 let project_id = project.id;
1911 let guest_user_id = session.user_id();
1912
1913 let worktrees = project
1914 .worktrees
1915 .iter()
1916 .map(|(id, worktree)| proto::WorktreeMetadata {
1917 id: *id,
1918 root_name: worktree.root_name.clone(),
1919 visible: worktree.visible,
1920 abs_path: worktree.abs_path.clone(),
1921 })
1922 .collect::<Vec<_>>();
1923
1924 let add_project_collaborator = proto::AddProjectCollaborator {
1925 project_id: project_id.to_proto(),
1926 collaborator: Some(proto::Collaborator {
1927 peer_id: Some(session.connection_id.into()),
1928 replica_id: replica_id.0 as u32,
1929 user_id: guest_user_id.to_proto(),
1930 is_host: false,
1931 committer_name: request.committer_name.clone(),
1932 committer_email: request.committer_email.clone(),
1933 }),
1934 };
1935
1936 for collaborator in &collaborators {
1937 session
1938 .peer
1939 .send(
1940 collaborator.peer_id.unwrap().into(),
1941 add_project_collaborator.clone(),
1942 )
1943 .trace_err();
1944 }
1945
1946 // First, we send the metadata associated with each worktree.
1947 response.send(proto::JoinProjectResponse {
1948 project_id: project.id.0 as u64,
1949 worktrees: worktrees.clone(),
1950 replica_id: replica_id.0 as u32,
1951 collaborators: collaborators.clone(),
1952 language_servers: project.language_servers.clone(),
1953 role: project.role.into(),
1954 })?;
1955
1956 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1957 // Stream this worktree's entries.
1958 let message = proto::UpdateWorktree {
1959 project_id: project_id.to_proto(),
1960 worktree_id,
1961 abs_path: worktree.abs_path.clone(),
1962 root_name: worktree.root_name,
1963 updated_entries: worktree.entries,
1964 removed_entries: Default::default(),
1965 scan_id: worktree.scan_id,
1966 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1967 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1968 removed_repositories: Default::default(),
1969 };
1970 for update in proto::split_worktree_update(message) {
1971 session.peer.send(session.connection_id, update.clone())?;
1972 }
1973
1974 // Stream this worktree's diagnostics.
1975 for summary in worktree.diagnostic_summaries {
1976 session.peer.send(
1977 session.connection_id,
1978 proto::UpdateDiagnosticSummary {
1979 project_id: project_id.to_proto(),
1980 worktree_id: worktree.id,
1981 summary: Some(summary),
1982 },
1983 )?;
1984 }
1985
1986 for settings_file in worktree.settings_files {
1987 session.peer.send(
1988 session.connection_id,
1989 proto::UpdateWorktreeSettings {
1990 project_id: project_id.to_proto(),
1991 worktree_id: worktree.id,
1992 path: settings_file.path,
1993 content: Some(settings_file.content),
1994 kind: Some(settings_file.kind.to_proto() as i32),
1995 },
1996 )?;
1997 }
1998 }
1999
2000 for repository in mem::take(&mut project.repositories) {
2001 for update in split_repository_update(repository) {
2002 session.peer.send(session.connection_id, update)?;
2003 }
2004 }
2005
2006 for language_server in &project.language_servers {
2007 session.peer.send(
2008 session.connection_id,
2009 proto::UpdateLanguageServer {
2010 project_id: project_id.to_proto(),
2011 server_name: Some(language_server.name.clone()),
2012 language_server_id: language_server.id,
2013 variant: Some(
2014 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2015 proto::LspDiskBasedDiagnosticsUpdated {},
2016 ),
2017 ),
2018 },
2019 )?;
2020 }
2021
2022 Ok(())
2023}
2024
2025/// Leave someone elses shared project.
2026async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
2027 let sender_id = session.connection_id;
2028 let project_id = ProjectId::from_proto(request.project_id);
2029 let db = session.db().await;
2030
2031 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2032 tracing::info!(
2033 %project_id,
2034 "leave project"
2035 );
2036
2037 project_left(project, &session);
2038 if let Some(room) = room {
2039 room_updated(room, &session.peer);
2040 }
2041
2042 Ok(())
2043}
2044
2045/// Updates other participants with changes to the project
2046async fn update_project(
2047 request: proto::UpdateProject,
2048 response: Response<proto::UpdateProject>,
2049 session: Session,
2050) -> Result<()> {
2051 let project_id = ProjectId::from_proto(request.project_id);
2052 let (room, guest_connection_ids) = &*session
2053 .db()
2054 .await
2055 .update_project(project_id, session.connection_id, &request.worktrees)
2056 .await?;
2057 broadcast(
2058 Some(session.connection_id),
2059 guest_connection_ids.iter().copied(),
2060 |connection_id| {
2061 session
2062 .peer
2063 .forward_send(session.connection_id, connection_id, request.clone())
2064 },
2065 );
2066 if let Some(room) = room {
2067 room_updated(room, &session.peer);
2068 }
2069 response.send(proto::Ack {})?;
2070
2071 Ok(())
2072}
2073
2074/// Updates other participants with changes to the worktree
2075async fn update_worktree(
2076 request: proto::UpdateWorktree,
2077 response: Response<proto::UpdateWorktree>,
2078 session: Session,
2079) -> Result<()> {
2080 let guest_connection_ids = session
2081 .db()
2082 .await
2083 .update_worktree(&request, session.connection_id)
2084 .await?;
2085
2086 broadcast(
2087 Some(session.connection_id),
2088 guest_connection_ids.iter().copied(),
2089 |connection_id| {
2090 session
2091 .peer
2092 .forward_send(session.connection_id, connection_id, request.clone())
2093 },
2094 );
2095 response.send(proto::Ack {})?;
2096 Ok(())
2097}
2098
2099async fn update_repository(
2100 request: proto::UpdateRepository,
2101 response: Response<proto::UpdateRepository>,
2102 session: Session,
2103) -> Result<()> {
2104 let guest_connection_ids = session
2105 .db()
2106 .await
2107 .update_repository(&request, session.connection_id)
2108 .await?;
2109
2110 broadcast(
2111 Some(session.connection_id),
2112 guest_connection_ids.iter().copied(),
2113 |connection_id| {
2114 session
2115 .peer
2116 .forward_send(session.connection_id, connection_id, request.clone())
2117 },
2118 );
2119 response.send(proto::Ack {})?;
2120 Ok(())
2121}
2122
2123async fn remove_repository(
2124 request: proto::RemoveRepository,
2125 response: Response<proto::RemoveRepository>,
2126 session: Session,
2127) -> Result<()> {
2128 let guest_connection_ids = session
2129 .db()
2130 .await
2131 .remove_repository(&request, session.connection_id)
2132 .await?;
2133
2134 broadcast(
2135 Some(session.connection_id),
2136 guest_connection_ids.iter().copied(),
2137 |connection_id| {
2138 session
2139 .peer
2140 .forward_send(session.connection_id, connection_id, request.clone())
2141 },
2142 );
2143 response.send(proto::Ack {})?;
2144 Ok(())
2145}
2146
2147/// Updates other participants with changes to the diagnostics
2148async fn update_diagnostic_summary(
2149 message: proto::UpdateDiagnosticSummary,
2150 session: Session,
2151) -> Result<()> {
2152 let guest_connection_ids = session
2153 .db()
2154 .await
2155 .update_diagnostic_summary(&message, session.connection_id)
2156 .await?;
2157
2158 broadcast(
2159 Some(session.connection_id),
2160 guest_connection_ids.iter().copied(),
2161 |connection_id| {
2162 session
2163 .peer
2164 .forward_send(session.connection_id, connection_id, message.clone())
2165 },
2166 );
2167
2168 Ok(())
2169}
2170
2171/// Updates other participants with changes to the worktree settings
2172async fn update_worktree_settings(
2173 message: proto::UpdateWorktreeSettings,
2174 session: Session,
2175) -> Result<()> {
2176 let guest_connection_ids = session
2177 .db()
2178 .await
2179 .update_worktree_settings(&message, session.connection_id)
2180 .await?;
2181
2182 broadcast(
2183 Some(session.connection_id),
2184 guest_connection_ids.iter().copied(),
2185 |connection_id| {
2186 session
2187 .peer
2188 .forward_send(session.connection_id, connection_id, message.clone())
2189 },
2190 );
2191
2192 Ok(())
2193}
2194
2195/// Notify other participants that a language server has started.
2196async fn start_language_server(
2197 request: proto::StartLanguageServer,
2198 session: Session,
2199) -> Result<()> {
2200 let guest_connection_ids = session
2201 .db()
2202 .await
2203 .start_language_server(&request, session.connection_id)
2204 .await?;
2205
2206 broadcast(
2207 Some(session.connection_id),
2208 guest_connection_ids.iter().copied(),
2209 |connection_id| {
2210 session
2211 .peer
2212 .forward_send(session.connection_id, connection_id, request.clone())
2213 },
2214 );
2215 Ok(())
2216}
2217
2218/// Notify other participants that a language server has changed.
2219async fn update_language_server(
2220 request: proto::UpdateLanguageServer,
2221 session: Session,
2222) -> Result<()> {
2223 let project_id = ProjectId::from_proto(request.project_id);
2224 let project_connection_ids = session
2225 .db()
2226 .await
2227 .project_connection_ids(project_id, session.connection_id, true)
2228 .await?;
2229 broadcast(
2230 Some(session.connection_id),
2231 project_connection_ids.iter().copied(),
2232 |connection_id| {
2233 session
2234 .peer
2235 .forward_send(session.connection_id, connection_id, request.clone())
2236 },
2237 );
2238 Ok(())
2239}
2240
2241/// forward a project request to the host. These requests should be read only
2242/// as guests are allowed to send them.
2243async fn forward_read_only_project_request<T>(
2244 request: T,
2245 response: Response<T>,
2246 session: Session,
2247) -> Result<()>
2248where
2249 T: EntityMessage + RequestMessage,
2250{
2251 let project_id = ProjectId::from_proto(request.remote_entity_id());
2252 let host_connection_id = session
2253 .db()
2254 .await
2255 .host_for_read_only_project_request(project_id, session.connection_id)
2256 .await?;
2257 let payload = session
2258 .peer
2259 .forward_request(session.connection_id, host_connection_id, request)
2260 .await?;
2261 response.send(payload)?;
2262 Ok(())
2263}
2264
2265async fn forward_find_search_candidates_request(
2266 request: proto::FindSearchCandidates,
2267 response: Response<proto::FindSearchCandidates>,
2268 session: Session,
2269) -> Result<()> {
2270 let project_id = ProjectId::from_proto(request.remote_entity_id());
2271 let host_connection_id = session
2272 .db()
2273 .await
2274 .host_for_read_only_project_request(project_id, session.connection_id)
2275 .await?;
2276 let payload = session
2277 .peer
2278 .forward_request(session.connection_id, host_connection_id, request)
2279 .await?;
2280 response.send(payload)?;
2281 Ok(())
2282}
2283
2284/// forward a project request to the host. These requests are disallowed
2285/// for guests.
2286async fn forward_mutating_project_request<T>(
2287 request: T,
2288 response: Response<T>,
2289 session: Session,
2290) -> Result<()>
2291where
2292 T: EntityMessage + RequestMessage,
2293{
2294 let project_id = ProjectId::from_proto(request.remote_entity_id());
2295
2296 let host_connection_id = session
2297 .db()
2298 .await
2299 .host_for_mutating_project_request(project_id, session.connection_id)
2300 .await?;
2301 let payload = session
2302 .peer
2303 .forward_request(session.connection_id, host_connection_id, request)
2304 .await?;
2305 response.send(payload)?;
2306 Ok(())
2307}
2308
2309/// Notify other participants that a new buffer has been created
2310async fn create_buffer_for_peer(
2311 request: proto::CreateBufferForPeer,
2312 session: Session,
2313) -> Result<()> {
2314 session
2315 .db()
2316 .await
2317 .check_user_is_project_host(
2318 ProjectId::from_proto(request.project_id),
2319 session.connection_id,
2320 )
2321 .await?;
2322 let peer_id = request.peer_id.context("invalid peer id")?;
2323 session
2324 .peer
2325 .forward_send(session.connection_id, peer_id.into(), request)?;
2326 Ok(())
2327}
2328
2329/// Notify other participants that a buffer has been updated. This is
2330/// allowed for guests as long as the update is limited to selections.
2331async fn update_buffer(
2332 request: proto::UpdateBuffer,
2333 response: Response<proto::UpdateBuffer>,
2334 session: Session,
2335) -> Result<()> {
2336 let project_id = ProjectId::from_proto(request.project_id);
2337 let mut capability = Capability::ReadOnly;
2338
2339 for op in request.operations.iter() {
2340 match op.variant {
2341 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2342 Some(_) => capability = Capability::ReadWrite,
2343 }
2344 }
2345
2346 let host = {
2347 let guard = session
2348 .db()
2349 .await
2350 .connections_for_buffer_update(project_id, session.connection_id, capability)
2351 .await?;
2352
2353 let (host, guests) = &*guard;
2354
2355 broadcast(
2356 Some(session.connection_id),
2357 guests.clone(),
2358 |connection_id| {
2359 session
2360 .peer
2361 .forward_send(session.connection_id, connection_id, request.clone())
2362 },
2363 );
2364
2365 *host
2366 };
2367
2368 if host != session.connection_id {
2369 session
2370 .peer
2371 .forward_request(session.connection_id, host, request.clone())
2372 .await?;
2373 }
2374
2375 response.send(proto::Ack {})?;
2376 Ok(())
2377}
2378
2379async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2380 let project_id = ProjectId::from_proto(message.project_id);
2381
2382 let operation = message.operation.as_ref().context("invalid operation")?;
2383 let capability = match operation.variant.as_ref() {
2384 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2385 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2386 match buffer_op.variant {
2387 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2388 Capability::ReadOnly
2389 }
2390 _ => Capability::ReadWrite,
2391 }
2392 } else {
2393 Capability::ReadWrite
2394 }
2395 }
2396 Some(_) => Capability::ReadWrite,
2397 None => Capability::ReadOnly,
2398 };
2399
2400 let guard = session
2401 .db()
2402 .await
2403 .connections_for_buffer_update(project_id, session.connection_id, capability)
2404 .await?;
2405
2406 let (host, guests) = &*guard;
2407
2408 broadcast(
2409 Some(session.connection_id),
2410 guests.iter().chain([host]).copied(),
2411 |connection_id| {
2412 session
2413 .peer
2414 .forward_send(session.connection_id, connection_id, message.clone())
2415 },
2416 );
2417
2418 Ok(())
2419}
2420
2421/// Notify other participants that a project has been updated.
2422async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2423 request: T,
2424 session: Session,
2425) -> Result<()> {
2426 let project_id = ProjectId::from_proto(request.remote_entity_id());
2427 let project_connection_ids = session
2428 .db()
2429 .await
2430 .project_connection_ids(project_id, session.connection_id, false)
2431 .await?;
2432
2433 broadcast(
2434 Some(session.connection_id),
2435 project_connection_ids.iter().copied(),
2436 |connection_id| {
2437 session
2438 .peer
2439 .forward_send(session.connection_id, connection_id, request.clone())
2440 },
2441 );
2442 Ok(())
2443}
2444
2445/// Start following another user in a call.
2446async fn follow(
2447 request: proto::Follow,
2448 response: Response<proto::Follow>,
2449 session: Session,
2450) -> Result<()> {
2451 let room_id = RoomId::from_proto(request.room_id);
2452 let project_id = request.project_id.map(ProjectId::from_proto);
2453 let leader_id = request.leader_id.context("invalid leader id")?.into();
2454 let follower_id = session.connection_id;
2455
2456 session
2457 .db()
2458 .await
2459 .check_room_participants(room_id, leader_id, session.connection_id)
2460 .await?;
2461
2462 let response_payload = session
2463 .peer
2464 .forward_request(session.connection_id, leader_id, request)
2465 .await?;
2466 response.send(response_payload)?;
2467
2468 if let Some(project_id) = project_id {
2469 let room = session
2470 .db()
2471 .await
2472 .follow(room_id, project_id, leader_id, follower_id)
2473 .await?;
2474 room_updated(&room, &session.peer);
2475 }
2476
2477 Ok(())
2478}
2479
2480/// Stop following another user in a call.
2481async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2482 let room_id = RoomId::from_proto(request.room_id);
2483 let project_id = request.project_id.map(ProjectId::from_proto);
2484 let leader_id = request.leader_id.context("invalid leader id")?.into();
2485 let follower_id = session.connection_id;
2486
2487 session
2488 .db()
2489 .await
2490 .check_room_participants(room_id, leader_id, session.connection_id)
2491 .await?;
2492
2493 session
2494 .peer
2495 .forward_send(session.connection_id, leader_id, request)?;
2496
2497 if let Some(project_id) = project_id {
2498 let room = session
2499 .db()
2500 .await
2501 .unfollow(room_id, project_id, leader_id, follower_id)
2502 .await?;
2503 room_updated(&room, &session.peer);
2504 }
2505
2506 Ok(())
2507}
2508
2509/// Notify everyone following you of your current location.
2510async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2511 let room_id = RoomId::from_proto(request.room_id);
2512 let database = session.db.lock().await;
2513
2514 let connection_ids = if let Some(project_id) = request.project_id {
2515 let project_id = ProjectId::from_proto(project_id);
2516 database
2517 .project_connection_ids(project_id, session.connection_id, true)
2518 .await?
2519 } else {
2520 database
2521 .room_connection_ids(room_id, session.connection_id)
2522 .await?
2523 };
2524
2525 // For now, don't send view update messages back to that view's current leader.
2526 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2527 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2528 _ => None,
2529 });
2530
2531 for connection_id in connection_ids.iter().cloned() {
2532 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2533 session
2534 .peer
2535 .forward_send(session.connection_id, connection_id, request.clone())?;
2536 }
2537 }
2538 Ok(())
2539}
2540
2541/// Get public data about users.
2542async fn get_users(
2543 request: proto::GetUsers,
2544 response: Response<proto::GetUsers>,
2545 session: Session,
2546) -> Result<()> {
2547 let user_ids = request
2548 .user_ids
2549 .into_iter()
2550 .map(UserId::from_proto)
2551 .collect();
2552 let users = session
2553 .db()
2554 .await
2555 .get_users_by_ids(user_ids)
2556 .await?
2557 .into_iter()
2558 .map(|user| proto::User {
2559 id: user.id.to_proto(),
2560 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2561 github_login: user.github_login,
2562 name: user.name,
2563 })
2564 .collect();
2565 response.send(proto::UsersResponse { users })?;
2566 Ok(())
2567}
2568
2569/// Search for users (to invite) buy Github login
2570async fn fuzzy_search_users(
2571 request: proto::FuzzySearchUsers,
2572 response: Response<proto::FuzzySearchUsers>,
2573 session: Session,
2574) -> Result<()> {
2575 let query = request.query;
2576 let users = match query.len() {
2577 0 => vec![],
2578 1 | 2 => session
2579 .db()
2580 .await
2581 .get_user_by_github_login(&query)
2582 .await?
2583 .into_iter()
2584 .collect(),
2585 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2586 };
2587 let users = users
2588 .into_iter()
2589 .filter(|user| user.id != session.user_id())
2590 .map(|user| proto::User {
2591 id: user.id.to_proto(),
2592 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2593 github_login: user.github_login,
2594 name: user.name,
2595 })
2596 .collect();
2597 response.send(proto::UsersResponse { users })?;
2598 Ok(())
2599}
2600
2601/// Send a contact request to another user.
2602async fn request_contact(
2603 request: proto::RequestContact,
2604 response: Response<proto::RequestContact>,
2605 session: Session,
2606) -> Result<()> {
2607 let requester_id = session.user_id();
2608 let responder_id = UserId::from_proto(request.responder_id);
2609 if requester_id == responder_id {
2610 return Err(anyhow!("cannot add yourself as a contact"))?;
2611 }
2612
2613 let notifications = session
2614 .db()
2615 .await
2616 .send_contact_request(requester_id, responder_id)
2617 .await?;
2618
2619 // Update outgoing contact requests of requester
2620 let mut update = proto::UpdateContacts::default();
2621 update.outgoing_requests.push(responder_id.to_proto());
2622 for connection_id in session
2623 .connection_pool()
2624 .await
2625 .user_connection_ids(requester_id)
2626 {
2627 session.peer.send(connection_id, update.clone())?;
2628 }
2629
2630 // Update incoming contact requests of responder
2631 let mut update = proto::UpdateContacts::default();
2632 update
2633 .incoming_requests
2634 .push(proto::IncomingContactRequest {
2635 requester_id: requester_id.to_proto(),
2636 });
2637 let connection_pool = session.connection_pool().await;
2638 for connection_id in connection_pool.user_connection_ids(responder_id) {
2639 session.peer.send(connection_id, update.clone())?;
2640 }
2641
2642 send_notifications(&connection_pool, &session.peer, notifications);
2643
2644 response.send(proto::Ack {})?;
2645 Ok(())
2646}
2647
2648/// Accept or decline a contact request
2649async fn respond_to_contact_request(
2650 request: proto::RespondToContactRequest,
2651 response: Response<proto::RespondToContactRequest>,
2652 session: Session,
2653) -> Result<()> {
2654 let responder_id = session.user_id();
2655 let requester_id = UserId::from_proto(request.requester_id);
2656 let db = session.db().await;
2657 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2658 db.dismiss_contact_notification(responder_id, requester_id)
2659 .await?;
2660 } else {
2661 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2662
2663 let notifications = db
2664 .respond_to_contact_request(responder_id, requester_id, accept)
2665 .await?;
2666 let requester_busy = db.is_user_busy(requester_id).await?;
2667 let responder_busy = db.is_user_busy(responder_id).await?;
2668
2669 let pool = session.connection_pool().await;
2670 // Update responder with new contact
2671 let mut update = proto::UpdateContacts::default();
2672 if accept {
2673 update
2674 .contacts
2675 .push(contact_for_user(requester_id, requester_busy, &pool));
2676 }
2677 update
2678 .remove_incoming_requests
2679 .push(requester_id.to_proto());
2680 for connection_id in pool.user_connection_ids(responder_id) {
2681 session.peer.send(connection_id, update.clone())?;
2682 }
2683
2684 // Update requester with new contact
2685 let mut update = proto::UpdateContacts::default();
2686 if accept {
2687 update
2688 .contacts
2689 .push(contact_for_user(responder_id, responder_busy, &pool));
2690 }
2691 update
2692 .remove_outgoing_requests
2693 .push(responder_id.to_proto());
2694
2695 for connection_id in pool.user_connection_ids(requester_id) {
2696 session.peer.send(connection_id, update.clone())?;
2697 }
2698
2699 send_notifications(&pool, &session.peer, notifications);
2700 }
2701
2702 response.send(proto::Ack {})?;
2703 Ok(())
2704}
2705
2706/// Remove a contact.
2707async fn remove_contact(
2708 request: proto::RemoveContact,
2709 response: Response<proto::RemoveContact>,
2710 session: Session,
2711) -> Result<()> {
2712 let requester_id = session.user_id();
2713 let responder_id = UserId::from_proto(request.user_id);
2714 let db = session.db().await;
2715 let (contact_accepted, deleted_notification_id) =
2716 db.remove_contact(requester_id, responder_id).await?;
2717
2718 let pool = session.connection_pool().await;
2719 // Update outgoing contact requests of requester
2720 let mut update = proto::UpdateContacts::default();
2721 if contact_accepted {
2722 update.remove_contacts.push(responder_id.to_proto());
2723 } else {
2724 update
2725 .remove_outgoing_requests
2726 .push(responder_id.to_proto());
2727 }
2728 for connection_id in pool.user_connection_ids(requester_id) {
2729 session.peer.send(connection_id, update.clone())?;
2730 }
2731
2732 // Update incoming contact requests of responder
2733 let mut update = proto::UpdateContacts::default();
2734 if contact_accepted {
2735 update.remove_contacts.push(requester_id.to_proto());
2736 } else {
2737 update
2738 .remove_incoming_requests
2739 .push(requester_id.to_proto());
2740 }
2741 for connection_id in pool.user_connection_ids(responder_id) {
2742 session.peer.send(connection_id, update.clone())?;
2743 if let Some(notification_id) = deleted_notification_id {
2744 session.peer.send(
2745 connection_id,
2746 proto::DeleteNotification {
2747 notification_id: notification_id.to_proto(),
2748 },
2749 )?;
2750 }
2751 }
2752
2753 response.send(proto::Ack {})?;
2754 Ok(())
2755}
2756
2757fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2758 version.0.minor() < 139
2759}
2760
2761async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2762 if is_staff {
2763 return Ok(proto::Plan::ZedPro);
2764 }
2765
2766 let subscription = db.get_active_billing_subscription(user_id).await?;
2767 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2768
2769 let plan = if let Some(subscription_kind) = subscription_kind {
2770 match subscription_kind {
2771 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2772 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2773 SubscriptionKind::ZedFree => proto::Plan::Free,
2774 }
2775 } else {
2776 proto::Plan::Free
2777 };
2778
2779 Ok(plan)
2780}
2781
2782async fn make_update_user_plan_message(
2783 user: &User,
2784 is_staff: bool,
2785 db: &Arc<Database>,
2786 llm_db: Option<Arc<LlmDatabase>>,
2787) -> Result<proto::UpdateUserPlan> {
2788 let feature_flags = db.get_user_flags(user.id).await?;
2789 let plan = current_plan(db, user.id, is_staff).await?;
2790 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2791 let billing_preferences = db.get_billing_preferences(user.id).await?;
2792
2793 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2794 let subscription = db.get_active_billing_subscription(user.id).await?;
2795
2796 let subscription_period =
2797 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2798
2799 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2800 llm_db
2801 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2802 .await?
2803 } else {
2804 None
2805 };
2806
2807 (subscription_period, usage)
2808 } else {
2809 (None, None)
2810 };
2811
2812 let bypass_account_age_check = feature_flags
2813 .iter()
2814 .any(|flag| flag == BYPASS_ACCOUNT_AGE_CHECK_FEATURE_FLAG);
2815 let account_too_young = !matches!(plan, proto::Plan::ZedPro)
2816 && !bypass_account_age_check
2817 && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2818
2819 Ok(proto::UpdateUserPlan {
2820 plan: plan.into(),
2821 trial_started_at: billing_customer
2822 .as_ref()
2823 .and_then(|billing_customer| billing_customer.trial_started_at)
2824 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2825 is_usage_based_billing_enabled: if is_staff {
2826 Some(true)
2827 } else {
2828 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2829 },
2830 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2831 proto::SubscriptionPeriod {
2832 started_at: started_at.timestamp() as u64,
2833 ended_at: ended_at.timestamp() as u64,
2834 }
2835 }),
2836 account_too_young: Some(account_too_young),
2837 has_overdue_invoices: billing_customer
2838 .map(|billing_customer| billing_customer.has_overdue_invoices),
2839 usage: Some(
2840 usage
2841 .map(|usage| subscription_usage_to_proto(plan, usage, &feature_flags))
2842 .unwrap_or_else(|| make_default_subscription_usage(plan, &feature_flags)),
2843 ),
2844 })
2845}
2846
2847fn model_requests_limit(
2848 plan: zed_llm_client::Plan,
2849 feature_flags: &Vec<String>,
2850) -> zed_llm_client::UsageLimit {
2851 match plan.model_requests_limit() {
2852 zed_llm_client::UsageLimit::Limited(limit) => {
2853 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2854 && feature_flags
2855 .iter()
2856 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2857 {
2858 1_000
2859 } else {
2860 limit
2861 };
2862
2863 zed_llm_client::UsageLimit::Limited(limit)
2864 }
2865 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2866 }
2867}
2868
2869fn subscription_usage_to_proto(
2870 plan: proto::Plan,
2871 usage: crate::llm::db::subscription_usage::Model,
2872 feature_flags: &Vec<String>,
2873) -> proto::SubscriptionUsage {
2874 let plan = match plan {
2875 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2876 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2877 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2878 };
2879
2880 proto::SubscriptionUsage {
2881 model_requests_usage_amount: usage.model_requests as u32,
2882 model_requests_usage_limit: Some(proto::UsageLimit {
2883 variant: Some(match model_requests_limit(plan, feature_flags) {
2884 zed_llm_client::UsageLimit::Limited(limit) => {
2885 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2886 limit: limit as u32,
2887 })
2888 }
2889 zed_llm_client::UsageLimit::Unlimited => {
2890 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2891 }
2892 }),
2893 }),
2894 edit_predictions_usage_amount: usage.edit_predictions as u32,
2895 edit_predictions_usage_limit: Some(proto::UsageLimit {
2896 variant: Some(match plan.edit_predictions_limit() {
2897 zed_llm_client::UsageLimit::Limited(limit) => {
2898 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2899 limit: limit as u32,
2900 })
2901 }
2902 zed_llm_client::UsageLimit::Unlimited => {
2903 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2904 }
2905 }),
2906 }),
2907 }
2908}
2909
2910fn make_default_subscription_usage(
2911 plan: proto::Plan,
2912 feature_flags: &Vec<String>,
2913) -> proto::SubscriptionUsage {
2914 let plan = match plan {
2915 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2916 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2917 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2918 };
2919
2920 proto::SubscriptionUsage {
2921 model_requests_usage_amount: 0,
2922 model_requests_usage_limit: Some(proto::UsageLimit {
2923 variant: Some(match model_requests_limit(plan, feature_flags) {
2924 zed_llm_client::UsageLimit::Limited(limit) => {
2925 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2926 limit: limit as u32,
2927 })
2928 }
2929 zed_llm_client::UsageLimit::Unlimited => {
2930 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2931 }
2932 }),
2933 }),
2934 edit_predictions_usage_amount: 0,
2935 edit_predictions_usage_limit: Some(proto::UsageLimit {
2936 variant: Some(match plan.edit_predictions_limit() {
2937 zed_llm_client::UsageLimit::Limited(limit) => {
2938 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2939 limit: limit as u32,
2940 })
2941 }
2942 zed_llm_client::UsageLimit::Unlimited => {
2943 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2944 }
2945 }),
2946 }),
2947 }
2948}
2949
2950async fn update_user_plan(session: &Session) -> Result<()> {
2951 let db = session.db().await;
2952
2953 let update_user_plan = make_update_user_plan_message(
2954 session.principal.user(),
2955 session.is_staff(),
2956 &db.0,
2957 session.app_state.llm_db.clone(),
2958 )
2959 .await?;
2960
2961 session
2962 .peer
2963 .send(session.connection_id, update_user_plan)
2964 .trace_err();
2965
2966 Ok(())
2967}
2968
2969async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2970 subscribe_user_to_channels(session.user_id(), &session).await?;
2971 Ok(())
2972}
2973
2974async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2975 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2976 let mut pool = session.connection_pool().await;
2977 for membership in &channels_for_user.channel_memberships {
2978 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2979 }
2980 session.peer.send(
2981 session.connection_id,
2982 build_update_user_channels(&channels_for_user),
2983 )?;
2984 session.peer.send(
2985 session.connection_id,
2986 build_channels_update(channels_for_user),
2987 )?;
2988 Ok(())
2989}
2990
2991/// Creates a new channel.
2992async fn create_channel(
2993 request: proto::CreateChannel,
2994 response: Response<proto::CreateChannel>,
2995 session: Session,
2996) -> Result<()> {
2997 let db = session.db().await;
2998
2999 let parent_id = request.parent_id.map(ChannelId::from_proto);
3000 let (channel, membership) = db
3001 .create_channel(&request.name, parent_id, session.user_id())
3002 .await?;
3003
3004 let root_id = channel.root_id();
3005 let channel = Channel::from_model(channel);
3006
3007 response.send(proto::CreateChannelResponse {
3008 channel: Some(channel.to_proto()),
3009 parent_id: request.parent_id,
3010 })?;
3011
3012 let mut connection_pool = session.connection_pool().await;
3013 if let Some(membership) = membership {
3014 connection_pool.subscribe_to_channel(
3015 membership.user_id,
3016 membership.channel_id,
3017 membership.role,
3018 );
3019 let update = proto::UpdateUserChannels {
3020 channel_memberships: vec![proto::ChannelMembership {
3021 channel_id: membership.channel_id.to_proto(),
3022 role: membership.role.into(),
3023 }],
3024 ..Default::default()
3025 };
3026 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3027 session.peer.send(connection_id, update.clone())?;
3028 }
3029 }
3030
3031 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3032 if !role.can_see_channel(channel.visibility) {
3033 continue;
3034 }
3035
3036 let update = proto::UpdateChannels {
3037 channels: vec![channel.to_proto()],
3038 ..Default::default()
3039 };
3040 session.peer.send(connection_id, update.clone())?;
3041 }
3042
3043 Ok(())
3044}
3045
3046/// Delete a channel
3047async fn delete_channel(
3048 request: proto::DeleteChannel,
3049 response: Response<proto::DeleteChannel>,
3050 session: Session,
3051) -> Result<()> {
3052 let db = session.db().await;
3053
3054 let channel_id = request.channel_id;
3055 let (root_channel, removed_channels) = db
3056 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3057 .await?;
3058 response.send(proto::Ack {})?;
3059
3060 // Notify members of removed channels
3061 let mut update = proto::UpdateChannels::default();
3062 update
3063 .delete_channels
3064 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3065
3066 let connection_pool = session.connection_pool().await;
3067 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3068 session.peer.send(connection_id, update.clone())?;
3069 }
3070
3071 Ok(())
3072}
3073
3074/// Invite someone to join a channel.
3075async fn invite_channel_member(
3076 request: proto::InviteChannelMember,
3077 response: Response<proto::InviteChannelMember>,
3078 session: Session,
3079) -> Result<()> {
3080 let db = session.db().await;
3081 let channel_id = ChannelId::from_proto(request.channel_id);
3082 let invitee_id = UserId::from_proto(request.user_id);
3083 let InviteMemberResult {
3084 channel,
3085 notifications,
3086 } = db
3087 .invite_channel_member(
3088 channel_id,
3089 invitee_id,
3090 session.user_id(),
3091 request.role().into(),
3092 )
3093 .await?;
3094
3095 let update = proto::UpdateChannels {
3096 channel_invitations: vec![channel.to_proto()],
3097 ..Default::default()
3098 };
3099
3100 let connection_pool = session.connection_pool().await;
3101 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3102 session.peer.send(connection_id, update.clone())?;
3103 }
3104
3105 send_notifications(&connection_pool, &session.peer, notifications);
3106
3107 response.send(proto::Ack {})?;
3108 Ok(())
3109}
3110
3111/// remove someone from a channel
3112async fn remove_channel_member(
3113 request: proto::RemoveChannelMember,
3114 response: Response<proto::RemoveChannelMember>,
3115 session: Session,
3116) -> Result<()> {
3117 let db = session.db().await;
3118 let channel_id = ChannelId::from_proto(request.channel_id);
3119 let member_id = UserId::from_proto(request.user_id);
3120
3121 let RemoveChannelMemberResult {
3122 membership_update,
3123 notification_id,
3124 } = db
3125 .remove_channel_member(channel_id, member_id, session.user_id())
3126 .await?;
3127
3128 let mut connection_pool = session.connection_pool().await;
3129 notify_membership_updated(
3130 &mut connection_pool,
3131 membership_update,
3132 member_id,
3133 &session.peer,
3134 );
3135 for connection_id in connection_pool.user_connection_ids(member_id) {
3136 if let Some(notification_id) = notification_id {
3137 session
3138 .peer
3139 .send(
3140 connection_id,
3141 proto::DeleteNotification {
3142 notification_id: notification_id.to_proto(),
3143 },
3144 )
3145 .trace_err();
3146 }
3147 }
3148
3149 response.send(proto::Ack {})?;
3150 Ok(())
3151}
3152
3153/// Toggle the channel between public and private.
3154/// Care is taken to maintain the invariant that public channels only descend from public channels,
3155/// (though members-only channels can appear at any point in the hierarchy).
3156async fn set_channel_visibility(
3157 request: proto::SetChannelVisibility,
3158 response: Response<proto::SetChannelVisibility>,
3159 session: Session,
3160) -> Result<()> {
3161 let db = session.db().await;
3162 let channel_id = ChannelId::from_proto(request.channel_id);
3163 let visibility = request.visibility().into();
3164
3165 let channel_model = db
3166 .set_channel_visibility(channel_id, visibility, session.user_id())
3167 .await?;
3168 let root_id = channel_model.root_id();
3169 let channel = Channel::from_model(channel_model);
3170
3171 let mut connection_pool = session.connection_pool().await;
3172 for (user_id, role) in connection_pool
3173 .channel_user_ids(root_id)
3174 .collect::<Vec<_>>()
3175 .into_iter()
3176 {
3177 let update = if role.can_see_channel(channel.visibility) {
3178 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3179 proto::UpdateChannels {
3180 channels: vec![channel.to_proto()],
3181 ..Default::default()
3182 }
3183 } else {
3184 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3185 proto::UpdateChannels {
3186 delete_channels: vec![channel.id.to_proto()],
3187 ..Default::default()
3188 }
3189 };
3190
3191 for connection_id in connection_pool.user_connection_ids(user_id) {
3192 session.peer.send(connection_id, update.clone())?;
3193 }
3194 }
3195
3196 response.send(proto::Ack {})?;
3197 Ok(())
3198}
3199
3200/// Alter the role for a user in the channel.
3201async fn set_channel_member_role(
3202 request: proto::SetChannelMemberRole,
3203 response: Response<proto::SetChannelMemberRole>,
3204 session: Session,
3205) -> Result<()> {
3206 let db = session.db().await;
3207 let channel_id = ChannelId::from_proto(request.channel_id);
3208 let member_id = UserId::from_proto(request.user_id);
3209 let result = db
3210 .set_channel_member_role(
3211 channel_id,
3212 session.user_id(),
3213 member_id,
3214 request.role().into(),
3215 )
3216 .await?;
3217
3218 match result {
3219 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3220 let mut connection_pool = session.connection_pool().await;
3221 notify_membership_updated(
3222 &mut connection_pool,
3223 membership_update,
3224 member_id,
3225 &session.peer,
3226 )
3227 }
3228 db::SetMemberRoleResult::InviteUpdated(channel) => {
3229 let update = proto::UpdateChannels {
3230 channel_invitations: vec![channel.to_proto()],
3231 ..Default::default()
3232 };
3233
3234 for connection_id in session
3235 .connection_pool()
3236 .await
3237 .user_connection_ids(member_id)
3238 {
3239 session.peer.send(connection_id, update.clone())?;
3240 }
3241 }
3242 }
3243
3244 response.send(proto::Ack {})?;
3245 Ok(())
3246}
3247
3248/// Change the name of a channel
3249async fn rename_channel(
3250 request: proto::RenameChannel,
3251 response: Response<proto::RenameChannel>,
3252 session: Session,
3253) -> Result<()> {
3254 let db = session.db().await;
3255 let channel_id = ChannelId::from_proto(request.channel_id);
3256 let channel_model = db
3257 .rename_channel(channel_id, session.user_id(), &request.name)
3258 .await?;
3259 let root_id = channel_model.root_id();
3260 let channel = Channel::from_model(channel_model);
3261
3262 response.send(proto::RenameChannelResponse {
3263 channel: Some(channel.to_proto()),
3264 })?;
3265
3266 let connection_pool = session.connection_pool().await;
3267 let update = proto::UpdateChannels {
3268 channels: vec![channel.to_proto()],
3269 ..Default::default()
3270 };
3271 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3272 if role.can_see_channel(channel.visibility) {
3273 session.peer.send(connection_id, update.clone())?;
3274 }
3275 }
3276
3277 Ok(())
3278}
3279
3280/// Move a channel to a new parent.
3281async fn move_channel(
3282 request: proto::MoveChannel,
3283 response: Response<proto::MoveChannel>,
3284 session: Session,
3285) -> Result<()> {
3286 let channel_id = ChannelId::from_proto(request.channel_id);
3287 let to = ChannelId::from_proto(request.to);
3288
3289 let (root_id, channels) = session
3290 .db()
3291 .await
3292 .move_channel(channel_id, to, session.user_id())
3293 .await?;
3294
3295 let connection_pool = session.connection_pool().await;
3296 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3297 let channels = channels
3298 .iter()
3299 .filter_map(|channel| {
3300 if role.can_see_channel(channel.visibility) {
3301 Some(channel.to_proto())
3302 } else {
3303 None
3304 }
3305 })
3306 .collect::<Vec<_>>();
3307 if channels.is_empty() {
3308 continue;
3309 }
3310
3311 let update = proto::UpdateChannels {
3312 channels,
3313 ..Default::default()
3314 };
3315
3316 session.peer.send(connection_id, update.clone())?;
3317 }
3318
3319 response.send(Ack {})?;
3320 Ok(())
3321}
3322
3323async fn reorder_channel(
3324 request: proto::ReorderChannel,
3325 response: Response<proto::ReorderChannel>,
3326 session: Session,
3327) -> Result<()> {
3328 let channel_id = ChannelId::from_proto(request.channel_id);
3329 let direction = request.direction();
3330
3331 let updated_channels = session
3332 .db()
3333 .await
3334 .reorder_channel(channel_id, direction, session.user_id())
3335 .await?;
3336
3337 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3338 let connection_pool = session.connection_pool().await;
3339 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3340 let channels = updated_channels
3341 .iter()
3342 .filter_map(|channel| {
3343 if role.can_see_channel(channel.visibility) {
3344 Some(channel.to_proto())
3345 } else {
3346 None
3347 }
3348 })
3349 .collect::<Vec<_>>();
3350
3351 if channels.is_empty() {
3352 continue;
3353 }
3354
3355 let update = proto::UpdateChannels {
3356 channels,
3357 ..Default::default()
3358 };
3359
3360 session.peer.send(connection_id, update.clone())?;
3361 }
3362 }
3363
3364 response.send(Ack {})?;
3365 Ok(())
3366}
3367
3368/// Get the list of channel members
3369async fn get_channel_members(
3370 request: proto::GetChannelMembers,
3371 response: Response<proto::GetChannelMembers>,
3372 session: Session,
3373) -> Result<()> {
3374 let db = session.db().await;
3375 let channel_id = ChannelId::from_proto(request.channel_id);
3376 let limit = if request.limit == 0 {
3377 u16::MAX as u64
3378 } else {
3379 request.limit
3380 };
3381 let (members, users) = db
3382 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3383 .await?;
3384 response.send(proto::GetChannelMembersResponse { members, users })?;
3385 Ok(())
3386}
3387
3388/// Accept or decline a channel invitation.
3389async fn respond_to_channel_invite(
3390 request: proto::RespondToChannelInvite,
3391 response: Response<proto::RespondToChannelInvite>,
3392 session: Session,
3393) -> Result<()> {
3394 let db = session.db().await;
3395 let channel_id = ChannelId::from_proto(request.channel_id);
3396 let RespondToChannelInvite {
3397 membership_update,
3398 notifications,
3399 } = db
3400 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3401 .await?;
3402
3403 let mut connection_pool = session.connection_pool().await;
3404 if let Some(membership_update) = membership_update {
3405 notify_membership_updated(
3406 &mut connection_pool,
3407 membership_update,
3408 session.user_id(),
3409 &session.peer,
3410 );
3411 } else {
3412 let update = proto::UpdateChannels {
3413 remove_channel_invitations: vec![channel_id.to_proto()],
3414 ..Default::default()
3415 };
3416
3417 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3418 session.peer.send(connection_id, update.clone())?;
3419 }
3420 };
3421
3422 send_notifications(&connection_pool, &session.peer, notifications);
3423
3424 response.send(proto::Ack {})?;
3425
3426 Ok(())
3427}
3428
3429/// Join the channels' room
3430async fn join_channel(
3431 request: proto::JoinChannel,
3432 response: Response<proto::JoinChannel>,
3433 session: Session,
3434) -> Result<()> {
3435 let channel_id = ChannelId::from_proto(request.channel_id);
3436 join_channel_internal(channel_id, Box::new(response), session).await
3437}
3438
/// A response handle that `join_channel_internal` can reply through.
///
/// The `JoinChannel` and `JoinRoom` RPCs both answer with a
/// `proto::JoinRoomResponse`. This trait lets the shared join logic send that
/// payload without knowing which of the two requests it is serving.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Presumably delegates to `Response`'s inherent `send`, which takes
        // precedence over this trait method of the same name.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as the `JoinChannel` impl, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3452
/// Shared implementation for joining the call room attached to a channel.
///
/// Steps, in order:
/// 1. If the user still has a stale room connection (for example, Zed quit
///    without leaving the room), that connection leaves its room first.
/// 2. The database joins the channel's room. It returns the joined room, an
///    optional membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, a token is minted:
///    - guests get a guest token with `can_publish = false`;
///    - everyone else gets a room token with `can_publish = true`.
///    If minting fails, the error goes through `trace_err` and the response
///    simply carries no LiveKit connection info.
/// 4. The `JoinRoomResponse` is sent. Membership changes and the room update
///    are then broadcast, followed by the channel update and a refresh of the
///    user's contacts.
///
/// # Errors
/// Propagates database, send, stale-connection cleanup and contact-update
/// failures. The "channel not returned" error can only occur after the
/// response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // Scoped block: `db` and the connection-pool guard taken inside it are
    // dropped at its end, before the pool is acquired again for
    // `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db handle while leaving the old room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured or minting a token fails:
        // each `.trace_err()?` turns a token error into `None` for the closure.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests receive a guest token and may not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3546
3547/// Start editing the channel notes
3548async fn join_channel_buffer(
3549 request: proto::JoinChannelBuffer,
3550 response: Response<proto::JoinChannelBuffer>,
3551 session: Session,
3552) -> Result<()> {
3553 let db = session.db().await;
3554 let channel_id = ChannelId::from_proto(request.channel_id);
3555
3556 let open_response = db
3557 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3558 .await?;
3559
3560 let collaborators = open_response.collaborators.clone();
3561 response.send(open_response)?;
3562
3563 let update = UpdateChannelBufferCollaborators {
3564 channel_id: channel_id.to_proto(),
3565 collaborators: collaborators.clone(),
3566 };
3567 channel_buffer_updated(
3568 session.connection_id,
3569 collaborators
3570 .iter()
3571 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3572 &update,
3573 &session.peer,
3574 );
3575
3576 Ok(())
3577}
3578
3579/// Edit the channel notes
3580async fn update_channel_buffer(
3581 request: proto::UpdateChannelBuffer,
3582 session: Session,
3583) -> Result<()> {
3584 let db = session.db().await;
3585 let channel_id = ChannelId::from_proto(request.channel_id);
3586
3587 let (collaborators, epoch, version) = db
3588 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3589 .await?;
3590
3591 channel_buffer_updated(
3592 session.connection_id,
3593 collaborators.clone(),
3594 &proto::UpdateChannelBuffer {
3595 channel_id: channel_id.to_proto(),
3596 operations: request.operations,
3597 },
3598 &session.peer,
3599 );
3600
3601 let pool = &*session.connection_pool().await;
3602
3603 let non_collaborators =
3604 pool.channel_connection_ids(channel_id)
3605 .filter_map(|(connection_id, _)| {
3606 if collaborators.contains(&connection_id) {
3607 None
3608 } else {
3609 Some(connection_id)
3610 }
3611 });
3612
3613 broadcast(None, non_collaborators, |peer_id| {
3614 session.peer.send(
3615 peer_id,
3616 proto::UpdateChannels {
3617 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3618 channel_id: channel_id.to_proto(),
3619 epoch: epoch as u64,
3620 version: version.clone(),
3621 }],
3622 ..Default::default()
3623 },
3624 )
3625 });
3626
3627 Ok(())
3628}
3629
3630/// Rejoin the channel notes after a connection blip
3631async fn rejoin_channel_buffers(
3632 request: proto::RejoinChannelBuffers,
3633 response: Response<proto::RejoinChannelBuffers>,
3634 session: Session,
3635) -> Result<()> {
3636 let db = session.db().await;
3637 let buffers = db
3638 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3639 .await?;
3640
3641 for rejoined_buffer in &buffers {
3642 let collaborators_to_notify = rejoined_buffer
3643 .buffer
3644 .collaborators
3645 .iter()
3646 .filter_map(|c| Some(c.peer_id?.into()));
3647 channel_buffer_updated(
3648 session.connection_id,
3649 collaborators_to_notify,
3650 &proto::UpdateChannelBufferCollaborators {
3651 channel_id: rejoined_buffer.buffer.channel_id,
3652 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3653 },
3654 &session.peer,
3655 );
3656 }
3657
3658 response.send(proto::RejoinChannelBuffersResponse {
3659 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3660 })?;
3661
3662 Ok(())
3663}
3664
3665/// Stop editing the channel notes
3666async fn leave_channel_buffer(
3667 request: proto::LeaveChannelBuffer,
3668 response: Response<proto::LeaveChannelBuffer>,
3669 session: Session,
3670) -> Result<()> {
3671 let db = session.db().await;
3672 let channel_id = ChannelId::from_proto(request.channel_id);
3673
3674 let left_buffer = db
3675 .leave_channel_buffer(channel_id, session.connection_id)
3676 .await?;
3677
3678 response.send(Ack {})?;
3679
3680 channel_buffer_updated(
3681 session.connection_id,
3682 left_buffer.connections,
3683 &proto::UpdateChannelBufferCollaborators {
3684 channel_id: channel_id.to_proto(),
3685 collaborators: left_buffer.collaborators,
3686 },
3687 &session.peer,
3688 );
3689
3690 Ok(())
3691}
3692
3693fn channel_buffer_updated<T: EnvelopedMessage>(
3694 sender_id: ConnectionId,
3695 collaborators: impl IntoIterator<Item = ConnectionId>,
3696 message: &T,
3697 peer: &Peer,
3698) {
3699 broadcast(Some(sender_id), collaborators, |peer_id| {
3700 peer.send(peer_id, message.clone())
3701 });
3702}
3703
3704fn send_notifications(
3705 connection_pool: &ConnectionPool,
3706 peer: &Peer,
3707 notifications: db::NotificationBatch,
3708) {
3709 for (user_id, notification) in notifications {
3710 for connection_id in connection_pool.user_connection_ids(user_id) {
3711 if let Err(error) = peer.send(
3712 connection_id,
3713 proto::AddNotification {
3714 notification: Some(notification.clone()),
3715 },
3716 ) {
3717 tracing::error!(
3718 "failed to send notification to {:?} {}",
3719 connection_id,
3720 error
3721 );
3722 }
3723 }
3724 }
3725}
3726
/// Post a new message to a channel's chat.
///
/// Validation:
/// - the body is trimmed and must be non-empty;
/// - the trimmed body must be at most `MAX_MESSAGE_LEN` bytes (`String::len`);
/// - a client nonce is required.
///
/// After the message is stored:
/// - the other chat participants receive `ChannelMessageSent`;
/// - the sender receives the stored message in the response;
/// - connections subscribed to the channel that are not chat participants get
///   the channel's latest message id via `UpdateChannels`;
/// - notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: Session,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored unchanged below. If mentions
    // carry offsets into the untrimmed body (presumably they do — confirm),
    // trimming leading whitespace shifts them.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request.nonce.context("nonce can't be blank")?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    // Fan the new message out to the other chat participants (not the sender).
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections watching the channel without having its chat open only
    // learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
3819
3820/// Delete a channel message
3821async fn remove_channel_message(
3822 request: proto::RemoveChannelMessage,
3823 response: Response<proto::RemoveChannelMessage>,
3824 session: Session,
3825) -> Result<()> {
3826 let channel_id = ChannelId::from_proto(request.channel_id);
3827 let message_id = MessageId::from_proto(request.message_id);
3828 let (connection_ids, existing_notification_ids) = session
3829 .db()
3830 .await
3831 .remove_channel_message(channel_id, message_id, session.user_id())
3832 .await?;
3833
3834 broadcast(
3835 Some(session.connection_id),
3836 connection_ids,
3837 move |connection| {
3838 session.peer.send(connection, request.clone())?;
3839
3840 for notification_id in &existing_notification_ids {
3841 session.peer.send(
3842 connection,
3843 proto::DeleteNotification {
3844 notification_id: (*notification_id).to_proto(),
3845 },
3846 )?;
3847 }
3848
3849 Ok(())
3850 },
3851 );
3852 response.send(proto::Ack {})?;
3853 Ok(())
3854}
3855
3856async fn update_channel_message(
3857 request: proto::UpdateChannelMessage,
3858 response: Response<proto::UpdateChannelMessage>,
3859 session: Session,
3860) -> Result<()> {
3861 let channel_id = ChannelId::from_proto(request.channel_id);
3862 let message_id = MessageId::from_proto(request.message_id);
3863 let updated_at = OffsetDateTime::now_utc();
3864 let UpdatedChannelMessage {
3865 message_id,
3866 participant_connection_ids,
3867 notifications,
3868 reply_to_message_id,
3869 timestamp,
3870 deleted_mention_notification_ids,
3871 updated_mention_notifications,
3872 } = session
3873 .db()
3874 .await
3875 .update_channel_message(
3876 channel_id,
3877 message_id,
3878 session.user_id(),
3879 request.body.as_str(),
3880 &request.mentions,
3881 updated_at,
3882 )
3883 .await?;
3884
3885 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3886
3887 let message = proto::ChannelMessage {
3888 sender_id: session.user_id().to_proto(),
3889 id: message_id.to_proto(),
3890 body: request.body.clone(),
3891 mentions: request.mentions.clone(),
3892 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3893 nonce: Some(nonce),
3894 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3895 edited_at: Some(updated_at.unix_timestamp() as u64),
3896 };
3897
3898 response.send(proto::Ack {})?;
3899
3900 let pool = &*session.connection_pool().await;
3901 broadcast(
3902 Some(session.connection_id),
3903 participant_connection_ids,
3904 |connection| {
3905 session.peer.send(
3906 connection,
3907 proto::ChannelMessageUpdate {
3908 channel_id: channel_id.to_proto(),
3909 message: Some(message.clone()),
3910 },
3911 )?;
3912
3913 for notification_id in &deleted_mention_notification_ids {
3914 session.peer.send(
3915 connection,
3916 proto::DeleteNotification {
3917 notification_id: (*notification_id).to_proto(),
3918 },
3919 )?;
3920 }
3921
3922 for notification in &updated_mention_notifications {
3923 session.peer.send(
3924 connection,
3925 proto::UpdateNotification {
3926 notification: Some(notification.clone()),
3927 },
3928 )?;
3929 }
3930
3931 Ok(())
3932 },
3933 );
3934
3935 send_notifications(pool, &session.peer, notifications);
3936
3937 Ok(())
3938}
3939
3940/// Mark a channel message as read
3941async fn acknowledge_channel_message(
3942 request: proto::AckChannelMessage,
3943 session: Session,
3944) -> Result<()> {
3945 let channel_id = ChannelId::from_proto(request.channel_id);
3946 let message_id = MessageId::from_proto(request.message_id);
3947 let notifications = session
3948 .db()
3949 .await
3950 .observe_channel_message(channel_id, session.user_id(), message_id)
3951 .await?;
3952 send_notifications(
3953 &*session.connection_pool().await,
3954 &session.peer,
3955 notifications,
3956 );
3957 Ok(())
3958}
3959
3960/// Mark a buffer version as synced
3961async fn acknowledge_buffer_version(
3962 request: proto::AckBufferOperation,
3963 session: Session,
3964) -> Result<()> {
3965 let buffer_id = BufferId::from_proto(request.buffer_id);
3966 session
3967 .db()
3968 .await
3969 .observe_buffer_version(
3970 buffer_id,
3971 session.user_id(),
3972 request.epoch as i32,
3973 &request.version,
3974 )
3975 .await?;
3976 Ok(())
3977}
3978
3979/// Get a Supermaven API key for the user
3980async fn get_supermaven_api_key(
3981 _request: proto::GetSupermavenApiKey,
3982 response: Response<proto::GetSupermavenApiKey>,
3983 session: Session,
3984) -> Result<()> {
3985 let user_id: String = session.user_id().to_string();
3986 if !session.is_staff() {
3987 return Err(anyhow!("supermaven not enabled for this account"))?;
3988 }
3989
3990 let email = session.email().context("user must have an email")?;
3991
3992 let supermaven_admin_api = session
3993 .supermaven_client
3994 .as_ref()
3995 .context("supermaven not configured")?;
3996
3997 let result = supermaven_admin_api
3998 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3999 .await?;
4000
4001 response.send(proto::GetSupermavenApiKeyResponse {
4002 api_key: result.api_key,
4003 })?;
4004
4005 Ok(())
4006}
4007
4008/// Start receiving chat updates for a channel
4009async fn join_channel_chat(
4010 request: proto::JoinChannelChat,
4011 response: Response<proto::JoinChannelChat>,
4012 session: Session,
4013) -> Result<()> {
4014 let channel_id = ChannelId::from_proto(request.channel_id);
4015
4016 let db = session.db().await;
4017 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4018 .await?;
4019 let messages = db
4020 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4021 .await?;
4022 response.send(proto::JoinChannelChatResponse {
4023 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4024 messages,
4025 })?;
4026 Ok(())
4027}
4028
4029/// Stop receiving chat updates for a channel
4030async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
4031 let channel_id = ChannelId::from_proto(request.channel_id);
4032 session
4033 .db()
4034 .await
4035 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4036 .await?;
4037 Ok(())
4038}
4039
4040/// Retrieve the chat history for a channel
4041async fn get_channel_messages(
4042 request: proto::GetChannelMessages,
4043 response: Response<proto::GetChannelMessages>,
4044 session: Session,
4045) -> Result<()> {
4046 let channel_id = ChannelId::from_proto(request.channel_id);
4047 let messages = session
4048 .db()
4049 .await
4050 .get_channel_messages(
4051 channel_id,
4052 session.user_id(),
4053 MESSAGE_COUNT_PER_PAGE,
4054 Some(MessageId::from_proto(request.before_message_id)),
4055 )
4056 .await?;
4057 response.send(proto::GetChannelMessagesResponse {
4058 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4059 messages,
4060 })?;
4061 Ok(())
4062}
4063
4064/// Retrieve specific chat messages
4065async fn get_channel_messages_by_id(
4066 request: proto::GetChannelMessagesById,
4067 response: Response<proto::GetChannelMessagesById>,
4068 session: Session,
4069) -> Result<()> {
4070 let message_ids = request
4071 .message_ids
4072 .iter()
4073 .map(|id| MessageId::from_proto(*id))
4074 .collect::<Vec<_>>();
4075 let messages = session
4076 .db()
4077 .await
4078 .get_channel_messages_by_id(session.user_id(), &message_ids)
4079 .await?;
4080 response.send(proto::GetChannelMessagesResponse {
4081 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4082 messages,
4083 })?;
4084 Ok(())
4085}
4086
4087/// Retrieve the current users notifications
4088async fn get_notifications(
4089 request: proto::GetNotifications,
4090 response: Response<proto::GetNotifications>,
4091 session: Session,
4092) -> Result<()> {
4093 let notifications = session
4094 .db()
4095 .await
4096 .get_notifications(
4097 session.user_id(),
4098 NOTIFICATION_COUNT_PER_PAGE,
4099 request.before_id.map(db::NotificationId::from_proto),
4100 )
4101 .await?;
4102 response.send(proto::GetNotificationsResponse {
4103 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4104 notifications,
4105 })?;
4106 Ok(())
4107}
4108
4109/// Mark notifications as read
4110async fn mark_notification_as_read(
4111 request: proto::MarkNotificationRead,
4112 response: Response<proto::MarkNotificationRead>,
4113 session: Session,
4114) -> Result<()> {
4115 let database = &session.db().await;
4116 let notifications = database
4117 .mark_notification_as_read_by_id(
4118 session.user_id(),
4119 NotificationId::from_proto(request.notification_id),
4120 )
4121 .await?;
4122 send_notifications(
4123 &*session.connection_pool().await,
4124 &session.peer,
4125 notifications,
4126 );
4127 response.send(proto::Ack {})?;
4128 Ok(())
4129}
4130
4131/// Get the current users information
4132async fn get_private_user_info(
4133 _request: proto::GetPrivateUserInfo,
4134 response: Response<proto::GetPrivateUserInfo>,
4135 session: Session,
4136) -> Result<()> {
4137 let db = session.db().await;
4138
4139 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4140 let user = db
4141 .get_user_by_id(session.user_id())
4142 .await?
4143 .context("user not found")?;
4144 let flags = db.get_user_flags(session.user_id()).await?;
4145
4146 response.send(proto::GetPrivateUserInfoResponse {
4147 metrics_id,
4148 staff: user.admin,
4149 flags,
4150 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4151 })?;
4152 Ok(())
4153}
4154
4155/// Accept the terms of service (tos) on behalf of the current user
4156async fn accept_terms_of_service(
4157 _request: proto::AcceptTermsOfService,
4158 response: Response<proto::AcceptTermsOfService>,
4159 session: Session,
4160) -> Result<()> {
4161 let db = session.db().await;
4162
4163 let accepted_tos_at = Utc::now();
4164 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4165 .await?;
4166
4167 response.send(proto::AcceptTermsOfServiceResponse {
4168 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4169 })?;
4170
4171 // When the user accepts the terms of service, we want to refresh their LLM
4172 // token to grant access.
4173 session
4174 .peer
4175 .send(session.connection_id, proto::RefreshLlmToken {})?;
4176
4177 Ok(())
4178}
4179
/// Issue a signed LLM API token for the calling user.
///
/// Preconditions:
/// - the user must exist and must have accepted the terms of service;
/// - the server must have both a Stripe client and a Stripe billing object
///   configured.
///
/// Along the way, this ensures the user has:
/// - a billing customer. If none is recorded, a Stripe customer is found or
///   created by email and a billing customer is found or created for it.
/// - an active subscription. If none exists, the customer is subscribed to
///   Zed Free in Stripe and that subscription is recorded.
///
/// The token is built from the user, staff status, billing customer,
/// preferences and subscription, feature flags, the client's system id, and
/// the server config.
///
/// NOTE(review): the `db` handle acquired at the top is held across the Stripe
/// network calls below. Confirm this is acceptable.
/// NOTE(review): two concurrent calls for a user without an active
/// subscription could each create a Zed Free subscription. Confirm whether
/// anything downstream deduplicates.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // No token is issued until the terms of service have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the recorded billing customer, or find/create one through Stripe.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Reuse the active subscription, or subscribe to Zed Free and record it.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4263
4264fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4265 let message = match message {
4266 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4267 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4268 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4269 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4270 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4271 code: frame.code.into(),
4272 reason: frame.reason.as_str().to_owned().into(),
4273 })),
4274 // We should never receive a frame while reading the message, according
4275 // to the `tungstenite` maintainers:
4276 //
4277 // > It cannot occur when you read messages from the WebSocket, but it
4278 // > can be used when you want to send the raw frames (e.g. you want to
4279 // > send the frames to the WebSocket without composing the full message first).
4280 // >
4281 // > — https://github.com/snapview/tungstenite-rs/issues/268
4282 TungsteniteMessage::Frame(_) => {
4283 bail!("received an unexpected frame while reading the message")
4284 }
4285 };
4286
4287 Ok(message)
4288}
4289
4290fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4291 match message {
4292 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4293 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4294 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4295 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4296 AxumMessage::Close(frame) => {
4297 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4298 code: frame.code.into(),
4299 reason: frame.reason.as_ref().into(),
4300 }))
4301 }
4302 }
4303}
4304
4305fn notify_membership_updated(
4306 connection_pool: &mut ConnectionPool,
4307 result: MembershipUpdated,
4308 user_id: UserId,
4309 peer: &Peer,
4310) {
4311 for membership in &result.new_channels.channel_memberships {
4312 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4313 }
4314 for channel_id in &result.removed_channels {
4315 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4316 }
4317
4318 let user_channels_update = proto::UpdateUserChannels {
4319 channel_memberships: result
4320 .new_channels
4321 .channel_memberships
4322 .iter()
4323 .map(|cm| proto::ChannelMembership {
4324 channel_id: cm.channel_id.to_proto(),
4325 role: cm.role.into(),
4326 })
4327 .collect(),
4328 ..Default::default()
4329 };
4330
4331 let mut update = build_channels_update(result.new_channels);
4332 update.delete_channels = result
4333 .removed_channels
4334 .into_iter()
4335 .map(|id| id.to_proto())
4336 .collect();
4337 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4338
4339 for connection_id in connection_pool.user_connection_ids(user_id) {
4340 peer.send(connection_id, user_channels_update.clone())
4341 .trace_err();
4342 peer.send(connection_id, update.clone()).trace_err();
4343 }
4344}
4345
4346fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4347 proto::UpdateUserChannels {
4348 channel_memberships: channels
4349 .channel_memberships
4350 .iter()
4351 .map(|m| proto::ChannelMembership {
4352 channel_id: m.channel_id.to_proto(),
4353 role: m.role.into(),
4354 })
4355 .collect(),
4356 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4357 observed_channel_message_id: channels.observed_channel_messages.clone(),
4358 }
4359}
4360
4361fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4362 let mut update = proto::UpdateChannels::default();
4363
4364 for channel in channels.channels {
4365 update.channels.push(channel.to_proto());
4366 }
4367
4368 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4369 update.latest_channel_message_ids = channels.latest_channel_messages;
4370
4371 for (channel_id, participants) in channels.channel_participants {
4372 update
4373 .channel_participants
4374 .push(proto::ChannelParticipants {
4375 channel_id: channel_id.to_proto(),
4376 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4377 });
4378 }
4379
4380 for channel in channels.invited_channels {
4381 update.channel_invitations.push(channel.to_proto());
4382 }
4383
4384 update
4385}
4386
4387fn build_initial_contacts_update(
4388 contacts: Vec<db::Contact>,
4389 pool: &ConnectionPool,
4390) -> proto::UpdateContacts {
4391 let mut update = proto::UpdateContacts::default();
4392
4393 for contact in contacts {
4394 match contact {
4395 db::Contact::Accepted { user_id, busy } => {
4396 update.contacts.push(contact_for_user(user_id, busy, pool));
4397 }
4398 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4399 db::Contact::Incoming { user_id } => {
4400 update
4401 .incoming_requests
4402 .push(proto::IncomingContactRequest {
4403 requester_id: user_id.to_proto(),
4404 })
4405 }
4406 }
4407 }
4408
4409 update
4410}
4411
4412fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4413 proto::Contact {
4414 user_id: user_id.to_proto(),
4415 online: pool.is_user_online(user_id),
4416 busy,
4417 }
4418}
4419
4420fn room_updated(room: &proto::Room, peer: &Peer) {
4421 broadcast(
4422 None,
4423 room.participants
4424 .iter()
4425 .filter_map(|participant| Some(participant.peer_id?.into())),
4426 |peer_id| {
4427 peer.send(
4428 peer_id,
4429 proto::RoomUpdated {
4430 room: Some(room.clone()),
4431 },
4432 )
4433 },
4434 );
4435}
4436
4437fn channel_updated(
4438 channel: &db::channel::Model,
4439 room: &proto::Room,
4440 peer: &Peer,
4441 pool: &ConnectionPool,
4442) {
4443 let participants = room
4444 .participants
4445 .iter()
4446 .map(|p| p.user_id)
4447 .collect::<Vec<_>>();
4448
4449 broadcast(
4450 None,
4451 pool.channel_connection_ids(channel.root_id())
4452 .filter_map(|(channel_id, role)| {
4453 role.can_see_channel(channel.visibility)
4454 .then_some(channel_id)
4455 }),
4456 |peer_id| {
4457 peer.send(
4458 peer_id,
4459 proto::UpdateChannels {
4460 channel_participants: vec![proto::ChannelParticipants {
4461 channel_id: channel.id.to_proto(),
4462 participant_user_ids: participants.clone(),
4463 }],
4464 ..Default::default()
4465 },
4466 )
4467 },
4468 );
4469}
4470
4471async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4472 let db = session.db().await;
4473
4474 let contacts = db.get_contacts(user_id).await?;
4475 let busy = db.is_user_busy(user_id).await?;
4476
4477 let pool = session.connection_pool().await;
4478 let updated_contact = contact_for_user(user_id, busy, &pool);
4479 for contact in contacts {
4480 if let db::Contact::Accepted {
4481 user_id: contact_user_id,
4482 ..
4483 } = contact
4484 {
4485 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4486 session
4487 .peer
4488 .send(
4489 contact_conn_id,
4490 proto::UpdateContacts {
4491 contacts: vec![updated_contact.clone()],
4492 remove_contacts: Default::default(),
4493 incoming_requests: Default::default(),
4494 remove_incoming_requests: Default::default(),
4495 outgoing_requests: Default::default(),
4496 remove_outgoing_requests: Default::default(),
4497 },
4498 )
4499 .trace_err();
4500 }
4501 }
4502 }
4503 Ok(())
4504}
4505
/// Removes `connection_id` from whatever room it is in and notifies everyone affected.
///
/// Does nothing (returns `Ok(())`) when the database reports that the connection was not in a
/// room. Otherwise it:
/// - notifies collaborators of every project left as a side effect (`project_left`);
/// - broadcasts the updated room to its participants, and the updated participant list to the
///   room's channel (if it has one);
/// - sends `CallCanceled` to every user whose pending call into this room was canceled;
/// - refreshes contact status for this session's user and every canceled callee;
/// - finally removes the user from the LiveKit room, deleting that room if the database reports
///   the room itself was deleted.
///
/// LiveKit failures are only logged.
///
/// # Errors
///
/// Propagates database errors from `leave_room` and from `update_user_contacts`.
async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared here but assigned inside the `if let` below. The needed values are moved out of
    // `left_room` there, so that `left_room` and the temporary `session.db()` handle go out of
    // scope at the end of that block rather than living for the whole function.
    // NOTE(review): both are presumably lock guards — confirm before restructuring this.
    let room_id;
    let canceled_calls_to_user_ids;
    let livekit_room;
    let delete_livekit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` is moved out, so the `room` broadcast below carries a defaulted
        // `livekit_room`.
        livekit_room = mem::take(&mut left_room.room.livekit_room);
        delete_livekit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        // The connection-pool handle is a temporary and is released at the end of this call.
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` below, which acquires
    // `session.connection_pool()` itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    // NOTE(review): an error here returns early and skips the LiveKit cleanup below.
    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
        live_kit
            .remove_participant(livekit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_livekit_room {
            live_kit.delete_room(livekit_room).await.trace_err();
        }
    }

    Ok(())
}
4579
4580async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4581 let left_channel_buffers = session
4582 .db()
4583 .await
4584 .leave_channel_buffers(session.connection_id)
4585 .await?;
4586
4587 for left_buffer in left_channel_buffers {
4588 channel_buffer_updated(
4589 session.connection_id,
4590 left_buffer.connections,
4591 &proto::UpdateChannelBufferCollaborators {
4592 channel_id: left_buffer.channel_id.to_proto(),
4593 collaborators: left_buffer.collaborators,
4594 },
4595 &session.peer,
4596 );
4597 }
4598
4599 Ok(())
4600}
4601
4602fn project_left(project: &db::LeftProject, session: &Session) {
4603 for connection_id in &project.connection_ids {
4604 if project.should_unshare {
4605 session
4606 .peer
4607 .send(
4608 *connection_id,
4609 proto::UnshareProject {
4610 project_id: project.id.to_proto(),
4611 },
4612 )
4613 .trace_err();
4614 } else {
4615 session
4616 .peer
4617 .send(
4618 *connection_id,
4619 proto::RemoveProjectCollaborator {
4620 project_id: project.id.to_proto(),
4621 peer_id: Some(session.connection_id.into()),
4622 },
4623 )
4624 .trace_err();
4625 }
4626 }
4627}
4628
4629pub trait ResultExt {
4630 type Ok;
4631
4632 fn trace_err(self) -> Option<Self::Ok>;
4633}
4634
4635impl<T, E> ResultExt for Result<T, E>
4636where
4637 E: std::fmt::Debug,
4638{
4639 type Ok = T;
4640
4641 #[track_caller]
4642 fn trace_err(self) -> Option<T> {
4643 match self {
4644 Ok(value) => Some(value),
4645 Err(error) => {
4646 tracing::error!("{:?}", error);
4647 None
4648 }
4649 }
4650 }
4651}