1mod connection_pool;
2
3use crate::api::billing::find_or_create_billing_customer;
4use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
5use crate::db::billing_subscription::SubscriptionKind;
6use crate::llm::db::LlmDatabase;
7use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, LlmTokenClaims};
8use crate::stripe_client::StripeCustomerId;
9use crate::{
10 AppState, Error, Result, auth,
11 db::{
12 self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
13 CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
14 NotificationId, Project, ProjectId, RejoinedProject, RemoveChannelMemberResult, ReplicaId,
15 RespondToChannelInvite, RoomId, ServerId, UpdatedChannelMessage, User, UserId,
16 },
17 executor::Executor,
18};
19use anyhow::{Context as _, anyhow, bail};
20use async_tungstenite::tungstenite::{
21 Message as TungsteniteMessage, protocol::CloseFrame as TungsteniteCloseFrame,
22};
23use axum::{
24 Extension, Router, TypedHeader,
25 body::Body,
26 extract::{
27 ConnectInfo, WebSocketUpgrade,
28 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
29 },
30 headers::{Header, HeaderName},
31 http::StatusCode,
32 middleware,
33 response::IntoResponse,
34 routing::get,
35};
36use chrono::Utc;
37use collections::{HashMap, HashSet};
38pub use connection_pool::{ConnectionPool, ZedVersion};
39use core::fmt::{self, Debug, Formatter};
40use reqwest_client::ReqwestClient;
41use rpc::proto::split_repository_update;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 FutureExt, SinkExt, StreamExt, TryStreamExt, channel::oneshot, future::BoxFuture,
46 stream::FuturesUnordered,
47};
48use prometheus::{IntGauge, register_int_gauge};
49use rpc::{
50 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
51 proto::{
52 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
53 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
54 },
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 Arc, OnceLock,
68 atomic::{AtomicBool, Ordering::SeqCst},
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{Semaphore, watch};
74use tower::ServiceBuilder;
75use tracing::{
76 Instrument,
77 field::{self},
78 info_span, instrument,
79};
80
/// Reconnection grace period (30 seconds). Presumably the window a dropped client has to come
/// back before its state is released — TODO confirm against the callers outside this chunk.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Delay (15 seconds) before `Server::start` sweeps stale resources. Kubernetes gives terminated
/// pods 10 seconds to shut down gracefully; once they are gone, their old resources can be
/// cleaned up safely.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Page size (100) used when paging channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum channel message length (1024); units presumably bytes — TODO confirm at use site.
const MAX_MESSAGE_LEN: usize = 1 << 10;
/// Page size (50) used when paging notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107#[derive(Clone, Debug)]
108pub enum Principal {
109 User(User),
110 Impersonated { user: User, admin: User },
111}
112
113impl Principal {
114 fn user(&self) -> &User {
115 match self {
116 Principal::User(user) => user,
117 Principal::Impersonated { user, .. } => user,
118 }
119 }
120
121 fn update_span(&self, span: &tracing::Span) {
122 match &self {
123 Principal::User(user) => {
124 span.record("user_id", user.id.0);
125 span.record("login", &user.github_login);
126 }
127 Principal::Impersonated { user, admin } => {
128 span.record("user_id", user.id.0);
129 span.record("login", &user.github_login);
130 span.record("impersonator", &admin.github_login);
131 }
132 }
133 }
134}
135
136#[derive(Clone)]
137struct Session {
138 principal: Principal,
139 connection_id: ConnectionId,
140 db: Arc<tokio::sync::Mutex<DbHandle>>,
141 peer: Arc<Peer>,
142 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
143 app_state: Arc<AppState>,
144 supermaven_client: Option<Arc<SupermavenAdminApi>>,
145 /// The GeoIP country code for the user.
146 #[allow(unused)]
147 geoip_country_code: Option<String>,
148 system_id: Option<String>,
149 _executor: Executor,
150}
151
152impl Session {
153 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
154 #[cfg(test)]
155 tokio::task::yield_now().await;
156 let guard = self.db.lock().await;
157 #[cfg(test)]
158 tokio::task::yield_now().await;
159 guard
160 }
161
162 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 let guard = self.connection_pool.lock();
166 ConnectionPoolGuard {
167 guard,
168 _not_send: PhantomData,
169 }
170 }
171
172 fn is_staff(&self) -> bool {
173 match &self.principal {
174 Principal::User(user) => user.admin,
175 Principal::Impersonated { .. } => true,
176 }
177 }
178
179 fn user_id(&self) -> UserId {
180 match &self.principal {
181 Principal::User(user) => user.id,
182 Principal::Impersonated { user, .. } => user.id,
183 }
184 }
185
186 pub fn email(&self) -> Option<String> {
187 match &self.principal {
188 Principal::User(user) => user.email_address.clone(),
189 Principal::Impersonated { user, .. } => user.email_address.clone(),
190 }
191 }
192}
193
194impl Debug for Session {
195 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
196 let mut result = f.debug_struct("Session");
197 match &self.principal {
198 Principal::User(user) => {
199 result.field("user", &user.github_login);
200 }
201 Principal::Impersonated { user, admin } => {
202 result.field("user", &user.github_login);
203 result.field("impersonator", &admin.github_login);
204 }
205 }
206 result.field("connection_id", &self.connection_id).finish()
207 }
208}
209
210struct DbHandle(Arc<Database>);
211
212impl Deref for DbHandle {
213 type Target = Database;
214
215 fn deref(&self) -> &Self::Target {
216 self.0.as_ref()
217 }
218}
219
/// The collaboration RPC server.
///
/// Owns the [`Peer`], the shared connection pool, and the per-message-type handler table
/// populated in `Server::new`.
pub struct Server {
    /// Current server id. Read by `start`; overwritten by the test-only `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Connection futures subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
228
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`. Any future that holds it across
    /// an `.await` therefore cannot be `Send`.
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server state: the peer plus the (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through its `Deref` target via `serialize_deref`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
240
241pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
242where
243 S: Serializer,
244 T: Deref<Target = U>,
245 U: Serialize,
246{
247 Serialize::serialize(value.deref(), serializer)
248}
249
250impl Server {
251 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
252 let mut server = Self {
253 id: parking_lot::Mutex::new(id),
254 peer: Peer::new(id.0 as u32),
255 app_state: app_state.clone(),
256 connection_pool: Default::default(),
257 handlers: Default::default(),
258 teardown: watch::channel(false).0,
259 };
260
261 server
262 .add_request_handler(ping)
263 .add_request_handler(create_room)
264 .add_request_handler(join_room)
265 .add_request_handler(rejoin_room)
266 .add_request_handler(leave_room)
267 .add_request_handler(set_room_participant_role)
268 .add_request_handler(call)
269 .add_request_handler(cancel_call)
270 .add_message_handler(decline_call)
271 .add_request_handler(update_participant_location)
272 .add_request_handler(share_project)
273 .add_message_handler(unshare_project)
274 .add_request_handler(join_project)
275 .add_message_handler(leave_project)
276 .add_request_handler(update_project)
277 .add_request_handler(update_worktree)
278 .add_request_handler(update_repository)
279 .add_request_handler(remove_repository)
280 .add_message_handler(start_language_server)
281 .add_message_handler(update_language_server)
282 .add_message_handler(update_diagnostic_summary)
283 .add_message_handler(update_worktree_settings)
284 .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
285 .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
286 .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
287 .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
288 .add_request_handler(forward_find_search_candidates_request)
289 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
290 .add_request_handler(forward_read_only_project_request::<proto::GetDocumentSymbols>)
291 .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
292 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
293 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
294 .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
295 .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
296 .add_request_handler(forward_read_only_project_request::<proto::ResolveInlayHint>)
297 .add_request_handler(forward_mutating_project_request::<proto::GetCodeLens>)
298 .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
299 .add_request_handler(forward_read_only_project_request::<proto::GitGetBranches>)
300 .add_request_handler(forward_read_only_project_request::<proto::OpenUnstagedDiff>)
301 .add_request_handler(forward_read_only_project_request::<proto::OpenUncommittedDiff>)
302 .add_request_handler(forward_read_only_project_request::<proto::LspExtExpandMacro>)
303 .add_request_handler(forward_read_only_project_request::<proto::LspExtOpenDocs>)
304 .add_request_handler(forward_mutating_project_request::<proto::LspExtRunnables>)
305 .add_request_handler(
306 forward_read_only_project_request::<proto::LspExtSwitchSourceHeader>,
307 )
308 .add_request_handler(forward_read_only_project_request::<proto::LspExtGoToParentModule>)
309 .add_request_handler(forward_read_only_project_request::<proto::LspExtCancelFlycheck>)
310 .add_request_handler(forward_read_only_project_request::<proto::LspExtRunFlycheck>)
311 .add_request_handler(forward_read_only_project_request::<proto::LspExtClearFlycheck>)
312 .add_request_handler(
313 forward_read_only_project_request::<proto::LanguageServerIdForName>,
314 )
315 .add_request_handler(
316 forward_mutating_project_request::<proto::RegisterBufferWithLanguageServers>,
317 )
318 .add_request_handler(forward_mutating_project_request::<proto::UpdateGitBranch>)
319 .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
320 .add_request_handler(
321 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
322 )
323 .add_request_handler(forward_mutating_project_request::<proto::OpenNewBuffer>)
324 .add_request_handler(
325 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
326 )
327 .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
328 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
329 .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
330 .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
331 .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
332 .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeActionKind>)
333 .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
334 .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
335 .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
336 .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
337 .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
338 .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
339 .add_request_handler(
340 forward_mutating_project_request::<proto::ExpandAllForProjectEntry>,
341 )
342 .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
343 .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
344 .add_request_handler(forward_mutating_project_request::<proto::BlameBuffer>)
345 .add_request_handler(forward_mutating_project_request::<proto::MultiLspQuery>)
346 .add_request_handler(forward_mutating_project_request::<proto::RestartLanguageServers>)
347 .add_request_handler(forward_mutating_project_request::<proto::StopLanguageServers>)
348 .add_request_handler(forward_mutating_project_request::<proto::LinkedEditingRange>)
349 .add_message_handler(create_buffer_for_peer)
350 .add_request_handler(update_buffer)
351 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
352 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshCodeLens>)
353 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
354 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
355 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
356 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBases>)
357 .add_request_handler(get_users)
358 .add_request_handler(fuzzy_search_users)
359 .add_request_handler(request_contact)
360 .add_request_handler(remove_contact)
361 .add_request_handler(respond_to_contact_request)
362 .add_message_handler(subscribe_to_channels)
363 .add_request_handler(create_channel)
364 .add_request_handler(delete_channel)
365 .add_request_handler(invite_channel_member)
366 .add_request_handler(remove_channel_member)
367 .add_request_handler(set_channel_member_role)
368 .add_request_handler(set_channel_visibility)
369 .add_request_handler(rename_channel)
370 .add_request_handler(join_channel_buffer)
371 .add_request_handler(leave_channel_buffer)
372 .add_message_handler(update_channel_buffer)
373 .add_request_handler(rejoin_channel_buffers)
374 .add_request_handler(get_channel_members)
375 .add_request_handler(respond_to_channel_invite)
376 .add_request_handler(join_channel)
377 .add_request_handler(join_channel_chat)
378 .add_message_handler(leave_channel_chat)
379 .add_request_handler(send_channel_message)
380 .add_request_handler(remove_channel_message)
381 .add_request_handler(update_channel_message)
382 .add_request_handler(get_channel_messages)
383 .add_request_handler(get_channel_messages_by_id)
384 .add_request_handler(get_notifications)
385 .add_request_handler(mark_notification_as_read)
386 .add_request_handler(move_channel)
387 .add_request_handler(reorder_channel)
388 .add_request_handler(follow)
389 .add_message_handler(unfollow)
390 .add_message_handler(update_followers)
391 .add_request_handler(get_private_user_info)
392 .add_request_handler(get_llm_api_token)
393 .add_request_handler(accept_terms_of_service)
394 .add_message_handler(acknowledge_channel_message)
395 .add_message_handler(acknowledge_buffer_version)
396 .add_request_handler(get_supermaven_api_key)
397 .add_request_handler(forward_mutating_project_request::<proto::OpenContext>)
398 .add_request_handler(forward_mutating_project_request::<proto::CreateContext>)
399 .add_request_handler(forward_mutating_project_request::<proto::SynchronizeContexts>)
400 .add_request_handler(forward_mutating_project_request::<proto::Stage>)
401 .add_request_handler(forward_mutating_project_request::<proto::Unstage>)
402 .add_request_handler(forward_mutating_project_request::<proto::Commit>)
403 .add_request_handler(forward_mutating_project_request::<proto::GitInit>)
404 .add_request_handler(forward_read_only_project_request::<proto::GetRemotes>)
405 .add_request_handler(forward_read_only_project_request::<proto::GitShow>)
406 .add_request_handler(forward_read_only_project_request::<proto::LoadCommitDiff>)
407 .add_request_handler(forward_read_only_project_request::<proto::GitReset>)
408 .add_request_handler(forward_read_only_project_request::<proto::GitCheckoutFiles>)
409 .add_request_handler(forward_mutating_project_request::<proto::SetIndexText>)
410 .add_request_handler(forward_mutating_project_request::<proto::ToggleBreakpoint>)
411 .add_message_handler(broadcast_project_message_from_host::<proto::BreakpointsForFile>)
412 .add_request_handler(forward_mutating_project_request::<proto::OpenCommitMessageBuffer>)
413 .add_request_handler(forward_mutating_project_request::<proto::GitDiff>)
414 .add_request_handler(forward_mutating_project_request::<proto::GitCreateBranch>)
415 .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
416 .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
417 .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
418 .add_message_handler(update_context);
419
420 Arc::new(server)
421 }
422
    /// Schedules the post-startup sweep of stale server resources.
    ///
    /// Spawns a detached task, instrumented with a `start server` span. The task first waits
    /// for [`CLEANUP_TIMEOUT`]. Then, using the server id captured at call time, it:
    /// - deletes stale channel-chat participants;
    /// - fetches stale room and channel-buffer ids (presumably resources owned by earlier
    ///   server instances — confirm in the db layer);
    /// - clears stale channel-buffer collaborators and notifies the remaining connections;
    /// - clears stale room participants, broadcasts room/channel updates, sends
    ///   `CallCanceled`, and pushes refreshed contact busy-state to accepted contacts;
    /// - deletes the LiveKit room when a room is left with no participants;
    /// - finally repeats the chat-participant cleanup, clears old worktree entries, and deletes
    ///   stale servers.
    ///
    /// Every fallible step inside the task goes through `trace_err()` instead of `?`. A failing
    /// step therefore does not abort the rest of the sweep.
    ///
    /// Returns `Ok(())` as soon as the task has been spawned.
    pub async fn start(&self) -> Result<()> {
        // Capture everything the detached task needs so it owns its data.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // The sleep future is created here and awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let livekit_client = self.app_state.livekit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");

                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults are used when clearing the room fails: nothing to notify
                        // and no LiveKit room deletion.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut livekit_room = String::new();
                        let mut delete_livekit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose calls were canceled
                            // get refreshed contact entries below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            livekit_room = mem::take(&mut refreshed_room.room.livekit_room);
                            delete_livekit_room = refreshed_room.room.participants.is_empty();
                        }

                        // The pool guard is scoped to this block so it is released before the
                        // next `.await` (presumably required because the parking_lot guard
                        // must not be held across await points).
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's updated contact entry (busy state) to all
                        // connections of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only delete the media room once no participants remain.
                        if let Some(live_kit) = livekit_client.as_ref() {
                            if delete_livekit_room {
                                live_kit.delete_room(livekit_room).await.trace_err();
                            }
                        }
                    }
                }

                // NOTE(review): identical to the call at the start of the sweep. Presumably it
                // is repeated to catch participants that went stale while the sweep ran —
                // confirm this is intentional.
                app_state
                    .db
                    .delete_stale_channel_chat_participants(
                        &app_state.config.zed_environment,
                        server_id,
                    )
                    .await
                    .trace_err();

                app_state
                    .db
                    .clear_old_worktree_entries(server_id)
                    .await
                    .trace_err();

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
593
594 pub fn teardown(&self) {
595 self.peer.teardown();
596 self.connection_pool.lock().reset();
597 let _ = self.teardown.send(true);
598 }
599
600 #[cfg(test)]
601 pub fn reset(&self, id: ServerId) {
602 self.teardown();
603 *self.id.lock() = id;
604 self.peer.reset(id.0 as u32);
605 let _ = self.teardown.send(false);
606 }
607
608 #[cfg(test)]
609 pub fn id(&self) -> ServerId {
610 *self.id.lock()
611 }
612
613 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
614 where
615 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
616 Fut: 'static + Send + Future<Output = Result<()>>,
617 M: EnvelopedMessage,
618 {
619 let prev_handler = self.handlers.insert(
620 TypeId::of::<M>(),
621 Box::new(move |envelope, session| {
622 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
623 let received_at = envelope.received_at;
624 tracing::info!("message received");
625 let start_time = Instant::now();
626 let future = (handler)(*envelope, session);
627 async move {
628 let result = future.await;
629 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
630 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
631 let queue_duration_ms = total_duration_ms - processing_duration_ms;
632 let payload_type = M::NAME;
633
634 match result {
635 Err(error) => {
636 tracing::error!(
637 ?error,
638 total_duration_ms,
639 processing_duration_ms,
640 queue_duration_ms,
641 payload_type,
642 "error handling message"
643 )
644 }
645 Ok(()) => tracing::info!(
646 total_duration_ms,
647 processing_duration_ms,
648 queue_duration_ms,
649 "finished handling message"
650 ),
651 }
652 }
653 .boxed()
654 }),
655 );
656 if prev_handler.is_some() {
657 panic!("registered a handler for the same message twice");
658 }
659 self
660 }
661
662 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
663 where
664 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
665 Fut: 'static + Send + Future<Output = Result<()>>,
666 M: EnvelopedMessage,
667 {
668 self.add_handler(move |envelope, session| handler(envelope.payload, session));
669 self
670 }
671
672 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
673 where
674 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
675 Fut: Send + Future<Output = Result<()>>,
676 M: RequestMessage,
677 {
678 let handler = Arc::new(handler);
679 self.add_handler(move |envelope, session| {
680 let receipt = envelope.receipt();
681 let handler = handler.clone();
682 async move {
683 let peer = session.peer.clone();
684 let responded = Arc::new(AtomicBool::default());
685 let response = Response {
686 peer: peer.clone(),
687 responded: responded.clone(),
688 receipt,
689 };
690 match (handler)(envelope.payload, response, session).await {
691 Ok(()) => {
692 if responded.load(std::sync::atomic::Ordering::SeqCst) {
693 Ok(())
694 } else {
695 Err(anyhow!("handler did not send a response"))?
696 }
697 }
698 Err(error) => {
699 let proto_err = match &error {
700 Error::Internal(err) => err.to_proto(),
701 _ => ErrorCode::Internal.message(format!("{error}")).to_proto(),
702 };
703 peer.respond_with_error(receipt, proto_err)?;
704 Err(error)
705 }
706 }
707 }
708 })
709 }
710
    /// Builds the future that serves one client connection for its whole lifetime.
    ///
    /// The returned future (`use<>`: it captures no borrowed lifetimes) does the following:
    /// 1. Returns immediately if the teardown flag is already set.
    /// 2. Registers `connection` with the peer and records the connection id on the span.
    /// 3. Builds the HTTP and optional Supermaven clients and the per-connection [`Session`],
    ///    then runs `send_initial_client_update`. The connection id is delivered through
    ///    `send_connection_id` if provided.
    /// 4. Dispatches incoming messages until I/O finishes or fails, the incoming stream ends,
    ///    or teardown is signalled.
    /// 5. On I/O end or stream close, drops any in-flight foreground handlers and calls
    ///    `connection_lost`.
    ///
    /// On teardown during the loop it returns without calling `connection_lost`. Setup failures
    /// (HTTP client, initial update) are logged and end the future early.
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        geoip_country_code: Option<String>,
        system_id: Option<String>,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> + use<> {
        let this = self.clone();
        // Fields left `Empty` here are filled in later via `Span::record`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            geoip_country_code=field::Empty
        );
        principal.update_span(&span);
        if let Some(country_code) = geoip_country_code.as_ref() {
            span.record("geoip_country_code", country_code);
        }

        // Subscribe before the future starts, so a teardown that happens in between is visible.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            // Inside `.instrument(span)`, the current span is the "handle connection" span.
            tracing::Span::current().record("connection_id", format!("{}", connection_id));

            tracing::info!("connection opened");

            let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
            let http_client = match ReqwestClient::user_agent(&user_agent) {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when an admin API key is configured.
            let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )));

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                app_state: this.app_state.clone(),
                geoip_country_code,
                system_id,
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stopped processing further
            // messages until their respective request completed, neither would get a chance to
            // respond to the other client's request, and they would deadlock.
            //
            // This arrangement ensures we attempt to process earlier messages first, but fall
            // back to processing messages that arrived later, so that progress is always made.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // Caps in-flight handlers (foreground and background) at 256. A permit is taken
            // before the next message is read and released when its handler completes.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins, then I/O, then progress on foreground handlers, then
                // reading new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit moves into the handler future, so it is held until
                                // the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                // Background handlers run detached; foreground ones only make
                                // progress while this loop polls them.
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancels any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
858
859 async fn send_initial_client_update(
860 &self,
861 connection_id: ConnectionId,
862 zed_version: ZedVersion,
863 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
864 session: &Session,
865 ) -> Result<()> {
866 self.peer.send(
867 connection_id,
868 proto::Hello {
869 peer_id: Some(connection_id.into()),
870 },
871 )?;
872 tracing::info!("sent hello message");
873 if let Some(send_connection_id) = send_connection_id.take() {
874 let _ = send_connection_id.send(connection_id);
875 }
876
877 match &session.principal {
878 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
879 if !user.connected_once {
880 self.peer.send(connection_id, proto::ShowContacts {})?;
881 self.app_state
882 .db
883 .set_user_connected_once(user.id, true)
884 .await?;
885 }
886
887 update_user_plan(session).await?;
888
889 let contacts = self.app_state.db.get_contacts(user.id).await?;
890
891 {
892 let mut pool = self.connection_pool.lock();
893 pool.add_connection(connection_id, user.id, user.admin, zed_version);
894 self.peer.send(
895 connection_id,
896 build_initial_contacts_update(contacts, &pool),
897 )?;
898 }
899
900 if should_auto_subscribe_to_channels(zed_version) {
901 subscribe_user_to_channels(user.id, session).await?;
902 }
903
904 if let Some(incoming_call) =
905 self.app_state.db.incoming_call_for_user(user.id).await?
906 {
907 self.peer.send(connection_id, incoming_call)?;
908 }
909
910 update_user_contacts(user.id, session).await?;
911 }
912 }
913
914 Ok(())
915 }
916
917 pub async fn invite_code_redeemed(
918 self: &Arc<Self>,
919 inviter_id: UserId,
920 invitee_id: UserId,
921 ) -> Result<()> {
922 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
923 if let Some(code) = &user.invite_code {
924 let pool = self.connection_pool.lock();
925 let invitee_contact = contact_for_user(invitee_id, false, &pool);
926 for connection_id in pool.user_connection_ids(inviter_id) {
927 self.peer.send(
928 connection_id,
929 proto::UpdateContacts {
930 contacts: vec![invitee_contact.clone()],
931 ..Default::default()
932 },
933 )?;
934 self.peer.send(
935 connection_id,
936 proto::UpdateInviteInfo {
937 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
938 count: user.invite_count as u32,
939 },
940 )?;
941 }
942 }
943 }
944 Ok(())
945 }
946
947 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
948 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
949 if let Some(invite_code) = &user.invite_code {
950 let pool = self.connection_pool.lock();
951 for connection_id in pool.user_connection_ids(user_id) {
952 self.peer.send(
953 connection_id,
954 proto::UpdateInviteInfo {
955 url: format!(
956 "{}{}",
957 self.app_state.config.invite_link_prefix, invite_code
958 ),
959 count: user.invite_count as u32,
960 },
961 )?;
962 }
963 }
964 }
965 Ok(())
966 }
967
968 pub async fn update_plan_for_user(self: &Arc<Self>, user_id: UserId) -> Result<()> {
969 let user = self
970 .app_state
971 .db
972 .get_user_by_id(user_id)
973 .await?
974 .context("user not found")?;
975
976 let update_user_plan = make_update_user_plan_message(
977 &user,
978 user.admin,
979 &self.app_state.db,
980 self.app_state.llm_db.clone(),
981 )
982 .await?;
983
984 let pool = self.connection_pool.lock();
985 for connection_id in pool.user_connection_ids(user_id) {
986 self.peer
987 .send(connection_id, update_user_plan.clone())
988 .trace_err();
989 }
990
991 Ok(())
992 }
993
994 pub async fn refresh_llm_tokens_for_user(self: &Arc<Self>, user_id: UserId) {
995 let pool = self.connection_pool.lock();
996 for connection_id in pool.user_connection_ids(user_id) {
997 self.peer
998 .send(connection_id, proto::RefreshLlmToken {})
999 .trace_err();
1000 }
1001 }
1002
1003 pub async fn snapshot(self: &Arc<Self>) -> ServerSnapshot {
1004 ServerSnapshot {
1005 connection_pool: ConnectionPoolGuard {
1006 guard: self.connection_pool.lock(),
1007 _not_send: PhantomData,
1008 },
1009 peer: &self.peer,
1010 }
1011 }
1012}
1013
1014impl Deref for ConnectionPoolGuard<'_> {
1015 type Target = ConnectionPool;
1016
1017 fn deref(&self) -> &Self::Target {
1018 &self.guard
1019 }
1020}
1021
1022impl DerefMut for ConnectionPoolGuard<'_> {
1023 fn deref_mut(&mut self) -> &mut Self::Target {
1024 &mut self.guard
1025 }
1026}
1027
/// In test builds, checks the pool's invariants whenever the guard is dropped.
/// `drop` runs before the inner lock guard field is dropped, so the check
/// still happens while the lock is held.
impl Drop for ConnectionPoolGuard<'_> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; production builds do no extra work here.
        #[cfg(test)]
        self.check_invariants();
    }
}
1034
1035fn broadcast<F>(
1036 sender_id: Option<ConnectionId>,
1037 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1038 mut f: F,
1039) where
1040 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1041{
1042 for receiver_id in receiver_ids {
1043 if Some(receiver_id) != sender_id {
1044 if let Err(error) = f(receiver_id) {
1045 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1046 }
1047 }
1048 }
1049}
1050
1051pub struct ProtocolVersion(u32);
1052
1053impl Header for ProtocolVersion {
1054 fn name() -> &'static HeaderName {
1055 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1056 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1057 }
1058
1059 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1060 where
1061 Self: Sized,
1062 I: Iterator<Item = &'i axum::http::HeaderValue>,
1063 {
1064 let version = values
1065 .next()
1066 .ok_or_else(axum::headers::Error::invalid)?
1067 .to_str()
1068 .map_err(|_| axum::headers::Error::invalid())?
1069 .parse()
1070 .map_err(|_| axum::headers::Error::invalid())?;
1071 Ok(Self(version))
1072 }
1073
1074 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1075 values.extend([self.0.to_string().parse().unwrap()]);
1076 }
1077}
1078
1079pub struct AppVersionHeader(SemanticVersion);
1080impl Header for AppVersionHeader {
1081 fn name() -> &'static HeaderName {
1082 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1083 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1084 }
1085
1086 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1087 where
1088 Self: Sized,
1089 I: Iterator<Item = &'i axum::http::HeaderValue>,
1090 {
1091 let version = values
1092 .next()
1093 .ok_or_else(axum::headers::Error::invalid)?
1094 .to_str()
1095 .map_err(|_| axum::headers::Error::invalid())?
1096 .parse()
1097 .map_err(|_| axum::headers::Error::invalid())?;
1098 Ok(Self(version))
1099 }
1100
1101 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1102 values.extend([self.0.to_string().parse().unwrap()]);
1103 }
1104}
1105
/// Builds the collab RPC router.
///
/// `Router::layer` wraps only the routes registered before it. As a result,
/// `/rpc` sits behind `auth::validate_header`, and the `AppState` extension is
/// added as the outer `ServiceBuilder` layer so the middleware can read it.
/// `/metrics` is registered after that layer and is therefore not behind it.
/// The final `Extension(server)` layer applies to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1117
1118pub async fn handle_websocket_request(
1119 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1120 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1121 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1122 Extension(server): Extension<Arc<Server>>,
1123 Extension(principal): Extension<Principal>,
1124 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1125 system_id_header: Option<TypedHeader<SystemIdHeader>>,
1126 ws: WebSocketUpgrade,
1127) -> axum::response::Response {
1128 if protocol_version != rpc::PROTOCOL_VERSION {
1129 return (
1130 StatusCode::UPGRADE_REQUIRED,
1131 "client must be upgraded".to_string(),
1132 )
1133 .into_response();
1134 }
1135
1136 let Some(version) = app_version_header.map(|header| ZedVersion(header.0.0)) else {
1137 return (
1138 StatusCode::UPGRADE_REQUIRED,
1139 "no version header found".to_string(),
1140 )
1141 .into_response();
1142 };
1143
1144 if !version.can_collaborate() {
1145 return (
1146 StatusCode::UPGRADE_REQUIRED,
1147 "client must be upgraded".to_string(),
1148 )
1149 .into_response();
1150 }
1151
1152 let socket_address = socket_address.to_string();
1153 ws.on_upgrade(move |socket| {
1154 let socket = socket
1155 .map_ok(to_tungstenite_message)
1156 .err_into()
1157 .with(|message| async move { to_axum_message(message) });
1158 let connection = Connection::new(Box::pin(socket));
1159 async move {
1160 server
1161 .handle_connection(
1162 connection,
1163 socket_address,
1164 principal,
1165 version,
1166 country_code_header.map(|header| header.to_string()),
1167 system_id_header.map(|header| header.to_string()),
1168 None,
1169 Executor::Production,
1170 )
1171 .await;
1172 }
1173 })
1174}
1175
1176pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1177 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1178 let connections_metric = CONNECTIONS_METRIC
1179 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1180
1181 let connections = server
1182 .connection_pool
1183 .lock()
1184 .connections()
1185 .filter(|connection| !connection.admin)
1186 .count();
1187 connections_metric.set(connections as _);
1188
1189 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1190 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1191 register_int_gauge!(
1192 "shared_projects",
1193 "number of open projects with one or more guests"
1194 )
1195 .unwrap()
1196 });
1197
1198 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1199 shared_projects_metric.set(shared_projects as _);
1200
1201 let encoder = prometheus::TextEncoder::new();
1202 let metric_families = prometheus::gather();
1203 let encoded_metrics = encoder
1204 .encode_to_string(&metric_families)
1205 .map_err(|err| anyhow!("{err}"))?;
1206 Ok(encoded_metrics)
1207}
1208
/// Cleans up after a client's connection is lost.
///
/// Runs immediately:
/// - disconnects the peer and removes the connection from the pool (a pool
///   error aborts the cleanup and is returned);
/// - records the loss in the database (errors are only logged).
///
/// Then waits for whichever happens first:
/// - `RECONNECT_TIMEOUT` elapses (presumably the grace period for the client
///   to reconnect). The session leaves its room and channel buffers, and if
///   the user has no online connection left, `decline_call(None, ..)` is
///   invoked for them. Finally the user's contacts are updated.
/// - `teardown` changes (presumably server shutdown; confirm with the caller).
///   Nothing further is done.
///
/// `select_biased!` polls the timeout branch first, so it wins if both are ready.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {

            log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
            leave_room_for_session(&session, session.connection_id).await.trace_err();
            leave_channel_buffers_for_session(&session)
                .await
                .trace_err();

            // Only decline a pending call if no other connection of this user is online.
            if !session
                .connection_pool()
                .await
                .is_user_online(session.user_id())
            {
                let db = session.db().await;
                if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                    room_updated(&room, &session.peer);
                }
            }

            update_user_contacts(session.user_id(), &session).await?;
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1255
1256/// Acknowledges a ping from a client, used to keep the connection alive.
1257async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1258 response.send(proto::Ack {})?;
1259 Ok(())
1260}
1261
1262/// Creates a new room for calling (outside of channels)
1263async fn create_room(
1264 _request: proto::CreateRoom,
1265 response: Response<proto::CreateRoom>,
1266 session: Session,
1267) -> Result<()> {
1268 let livekit_room = nanoid::nanoid!(30);
1269
1270 let live_kit_connection_info = util::maybe!(async {
1271 let live_kit = session.app_state.livekit_client.as_ref();
1272 let live_kit = live_kit?;
1273 let user_id = session.user_id().to_string();
1274
1275 let token = live_kit
1276 .room_token(&livekit_room, &user_id.to_string())
1277 .trace_err()?;
1278
1279 Some(proto::LiveKitConnectionInfo {
1280 server_url: live_kit.url().into(),
1281 token,
1282 can_publish: true,
1283 })
1284 })
1285 .await;
1286
1287 let room = session
1288 .db()
1289 .await
1290 .create_room(session.user_id(), session.connection_id, &livekit_room)
1291 .await?;
1292
1293 response.send(proto::CreateRoomResponse {
1294 room: Some(room.clone()),
1295 live_kit_connection_info,
1296 })?;
1297
1298 update_user_contacts(session.user_id(), &session).await?;
1299 Ok(())
1300}
1301
1302/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1303async fn join_room(
1304 request: proto::JoinRoom,
1305 response: Response<proto::JoinRoom>,
1306 session: Session,
1307) -> Result<()> {
1308 let room_id = RoomId::from_proto(request.id);
1309
1310 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1311
1312 if let Some(channel_id) = channel_id {
1313 return join_channel_internal(channel_id, Box::new(response), session).await;
1314 }
1315
1316 let joined_room = {
1317 let room = session
1318 .db()
1319 .await
1320 .join_room(room_id, session.user_id(), session.connection_id)
1321 .await?;
1322 room_updated(&room.room, &session.peer);
1323 room.into_inner()
1324 };
1325
1326 for connection_id in session
1327 .connection_pool()
1328 .await
1329 .user_connection_ids(session.user_id())
1330 {
1331 session
1332 .peer
1333 .send(
1334 connection_id,
1335 proto::CallCanceled {
1336 room_id: room_id.to_proto(),
1337 },
1338 )
1339 .trace_err();
1340 }
1341
1342 let live_kit_connection_info = if let Some(live_kit) = session.app_state.livekit_client.as_ref()
1343 {
1344 live_kit
1345 .room_token(
1346 &joined_room.room.livekit_room,
1347 &session.user_id().to_string(),
1348 )
1349 .trace_err()
1350 .map(|token| proto::LiveKitConnectionInfo {
1351 server_url: live_kit.url().into(),
1352 token,
1353 can_publish: true,
1354 })
1355 } else {
1356 None
1357 };
1358
1359 response.send(proto::JoinRoomResponse {
1360 room: Some(joined_room.room),
1361 channel_id: None,
1362 live_kit_connection_info,
1363 })?;
1364
1365 update_user_contacts(session.user_id(), &session).await?;
1366 Ok(())
1367}
1368
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-admits this connection to the room via the database. It replies with the
/// room, the reshared projects (presumably the ones this user hosts) and the
/// rejoined projects (presumably ones joined as a guest). It then:
/// - broadcasts the room update;
/// - for each reshared project, tells every collaborator that this user's old
///   peer id is now the current connection, and forwards the project's
///   worktrees to the other collaborators;
/// - hands the rejoined projects to `notify_rejoined_projects`;
/// - if the room belongs to a channel, calls `channel_updated`;
/// - updates the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: Session,
) -> Result<()> {
    let room;
    let channel;
    // The database guard lives only inside this block; `into_inner` consumes it
    // and only the room and channel escape.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Remap this user's previous peer id to the new connection for each collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's worktree list to everyone but this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1460
/// Notifies collaborators about, and re-streams state for, the projects this
/// connection rejoined.
///
/// First, every collaborator of every rejoined project is told that this user's
/// old peer id has been replaced by the current connection (failures are
/// logged). Then, for each project, the worktrees, updated repositories, and
/// removed repositories are moved out of `rejoined_projects` (via `mem::take`,
/// so they are empty afterwards) and sent to this connection. That covers
/// worktree entry updates (split with `split_worktree_update`), diagnostic
/// summaries, worktree settings files, repository updates (split with
/// `split_repository_update`), and repository removals.
///
/// # Errors
/// Returns as soon as a message to this connection fails to send.
fn notify_rejoined_projects(
    rejoined_projects: &mut Vec<RejoinedProject>,
    session: &Session,
) -> Result<()> {
    for project in rejoined_projects.iter() {
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }
    }

    for project in rejoined_projects {
        for worktree in mem::take(&mut project.worktrees) {
            // Stream this worktree's entries.
            let message = proto::UpdateWorktree {
                project_id: project.id.to_proto(),
                worktree_id: worktree.id,
                abs_path: worktree.abs_path.clone(),
                root_name: worktree.root_name,
                updated_entries: worktree.updated_entries,
                removed_entries: worktree.removed_entries,
                scan_id: worktree.scan_id,
                is_last_update: worktree.completed_scan_id == worktree.scan_id,
                updated_repositories: worktree.updated_repositories,
                removed_repositories: worktree.removed_repositories,
            };
            for update in proto::split_worktree_update(message) {
                session.peer.send(session.connection_id, update)?;
            }

            // Stream this worktree's diagnostics.
            for summary in worktree.diagnostic_summaries {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateDiagnosticSummary {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        summary: Some(summary),
                    },
                )?;
            }

            // Stream this worktree's settings files.
            for settings_file in worktree.settings_files {
                session.peer.send(
                    session.connection_id,
                    proto::UpdateWorktreeSettings {
                        project_id: project.id.to_proto(),
                        worktree_id: worktree.id,
                        path: settings_file.path,
                        content: Some(settings_file.content),
                        kind: Some(settings_file.kind.to_proto().into()),
                    },
                )?;
            }
        }

        // Project-level repository state, sent after all worktrees.
        for repository in mem::take(&mut project.updated_repositories) {
            for update in split_repository_update(repository) {
                session.peer.send(session.connection_id, update)?;
            }
        }

        for id in mem::take(&mut project.removed_repositories) {
            session.peer.send(
                session.connection_id,
                proto::RemoveRepository {
                    project_id: project.id.to_proto(),
                    id,
                },
            )?;
        }
    }

    Ok(())
}
1545
1546/// leave room disconnects from the room.
1547async fn leave_room(
1548 _: proto::LeaveRoom,
1549 response: Response<proto::LeaveRoom>,
1550 session: Session,
1551) -> Result<()> {
1552 leave_room_for_session(&session, session.connection_id).await?;
1553 response.send(proto::Ack {})?;
1554 Ok(())
1555}
1556
1557/// Updates the permissions of someone else in the room.
1558async fn set_room_participant_role(
1559 request: proto::SetRoomParticipantRole,
1560 response: Response<proto::SetRoomParticipantRole>,
1561 session: Session,
1562) -> Result<()> {
1563 let user_id = UserId::from_proto(request.user_id);
1564 let role = ChannelRole::from(request.role());
1565
1566 let (livekit_room, can_publish) = {
1567 let room = session
1568 .db()
1569 .await
1570 .set_room_participant_role(
1571 session.user_id(),
1572 RoomId::from_proto(request.room_id),
1573 user_id,
1574 role,
1575 )
1576 .await?;
1577
1578 let livekit_room = room.livekit_room.clone();
1579 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1580 room_updated(&room, &session.peer);
1581 (livekit_room, can_publish)
1582 };
1583
1584 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
1585 live_kit
1586 .update_participant(
1587 livekit_room.clone(),
1588 request.user_id.to_string(),
1589 livekit_api::proto::ParticipantPermission {
1590 can_subscribe: true,
1591 can_publish,
1592 can_publish_data: can_publish,
1593 hidden: false,
1594 recorder: false,
1595 },
1596 )
1597 .await
1598 .trace_err();
1599 }
1600
1601 response.send(proto::Ack {})?;
1602 Ok(())
1603}
1604
/// Call someone else into the current room
///
/// The callee must be one of the caller's contacts. The call is recorded in the
/// database, the room update is broadcast, and the callee's contacts are
/// updated. Every connection of the callee is then sent the incoming call as a
/// request. The first successful response acknowledges the caller. If no
/// connection succeeds (including when the callee has none), the call is
/// marked failed, the room update is broadcast again, the callee's contacts
/// are updated, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: Session,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        // `Err(..)?` converts the anyhow error into the crate's error type.
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Take the incoming-call payload out of the database guard; the guard is
    // dropped at the end of this block.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections concurrently.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // First success wins; failures are logged and the next response is awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody picked up: roll the call back.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1673
1674/// Cancel an outgoing call.
1675async fn cancel_call(
1676 request: proto::CancelCall,
1677 response: Response<proto::CancelCall>,
1678 session: Session,
1679) -> Result<()> {
1680 let called_user_id = UserId::from_proto(request.called_user_id);
1681 let room_id = RoomId::from_proto(request.room_id);
1682 {
1683 let room = session
1684 .db()
1685 .await
1686 .cancel_call(room_id, session.connection_id, called_user_id)
1687 .await?;
1688 room_updated(&room, &session.peer);
1689 }
1690
1691 for connection_id in session
1692 .connection_pool()
1693 .await
1694 .user_connection_ids(called_user_id)
1695 {
1696 session
1697 .peer
1698 .send(
1699 connection_id,
1700 proto::CallCanceled {
1701 room_id: room_id.to_proto(),
1702 },
1703 )
1704 .trace_err();
1705 }
1706 response.send(proto::Ack {})?;
1707
1708 update_user_contacts(called_user_id, &session).await?;
1709 Ok(())
1710}
1711
1712/// Decline an incoming call.
1713async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> {
1714 let room_id = RoomId::from_proto(message.room_id);
1715 {
1716 let room = session
1717 .db()
1718 .await
1719 .decline_call(Some(room_id), session.user_id())
1720 .await?
1721 .context("declining call")?;
1722 room_updated(&room, &session.peer);
1723 }
1724
1725 for connection_id in session
1726 .connection_pool()
1727 .await
1728 .user_connection_ids(session.user_id())
1729 {
1730 session
1731 .peer
1732 .send(
1733 connection_id,
1734 proto::CallCanceled {
1735 room_id: room_id.to_proto(),
1736 },
1737 )
1738 .trace_err();
1739 }
1740 update_user_contacts(session.user_id(), &session).await?;
1741 Ok(())
1742}
1743
1744/// Updates other participants in the room with your current location.
1745async fn update_participant_location(
1746 request: proto::UpdateParticipantLocation,
1747 response: Response<proto::UpdateParticipantLocation>,
1748 session: Session,
1749) -> Result<()> {
1750 let room_id = RoomId::from_proto(request.room_id);
1751 let location = request.location.context("invalid location")?;
1752
1753 let db = session.db().await;
1754 let room = db
1755 .update_room_participant_location(room_id, session.connection_id, location)
1756 .await?;
1757
1758 room_updated(&room, &session.peer);
1759 response.send(proto::Ack {})?;
1760 Ok(())
1761}
1762
/// Share a project into the room.
///
/// Registers the sender's worktrees as a new shared project in the room. The
/// request's `is_ssh_project` flag is passed through to the database. The
/// handler replies with the new project id and then broadcasts the room
/// update. The database guard borrowed as `project_id`/`room` stays alive
/// until the function returns.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: Session,
) -> Result<()> {
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request.is_ssh_project,
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(room, &session.peer);

    Ok(())
}
1786
1787/// Unshare a project from the room.
1788async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
1789 let project_id = ProjectId::from_proto(message.project_id);
1790 unshare_project_internal(project_id, session.connection_id, &session).await
1791}
1792
/// Unshares `project_id` on behalf of `connection_id` (presumably the host's
/// connection; confirm with callers).
///
/// The database returns whether the project should be deleted, the affected
/// room (if any), and the guest connections. Guests other than `connection_id`
/// receive `UnshareProject` (send failures are logged by `broadcast`), and the
/// room update is broadcast. If deletion was requested, the project is deleted
/// after the unshare guard has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Only the `delete` flag escapes this block; the guard is dropped at its end.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
1830
/// Join someone else's shared project.
///
/// Admits this connection to the project in the database, then delegates the
/// replies and the streaming of project state to `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the guard returned by
    // `join_project`, which remains alive after `db` is dropped below.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming project state to the client.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
1849
/// Reply sink for the `JoinProjectResponse` built by `join_project_internal`.
/// It is abstracted as a trait so the join logic is not tied to one concrete
/// responder type.
trait JoinProjectInternalResponse {
    /// Sends the join response, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// The ordinary RPC responder for `proto::JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified to make explicit that this forwards to `Response`'s own `send`.
        Response::<proto::JoinProject>::send(self, result)
    }
}
1858
1859fn join_project_internal(
1860 response: impl JoinProjectInternalResponse,
1861 session: Session,
1862 project: &mut Project,
1863 replica_id: &ReplicaId,
1864) -> Result<()> {
1865 let collaborators = project
1866 .collaborators
1867 .iter()
1868 .filter(|collaborator| collaborator.connection_id != session.connection_id)
1869 .map(|collaborator| collaborator.to_proto())
1870 .collect::<Vec<_>>();
1871 let project_id = project.id;
1872 let guest_user_id = session.user_id();
1873
1874 let worktrees = project
1875 .worktrees
1876 .iter()
1877 .map(|(id, worktree)| proto::WorktreeMetadata {
1878 id: *id,
1879 root_name: worktree.root_name.clone(),
1880 visible: worktree.visible,
1881 abs_path: worktree.abs_path.clone(),
1882 })
1883 .collect::<Vec<_>>();
1884
1885 let add_project_collaborator = proto::AddProjectCollaborator {
1886 project_id: project_id.to_proto(),
1887 collaborator: Some(proto::Collaborator {
1888 peer_id: Some(session.connection_id.into()),
1889 replica_id: replica_id.0 as u32,
1890 user_id: guest_user_id.to_proto(),
1891 is_host: false,
1892 }),
1893 };
1894
1895 for collaborator in &collaborators {
1896 session
1897 .peer
1898 .send(
1899 collaborator.peer_id.unwrap().into(),
1900 add_project_collaborator.clone(),
1901 )
1902 .trace_err();
1903 }
1904
1905 // First, we send the metadata associated with each worktree.
1906 response.send(proto::JoinProjectResponse {
1907 project_id: project.id.0 as u64,
1908 worktrees: worktrees.clone(),
1909 replica_id: replica_id.0 as u32,
1910 collaborators: collaborators.clone(),
1911 language_servers: project.language_servers.clone(),
1912 role: project.role.into(),
1913 })?;
1914
1915 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
1916 // Stream this worktree's entries.
1917 let message = proto::UpdateWorktree {
1918 project_id: project_id.to_proto(),
1919 worktree_id,
1920 abs_path: worktree.abs_path.clone(),
1921 root_name: worktree.root_name,
1922 updated_entries: worktree.entries,
1923 removed_entries: Default::default(),
1924 scan_id: worktree.scan_id,
1925 is_last_update: worktree.scan_id == worktree.completed_scan_id,
1926 updated_repositories: worktree.legacy_repository_entries.into_values().collect(),
1927 removed_repositories: Default::default(),
1928 };
1929 for update in proto::split_worktree_update(message) {
1930 session.peer.send(session.connection_id, update.clone())?;
1931 }
1932
1933 // Stream this worktree's diagnostics.
1934 for summary in worktree.diagnostic_summaries {
1935 session.peer.send(
1936 session.connection_id,
1937 proto::UpdateDiagnosticSummary {
1938 project_id: project_id.to_proto(),
1939 worktree_id: worktree.id,
1940 summary: Some(summary),
1941 },
1942 )?;
1943 }
1944
1945 for settings_file in worktree.settings_files {
1946 session.peer.send(
1947 session.connection_id,
1948 proto::UpdateWorktreeSettings {
1949 project_id: project_id.to_proto(),
1950 worktree_id: worktree.id,
1951 path: settings_file.path,
1952 content: Some(settings_file.content),
1953 kind: Some(settings_file.kind.to_proto() as i32),
1954 },
1955 )?;
1956 }
1957 }
1958
1959 for repository in mem::take(&mut project.repositories) {
1960 for update in split_repository_update(repository) {
1961 session.peer.send(session.connection_id, update)?;
1962 }
1963 }
1964
1965 for language_server in &project.language_servers {
1966 session.peer.send(
1967 session.connection_id,
1968 proto::UpdateLanguageServer {
1969 project_id: project_id.to_proto(),
1970 language_server_id: language_server.id,
1971 variant: Some(
1972 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1973 proto::LspDiskBasedDiagnosticsUpdated {},
1974 ),
1975 ),
1976 },
1977 )?;
1978 }
1979
1980 Ok(())
1981}
1982
/// Leave someone else's shared project.
///
/// Removes this connection from the project in the database, calls
/// `project_left` for the session, and broadcasts the room update if the
/// project belongs to a room. The database handle and the guard returned by
/// `leave_project` are held until the function returns.
async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> {
    let sender_id = session.connection_id;
    let project_id = ProjectId::from_proto(request.project_id);
    let db = session.db().await;

    let (room, project) = &*db.leave_project(project_id, sender_id).await?;
    tracing::info!(
        %project_id,
        "leave project"
    );

    project_left(project, &session);
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2002
/// Updates other participants with changes to the project
///
/// Persists the project's new worktree list for the sending connection and
/// forwards the original request to every guest connection except the sender
/// (failures are logged by `broadcast`). It then broadcasts the room update if
/// the project is in a room and acknowledges. The database guard borrowed as
/// `room`/`guest_connection_ids` stays alive until the function returns.
async fn update_project(
    request: proto::UpdateProject,
    response: Response<proto::UpdateProject>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .update_project(project_id, session.connection_id, &request.worktrees)
        .await?;
    broadcast(
        Some(session.connection_id),
        guest_connection_ids.iter().copied(),
        |connection_id| {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())
        },
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }
    response.send(proto::Ack {})?;

    Ok(())
}
2031
2032/// Updates other participants with changes to the worktree
2033async fn update_worktree(
2034 request: proto::UpdateWorktree,
2035 response: Response<proto::UpdateWorktree>,
2036 session: Session,
2037) -> Result<()> {
2038 let guest_connection_ids = session
2039 .db()
2040 .await
2041 .update_worktree(&request, session.connection_id)
2042 .await?;
2043
2044 broadcast(
2045 Some(session.connection_id),
2046 guest_connection_ids.iter().copied(),
2047 |connection_id| {
2048 session
2049 .peer
2050 .forward_send(session.connection_id, connection_id, request.clone())
2051 },
2052 );
2053 response.send(proto::Ack {})?;
2054 Ok(())
2055}
2056
2057async fn update_repository(
2058 request: proto::UpdateRepository,
2059 response: Response<proto::UpdateRepository>,
2060 session: Session,
2061) -> Result<()> {
2062 let guest_connection_ids = session
2063 .db()
2064 .await
2065 .update_repository(&request, session.connection_id)
2066 .await?;
2067
2068 broadcast(
2069 Some(session.connection_id),
2070 guest_connection_ids.iter().copied(),
2071 |connection_id| {
2072 session
2073 .peer
2074 .forward_send(session.connection_id, connection_id, request.clone())
2075 },
2076 );
2077 response.send(proto::Ack {})?;
2078 Ok(())
2079}
2080
2081async fn remove_repository(
2082 request: proto::RemoveRepository,
2083 response: Response<proto::RemoveRepository>,
2084 session: Session,
2085) -> Result<()> {
2086 let guest_connection_ids = session
2087 .db()
2088 .await
2089 .remove_repository(&request, session.connection_id)
2090 .await?;
2091
2092 broadcast(
2093 Some(session.connection_id),
2094 guest_connection_ids.iter().copied(),
2095 |connection_id| {
2096 session
2097 .peer
2098 .forward_send(session.connection_id, connection_id, request.clone())
2099 },
2100 );
2101 response.send(proto::Ack {})?;
2102 Ok(())
2103}
2104
2105/// Updates other participants with changes to the diagnostics
2106async fn update_diagnostic_summary(
2107 message: proto::UpdateDiagnosticSummary,
2108 session: Session,
2109) -> Result<()> {
2110 let guest_connection_ids = session
2111 .db()
2112 .await
2113 .update_diagnostic_summary(&message, session.connection_id)
2114 .await?;
2115
2116 broadcast(
2117 Some(session.connection_id),
2118 guest_connection_ids.iter().copied(),
2119 |connection_id| {
2120 session
2121 .peer
2122 .forward_send(session.connection_id, connection_id, message.clone())
2123 },
2124 );
2125
2126 Ok(())
2127}
2128
2129/// Updates other participants with changes to the worktree settings
2130async fn update_worktree_settings(
2131 message: proto::UpdateWorktreeSettings,
2132 session: Session,
2133) -> Result<()> {
2134 let guest_connection_ids = session
2135 .db()
2136 .await
2137 .update_worktree_settings(&message, session.connection_id)
2138 .await?;
2139
2140 broadcast(
2141 Some(session.connection_id),
2142 guest_connection_ids.iter().copied(),
2143 |connection_id| {
2144 session
2145 .peer
2146 .forward_send(session.connection_id, connection_id, message.clone())
2147 },
2148 );
2149
2150 Ok(())
2151}
2152
2153/// Notify other participants that a language server has started.
2154async fn start_language_server(
2155 request: proto::StartLanguageServer,
2156 session: Session,
2157) -> Result<()> {
2158 let guest_connection_ids = session
2159 .db()
2160 .await
2161 .start_language_server(&request, session.connection_id)
2162 .await?;
2163
2164 broadcast(
2165 Some(session.connection_id),
2166 guest_connection_ids.iter().copied(),
2167 |connection_id| {
2168 session
2169 .peer
2170 .forward_send(session.connection_id, connection_id, request.clone())
2171 },
2172 );
2173 Ok(())
2174}
2175
2176/// Notify other participants that a language server has changed.
2177async fn update_language_server(
2178 request: proto::UpdateLanguageServer,
2179 session: Session,
2180) -> Result<()> {
2181 let project_id = ProjectId::from_proto(request.project_id);
2182 let project_connection_ids = session
2183 .db()
2184 .await
2185 .project_connection_ids(project_id, session.connection_id, true)
2186 .await?;
2187 broadcast(
2188 Some(session.connection_id),
2189 project_connection_ids.iter().copied(),
2190 |connection_id| {
2191 session
2192 .peer
2193 .forward_send(session.connection_id, connection_id, request.clone())
2194 },
2195 );
2196 Ok(())
2197}
2198
2199/// forward a project request to the host. These requests should be read only
2200/// as guests are allowed to send them.
2201async fn forward_read_only_project_request<T>(
2202 request: T,
2203 response: Response<T>,
2204 session: Session,
2205) -> Result<()>
2206where
2207 T: EntityMessage + RequestMessage,
2208{
2209 let project_id = ProjectId::from_proto(request.remote_entity_id());
2210 let host_connection_id = session
2211 .db()
2212 .await
2213 .host_for_read_only_project_request(project_id, session.connection_id)
2214 .await?;
2215 let payload = session
2216 .peer
2217 .forward_request(session.connection_id, host_connection_id, request)
2218 .await?;
2219 response.send(payload)?;
2220 Ok(())
2221}
2222
2223async fn forward_find_search_candidates_request(
2224 request: proto::FindSearchCandidates,
2225 response: Response<proto::FindSearchCandidates>,
2226 session: Session,
2227) -> Result<()> {
2228 let project_id = ProjectId::from_proto(request.remote_entity_id());
2229 let host_connection_id = session
2230 .db()
2231 .await
2232 .host_for_read_only_project_request(project_id, session.connection_id)
2233 .await?;
2234 let payload = session
2235 .peer
2236 .forward_request(session.connection_id, host_connection_id, request)
2237 .await?;
2238 response.send(payload)?;
2239 Ok(())
2240}
2241
2242/// forward a project request to the host. These requests are disallowed
2243/// for guests.
2244async fn forward_mutating_project_request<T>(
2245 request: T,
2246 response: Response<T>,
2247 session: Session,
2248) -> Result<()>
2249where
2250 T: EntityMessage + RequestMessage,
2251{
2252 let project_id = ProjectId::from_proto(request.remote_entity_id());
2253
2254 let host_connection_id = session
2255 .db()
2256 .await
2257 .host_for_mutating_project_request(project_id, session.connection_id)
2258 .await?;
2259 let payload = session
2260 .peer
2261 .forward_request(session.connection_id, host_connection_id, request)
2262 .await?;
2263 response.send(payload)?;
2264 Ok(())
2265}
2266
2267/// Notify other participants that a new buffer has been created
2268async fn create_buffer_for_peer(
2269 request: proto::CreateBufferForPeer,
2270 session: Session,
2271) -> Result<()> {
2272 session
2273 .db()
2274 .await
2275 .check_user_is_project_host(
2276 ProjectId::from_proto(request.project_id),
2277 session.connection_id,
2278 )
2279 .await?;
2280 let peer_id = request.peer_id.context("invalid peer id")?;
2281 session
2282 .peer
2283 .forward_send(session.connection_id, peer_id.into(), request)?;
2284 Ok(())
2285}
2286
2287/// Notify other participants that a buffer has been updated. This is
2288/// allowed for guests as long as the update is limited to selections.
2289async fn update_buffer(
2290 request: proto::UpdateBuffer,
2291 response: Response<proto::UpdateBuffer>,
2292 session: Session,
2293) -> Result<()> {
2294 let project_id = ProjectId::from_proto(request.project_id);
2295 let mut capability = Capability::ReadOnly;
2296
2297 for op in request.operations.iter() {
2298 match op.variant {
2299 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2300 Some(_) => capability = Capability::ReadWrite,
2301 }
2302 }
2303
2304 let host = {
2305 let guard = session
2306 .db()
2307 .await
2308 .connections_for_buffer_update(project_id, session.connection_id, capability)
2309 .await?;
2310
2311 let (host, guests) = &*guard;
2312
2313 broadcast(
2314 Some(session.connection_id),
2315 guests.clone(),
2316 |connection_id| {
2317 session
2318 .peer
2319 .forward_send(session.connection_id, connection_id, request.clone())
2320 },
2321 );
2322
2323 *host
2324 };
2325
2326 if host != session.connection_id {
2327 session
2328 .peer
2329 .forward_request(session.connection_id, host, request.clone())
2330 .await?;
2331 }
2332
2333 response.send(proto::Ack {})?;
2334 Ok(())
2335}
2336
2337async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
2338 let project_id = ProjectId::from_proto(message.project_id);
2339
2340 let operation = message.operation.as_ref().context("invalid operation")?;
2341 let capability = match operation.variant.as_ref() {
2342 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
2343 if let Some(buffer_op) = buffer_op.operation.as_ref() {
2344 match buffer_op.variant {
2345 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
2346 Capability::ReadOnly
2347 }
2348 _ => Capability::ReadWrite,
2349 }
2350 } else {
2351 Capability::ReadWrite
2352 }
2353 }
2354 Some(_) => Capability::ReadWrite,
2355 None => Capability::ReadOnly,
2356 };
2357
2358 let guard = session
2359 .db()
2360 .await
2361 .connections_for_buffer_update(project_id, session.connection_id, capability)
2362 .await?;
2363
2364 let (host, guests) = &*guard;
2365
2366 broadcast(
2367 Some(session.connection_id),
2368 guests.iter().chain([host]).copied(),
2369 |connection_id| {
2370 session
2371 .peer
2372 .forward_send(session.connection_id, connection_id, message.clone())
2373 },
2374 );
2375
2376 Ok(())
2377}
2378
2379/// Notify other participants that a project has been updated.
2380async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2381 request: T,
2382 session: Session,
2383) -> Result<()> {
2384 let project_id = ProjectId::from_proto(request.remote_entity_id());
2385 let project_connection_ids = session
2386 .db()
2387 .await
2388 .project_connection_ids(project_id, session.connection_id, false)
2389 .await?;
2390
2391 broadcast(
2392 Some(session.connection_id),
2393 project_connection_ids.iter().copied(),
2394 |connection_id| {
2395 session
2396 .peer
2397 .forward_send(session.connection_id, connection_id, request.clone())
2398 },
2399 );
2400 Ok(())
2401}
2402
2403/// Start following another user in a call.
2404async fn follow(
2405 request: proto::Follow,
2406 response: Response<proto::Follow>,
2407 session: Session,
2408) -> Result<()> {
2409 let room_id = RoomId::from_proto(request.room_id);
2410 let project_id = request.project_id.map(ProjectId::from_proto);
2411 let leader_id = request.leader_id.context("invalid leader id")?.into();
2412 let follower_id = session.connection_id;
2413
2414 session
2415 .db()
2416 .await
2417 .check_room_participants(room_id, leader_id, session.connection_id)
2418 .await?;
2419
2420 let response_payload = session
2421 .peer
2422 .forward_request(session.connection_id, leader_id, request)
2423 .await?;
2424 response.send(response_payload)?;
2425
2426 if let Some(project_id) = project_id {
2427 let room = session
2428 .db()
2429 .await
2430 .follow(room_id, project_id, leader_id, follower_id)
2431 .await?;
2432 room_updated(&room, &session.peer);
2433 }
2434
2435 Ok(())
2436}
2437
2438/// Stop following another user in a call.
2439async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
2440 let room_id = RoomId::from_proto(request.room_id);
2441 let project_id = request.project_id.map(ProjectId::from_proto);
2442 let leader_id = request.leader_id.context("invalid leader id")?.into();
2443 let follower_id = session.connection_id;
2444
2445 session
2446 .db()
2447 .await
2448 .check_room_participants(room_id, leader_id, session.connection_id)
2449 .await?;
2450
2451 session
2452 .peer
2453 .forward_send(session.connection_id, leader_id, request)?;
2454
2455 if let Some(project_id) = project_id {
2456 let room = session
2457 .db()
2458 .await
2459 .unfollow(room_id, project_id, leader_id, follower_id)
2460 .await?;
2461 room_updated(&room, &session.peer);
2462 }
2463
2464 Ok(())
2465}
2466
2467/// Notify everyone following you of your current location.
2468async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
2469 let room_id = RoomId::from_proto(request.room_id);
2470 let database = session.db.lock().await;
2471
2472 let connection_ids = if let Some(project_id) = request.project_id {
2473 let project_id = ProjectId::from_proto(project_id);
2474 database
2475 .project_connection_ids(project_id, session.connection_id, true)
2476 .await?
2477 } else {
2478 database
2479 .room_connection_ids(room_id, session.connection_id)
2480 .await?
2481 };
2482
2483 // For now, don't send view update messages back to that view's current leader.
2484 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
2485 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
2486 _ => None,
2487 });
2488
2489 for connection_id in connection_ids.iter().cloned() {
2490 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
2491 session
2492 .peer
2493 .forward_send(session.connection_id, connection_id, request.clone())?;
2494 }
2495 }
2496 Ok(())
2497}
2498
2499/// Get public data about users.
2500async fn get_users(
2501 request: proto::GetUsers,
2502 response: Response<proto::GetUsers>,
2503 session: Session,
2504) -> Result<()> {
2505 let user_ids = request
2506 .user_ids
2507 .into_iter()
2508 .map(UserId::from_proto)
2509 .collect();
2510 let users = session
2511 .db()
2512 .await
2513 .get_users_by_ids(user_ids)
2514 .await?
2515 .into_iter()
2516 .map(|user| proto::User {
2517 id: user.id.to_proto(),
2518 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2519 github_login: user.github_login,
2520 email: user.email_address,
2521 name: user.name,
2522 })
2523 .collect();
2524 response.send(proto::UsersResponse { users })?;
2525 Ok(())
2526}
2527
2528/// Search for users (to invite) buy Github login
2529async fn fuzzy_search_users(
2530 request: proto::FuzzySearchUsers,
2531 response: Response<proto::FuzzySearchUsers>,
2532 session: Session,
2533) -> Result<()> {
2534 let query = request.query;
2535 let users = match query.len() {
2536 0 => vec![],
2537 1 | 2 => session
2538 .db()
2539 .await
2540 .get_user_by_github_login(&query)
2541 .await?
2542 .into_iter()
2543 .collect(),
2544 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
2545 };
2546 let users = users
2547 .into_iter()
2548 .filter(|user| user.id != session.user_id())
2549 .map(|user| proto::User {
2550 id: user.id.to_proto(),
2551 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
2552 github_login: user.github_login,
2553 name: user.name,
2554 email: user.email_address,
2555 })
2556 .collect();
2557 response.send(proto::UsersResponse { users })?;
2558 Ok(())
2559}
2560
2561/// Send a contact request to another user.
2562async fn request_contact(
2563 request: proto::RequestContact,
2564 response: Response<proto::RequestContact>,
2565 session: Session,
2566) -> Result<()> {
2567 let requester_id = session.user_id();
2568 let responder_id = UserId::from_proto(request.responder_id);
2569 if requester_id == responder_id {
2570 return Err(anyhow!("cannot add yourself as a contact"))?;
2571 }
2572
2573 let notifications = session
2574 .db()
2575 .await
2576 .send_contact_request(requester_id, responder_id)
2577 .await?;
2578
2579 // Update outgoing contact requests of requester
2580 let mut update = proto::UpdateContacts::default();
2581 update.outgoing_requests.push(responder_id.to_proto());
2582 for connection_id in session
2583 .connection_pool()
2584 .await
2585 .user_connection_ids(requester_id)
2586 {
2587 session.peer.send(connection_id, update.clone())?;
2588 }
2589
2590 // Update incoming contact requests of responder
2591 let mut update = proto::UpdateContacts::default();
2592 update
2593 .incoming_requests
2594 .push(proto::IncomingContactRequest {
2595 requester_id: requester_id.to_proto(),
2596 });
2597 let connection_pool = session.connection_pool().await;
2598 for connection_id in connection_pool.user_connection_ids(responder_id) {
2599 session.peer.send(connection_id, update.clone())?;
2600 }
2601
2602 send_notifications(&connection_pool, &session.peer, notifications);
2603
2604 response.send(proto::Ack {})?;
2605 Ok(())
2606}
2607
2608/// Accept or decline a contact request
2609async fn respond_to_contact_request(
2610 request: proto::RespondToContactRequest,
2611 response: Response<proto::RespondToContactRequest>,
2612 session: Session,
2613) -> Result<()> {
2614 let responder_id = session.user_id();
2615 let requester_id = UserId::from_proto(request.requester_id);
2616 let db = session.db().await;
2617 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
2618 db.dismiss_contact_notification(responder_id, requester_id)
2619 .await?;
2620 } else {
2621 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
2622
2623 let notifications = db
2624 .respond_to_contact_request(responder_id, requester_id, accept)
2625 .await?;
2626 let requester_busy = db.is_user_busy(requester_id).await?;
2627 let responder_busy = db.is_user_busy(responder_id).await?;
2628
2629 let pool = session.connection_pool().await;
2630 // Update responder with new contact
2631 let mut update = proto::UpdateContacts::default();
2632 if accept {
2633 update
2634 .contacts
2635 .push(contact_for_user(requester_id, requester_busy, &pool));
2636 }
2637 update
2638 .remove_incoming_requests
2639 .push(requester_id.to_proto());
2640 for connection_id in pool.user_connection_ids(responder_id) {
2641 session.peer.send(connection_id, update.clone())?;
2642 }
2643
2644 // Update requester with new contact
2645 let mut update = proto::UpdateContacts::default();
2646 if accept {
2647 update
2648 .contacts
2649 .push(contact_for_user(responder_id, responder_busy, &pool));
2650 }
2651 update
2652 .remove_outgoing_requests
2653 .push(responder_id.to_proto());
2654
2655 for connection_id in pool.user_connection_ids(requester_id) {
2656 session.peer.send(connection_id, update.clone())?;
2657 }
2658
2659 send_notifications(&pool, &session.peer, notifications);
2660 }
2661
2662 response.send(proto::Ack {})?;
2663 Ok(())
2664}
2665
2666/// Remove a contact.
2667async fn remove_contact(
2668 request: proto::RemoveContact,
2669 response: Response<proto::RemoveContact>,
2670 session: Session,
2671) -> Result<()> {
2672 let requester_id = session.user_id();
2673 let responder_id = UserId::from_proto(request.user_id);
2674 let db = session.db().await;
2675 let (contact_accepted, deleted_notification_id) =
2676 db.remove_contact(requester_id, responder_id).await?;
2677
2678 let pool = session.connection_pool().await;
2679 // Update outgoing contact requests of requester
2680 let mut update = proto::UpdateContacts::default();
2681 if contact_accepted {
2682 update.remove_contacts.push(responder_id.to_proto());
2683 } else {
2684 update
2685 .remove_outgoing_requests
2686 .push(responder_id.to_proto());
2687 }
2688 for connection_id in pool.user_connection_ids(requester_id) {
2689 session.peer.send(connection_id, update.clone())?;
2690 }
2691
2692 // Update incoming contact requests of responder
2693 let mut update = proto::UpdateContacts::default();
2694 if contact_accepted {
2695 update.remove_contacts.push(requester_id.to_proto());
2696 } else {
2697 update
2698 .remove_incoming_requests
2699 .push(requester_id.to_proto());
2700 }
2701 for connection_id in pool.user_connection_ids(responder_id) {
2702 session.peer.send(connection_id, update.clone())?;
2703 if let Some(notification_id) = deleted_notification_id {
2704 session.peer.send(
2705 connection_id,
2706 proto::DeleteNotification {
2707 notification_id: notification_id.to_proto(),
2708 },
2709 )?;
2710 }
2711 }
2712
2713 response.send(proto::Ack {})?;
2714 Ok(())
2715}
2716
2717fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
2718 version.0.minor() < 139
2719}
2720
2721async fn current_plan(db: &Arc<Database>, user_id: UserId, is_staff: bool) -> Result<proto::Plan> {
2722 if is_staff {
2723 return Ok(proto::Plan::ZedPro);
2724 }
2725
2726 let subscription = db.get_active_billing_subscription(user_id).await?;
2727 let subscription_kind = subscription.and_then(|subscription| subscription.kind);
2728
2729 let plan = if let Some(subscription_kind) = subscription_kind {
2730 match subscription_kind {
2731 SubscriptionKind::ZedPro => proto::Plan::ZedPro,
2732 SubscriptionKind::ZedProTrial => proto::Plan::ZedProTrial,
2733 SubscriptionKind::ZedFree => proto::Plan::Free,
2734 }
2735 } else {
2736 proto::Plan::Free
2737 };
2738
2739 Ok(plan)
2740}
2741
2742async fn make_update_user_plan_message(
2743 user: &User,
2744 is_staff: bool,
2745 db: &Arc<Database>,
2746 llm_db: Option<Arc<LlmDatabase>>,
2747) -> Result<proto::UpdateUserPlan> {
2748 let feature_flags = db.get_user_flags(user.id).await?;
2749 let plan = current_plan(db, user.id, is_staff).await?;
2750 let billing_customer = db.get_billing_customer_by_user_id(user.id).await?;
2751 let billing_preferences = db.get_billing_preferences(user.id).await?;
2752
2753 let (subscription_period, usage) = if let Some(llm_db) = llm_db {
2754 let subscription = db.get_active_billing_subscription(user.id).await?;
2755
2756 let subscription_period =
2757 crate::db::billing_subscription::Model::current_period(subscription, is_staff);
2758
2759 let usage = if let Some((period_start_at, period_end_at)) = subscription_period {
2760 llm_db
2761 .get_subscription_usage_for_period(user.id, period_start_at, period_end_at)
2762 .await?
2763 } else {
2764 None
2765 };
2766
2767 (subscription_period, usage)
2768 } else {
2769 (None, None)
2770 };
2771
2772 let account_too_young =
2773 !matches!(plan, proto::Plan::ZedPro) && user.account_age() < MIN_ACCOUNT_AGE_FOR_LLM_USE;
2774
2775 Ok(proto::UpdateUserPlan {
2776 plan: plan.into(),
2777 trial_started_at: billing_customer
2778 .as_ref()
2779 .and_then(|billing_customer| billing_customer.trial_started_at)
2780 .map(|trial_started_at| trial_started_at.and_utc().timestamp() as u64),
2781 is_usage_based_billing_enabled: if is_staff {
2782 Some(true)
2783 } else {
2784 billing_preferences.map(|preferences| preferences.model_request_overages_enabled)
2785 },
2786 subscription_period: subscription_period.map(|(started_at, ended_at)| {
2787 proto::SubscriptionPeriod {
2788 started_at: started_at.timestamp() as u64,
2789 ended_at: ended_at.timestamp() as u64,
2790 }
2791 }),
2792 account_too_young: Some(account_too_young),
2793 has_overdue_invoices: billing_customer
2794 .map(|billing_customer| billing_customer.has_overdue_invoices),
2795 usage: usage.map(|usage| {
2796 let plan = match plan {
2797 proto::Plan::Free => zed_llm_client::Plan::ZedFree,
2798 proto::Plan::ZedPro => zed_llm_client::Plan::ZedPro,
2799 proto::Plan::ZedProTrial => zed_llm_client::Plan::ZedProTrial,
2800 };
2801
2802 let model_requests_limit = match plan.model_requests_limit() {
2803 zed_llm_client::UsageLimit::Limited(limit) => {
2804 let limit = if plan == zed_llm_client::Plan::ZedProTrial
2805 && feature_flags
2806 .iter()
2807 .any(|flag| flag == AGENT_EXTENDED_TRIAL_FEATURE_FLAG)
2808 {
2809 1_000
2810 } else {
2811 limit
2812 };
2813
2814 zed_llm_client::UsageLimit::Limited(limit)
2815 }
2816 zed_llm_client::UsageLimit::Unlimited => zed_llm_client::UsageLimit::Unlimited,
2817 };
2818
2819 proto::SubscriptionUsage {
2820 model_requests_usage_amount: usage.model_requests as u32,
2821 model_requests_usage_limit: Some(proto::UsageLimit {
2822 variant: Some(match model_requests_limit {
2823 zed_llm_client::UsageLimit::Limited(limit) => {
2824 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2825 limit: limit as u32,
2826 })
2827 }
2828 zed_llm_client::UsageLimit::Unlimited => {
2829 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2830 }
2831 }),
2832 }),
2833 edit_predictions_usage_amount: usage.edit_predictions as u32,
2834 edit_predictions_usage_limit: Some(proto::UsageLimit {
2835 variant: Some(match plan.edit_predictions_limit() {
2836 zed_llm_client::UsageLimit::Limited(limit) => {
2837 proto::usage_limit::Variant::Limited(proto::usage_limit::Limited {
2838 limit: limit as u32,
2839 })
2840 }
2841 zed_llm_client::UsageLimit::Unlimited => {
2842 proto::usage_limit::Variant::Unlimited(proto::usage_limit::Unlimited {})
2843 }
2844 }),
2845 }),
2846 }
2847 }),
2848 })
2849}
2850
2851async fn update_user_plan(session: &Session) -> Result<()> {
2852 let db = session.db().await;
2853
2854 let update_user_plan = make_update_user_plan_message(
2855 session.principal.user(),
2856 session.is_staff(),
2857 &db.0,
2858 session.app_state.llm_db.clone(),
2859 )
2860 .await?;
2861
2862 session
2863 .peer
2864 .send(session.connection_id, update_user_plan)
2865 .trace_err();
2866
2867 Ok(())
2868}
2869
2870async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
2871 subscribe_user_to_channels(session.user_id(), &session).await?;
2872 Ok(())
2873}
2874
2875async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
2876 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
2877 let mut pool = session.connection_pool().await;
2878 for membership in &channels_for_user.channel_memberships {
2879 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
2880 }
2881 session.peer.send(
2882 session.connection_id,
2883 build_update_user_channels(&channels_for_user),
2884 )?;
2885 session.peer.send(
2886 session.connection_id,
2887 build_channels_update(channels_for_user),
2888 )?;
2889 Ok(())
2890}
2891
2892/// Creates a new channel.
2893async fn create_channel(
2894 request: proto::CreateChannel,
2895 response: Response<proto::CreateChannel>,
2896 session: Session,
2897) -> Result<()> {
2898 let db = session.db().await;
2899
2900 let parent_id = request.parent_id.map(ChannelId::from_proto);
2901 let (channel, membership) = db
2902 .create_channel(&request.name, parent_id, session.user_id())
2903 .await?;
2904
2905 let root_id = channel.root_id();
2906 let channel = Channel::from_model(channel);
2907
2908 response.send(proto::CreateChannelResponse {
2909 channel: Some(channel.to_proto()),
2910 parent_id: request.parent_id,
2911 })?;
2912
2913 let mut connection_pool = session.connection_pool().await;
2914 if let Some(membership) = membership {
2915 connection_pool.subscribe_to_channel(
2916 membership.user_id,
2917 membership.channel_id,
2918 membership.role,
2919 );
2920 let update = proto::UpdateUserChannels {
2921 channel_memberships: vec![proto::ChannelMembership {
2922 channel_id: membership.channel_id.to_proto(),
2923 role: membership.role.into(),
2924 }],
2925 ..Default::default()
2926 };
2927 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
2928 session.peer.send(connection_id, update.clone())?;
2929 }
2930 }
2931
2932 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
2933 if !role.can_see_channel(channel.visibility) {
2934 continue;
2935 }
2936
2937 let update = proto::UpdateChannels {
2938 channels: vec![channel.to_proto()],
2939 ..Default::default()
2940 };
2941 session.peer.send(connection_id, update.clone())?;
2942 }
2943
2944 Ok(())
2945}
2946
2947/// Delete a channel
2948async fn delete_channel(
2949 request: proto::DeleteChannel,
2950 response: Response<proto::DeleteChannel>,
2951 session: Session,
2952) -> Result<()> {
2953 let db = session.db().await;
2954
2955 let channel_id = request.channel_id;
2956 let (root_channel, removed_channels) = db
2957 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
2958 .await?;
2959 response.send(proto::Ack {})?;
2960
2961 // Notify members of removed channels
2962 let mut update = proto::UpdateChannels::default();
2963 update
2964 .delete_channels
2965 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
2966
2967 let connection_pool = session.connection_pool().await;
2968 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
2969 session.peer.send(connection_id, update.clone())?;
2970 }
2971
2972 Ok(())
2973}
2974
2975/// Invite someone to join a channel.
2976async fn invite_channel_member(
2977 request: proto::InviteChannelMember,
2978 response: Response<proto::InviteChannelMember>,
2979 session: Session,
2980) -> Result<()> {
2981 let db = session.db().await;
2982 let channel_id = ChannelId::from_proto(request.channel_id);
2983 let invitee_id = UserId::from_proto(request.user_id);
2984 let InviteMemberResult {
2985 channel,
2986 notifications,
2987 } = db
2988 .invite_channel_member(
2989 channel_id,
2990 invitee_id,
2991 session.user_id(),
2992 request.role().into(),
2993 )
2994 .await?;
2995
2996 let update = proto::UpdateChannels {
2997 channel_invitations: vec![channel.to_proto()],
2998 ..Default::default()
2999 };
3000
3001 let connection_pool = session.connection_pool().await;
3002 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3003 session.peer.send(connection_id, update.clone())?;
3004 }
3005
3006 send_notifications(&connection_pool, &session.peer, notifications);
3007
3008 response.send(proto::Ack {})?;
3009 Ok(())
3010}
3011
3012/// remove someone from a channel
3013async fn remove_channel_member(
3014 request: proto::RemoveChannelMember,
3015 response: Response<proto::RemoveChannelMember>,
3016 session: Session,
3017) -> Result<()> {
3018 let db = session.db().await;
3019 let channel_id = ChannelId::from_proto(request.channel_id);
3020 let member_id = UserId::from_proto(request.user_id);
3021
3022 let RemoveChannelMemberResult {
3023 membership_update,
3024 notification_id,
3025 } = db
3026 .remove_channel_member(channel_id, member_id, session.user_id())
3027 .await?;
3028
3029 let mut connection_pool = session.connection_pool().await;
3030 notify_membership_updated(
3031 &mut connection_pool,
3032 membership_update,
3033 member_id,
3034 &session.peer,
3035 );
3036 for connection_id in connection_pool.user_connection_ids(member_id) {
3037 if let Some(notification_id) = notification_id {
3038 session
3039 .peer
3040 .send(
3041 connection_id,
3042 proto::DeleteNotification {
3043 notification_id: notification_id.to_proto(),
3044 },
3045 )
3046 .trace_err();
3047 }
3048 }
3049
3050 response.send(proto::Ack {})?;
3051 Ok(())
3052}
3053
3054/// Toggle the channel between public and private.
3055/// Care is taken to maintain the invariant that public channels only descend from public channels,
3056/// (though members-only channels can appear at any point in the hierarchy).
3057async fn set_channel_visibility(
3058 request: proto::SetChannelVisibility,
3059 response: Response<proto::SetChannelVisibility>,
3060 session: Session,
3061) -> Result<()> {
3062 let db = session.db().await;
3063 let channel_id = ChannelId::from_proto(request.channel_id);
3064 let visibility = request.visibility().into();
3065
3066 let channel_model = db
3067 .set_channel_visibility(channel_id, visibility, session.user_id())
3068 .await?;
3069 let root_id = channel_model.root_id();
3070 let channel = Channel::from_model(channel_model);
3071
3072 let mut connection_pool = session.connection_pool().await;
3073 for (user_id, role) in connection_pool
3074 .channel_user_ids(root_id)
3075 .collect::<Vec<_>>()
3076 .into_iter()
3077 {
3078 let update = if role.can_see_channel(channel.visibility) {
3079 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3080 proto::UpdateChannels {
3081 channels: vec![channel.to_proto()],
3082 ..Default::default()
3083 }
3084 } else {
3085 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3086 proto::UpdateChannels {
3087 delete_channels: vec![channel.id.to_proto()],
3088 ..Default::default()
3089 }
3090 };
3091
3092 for connection_id in connection_pool.user_connection_ids(user_id) {
3093 session.peer.send(connection_id, update.clone())?;
3094 }
3095 }
3096
3097 response.send(proto::Ack {})?;
3098 Ok(())
3099}
3100
3101/// Alter the role for a user in the channel.
3102async fn set_channel_member_role(
3103 request: proto::SetChannelMemberRole,
3104 response: Response<proto::SetChannelMemberRole>,
3105 session: Session,
3106) -> Result<()> {
3107 let db = session.db().await;
3108 let channel_id = ChannelId::from_proto(request.channel_id);
3109 let member_id = UserId::from_proto(request.user_id);
3110 let result = db
3111 .set_channel_member_role(
3112 channel_id,
3113 session.user_id(),
3114 member_id,
3115 request.role().into(),
3116 )
3117 .await?;
3118
3119 match result {
3120 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3121 let mut connection_pool = session.connection_pool().await;
3122 notify_membership_updated(
3123 &mut connection_pool,
3124 membership_update,
3125 member_id,
3126 &session.peer,
3127 )
3128 }
3129 db::SetMemberRoleResult::InviteUpdated(channel) => {
3130 let update = proto::UpdateChannels {
3131 channel_invitations: vec![channel.to_proto()],
3132 ..Default::default()
3133 };
3134
3135 for connection_id in session
3136 .connection_pool()
3137 .await
3138 .user_connection_ids(member_id)
3139 {
3140 session.peer.send(connection_id, update.clone())?;
3141 }
3142 }
3143 }
3144
3145 response.send(proto::Ack {})?;
3146 Ok(())
3147}
3148
3149/// Change the name of a channel
3150async fn rename_channel(
3151 request: proto::RenameChannel,
3152 response: Response<proto::RenameChannel>,
3153 session: Session,
3154) -> Result<()> {
3155 let db = session.db().await;
3156 let channel_id = ChannelId::from_proto(request.channel_id);
3157 let channel_model = db
3158 .rename_channel(channel_id, session.user_id(), &request.name)
3159 .await?;
3160 let root_id = channel_model.root_id();
3161 let channel = Channel::from_model(channel_model);
3162
3163 response.send(proto::RenameChannelResponse {
3164 channel: Some(channel.to_proto()),
3165 })?;
3166
3167 let connection_pool = session.connection_pool().await;
3168 let update = proto::UpdateChannels {
3169 channels: vec![channel.to_proto()],
3170 ..Default::default()
3171 };
3172 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3173 if role.can_see_channel(channel.visibility) {
3174 session.peer.send(connection_id, update.clone())?;
3175 }
3176 }
3177
3178 Ok(())
3179}
3180
3181/// Move a channel to a new parent.
3182async fn move_channel(
3183 request: proto::MoveChannel,
3184 response: Response<proto::MoveChannel>,
3185 session: Session,
3186) -> Result<()> {
3187 let channel_id = ChannelId::from_proto(request.channel_id);
3188 let to = ChannelId::from_proto(request.to);
3189
3190 let (root_id, channels) = session
3191 .db()
3192 .await
3193 .move_channel(channel_id, to, session.user_id())
3194 .await?;
3195
3196 let connection_pool = session.connection_pool().await;
3197 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3198 let channels = channels
3199 .iter()
3200 .filter_map(|channel| {
3201 if role.can_see_channel(channel.visibility) {
3202 Some(channel.to_proto())
3203 } else {
3204 None
3205 }
3206 })
3207 .collect::<Vec<_>>();
3208 if channels.is_empty() {
3209 continue;
3210 }
3211
3212 let update = proto::UpdateChannels {
3213 channels,
3214 ..Default::default()
3215 };
3216
3217 session.peer.send(connection_id, update.clone())?;
3218 }
3219
3220 response.send(Ack {})?;
3221 Ok(())
3222}
3223
3224async fn reorder_channel(
3225 request: proto::ReorderChannel,
3226 response: Response<proto::ReorderChannel>,
3227 session: Session,
3228) -> Result<()> {
3229 let channel_id = ChannelId::from_proto(request.channel_id);
3230 let direction = request.direction();
3231
3232 let updated_channels = session
3233 .db()
3234 .await
3235 .reorder_channel(channel_id, direction, session.user_id())
3236 .await?;
3237
3238 if let Some(root_id) = updated_channels.first().map(|channel| channel.root_id()) {
3239 let connection_pool = session.connection_pool().await;
3240 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3241 let channels = updated_channels
3242 .iter()
3243 .filter_map(|channel| {
3244 if role.can_see_channel(channel.visibility) {
3245 Some(channel.to_proto())
3246 } else {
3247 None
3248 }
3249 })
3250 .collect::<Vec<_>>();
3251
3252 if channels.is_empty() {
3253 continue;
3254 }
3255
3256 let update = proto::UpdateChannels {
3257 channels,
3258 ..Default::default()
3259 };
3260
3261 session.peer.send(connection_id, update.clone())?;
3262 }
3263 }
3264
3265 response.send(Ack {})?;
3266 Ok(())
3267}
3268
3269/// Get the list of channel members
3270async fn get_channel_members(
3271 request: proto::GetChannelMembers,
3272 response: Response<proto::GetChannelMembers>,
3273 session: Session,
3274) -> Result<()> {
3275 let db = session.db().await;
3276 let channel_id = ChannelId::from_proto(request.channel_id);
3277 let limit = if request.limit == 0 {
3278 u16::MAX as u64
3279 } else {
3280 request.limit
3281 };
3282 let (members, users) = db
3283 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3284 .await?;
3285 response.send(proto::GetChannelMembersResponse { members, users })?;
3286 Ok(())
3287}
3288
3289/// Accept or decline a channel invitation.
3290async fn respond_to_channel_invite(
3291 request: proto::RespondToChannelInvite,
3292 response: Response<proto::RespondToChannelInvite>,
3293 session: Session,
3294) -> Result<()> {
3295 let db = session.db().await;
3296 let channel_id = ChannelId::from_proto(request.channel_id);
3297 let RespondToChannelInvite {
3298 membership_update,
3299 notifications,
3300 } = db
3301 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3302 .await?;
3303
3304 let mut connection_pool = session.connection_pool().await;
3305 if let Some(membership_update) = membership_update {
3306 notify_membership_updated(
3307 &mut connection_pool,
3308 membership_update,
3309 session.user_id(),
3310 &session.peer,
3311 );
3312 } else {
3313 let update = proto::UpdateChannels {
3314 remove_channel_invitations: vec![channel_id.to_proto()],
3315 ..Default::default()
3316 };
3317
3318 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3319 session.peer.send(connection_id, update.clone())?;
3320 }
3321 };
3322
3323 send_notifications(&connection_pool, &session.peer, notifications);
3324
3325 response.send(proto::Ack {})?;
3326
3327 Ok(())
3328}
3329
3330/// Join the channels' room
3331async fn join_channel(
3332 request: proto::JoinChannel,
3333 response: Response<proto::JoinChannel>,
3334 session: Session,
3335) -> Result<()> {
3336 let channel_id = ChannelId::from_proto(request.channel_id);
3337 join_channel_internal(channel_id, Box::new(response), session).await
3338}
3339
/// A response handle that can complete a channel join.
///
/// Both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` reply with a
/// `proto::JoinRoomResponse`, so `join_channel_internal` accepts either handle
/// through this trait.
trait JoinChannelInternalResponse {
    /// Send `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl forwards through the fully-qualified `Response::<R>::send` path.
// That makes it explicit the call goes to `Response`'s own `send` (presumably an
// inherent method) rather than recursing into this trait method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3353
/// Join the room that belongs to `channel_id` on behalf of `session`'s user.
///
/// `response` may be any `JoinChannelInternalResponse`, so both `JoinChannel` and
/// `JoinRoom` response handles can use this path. In order, this function:
/// 1. evicts a stale room connection the user may have left behind,
/// 2. joins the channel's room in the database, which also yields the user's
///    channel role and an optional membership change,
/// 3. mints LiveKit credentials when a LiveKit client is configured (guests get
///    a guest token with `can_publish: false`),
/// 4. replies to the requester, pushes any membership change to the user, and
///    calls `room_updated` for the joined room,
/// 5. calls `channel_updated` for the channel and refreshes the user's contacts.
///
/// # Errors
///
/// Propagates errors from the database, the stale-connection cleanup, sending the
/// response, and `update_user_contacts`. Also fails with "channel not returned" if
/// the joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: Session,
) -> Result<()> {
    // The database and connection-pool guards taken inside this block are dropped
    // at its end, before the connection pool is acquired again for
    // `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room (presumably
            // `leave_room_for_session` acquires it itself), then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token fails
        // (the failure is logged by `trace_err` and the `?` exits the closure).
        // Guests get a guest token and may not publish; everyone else gets a room
        // token and may publish.
        let live_kit_connection_info =
            session
                .app_state
                .livekit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.livekit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        // Answer the requester first; notifications to other parties follow.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room.channel.context("channel not returned")?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3447
3448/// Start editing the channel notes
3449async fn join_channel_buffer(
3450 request: proto::JoinChannelBuffer,
3451 response: Response<proto::JoinChannelBuffer>,
3452 session: Session,
3453) -> Result<()> {
3454 let db = session.db().await;
3455 let channel_id = ChannelId::from_proto(request.channel_id);
3456
3457 let open_response = db
3458 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3459 .await?;
3460
3461 let collaborators = open_response.collaborators.clone();
3462 response.send(open_response)?;
3463
3464 let update = UpdateChannelBufferCollaborators {
3465 channel_id: channel_id.to_proto(),
3466 collaborators: collaborators.clone(),
3467 };
3468 channel_buffer_updated(
3469 session.connection_id,
3470 collaborators
3471 .iter()
3472 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3473 &update,
3474 &session.peer,
3475 );
3476
3477 Ok(())
3478}
3479
3480/// Edit the channel notes
3481async fn update_channel_buffer(
3482 request: proto::UpdateChannelBuffer,
3483 session: Session,
3484) -> Result<()> {
3485 let db = session.db().await;
3486 let channel_id = ChannelId::from_proto(request.channel_id);
3487
3488 let (collaborators, epoch, version) = db
3489 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3490 .await?;
3491
3492 channel_buffer_updated(
3493 session.connection_id,
3494 collaborators.clone(),
3495 &proto::UpdateChannelBuffer {
3496 channel_id: channel_id.to_proto(),
3497 operations: request.operations,
3498 },
3499 &session.peer,
3500 );
3501
3502 let pool = &*session.connection_pool().await;
3503
3504 let non_collaborators =
3505 pool.channel_connection_ids(channel_id)
3506 .filter_map(|(connection_id, _)| {
3507 if collaborators.contains(&connection_id) {
3508 None
3509 } else {
3510 Some(connection_id)
3511 }
3512 });
3513
3514 broadcast(None, non_collaborators, |peer_id| {
3515 session.peer.send(
3516 peer_id,
3517 proto::UpdateChannels {
3518 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3519 channel_id: channel_id.to_proto(),
3520 epoch: epoch as u64,
3521 version: version.clone(),
3522 }],
3523 ..Default::default()
3524 },
3525 )
3526 });
3527
3528 Ok(())
3529}
3530
3531/// Rejoin the channel notes after a connection blip
3532async fn rejoin_channel_buffers(
3533 request: proto::RejoinChannelBuffers,
3534 response: Response<proto::RejoinChannelBuffers>,
3535 session: Session,
3536) -> Result<()> {
3537 let db = session.db().await;
3538 let buffers = db
3539 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3540 .await?;
3541
3542 for rejoined_buffer in &buffers {
3543 let collaborators_to_notify = rejoined_buffer
3544 .buffer
3545 .collaborators
3546 .iter()
3547 .filter_map(|c| Some(c.peer_id?.into()));
3548 channel_buffer_updated(
3549 session.connection_id,
3550 collaborators_to_notify,
3551 &proto::UpdateChannelBufferCollaborators {
3552 channel_id: rejoined_buffer.buffer.channel_id,
3553 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3554 },
3555 &session.peer,
3556 );
3557 }
3558
3559 response.send(proto::RejoinChannelBuffersResponse {
3560 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3561 })?;
3562
3563 Ok(())
3564}
3565
3566/// Stop editing the channel notes
3567async fn leave_channel_buffer(
3568 request: proto::LeaveChannelBuffer,
3569 response: Response<proto::LeaveChannelBuffer>,
3570 session: Session,
3571) -> Result<()> {
3572 let db = session.db().await;
3573 let channel_id = ChannelId::from_proto(request.channel_id);
3574
3575 let left_buffer = db
3576 .leave_channel_buffer(channel_id, session.connection_id)
3577 .await?;
3578
3579 response.send(Ack {})?;
3580
3581 channel_buffer_updated(
3582 session.connection_id,
3583 left_buffer.connections,
3584 &proto::UpdateChannelBufferCollaborators {
3585 channel_id: channel_id.to_proto(),
3586 collaborators: left_buffer.collaborators,
3587 },
3588 &session.peer,
3589 );
3590
3591 Ok(())
3592}
3593
3594fn channel_buffer_updated<T: EnvelopedMessage>(
3595 sender_id: ConnectionId,
3596 collaborators: impl IntoIterator<Item = ConnectionId>,
3597 message: &T,
3598 peer: &Peer,
3599) {
3600 broadcast(Some(sender_id), collaborators, |peer_id| {
3601 peer.send(peer_id, message.clone())
3602 });
3603}
3604
3605fn send_notifications(
3606 connection_pool: &ConnectionPool,
3607 peer: &Peer,
3608 notifications: db::NotificationBatch,
3609) {
3610 for (user_id, notification) in notifications {
3611 for connection_id in connection_pool.user_connection_ids(user_id) {
3612 if let Err(error) = peer.send(
3613 connection_id,
3614 proto::AddNotification {
3615 notification: Some(notification.clone()),
3616 },
3617 ) {
3618 tracing::error!(
3619 "failed to send notification to {:?} {}",
3620 connection_id,
3621 error
3622 );
3623 }
3624 }
3625 }
3626}
3627
3628/// Send a message to the channel
3629async fn send_channel_message(
3630 request: proto::SendChannelMessage,
3631 response: Response<proto::SendChannelMessage>,
3632 session: Session,
3633) -> Result<()> {
3634 // Validate the message body.
3635 let body = request.body.trim().to_string();
3636 if body.len() > MAX_MESSAGE_LEN {
3637 return Err(anyhow!("message is too long"))?;
3638 }
3639 if body.is_empty() {
3640 return Err(anyhow!("message can't be blank"))?;
3641 }
3642
3643 // TODO: adjust mentions if body is trimmed
3644
3645 let timestamp = OffsetDateTime::now_utc();
3646 let nonce = request.nonce.context("nonce can't be blank")?;
3647
3648 let channel_id = ChannelId::from_proto(request.channel_id);
3649 let CreatedChannelMessage {
3650 message_id,
3651 participant_connection_ids,
3652 notifications,
3653 } = session
3654 .db()
3655 .await
3656 .create_channel_message(
3657 channel_id,
3658 session.user_id(),
3659 &body,
3660 &request.mentions,
3661 timestamp,
3662 nonce.clone().into(),
3663 request.reply_to_message_id.map(MessageId::from_proto),
3664 )
3665 .await?;
3666
3667 let message = proto::ChannelMessage {
3668 sender_id: session.user_id().to_proto(),
3669 id: message_id.to_proto(),
3670 body,
3671 mentions: request.mentions,
3672 timestamp: timestamp.unix_timestamp() as u64,
3673 nonce: Some(nonce),
3674 reply_to_message_id: request.reply_to_message_id,
3675 edited_at: None,
3676 };
3677 broadcast(
3678 Some(session.connection_id),
3679 participant_connection_ids.clone(),
3680 |connection| {
3681 session.peer.send(
3682 connection,
3683 proto::ChannelMessageSent {
3684 channel_id: channel_id.to_proto(),
3685 message: Some(message.clone()),
3686 },
3687 )
3688 },
3689 );
3690 response.send(proto::SendChannelMessageResponse {
3691 message: Some(message),
3692 })?;
3693
3694 let pool = &*session.connection_pool().await;
3695 let non_participants =
3696 pool.channel_connection_ids(channel_id)
3697 .filter_map(|(connection_id, _)| {
3698 if participant_connection_ids.contains(&connection_id) {
3699 None
3700 } else {
3701 Some(connection_id)
3702 }
3703 });
3704 broadcast(None, non_participants, |peer_id| {
3705 session.peer.send(
3706 peer_id,
3707 proto::UpdateChannels {
3708 latest_channel_message_ids: vec![proto::ChannelMessageId {
3709 channel_id: channel_id.to_proto(),
3710 message_id: message_id.to_proto(),
3711 }],
3712 ..Default::default()
3713 },
3714 )
3715 });
3716 send_notifications(pool, &session.peer, notifications);
3717
3718 Ok(())
3719}
3720
3721/// Delete a channel message
3722async fn remove_channel_message(
3723 request: proto::RemoveChannelMessage,
3724 response: Response<proto::RemoveChannelMessage>,
3725 session: Session,
3726) -> Result<()> {
3727 let channel_id = ChannelId::from_proto(request.channel_id);
3728 let message_id = MessageId::from_proto(request.message_id);
3729 let (connection_ids, existing_notification_ids) = session
3730 .db()
3731 .await
3732 .remove_channel_message(channel_id, message_id, session.user_id())
3733 .await?;
3734
3735 broadcast(
3736 Some(session.connection_id),
3737 connection_ids,
3738 move |connection| {
3739 session.peer.send(connection, request.clone())?;
3740
3741 for notification_id in &existing_notification_ids {
3742 session.peer.send(
3743 connection,
3744 proto::DeleteNotification {
3745 notification_id: (*notification_id).to_proto(),
3746 },
3747 )?;
3748 }
3749
3750 Ok(())
3751 },
3752 );
3753 response.send(proto::Ack {})?;
3754 Ok(())
3755}
3756
3757async fn update_channel_message(
3758 request: proto::UpdateChannelMessage,
3759 response: Response<proto::UpdateChannelMessage>,
3760 session: Session,
3761) -> Result<()> {
3762 let channel_id = ChannelId::from_proto(request.channel_id);
3763 let message_id = MessageId::from_proto(request.message_id);
3764 let updated_at = OffsetDateTime::now_utc();
3765 let UpdatedChannelMessage {
3766 message_id,
3767 participant_connection_ids,
3768 notifications,
3769 reply_to_message_id,
3770 timestamp,
3771 deleted_mention_notification_ids,
3772 updated_mention_notifications,
3773 } = session
3774 .db()
3775 .await
3776 .update_channel_message(
3777 channel_id,
3778 message_id,
3779 session.user_id(),
3780 request.body.as_str(),
3781 &request.mentions,
3782 updated_at,
3783 )
3784 .await?;
3785
3786 let nonce = request.nonce.clone().context("nonce can't be blank")?;
3787
3788 let message = proto::ChannelMessage {
3789 sender_id: session.user_id().to_proto(),
3790 id: message_id.to_proto(),
3791 body: request.body.clone(),
3792 mentions: request.mentions.clone(),
3793 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
3794 nonce: Some(nonce),
3795 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
3796 edited_at: Some(updated_at.unix_timestamp() as u64),
3797 };
3798
3799 response.send(proto::Ack {})?;
3800
3801 let pool = &*session.connection_pool().await;
3802 broadcast(
3803 Some(session.connection_id),
3804 participant_connection_ids,
3805 |connection| {
3806 session.peer.send(
3807 connection,
3808 proto::ChannelMessageUpdate {
3809 channel_id: channel_id.to_proto(),
3810 message: Some(message.clone()),
3811 },
3812 )?;
3813
3814 for notification_id in &deleted_mention_notification_ids {
3815 session.peer.send(
3816 connection,
3817 proto::DeleteNotification {
3818 notification_id: (*notification_id).to_proto(),
3819 },
3820 )?;
3821 }
3822
3823 for notification in &updated_mention_notifications {
3824 session.peer.send(
3825 connection,
3826 proto::UpdateNotification {
3827 notification: Some(notification.clone()),
3828 },
3829 )?;
3830 }
3831
3832 Ok(())
3833 },
3834 );
3835
3836 send_notifications(pool, &session.peer, notifications);
3837
3838 Ok(())
3839}
3840
3841/// Mark a channel message as read
3842async fn acknowledge_channel_message(
3843 request: proto::AckChannelMessage,
3844 session: Session,
3845) -> Result<()> {
3846 let channel_id = ChannelId::from_proto(request.channel_id);
3847 let message_id = MessageId::from_proto(request.message_id);
3848 let notifications = session
3849 .db()
3850 .await
3851 .observe_channel_message(channel_id, session.user_id(), message_id)
3852 .await?;
3853 send_notifications(
3854 &*session.connection_pool().await,
3855 &session.peer,
3856 notifications,
3857 );
3858 Ok(())
3859}
3860
3861/// Mark a buffer version as synced
3862async fn acknowledge_buffer_version(
3863 request: proto::AckBufferOperation,
3864 session: Session,
3865) -> Result<()> {
3866 let buffer_id = BufferId::from_proto(request.buffer_id);
3867 session
3868 .db()
3869 .await
3870 .observe_buffer_version(
3871 buffer_id,
3872 session.user_id(),
3873 request.epoch as i32,
3874 &request.version,
3875 )
3876 .await?;
3877 Ok(())
3878}
3879
3880/// Get a Supermaven API key for the user
3881async fn get_supermaven_api_key(
3882 _request: proto::GetSupermavenApiKey,
3883 response: Response<proto::GetSupermavenApiKey>,
3884 session: Session,
3885) -> Result<()> {
3886 let user_id: String = session.user_id().to_string();
3887 if !session.is_staff() {
3888 return Err(anyhow!("supermaven not enabled for this account"))?;
3889 }
3890
3891 let email = session.email().context("user must have an email")?;
3892
3893 let supermaven_admin_api = session
3894 .supermaven_client
3895 .as_ref()
3896 .context("supermaven not configured")?;
3897
3898 let result = supermaven_admin_api
3899 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
3900 .await?;
3901
3902 response.send(proto::GetSupermavenApiKeyResponse {
3903 api_key: result.api_key,
3904 })?;
3905
3906 Ok(())
3907}
3908
3909/// Start receiving chat updates for a channel
3910async fn join_channel_chat(
3911 request: proto::JoinChannelChat,
3912 response: Response<proto::JoinChannelChat>,
3913 session: Session,
3914) -> Result<()> {
3915 let channel_id = ChannelId::from_proto(request.channel_id);
3916
3917 let db = session.db().await;
3918 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
3919 .await?;
3920 let messages = db
3921 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
3922 .await?;
3923 response.send(proto::JoinChannelChatResponse {
3924 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3925 messages,
3926 })?;
3927 Ok(())
3928}
3929
3930/// Stop receiving chat updates for a channel
3931async fn leave_channel_chat(request: proto::LeaveChannelChat, session: Session) -> Result<()> {
3932 let channel_id = ChannelId::from_proto(request.channel_id);
3933 session
3934 .db()
3935 .await
3936 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
3937 .await?;
3938 Ok(())
3939}
3940
3941/// Retrieve the chat history for a channel
3942async fn get_channel_messages(
3943 request: proto::GetChannelMessages,
3944 response: Response<proto::GetChannelMessages>,
3945 session: Session,
3946) -> Result<()> {
3947 let channel_id = ChannelId::from_proto(request.channel_id);
3948 let messages = session
3949 .db()
3950 .await
3951 .get_channel_messages(
3952 channel_id,
3953 session.user_id(),
3954 MESSAGE_COUNT_PER_PAGE,
3955 Some(MessageId::from_proto(request.before_message_id)),
3956 )
3957 .await?;
3958 response.send(proto::GetChannelMessagesResponse {
3959 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3960 messages,
3961 })?;
3962 Ok(())
3963}
3964
3965/// Retrieve specific chat messages
3966async fn get_channel_messages_by_id(
3967 request: proto::GetChannelMessagesById,
3968 response: Response<proto::GetChannelMessagesById>,
3969 session: Session,
3970) -> Result<()> {
3971 let message_ids = request
3972 .message_ids
3973 .iter()
3974 .map(|id| MessageId::from_proto(*id))
3975 .collect::<Vec<_>>();
3976 let messages = session
3977 .db()
3978 .await
3979 .get_channel_messages_by_id(session.user_id(), &message_ids)
3980 .await?;
3981 response.send(proto::GetChannelMessagesResponse {
3982 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
3983 messages,
3984 })?;
3985 Ok(())
3986}
3987
3988/// Retrieve the current users notifications
3989async fn get_notifications(
3990 request: proto::GetNotifications,
3991 response: Response<proto::GetNotifications>,
3992 session: Session,
3993) -> Result<()> {
3994 let notifications = session
3995 .db()
3996 .await
3997 .get_notifications(
3998 session.user_id(),
3999 NOTIFICATION_COUNT_PER_PAGE,
4000 request.before_id.map(db::NotificationId::from_proto),
4001 )
4002 .await?;
4003 response.send(proto::GetNotificationsResponse {
4004 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4005 notifications,
4006 })?;
4007 Ok(())
4008}
4009
4010/// Mark notifications as read
4011async fn mark_notification_as_read(
4012 request: proto::MarkNotificationRead,
4013 response: Response<proto::MarkNotificationRead>,
4014 session: Session,
4015) -> Result<()> {
4016 let database = &session.db().await;
4017 let notifications = database
4018 .mark_notification_as_read_by_id(
4019 session.user_id(),
4020 NotificationId::from_proto(request.notification_id),
4021 )
4022 .await?;
4023 send_notifications(
4024 &*session.connection_pool().await,
4025 &session.peer,
4026 notifications,
4027 );
4028 response.send(proto::Ack {})?;
4029 Ok(())
4030}
4031
4032/// Get the current users information
4033async fn get_private_user_info(
4034 _request: proto::GetPrivateUserInfo,
4035 response: Response<proto::GetPrivateUserInfo>,
4036 session: Session,
4037) -> Result<()> {
4038 let db = session.db().await;
4039
4040 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4041 let user = db
4042 .get_user_by_id(session.user_id())
4043 .await?
4044 .context("user not found")?;
4045 let flags = db.get_user_flags(session.user_id()).await?;
4046
4047 response.send(proto::GetPrivateUserInfoResponse {
4048 metrics_id,
4049 staff: user.admin,
4050 flags,
4051 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4052 })?;
4053 Ok(())
4054}
4055
4056/// Accept the terms of service (tos) on behalf of the current user
4057async fn accept_terms_of_service(
4058 _request: proto::AcceptTermsOfService,
4059 response: Response<proto::AcceptTermsOfService>,
4060 session: Session,
4061) -> Result<()> {
4062 let db = session.db().await;
4063
4064 let accepted_tos_at = Utc::now();
4065 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4066 .await?;
4067
4068 response.send(proto::AcceptTermsOfServiceResponse {
4069 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4070 })?;
4071 Ok(())
4072}
4073
/// The minimum age (30 days) an account must have before it may use the LLM service.
pub const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4076
/// Issue an LLM API token for the current user.
///
/// Requires that:
/// - the user exists and has accepted the terms of service;
/// - the server has both a Stripe client and a Stripe billing object configured.
///
/// Side effects:
/// - If the user has no billing customer yet, a Stripe customer is found or
///   created by email and resolved through `find_or_create_billing_customer`.
/// - If the user has no active billing subscription, they are subscribed to Zed
///   Free in Stripe and the subscription is stored with
///   `create_billing_subscription`.
///
/// The token is built by `LlmTokenClaims::create` from the user, staff status,
/// billing customer, billing preferences, feature flags, subscription, system id,
/// and server config.
async fn get_llm_api_token(
    _request: proto::GetLlmToken,
    response: Response<proto::GetLlmToken>,
    session: Session,
) -> Result<()> {
    let db = session.db().await;

    let flags = db.get_user_flags(session.user_id()).await?;

    let user_id = session.user_id();
    let user = db
        .get_user_by_id(user_id)
        .await?
        .with_context(|| format!("user {user_id} not found"))?;

    // Tokens are only issued once the terms of service have been accepted.
    if user.accepted_tos_at.is_none() {
        Err(anyhow!("terms of service not accepted"))?
    }

    let stripe_client = session
        .app_state
        .stripe_client
        .as_ref()
        .context("failed to retrieve Stripe client")?;

    let stripe_billing = session
        .app_state
        .stripe_billing
        .as_ref()
        .context("failed to retrieve Stripe billing object")?;

    // Reuse the stored billing customer. Otherwise find/create the Stripe customer
    // by email and resolve the matching billing customer record.
    let billing_customer = if let Some(billing_customer) =
        db.get_billing_customer_by_user_id(user.id).await?
    {
        billing_customer
    } else {
        let customer_id = stripe_billing
            .find_or_create_customer_by_email(user.email_address.as_deref())
            .await?;

        find_or_create_billing_customer(&session.app_state, stripe_client.as_ref(), &customer_id)
            .await?
            .context("billing customer not found")?
    };

    // Ensure there is an active subscription. Fall back to a new Zed Free
    // subscription created in Stripe and persisted in the database.
    let billing_subscription =
        if let Some(billing_subscription) = db.get_active_billing_subscription(user.id).await? {
            billing_subscription
        } else {
            let stripe_customer_id =
                StripeCustomerId(billing_customer.stripe_customer_id.clone().into());

            let stripe_subscription = stripe_billing
                .subscribe_to_zed_free(stripe_customer_id)
                .await?;

            db.create_billing_subscription(&db::CreateBillingSubscriptionParams {
                billing_customer_id: billing_customer.id,
                kind: Some(SubscriptionKind::ZedFree),
                stripe_subscription_id: stripe_subscription.id.to_string(),
                stripe_subscription_status: stripe_subscription.status.into(),
                stripe_cancellation_reason: None,
                stripe_current_period_start: Some(stripe_subscription.current_period_start),
                stripe_current_period_end: Some(stripe_subscription.current_period_end),
            })
            .await?
        };

    let billing_preferences = db.get_billing_preferences(user.id).await?;

    let token = LlmTokenClaims::create(
        &user,
        session.is_staff(),
        billing_customer,
        billing_preferences,
        &flags,
        billing_subscription,
        session.system_id.clone(),
        &session.app_state.config,
    )?;
    response.send(proto::GetLlmTokenResponse { token })?;
    Ok(())
}
4160
4161fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4162 let message = match message {
4163 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload.as_str().to_string()),
4164 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload.into()),
4165 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload.into()),
4166 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload.into()),
4167 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4168 code: frame.code.into(),
4169 reason: frame.reason.as_str().to_owned().into(),
4170 })),
4171 // We should never receive a frame while reading the message, according
4172 // to the `tungstenite` maintainers:
4173 //
4174 // > It cannot occur when you read messages from the WebSocket, but it
4175 // > can be used when you want to send the raw frames (e.g. you want to
4176 // > send the frames to the WebSocket without composing the full message first).
4177 // >
4178 // > — https://github.com/snapview/tungstenite-rs/issues/268
4179 TungsteniteMessage::Frame(_) => {
4180 bail!("received an unexpected frame while reading the message")
4181 }
4182 };
4183
4184 Ok(message)
4185}
4186
4187fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4188 match message {
4189 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload.into()),
4190 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload.into()),
4191 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload.into()),
4192 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload.into()),
4193 AxumMessage::Close(frame) => {
4194 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4195 code: frame.code.into(),
4196 reason: frame.reason.as_ref().into(),
4197 }))
4198 }
4199 }
4200}
4201
4202fn notify_membership_updated(
4203 connection_pool: &mut ConnectionPool,
4204 result: MembershipUpdated,
4205 user_id: UserId,
4206 peer: &Peer,
4207) {
4208 for membership in &result.new_channels.channel_memberships {
4209 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4210 }
4211 for channel_id in &result.removed_channels {
4212 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4213 }
4214
4215 let user_channels_update = proto::UpdateUserChannels {
4216 channel_memberships: result
4217 .new_channels
4218 .channel_memberships
4219 .iter()
4220 .map(|cm| proto::ChannelMembership {
4221 channel_id: cm.channel_id.to_proto(),
4222 role: cm.role.into(),
4223 })
4224 .collect(),
4225 ..Default::default()
4226 };
4227
4228 let mut update = build_channels_update(result.new_channels);
4229 update.delete_channels = result
4230 .removed_channels
4231 .into_iter()
4232 .map(|id| id.to_proto())
4233 .collect();
4234 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4235
4236 for connection_id in connection_pool.user_connection_ids(user_id) {
4237 peer.send(connection_id, user_channels_update.clone())
4238 .trace_err();
4239 peer.send(connection_id, update.clone()).trace_err();
4240 }
4241}
4242
4243fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4244 proto::UpdateUserChannels {
4245 channel_memberships: channels
4246 .channel_memberships
4247 .iter()
4248 .map(|m| proto::ChannelMembership {
4249 channel_id: m.channel_id.to_proto(),
4250 role: m.role.into(),
4251 })
4252 .collect(),
4253 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4254 observed_channel_message_id: channels.observed_channel_messages.clone(),
4255 }
4256}
4257
4258fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
4259 let mut update = proto::UpdateChannels::default();
4260
4261 for channel in channels.channels {
4262 update.channels.push(channel.to_proto());
4263 }
4264
4265 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4266 update.latest_channel_message_ids = channels.latest_channel_messages;
4267
4268 for (channel_id, participants) in channels.channel_participants {
4269 update
4270 .channel_participants
4271 .push(proto::ChannelParticipants {
4272 channel_id: channel_id.to_proto(),
4273 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4274 });
4275 }
4276
4277 for channel in channels.invited_channels {
4278 update.channel_invitations.push(channel.to_proto());
4279 }
4280
4281 update
4282}
4283
4284fn build_initial_contacts_update(
4285 contacts: Vec<db::Contact>,
4286 pool: &ConnectionPool,
4287) -> proto::UpdateContacts {
4288 let mut update = proto::UpdateContacts::default();
4289
4290 for contact in contacts {
4291 match contact {
4292 db::Contact::Accepted { user_id, busy } => {
4293 update.contacts.push(contact_for_user(user_id, busy, pool));
4294 }
4295 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4296 db::Contact::Incoming { user_id } => {
4297 update
4298 .incoming_requests
4299 .push(proto::IncomingContactRequest {
4300 requester_id: user_id.to_proto(),
4301 })
4302 }
4303 }
4304 }
4305
4306 update
4307}
4308
4309fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4310 proto::Contact {
4311 user_id: user_id.to_proto(),
4312 online: pool.is_user_online(user_id),
4313 busy,
4314 }
4315}
4316
4317fn room_updated(room: &proto::Room, peer: &Peer) {
4318 broadcast(
4319 None,
4320 room.participants
4321 .iter()
4322 .filter_map(|participant| Some(participant.peer_id?.into())),
4323 |peer_id| {
4324 peer.send(
4325 peer_id,
4326 proto::RoomUpdated {
4327 room: Some(room.clone()),
4328 },
4329 )
4330 },
4331 );
4332}
4333
4334fn channel_updated(
4335 channel: &db::channel::Model,
4336 room: &proto::Room,
4337 peer: &Peer,
4338 pool: &ConnectionPool,
4339) {
4340 let participants = room
4341 .participants
4342 .iter()
4343 .map(|p| p.user_id)
4344 .collect::<Vec<_>>();
4345
4346 broadcast(
4347 None,
4348 pool.channel_connection_ids(channel.root_id())
4349 .filter_map(|(channel_id, role)| {
4350 role.can_see_channel(channel.visibility)
4351 .then_some(channel_id)
4352 }),
4353 |peer_id| {
4354 peer.send(
4355 peer_id,
4356 proto::UpdateChannels {
4357 channel_participants: vec![proto::ChannelParticipants {
4358 channel_id: channel.id.to_proto(),
4359 participant_user_ids: participants.clone(),
4360 }],
4361 ..Default::default()
4362 },
4363 )
4364 },
4365 );
4366}
4367
4368async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
4369 let db = session.db().await;
4370
4371 let contacts = db.get_contacts(user_id).await?;
4372 let busy = db.is_user_busy(user_id).await?;
4373
4374 let pool = session.connection_pool().await;
4375 let updated_contact = contact_for_user(user_id, busy, &pool);
4376 for contact in contacts {
4377 if let db::Contact::Accepted {
4378 user_id: contact_user_id,
4379 ..
4380 } = contact
4381 {
4382 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
4383 session
4384 .peer
4385 .send(
4386 contact_conn_id,
4387 proto::UpdateContacts {
4388 contacts: vec![updated_contact.clone()],
4389 remove_contacts: Default::default(),
4390 incoming_requests: Default::default(),
4391 remove_incoming_requests: Default::default(),
4392 outgoing_requests: Default::default(),
4393 remove_outgoing_requests: Default::default(),
4394 },
4395 )
4396 .trace_err();
4397 }
4398 }
4399 }
4400 Ok(())
4401}
4402
4403async fn leave_room_for_session(session: &Session, connection_id: ConnectionId) -> Result<()> {
4404 let mut contacts_to_update = HashSet::default();
4405
4406 let room_id;
4407 let canceled_calls_to_user_ids;
4408 let livekit_room;
4409 let delete_livekit_room;
4410 let room;
4411 let channel;
4412
4413 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
4414 contacts_to_update.insert(session.user_id());
4415
4416 for project in left_room.left_projects.values() {
4417 project_left(project, session);
4418 }
4419
4420 room_id = RoomId::from_proto(left_room.room.id);
4421 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
4422 livekit_room = mem::take(&mut left_room.room.livekit_room);
4423 delete_livekit_room = left_room.deleted;
4424 room = mem::take(&mut left_room.room);
4425 channel = mem::take(&mut left_room.channel);
4426
4427 room_updated(&room, &session.peer);
4428 } else {
4429 return Ok(());
4430 }
4431
4432 if let Some(channel) = channel {
4433 channel_updated(
4434 &channel,
4435 &room,
4436 &session.peer,
4437 &*session.connection_pool().await,
4438 );
4439 }
4440
4441 {
4442 let pool = session.connection_pool().await;
4443 for canceled_user_id in canceled_calls_to_user_ids {
4444 for connection_id in pool.user_connection_ids(canceled_user_id) {
4445 session
4446 .peer
4447 .send(
4448 connection_id,
4449 proto::CallCanceled {
4450 room_id: room_id.to_proto(),
4451 },
4452 )
4453 .trace_err();
4454 }
4455 contacts_to_update.insert(canceled_user_id);
4456 }
4457 }
4458
4459 for contact_user_id in contacts_to_update {
4460 update_user_contacts(contact_user_id, session).await?;
4461 }
4462
4463 if let Some(live_kit) = session.app_state.livekit_client.as_ref() {
4464 live_kit
4465 .remove_participant(livekit_room.clone(), session.user_id().to_string())
4466 .await
4467 .trace_err();
4468
4469 if delete_livekit_room {
4470 live_kit.delete_room(livekit_room).await.trace_err();
4471 }
4472 }
4473
4474 Ok(())
4475}
4476
4477async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
4478 let left_channel_buffers = session
4479 .db()
4480 .await
4481 .leave_channel_buffers(session.connection_id)
4482 .await?;
4483
4484 for left_buffer in left_channel_buffers {
4485 channel_buffer_updated(
4486 session.connection_id,
4487 left_buffer.connections,
4488 &proto::UpdateChannelBufferCollaborators {
4489 channel_id: left_buffer.channel_id.to_proto(),
4490 collaborators: left_buffer.collaborators,
4491 },
4492 &session.peer,
4493 );
4494 }
4495
4496 Ok(())
4497}
4498
4499fn project_left(project: &db::LeftProject, session: &Session) {
4500 for connection_id in &project.connection_ids {
4501 if project.should_unshare {
4502 session
4503 .peer
4504 .send(
4505 *connection_id,
4506 proto::UnshareProject {
4507 project_id: project.id.to_proto(),
4508 },
4509 )
4510 .trace_err();
4511 } else {
4512 session
4513 .peer
4514 .send(
4515 *connection_id,
4516 proto::RemoveProjectCollaborator {
4517 project_id: project.id.to_proto(),
4518 peer_id: Some(session.connection_id.into()),
4519 },
4520 )
4521 .trace_err();
4522 }
4523 }
4524}
4525
4526pub trait ResultExt {
4527 type Ok;
4528
4529 fn trace_err(self) -> Option<Self::Ok>;
4530}
4531
4532impl<T, E> ResultExt for Result<T, E>
4533where
4534 E: std::fmt::Debug,
4535{
4536 type Ok = T;
4537
4538 #[track_caller]
4539 fn trace_err(self) -> Option<T> {
4540 match self {
4541 Ok(value) => Some(value),
4542 Err(error) => {
4543 tracing::error!("{:?}", error);
4544 None
4545 }
4546 }
4547 }
4548}