1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
40use reqwest_client::ReqwestClient;
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
/// Presumably the grace period a disconnected client is given to reconnect before its
/// state is released — it is not used in this part of the file; confirm at call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up stale server resources.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Only after they are gone
/// is it safe to clean up the resources they left behind, so we wait 15s (a 5s margin
/// beyond that grace period) before clearing stale rooms and channel buffers.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel message history (presumably; usage is outside this view).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Upper bound on a channel message's length (presumably in bytes; confirm at use site).
const MAX_MESSAGE_LEN: usize = 1024;
/// Page size for notification queries (presumably; usage is outside this view).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`: receives any typed envelope plus the
/// session and returns a boxed, `'static` future that performs the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone, Debug)]
109pub enum Principal {
110 User(User),
111 Impersonated { user: User, admin: User },
112 DevServer(dev_server::Model),
113}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
134#[derive(Clone)]
135struct Session {
136 principal: Principal,
137 connection_id: ConnectionId,
138 db: Arc<tokio::sync::Mutex<DbHandle>>,
139 peer: Arc<Peer>,
140 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
141 app_state: Arc<AppState>,
142 supermaven_client: Option<Arc<SupermavenAdminApi>>,
143 http_client: Arc<dyn HttpClient>,
144 /// The GeoIP country code for the user.
145 #[allow(unused)]
146 geoip_country_code: Option<String>,
147 _executor: Executor,
148}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
195 if self.is_staff() {
196 return Ok(proto::Plan::ZedPro);
197 }
198
199 let Some(user_id) = self.user_id() else {
200 return Ok(proto::Plan::Free);
201 };
202
203 if db.has_active_billing_subscription(user_id).await? {
204 Ok(proto::Plan::ZedPro)
205 } else {
206 Ok(proto::Plan::Free)
207 }
208 }
209
210 fn dev_server_id(&self) -> Option<DevServerId> {
211 match &self.principal {
212 Principal::User(_) | Principal::Impersonated { .. } => None,
213 Principal::DevServer(dev_server) => Some(dev_server.id),
214 }
215 }
216
217 fn principal_id(&self) -> PrincipalId {
218 match &self.principal {
219 Principal::User(user) => PrincipalId::UserId(user.id),
220 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
221 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
222 }
223 }
224}
225
226impl Debug for Session {
227 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228 let mut result = f.debug_struct("Session");
229 match &self.principal {
230 Principal::User(user) => {
231 result.field("user", &user.github_login);
232 }
233 Principal::Impersonated { user, admin } => {
234 result.field("user", &user.github_login);
235 result.field("impersonator", &admin.github_login);
236 }
237 Principal::DevServer(dev_server) => {
238 result.field("dev_server", &dev_server.id);
239 }
240 }
241 result.field("connection_id", &self.connection_id).finish()
242 }
243}
244
245struct UserSession(Session);
246
247impl UserSession {
248 pub fn new(s: Session) -> Option<Self> {
249 s.user_id().map(|_| UserSession(s))
250 }
251 pub fn user_id(&self) -> UserId {
252 self.0.user_id().unwrap()
253 }
254
255 pub fn email(&self) -> Option<String> {
256 match &self.0.principal {
257 Principal::User(user) => user.email_address.clone(),
258 Principal::Impersonated { user, .. } => user.email_address.clone(),
259 Principal::DevServer(..) => None,
260 }
261 }
262}
263
264impl Deref for UserSession {
265 type Target = Session;
266
267 fn deref(&self) -> &Self::Target {
268 &self.0
269 }
270}
271impl DerefMut for UserSession {
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self.0
274 }
275}
276
277struct DevServerSession(Session);
278
279impl DevServerSession {
280 pub fn new(s: Session) -> Option<Self> {
281 s.dev_server_id().map(|_| DevServerSession(s))
282 }
283 pub fn dev_server_id(&self) -> DevServerId {
284 self.0.dev_server_id().unwrap()
285 }
286
287 fn dev_server(&self) -> &dev_server::Model {
288 match &self.0.principal {
289 Principal::DevServer(dev_server) => dev_server,
290 _ => unreachable!(),
291 }
292 }
293}
294
295impl Deref for DevServerSession {
296 type Target = Session;
297
298 fn deref(&self) -> &Self::Target {
299 &self.0
300 }
301}
302impl DerefMut for DevServerSession {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut self.0
305 }
306}
307
308fn user_handler<M: RequestMessage, Fut>(
309 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
311where
312 Fut: Send + Future<Output = Result<()>>,
313{
314 let handler = Arc::new(handler);
315 move |message, response, session| {
316 let handler = handler.clone();
317 Box::pin(async move {
318 if let Some(user_session) = session.for_user() {
319 Ok(handler(message, response, user_session).await?)
320 } else {
321 Err(Error::Internal(anyhow!(
322 "must be a user to call {}",
323 M::NAME
324 )))
325 }
326 })
327 }
328}
329
330fn dev_server_handler<M: RequestMessage, Fut>(
331 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
333where
334 Fut: Send + Future<Output = Result<()>>,
335{
336 let handler = Arc::new(handler);
337 move |message, response, session| {
338 let handler = handler.clone();
339 Box::pin(async move {
340 if let Some(dev_server_session) = session.for_dev_server() {
341 Ok(handler(message, response, dev_server_session).await?)
342 } else {
343 Err(Error::Internal(anyhow!(
344 "must be a dev server to call {}",
345 M::NAME
346 )))
347 }
348 })
349 }
350}
351
352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
353 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
355where
356 InnertRetFut: Send + Future<Output = Result<()>>,
357{
358 let handler = Arc::new(handler);
359 move |message, session| {
360 let handler = handler.clone();
361 Box::pin(async move {
362 if let Some(user_session) = session.for_user() {
363 Ok(handler(message, user_session).await?)
364 } else {
365 Err(Error::Internal(anyhow!(
366 "must be a user to call {}",
367 M::NAME
368 )))
369 }
370 })
371 }
372}
373
374struct DbHandle(Arc<Database>);
375
376impl Deref for DbHandle {
377 type Target = Database;
378
379 fn deref(&self) -> &Self::Target {
380 self.0.as_ref()
381 }
382}
383
/// The collaboration RPC server.
///
/// Owns the peer, the shared connection pool, and the registry of message handlers used
/// to dispatch incoming messages.
pub struct Server {
    /// This server instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message payload type.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. Each connection subscribes to it and stops when it changes.
    teardown: watch::Sender<bool>,
}
392
/// Lock guard over the shared [`ConnectionPool`].
///
/// `_not_send` holds an `Rc`, which makes the guard `!Send`. Presumably this is meant to
/// stop the synchronous `parking_lot` lock from being held across an `.await` inside a
/// `Send` future — confirm intent.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through the guard's `Deref` target via `serialize_deref`. The pool stays
    /// locked for as long as the snapshot is alive.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
404
405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
406where
407 S: Serializer,
408 T: Deref<Target = U>,
409 U: Serialize,
410{
411 Serialize::serialize(value.deref(), serializer)
412}
413
414impl Server {
    /// Creates a collaboration server with the given `id` and registers every RPC handler.
    ///
    /// Each message type may be registered only once; `add_handler` panics on a duplicate.
    ///
    /// Wrapper conventions used in the registration chain:
    /// - `user_handler` / `user_message_handler`: the call is rejected unless the session's
    ///   principal is a user (direct or impersonated).
    /// - `dev_server_handler`: the call is rejected unless the principal is a dev server.
    /// - No wrapper: the handler receives the raw `Session`, whatever the principal.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped immediately; connections obtain receivers
            // later through `Sender::subscribe`.
            teardown: watch::channel(false).0,
        };

        server
            // Connectivity, rooms, and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Project sharing and joining.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            // Dev servers and dev-server projects (user-initiated side).
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Requests that only a dev server may make.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            // Project, worktree, and language-server state updates.
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another participant. The `_for_owner`,
            // `read_only`, and `mutating` variants are defined elsewhere; presumably they
            // differ in who may call them — confirm in their definitions.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffer replication and messages rebroadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and channel chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            // Notifications.
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account info, terms, tokens, and read acknowledgements.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Handlers that need server configuration captured via `app_state`.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
651
    /// Schedules background cleanup of resources left behind by stale server instances.
    ///
    /// Returns `Ok(())` right after spawning a detached task that:
    /// 1. waits `CLEANUP_TIMEOUT`;
    /// 2. fetches stale room and channel ids for this environment
    ///    (`stale_server_resource_ids`);
    /// 3. clears stale channel-buffer collaborators and sends the refreshed collaborator
    ///    list to the remaining connections;
    /// 4. for each stale room: clears stale participants, broadcasts the updated room (and
    ///    its channel), notifies users whose calls were canceled, pushes refreshed contact
    ///    state to affected users' contacts, and deletes the LiveKit room if no
    ///    participants remain;
    /// 5. deletes the stale server records.
    ///
    /// Every fallible step goes through `trace_err`. A failed step is skipped and the rest
    /// of the cleanup continues.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created eagerly; awaited inside the spawned task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop collaborators that belonged to stale servers and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify,
                        // and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Removed participants and canceled callees both need their
                            // contact status refreshed.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to every
                        // connection of their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
797
798 pub fn teardown(&self) {
799 self.peer.teardown();
800 self.connection_pool.lock().reset();
801 let _ = self.teardown.send(true);
802 }
803
804 #[cfg(test)]
805 pub fn reset(&self, id: ServerId) {
806 self.teardown();
807 *self.id.lock() = id;
808 self.peer.reset(id.0 as u32);
809 let _ = self.teardown.send(false);
810 }
811
812 #[cfg(test)]
813 pub fn id(&self) -> ServerId {
814 *self.id.lock()
815 }
816
817 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
818 where
819 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
820 Fut: 'static + Send + Future<Output = Result<()>>,
821 M: EnvelopedMessage,
822 {
823 let prev_handler = self.handlers.insert(
824 TypeId::of::<M>(),
825 Box::new(move |envelope, session| {
826 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
827 let received_at = envelope.received_at;
828 tracing::info!("message received");
829 let start_time = Instant::now();
830 let future = (handler)(*envelope, session);
831 async move {
832 let result = future.await;
833 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
834 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
835 let queue_duration_ms = total_duration_ms - processing_duration_ms;
836 let payload_type = M::NAME;
837
838 match result {
839 Err(error) => {
840 tracing::error!(
841 ?error,
842 total_duration_ms,
843 processing_duration_ms,
844 queue_duration_ms,
845 payload_type,
846 "error handling message"
847 )
848 }
849 Ok(()) => tracing::info!(
850 total_duration_ms,
851 processing_duration_ms,
852 queue_duration_ms,
853 "finished handling message"
854 ),
855 }
856 }
857 .boxed()
858 }),
859 );
860 if prev_handler.is_some() {
861 panic!("registered a handler for the same message twice");
862 }
863 self
864 }
865
866 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
869 Fut: 'static + Send + Future<Output = Result<()>>,
870 M: EnvelopedMessage,
871 {
872 self.add_handler(move |envelope, session| handler(envelope.payload, session));
873 self
874 }
875
876 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
879 Fut: Send + Future<Output = Result<()>>,
880 M: RequestMessage,
881 {
882 let handler = Arc::new(handler);
883 self.add_handler(move |envelope, session| {
884 let receipt = envelope.receipt();
885 let handler = handler.clone();
886 async move {
887 let peer = session.peer.clone();
888 let responded = Arc::new(AtomicBool::default());
889 let response = Response {
890 peer: peer.clone(),
891 responded: responded.clone(),
892 receipt,
893 };
894 match (handler)(envelope.payload, response, session).await {
895 Ok(()) => {
896 if responded.load(std::sync::atomic::Ordering::SeqCst) {
897 Ok(())
898 } else {
899 Err(anyhow!("handler did not send a response"))?
900 }
901 }
902 Err(error) => {
903 let proto_err = match &error {
904 Error::Internal(err) => err.to_proto(),
905 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
906 };
907 peer.respond_with_error(receipt, proto_err)?;
908 Err(error)
909 }
910 }
911 }
912 })
913 }
914
915 #[allow(clippy::too_many_arguments)]
916 pub fn handle_connection(
917 self: &Arc<Self>,
918 connection: Connection,
919 address: String,
920 principal: Principal,
921 zed_version: ZedVersion,
922 geoip_country_code: Option<String>,
923 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
924 executor: Executor,
925 ) -> impl Future<Output = ()> {
926 let this = self.clone();
927 let span = info_span!("handle connection", %address,
928 connection_id=field::Empty,
929 user_id=field::Empty,
930 login=field::Empty,
931 impersonator=field::Empty,
932 dev_server_id=field::Empty,
933 geoip_country_code=field::Empty
934 );
935 principal.update_span(&span);
936 if let Some(country_code) = geoip_country_code.as_ref() {
937 span.record("geoip_country_code", country_code);
938 }
939
940 let mut teardown = self.teardown.subscribe();
941 async move {
942 if *teardown.borrow() {
943 tracing::error!("server is tearing down");
944 return
945 }
946 let (connection_id, handle_io, mut incoming_rx) = this
947 .peer
948 .add_connection(connection, {
949 let executor = executor.clone();
950 move |duration| executor.sleep(duration)
951 });
952 tracing::Span::current().record("connection_id", format!("{}", connection_id));
953
954 tracing::info!("connection opened");
955
956 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
957 let http_client = match ReqwestClient::user_agent(&user_agent) {
958 Ok(http_client) => Arc::new(http_client),
959 Err(error) => {
960 tracing::error!(?error, "failed to create HTTP client");
961 return;
962 }
963 };
964
965 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
966 supermaven_admin_api_key.to_string(),
967 http_client.clone(),
968 )));
969
970 let session = Session {
971 principal: principal.clone(),
972 connection_id,
973 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
974 peer: this.peer.clone(),
975 connection_pool: this.connection_pool.clone(),
976 app_state: this.app_state.clone(),
977 http_client,
978 geoip_country_code,
979 _executor: executor.clone(),
980 supermaven_client,
981 };
982
983 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
984 tracing::error!(?error, "failed to send initial client update");
985 return;
986 }
987
988 let handle_io = handle_io.fuse();
989 futures::pin_mut!(handle_io);
990
991 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
992 // This prevents deadlocks when e.g., client A performs a request to client B and
993 // client B performs a request to client A. If both clients stop processing further
994 // messages until their respective request completes, they won't have a chance to
995 // respond to the other client's request and cause a deadlock.
996 //
997 // This arrangement ensures we will attempt to process earlier messages first, but fall
998 // back to processing messages arrived later in the spirit of making progress.
999 let mut foreground_message_handlers = FuturesUnordered::new();
1000 let concurrent_handlers = Arc::new(Semaphore::new(256));
1001 loop {
1002 let next_message = async {
1003 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1004 let message = incoming_rx.next().await;
1005 (permit, message)
1006 }.fuse();
1007 futures::pin_mut!(next_message);
1008 futures::select_biased! {
1009 _ = teardown.changed().fuse() => return,
1010 result = handle_io => {
1011 if let Err(error) = result {
1012 tracing::error!(?error, "error handling I/O");
1013 }
1014 break;
1015 }
1016 _ = foreground_message_handlers.next() => {}
1017 next_message = next_message => {
1018 let (permit, message) = next_message;
1019 if let Some(message) = message {
1020 let type_name = message.payload_type_name();
1021 // note: we copy all the fields from the parent span so we can query them in the logs.
1022 // (https://github.com/tokio-rs/tracing/issues/2670).
1023 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1024 user_id=field::Empty,
1025 login=field::Empty,
1026 impersonator=field::Empty,
1027 dev_server_id=field::Empty
1028 );
1029 principal.update_span(&span);
1030 let span_enter = span.enter();
1031 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1032 let is_background = message.is_background();
1033 let handle_message = (handler)(message, session.clone());
1034 drop(span_enter);
1035
1036 let handle_message = async move {
1037 handle_message.await;
1038 drop(permit);
1039 }.instrument(span);
1040 if is_background {
1041 executor.spawn_detached(handle_message);
1042 } else {
1043 foreground_message_handlers.push(handle_message);
1044 }
1045 } else {
1046 tracing::error!("no message handler");
1047 }
1048 } else {
1049 tracing::info!("connection closed");
1050 break;
1051 }
1052 }
1053 }
1054 }
1055
1056 drop(foreground_message_handlers);
1057 tracing::info!("signing out");
1058 if let Err(error) = connection_lost(session, teardown, executor).await {
1059 tracing::error!(?error, "error signing out");
1060 }
1061
1062 }.instrument(span)
1063 }
1064
1065 async fn send_initial_client_update(
1066 &self,
1067 connection_id: ConnectionId,
1068 principal: &Principal,
1069 zed_version: ZedVersion,
1070 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1071 session: &Session,
1072 ) -> Result<()> {
1073 self.peer.send(
1074 connection_id,
1075 proto::Hello {
1076 peer_id: Some(connection_id.into()),
1077 },
1078 )?;
1079 tracing::info!("sent hello message");
1080 if let Some(send_connection_id) = send_connection_id.take() {
1081 let _ = send_connection_id.send(connection_id);
1082 }
1083
1084 match principal {
1085 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1086 if !user.connected_once {
1087 self.peer.send(connection_id, proto::ShowContacts {})?;
1088 self.app_state
1089 .db
1090 .set_user_connected_once(user.id, true)
1091 .await?;
1092 }
1093
1094 update_user_plan(user.id, session).await?;
1095
1096 let (contacts, dev_server_projects) = future::try_join(
1097 self.app_state.db.get_contacts(user.id),
1098 self.app_state.db.dev_server_projects_update(user.id),
1099 )
1100 .await?;
1101
1102 {
1103 let mut pool = self.connection_pool.lock();
1104 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1105 self.peer.send(
1106 connection_id,
1107 build_initial_contacts_update(contacts, &pool),
1108 )?;
1109 }
1110
1111 if should_auto_subscribe_to_channels(zed_version) {
1112 subscribe_user_to_channels(user.id, session).await?;
1113 }
1114
1115 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1116
1117 if let Some(incoming_call) =
1118 self.app_state.db.incoming_call_for_user(user.id).await?
1119 {
1120 self.peer.send(connection_id, incoming_call)?;
1121 }
1122
1123 update_user_contacts(user.id, session).await?;
1124 }
1125 Principal::DevServer(dev_server) => {
1126 {
1127 let mut pool = self.connection_pool.lock();
1128 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1129 {
1130 self.peer.send(
1131 stale_connection_id,
1132 proto::ShutdownDevServer {
1133 reason: Some(
1134 "another dev server connected with the same token".to_string(),
1135 ),
1136 },
1137 )?;
1138 pool.remove_connection(stale_connection_id)?;
1139 };
1140 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1141 }
1142
1143 let projects = self
1144 .app_state
1145 .db
1146 .get_projects_for_dev_server(dev_server.id)
1147 .await?;
1148 self.peer
1149 .send(connection_id, proto::DevServerInstructions { projects })?;
1150
1151 let status = self
1152 .app_state
1153 .db
1154 .dev_server_projects_update(dev_server.user_id)
1155 .await?;
1156 send_dev_server_projects_update(dev_server.user_id, status, session).await;
1157 }
1158 }
1159
1160 Ok(())
1161 }
1162
1163 pub async fn invite_code_redeemed(
1164 self: &Arc<Self>,
1165 inviter_id: UserId,
1166 invitee_id: UserId,
1167 ) -> Result<()> {
1168 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1169 if let Some(code) = &user.invite_code {
1170 let pool = self.connection_pool.lock();
1171 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1172 for connection_id in pool.user_connection_ids(inviter_id) {
1173 self.peer.send(
1174 connection_id,
1175 proto::UpdateContacts {
1176 contacts: vec![invitee_contact.clone()],
1177 ..Default::default()
1178 },
1179 )?;
1180 self.peer.send(
1181 connection_id,
1182 proto::UpdateInviteInfo {
1183 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1184 count: user.invite_count as u32,
1185 },
1186 )?;
1187 }
1188 }
1189 }
1190 Ok(())
1191 }
1192
1193 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1194 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1195 if let Some(invite_code) = &user.invite_code {
1196 let pool = self.connection_pool.lock();
1197 for connection_id in pool.user_connection_ids(user_id) {
1198 self.peer.send(
1199 connection_id,
1200 proto::UpdateInviteInfo {
1201 url: format!(
1202 "{}{}",
1203 self.app_state.config.invite_link_prefix, invite_code
1204 ),
1205 count: user.invite_count as u32,
1206 },
1207 )?;
1208 }
1209 }
1210 }
1211 Ok(())
1212 }
1213
1214 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1215 ServerSnapshot {
1216 connection_pool: ConnectionPoolGuard {
1217 guard: self.connection_pool.lock(),
1218 _not_send: PhantomData,
1219 },
1220 peer: &self.peer,
1221 }
1222 }
1223}
1224
1225impl<'a> Deref for ConnectionPoolGuard<'a> {
1226 type Target = ConnectionPool;
1227
1228 fn deref(&self) -> &Self::Target {
1229 &self.guard
1230 }
1231}
1232
1233impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1234 fn deref_mut(&mut self) -> &mut Self::Target {
1235 &mut self.guard
1236 }
1237}
1238
/// In test builds, runs `check_invariants` on the pool whenever a
/// `ConnectionPoolGuard` is dropped. `Drop::drop` runs before the struct's
/// fields are dropped, so the inner lock guard is still held during the check.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Only compiled under `cfg(test)`; non-test builds do nothing here.
        #[cfg(test)]
        self.check_invariants();
    }
}
1245
1246fn broadcast<F>(
1247 sender_id: Option<ConnectionId>,
1248 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1249 mut f: F,
1250) where
1251 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1252{
1253 for receiver_id in receiver_ids {
1254 if Some(receiver_id) != sender_id {
1255 if let Err(error) = f(receiver_id) {
1256 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1257 }
1258 }
1259 }
1260}
1261
1262pub struct ProtocolVersion(u32);
1263
1264impl Header for ProtocolVersion {
1265 fn name() -> &'static HeaderName {
1266 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1267 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1268 }
1269
1270 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1271 where
1272 Self: Sized,
1273 I: Iterator<Item = &'i axum::http::HeaderValue>,
1274 {
1275 let version = values
1276 .next()
1277 .ok_or_else(axum::headers::Error::invalid)?
1278 .to_str()
1279 .map_err(|_| axum::headers::Error::invalid())?
1280 .parse()
1281 .map_err(|_| axum::headers::Error::invalid())?;
1282 Ok(Self(version))
1283 }
1284
1285 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1286 values.extend([self.0.to_string().parse().unwrap()]);
1287 }
1288}
1289
1290pub struct AppVersionHeader(SemanticVersion);
1291impl Header for AppVersionHeader {
1292 fn name() -> &'static HeaderName {
1293 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1295 }
1296
1297 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298 where
1299 Self: Sized,
1300 I: Iterator<Item = &'i axum::http::HeaderValue>,
1301 {
1302 let version = values
1303 .next()
1304 .ok_or_else(axum::headers::Error::invalid)?
1305 .to_str()
1306 .map_err(|_| axum::headers::Error::invalid())?
1307 .parse()
1308 .map_err(|_| axum::headers::Error::invalid())?;
1309 Ok(Self(version))
1310 }
1311
1312 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313 values.extend([self.0.to_string().parse().unwrap()]);
1314 }
1315}
1316
1317pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1318 Router::new()
1319 .route("/rpc", get(handle_websocket_request))
1320 .layer(
1321 ServiceBuilder::new()
1322 .layer(Extension(server.app_state.clone()))
1323 .layer(middleware::from_fn(auth::validate_header)),
1324 )
1325 .route("/metrics", get(handle_metrics))
1326 .layer(Extension(server))
1327}
1328
1329pub async fn handle_websocket_request(
1330 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1331 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1332 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1333 Extension(server): Extension<Arc<Server>>,
1334 Extension(principal): Extension<Principal>,
1335 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1336 ws: WebSocketUpgrade,
1337) -> axum::response::Response {
1338 if protocol_version != rpc::PROTOCOL_VERSION {
1339 return (
1340 StatusCode::UPGRADE_REQUIRED,
1341 "client must be upgraded".to_string(),
1342 )
1343 .into_response();
1344 }
1345
1346 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1347 return (
1348 StatusCode::UPGRADE_REQUIRED,
1349 "no version header found".to_string(),
1350 )
1351 .into_response();
1352 };
1353
1354 if !version.can_collaborate() {
1355 return (
1356 StatusCode::UPGRADE_REQUIRED,
1357 "client must be upgraded".to_string(),
1358 )
1359 .into_response();
1360 }
1361
1362 let socket_address = socket_address.to_string();
1363 ws.on_upgrade(move |socket| {
1364 let socket = socket
1365 .map_ok(to_tungstenite_message)
1366 .err_into()
1367 .with(|message| async move { to_axum_message(message) });
1368 let connection = Connection::new(Box::pin(socket));
1369 async move {
1370 server
1371 .handle_connection(
1372 connection,
1373 socket_address,
1374 principal,
1375 version,
1376 country_code_header.map(|header| header.to_string()),
1377 None,
1378 Executor::Production,
1379 )
1380 .await;
1381 }
1382 })
1383}
1384
1385pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1386 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1387 let connections_metric = CONNECTIONS_METRIC
1388 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1389
1390 let connections = server
1391 .connection_pool
1392 .lock()
1393 .connections()
1394 .filter(|connection| !connection.admin)
1395 .count();
1396 connections_metric.set(connections as _);
1397
1398 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1399 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1400 register_int_gauge!(
1401 "shared_projects",
1402 "number of open projects with one or more guests"
1403 )
1404 .unwrap()
1405 });
1406
1407 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1408 shared_projects_metric.set(shared_projects as _);
1409
1410 let encoder = prometheus::TextEncoder::new();
1411 let metric_families = prometheus::gather();
1412 let encoded_metrics = encoder
1413 .encode_to_string(&metric_families)
1414 .map_err(|err| anyhow!("{}", err))?;
1415 Ok(encoded_metrics)
1416}
1417
/// Cleans up after a connection's message loop has ended.
///
/// The following happens immediately:
/// - the peer is disconnected;
/// - the connection is removed from the connection pool, and an error is
///   returned if that fails;
/// - the loss is recorded in the database, where errors are only logged.
///
/// It then waits for whichever comes first:
/// - `RECONNECT_TIMEOUT` elapses, and the principal's remaining resources are
///   released:
///   - users leave their room and channel buffers, have any pending incoming
///     call declined if none of their connections is still online, and have
///     their contacts updated;
///   - dev servers go through `lost_dev_server_connection`.
/// - `teardown` changes, and the function returns without that cleanup.
///
/// NOTE(review): the delay presumably gives the client a window to reconnect
/// before its state is released — confirm.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Failures here are logged rather than aborting the rest of the cleanup.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout
    // branch is checked before the teardown branch on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The principal was just matched as a user, so `for_user`
                    // is expected to return `Some` here.
                    let session = session.for_user().unwrap();

                    // NOTE(review): this uses `log` while the surrounding code uses `tracing`.
                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call once no connection of this user remains online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1472
1473/// Acknowledges a ping from a client, used to keep the connection alive.
1474async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1475 response.send(proto::Ack {})?;
1476 Ok(())
1477}
1478
1479/// Creates a new room for calling (outside of channels)
1480async fn create_room(
1481 _request: proto::CreateRoom,
1482 response: Response<proto::CreateRoom>,
1483 session: UserSession,
1484) -> Result<()> {
1485 let live_kit_room = nanoid::nanoid!(30);
1486
1487 let live_kit_connection_info = util::maybe!(async {
1488 let live_kit = session.app_state.live_kit_client.as_ref();
1489 let live_kit = live_kit?;
1490 let user_id = session.user_id().to_string();
1491
1492 let token = live_kit
1493 .room_token(&live_kit_room, &user_id.to_string())
1494 .trace_err()?;
1495
1496 Some(proto::LiveKitConnectionInfo {
1497 server_url: live_kit.url().into(),
1498 token,
1499 can_publish: true,
1500 })
1501 })
1502 .await;
1503
1504 let room = session
1505 .db()
1506 .await
1507 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1508 .await?;
1509
1510 response.send(proto::CreateRoomResponse {
1511 room: Some(room.clone()),
1512 live_kit_connection_info,
1513 })?;
1514
1515 update_user_contacts(session.user_id(), &session).await?;
1516 Ok(())
1517}
1518
1519/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1520async fn join_room(
1521 request: proto::JoinRoom,
1522 response: Response<proto::JoinRoom>,
1523 session: UserSession,
1524) -> Result<()> {
1525 let room_id = RoomId::from_proto(request.id);
1526
1527 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1528
1529 if let Some(channel_id) = channel_id {
1530 return join_channel_internal(channel_id, Box::new(response), session).await;
1531 }
1532
1533 let joined_room = {
1534 let room = session
1535 .db()
1536 .await
1537 .join_room(room_id, session.user_id(), session.connection_id)
1538 .await?;
1539 room_updated(&room.room, &session.peer);
1540 room.into_inner()
1541 };
1542
1543 for connection_id in session
1544 .connection_pool()
1545 .await
1546 .user_connection_ids(session.user_id())
1547 {
1548 session
1549 .peer
1550 .send(
1551 connection_id,
1552 proto::CallCanceled {
1553 room_id: room_id.to_proto(),
1554 },
1555 )
1556 .trace_err();
1557 }
1558
1559 let live_kit_connection_info =
1560 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1561 live_kit
1562 .room_token(
1563 &joined_room.room.live_kit_room,
1564 &session.user_id().to_string(),
1565 )
1566 .trace_err()
1567 .map(|token| proto::LiveKitConnectionInfo {
1568 server_url: live_kit.url().into(),
1569 token,
1570 can_publish: true,
1571 })
1572 } else {
1573 None
1574 };
1575
1576 response.send(proto::JoinRoomResponse {
1577 room: Some(joined_room.room),
1578 channel_id: None,
1579 live_kit_connection_info,
1580 })?;
1581
1582 update_user_contacts(session.user_id(), &session).await?;
1583 Ok(())
1584}
1585
/// Reconnects to a room after a connection error.
///
/// The database call `rejoin_room` re-associates the user with the room using
/// this session's new connection. The client is sent a `RejoinRoomResponse`
/// listing:
/// - the room;
/// - the reshared projects, with their collaborators;
/// - the rejoined projects.
///
/// The room update is then broadcast, and other participants are updated:
/// - For each reshared project (NOTE(review): presumably projects this user
///   hosts — confirm), every collaborator is told that the user's peer id
///   changed from `old_connection_id` to the current connection. Every
///   collaborator other than this connection is then sent an `UpdateProject`
///   with the project's worktrees, forwarded with this connection as the
///   sender.
/// - Rejoined projects are handled by `notify_rejoined_projects`.
///
/// Finally, if the room belongs to a channel, the channel update is broadcast,
/// and the user's contacts are updated.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Everything that needs the rejoined-room guard happens in this scope; only
    // the room and channel are moved out for the channel update below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Best-effort: tell each collaborator about this user's new peer id.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree metadata to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1677
1678fn notify_rejoined_projects(
1679 rejoined_projects: &mut Vec<RejoinedProject>,
1680 session: &UserSession,
1681) -> Result<()> {
1682 for project in rejoined_projects.iter() {
1683 for collaborator in &project.collaborators {
1684 session
1685 .peer
1686 .send(
1687 collaborator.connection_id,
1688 proto::UpdateProjectCollaborator {
1689 project_id: project.id.to_proto(),
1690 old_peer_id: Some(project.old_connection_id.into()),
1691 new_peer_id: Some(session.connection_id.into()),
1692 },
1693 )
1694 .trace_err();
1695 }
1696 }
1697
1698 for project in rejoined_projects {
1699 for worktree in mem::take(&mut project.worktrees) {
1700 #[cfg(any(test, feature = "test-support"))]
1701 const MAX_CHUNK_SIZE: usize = 2;
1702 #[cfg(not(any(test, feature = "test-support")))]
1703 const MAX_CHUNK_SIZE: usize = 256;
1704
1705 // Stream this worktree's entries.
1706 let message = proto::UpdateWorktree {
1707 project_id: project.id.to_proto(),
1708 worktree_id: worktree.id,
1709 abs_path: worktree.abs_path.clone(),
1710 root_name: worktree.root_name,
1711 updated_entries: worktree.updated_entries,
1712 removed_entries: worktree.removed_entries,
1713 scan_id: worktree.scan_id,
1714 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1715 updated_repositories: worktree.updated_repositories,
1716 removed_repositories: worktree.removed_repositories,
1717 };
1718 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1719 session.peer.send(session.connection_id, update.clone())?;
1720 }
1721
1722 // Stream this worktree's diagnostics.
1723 for summary in worktree.diagnostic_summaries {
1724 session.peer.send(
1725 session.connection_id,
1726 proto::UpdateDiagnosticSummary {
1727 project_id: project.id.to_proto(),
1728 worktree_id: worktree.id,
1729 summary: Some(summary),
1730 },
1731 )?;
1732 }
1733
1734 for settings_file in worktree.settings_files {
1735 session.peer.send(
1736 session.connection_id,
1737 proto::UpdateWorktreeSettings {
1738 project_id: project.id.to_proto(),
1739 worktree_id: worktree.id,
1740 path: settings_file.path,
1741 content: Some(settings_file.content),
1742 kind: Some(settings_file.kind.to_proto().into()),
1743 },
1744 )?;
1745 }
1746 }
1747
1748 for language_server in &project.language_servers {
1749 session.peer.send(
1750 session.connection_id,
1751 proto::UpdateLanguageServer {
1752 project_id: project.id.to_proto(),
1753 language_server_id: language_server.id,
1754 variant: Some(
1755 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1756 proto::LspDiskBasedDiagnosticsUpdated {},
1757 ),
1758 ),
1759 },
1760 )?;
1761 }
1762 }
1763 Ok(())
1764}
1765
1766/// leave room disconnects from the room.
1767async fn leave_room(
1768 _: proto::LeaveRoom,
1769 response: Response<proto::LeaveRoom>,
1770 session: UserSession,
1771) -> Result<()> {
1772 leave_room_for_session(&session, session.connection_id).await?;
1773 response.send(proto::Ack {})?;
1774 Ok(())
1775}
1776
1777/// Updates the permissions of someone else in the room.
1778async fn set_room_participant_role(
1779 request: proto::SetRoomParticipantRole,
1780 response: Response<proto::SetRoomParticipantRole>,
1781 session: UserSession,
1782) -> Result<()> {
1783 let user_id = UserId::from_proto(request.user_id);
1784 let role = ChannelRole::from(request.role());
1785
1786 let (live_kit_room, can_publish) = {
1787 let room = session
1788 .db()
1789 .await
1790 .set_room_participant_role(
1791 session.user_id(),
1792 RoomId::from_proto(request.room_id),
1793 user_id,
1794 role,
1795 )
1796 .await?;
1797
1798 let live_kit_room = room.live_kit_room.clone();
1799 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1800 room_updated(&room, &session.peer);
1801 (live_kit_room, can_publish)
1802 };
1803
1804 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1805 live_kit
1806 .update_participant(
1807 live_kit_room.clone(),
1808 request.user_id.to_string(),
1809 live_kit_server::proto::ParticipantPermission {
1810 can_subscribe: true,
1811 can_publish,
1812 can_publish_data: can_publish,
1813 hidden: false,
1814 recorder: false,
1815 },
1816 )
1817 .await
1818 .trace_err();
1819 }
1820
1821 response.send(proto::Ack {})?;
1822 Ok(())
1823}
1824
/// Calls another user into the current room.
///
/// The called user must be a contact of the caller. The call is recorded in
/// the database, the room update is broadcast, and the called user's contacts
/// are updated. An incoming-call request is then sent concurrently to every
/// connection of the called user.
///
/// The first connection that responds successfully causes the caller's
/// request to be acknowledged. If no connection does (including when the
/// called user has no connections):
/// - the call is marked as failed;
/// - the room update is broadcast again;
/// - the called user's contacts are updated;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` is dropped before the awaits
    // below; the incoming-call message is moved out with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The requests are collected up front, so the pool guard is released at the
    // end of this statement, before any response is awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                // Log this connection's failure and keep waiting for another one.
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1893
1894/// Cancel an outgoing call.
1895async fn cancel_call(
1896 request: proto::CancelCall,
1897 response: Response<proto::CancelCall>,
1898 session: UserSession,
1899) -> Result<()> {
1900 let called_user_id = UserId::from_proto(request.called_user_id);
1901 let room_id = RoomId::from_proto(request.room_id);
1902 {
1903 let room = session
1904 .db()
1905 .await
1906 .cancel_call(room_id, session.connection_id, called_user_id)
1907 .await?;
1908 room_updated(&room, &session.peer);
1909 }
1910
1911 for connection_id in session
1912 .connection_pool()
1913 .await
1914 .user_connection_ids(called_user_id)
1915 {
1916 session
1917 .peer
1918 .send(
1919 connection_id,
1920 proto::CallCanceled {
1921 room_id: room_id.to_proto(),
1922 },
1923 )
1924 .trace_err();
1925 }
1926 response.send(proto::Ack {})?;
1927
1928 update_user_contacts(called_user_id, &session).await?;
1929 Ok(())
1930}
1931
1932/// Decline an incoming call.
1933async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1934 let room_id = RoomId::from_proto(message.room_id);
1935 {
1936 let room = session
1937 .db()
1938 .await
1939 .decline_call(Some(room_id), session.user_id())
1940 .await?
1941 .ok_or_else(|| anyhow!("failed to decline call"))?;
1942 room_updated(&room, &session.peer);
1943 }
1944
1945 for connection_id in session
1946 .connection_pool()
1947 .await
1948 .user_connection_ids(session.user_id())
1949 {
1950 session
1951 .peer
1952 .send(
1953 connection_id,
1954 proto::CallCanceled {
1955 room_id: room_id.to_proto(),
1956 },
1957 )
1958 .trace_err();
1959 }
1960 update_user_contacts(session.user_id(), &session).await?;
1961 Ok(())
1962}
1963
1964/// Updates other participants in the room with your current location.
1965async fn update_participant_location(
1966 request: proto::UpdateParticipantLocation,
1967 response: Response<proto::UpdateParticipantLocation>,
1968 session: UserSession,
1969) -> Result<()> {
1970 let room_id = RoomId::from_proto(request.room_id);
1971 let location = request
1972 .location
1973 .ok_or_else(|| anyhow!("invalid location"))?;
1974
1975 let db = session.db().await;
1976 let room = db
1977 .update_room_participant_location(room_id, session.connection_id, location)
1978 .await?;
1979
1980 room_updated(&room, &session.peer);
1981 response.send(proto::Ack {})?;
1982 Ok(())
1983}
1984
1985/// Share a project into the room.
1986async fn share_project(
1987 request: proto::ShareProject,
1988 response: Response<proto::ShareProject>,
1989 session: UserSession,
1990) -> Result<()> {
1991 let (project_id, room) = &*session
1992 .db()
1993 .await
1994 .share_project(
1995 RoomId::from_proto(request.room_id),
1996 session.connection_id,
1997 &request.worktrees,
1998 request.is_ssh_project,
1999 request
2000 .dev_server_project_id
2001 .map(DevServerProjectId::from_proto),
2002 )
2003 .await?;
2004 response.send(proto::ShareProjectResponse {
2005 project_id: project_id.to_proto(),
2006 })?;
2007 room_updated(room, &session.peer);
2008
2009 Ok(())
2010}
2011
2012/// Unshare a project from the room.
2013async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2014 let project_id = ProjectId::from_proto(message.project_id);
2015 unshare_project_internal(
2016 project_id,
2017 session.connection_id,
2018 session.user_id(),
2019 &session,
2020 )
2021 .await
2022}
2023
/// Unshares `project_id` on behalf of `connection_id` (and `user_id`, when
/// known).
///
/// The database call returns three things, which are handled as follows:
/// - the guest connections: each one except `connection_id` is sent
///   `UnshareProject` (best-effort);
/// - the project's room, if any: its update is broadcast;
/// - a `delete` flag: if set, the project is deleted from the database
///   afterwards.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so the room guard is dropped before the database is used again
    // for `delete_project`.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2062
2063/// DevServer makes a project available online
2064async fn share_dev_server_project(
2065 request: proto::ShareDevServerProject,
2066 response: Response<proto::ShareDevServerProject>,
2067 session: DevServerSession,
2068) -> Result<()> {
2069 let (dev_server_project, user_id, status) = session
2070 .db()
2071 .await
2072 .share_dev_server_project(
2073 DevServerProjectId::from_proto(request.dev_server_project_id),
2074 session.dev_server_id(),
2075 session.connection_id,
2076 &request.worktrees,
2077 )
2078 .await?;
2079 let Some(project_id) = dev_server_project.project_id else {
2080 return Err(anyhow!("failed to share remote project"))?;
2081 };
2082
2083 send_dev_server_projects_update(user_id, status, &session).await;
2084
2085 response.send(proto::ShareProjectResponse { project_id })?;
2086
2087 Ok(())
2088}
2089
/// Joins someone else's shared project.
///
/// Records the join in the database (which assigns a replica id), releases the
/// database handle, and then delegates to `join_project_internal`. That
/// function notifies the existing collaborators and sends the project's state
/// to this connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before building and sending the join response.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2108
/// A response handle that can deliver a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` answer both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Each impl below forwards to `Response::send` through a fully-qualified path
// (presumably the inherent method), which makes it explicit that the trait
// method is not calling itself.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2122
2123fn join_project_internal(
2124 response: impl JoinProjectInternalResponse,
2125 session: UserSession,
2126 project: &mut Project,
2127 replica_id: &ReplicaId,
2128) -> Result<()> {
2129 let collaborators = project
2130 .collaborators
2131 .iter()
2132 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2133 .map(|collaborator| collaborator.to_proto())
2134 .collect::<Vec<_>>();
2135 let project_id = project.id;
2136 let guest_user_id = session.user_id();
2137
2138 let worktrees = project
2139 .worktrees
2140 .iter()
2141 .map(|(id, worktree)| proto::WorktreeMetadata {
2142 id: *id,
2143 root_name: worktree.root_name.clone(),
2144 visible: worktree.visible,
2145 abs_path: worktree.abs_path.clone(),
2146 })
2147 .collect::<Vec<_>>();
2148
2149 let add_project_collaborator = proto::AddProjectCollaborator {
2150 project_id: project_id.to_proto(),
2151 collaborator: Some(proto::Collaborator {
2152 peer_id: Some(session.connection_id.into()),
2153 replica_id: replica_id.0 as u32,
2154 user_id: guest_user_id.to_proto(),
2155 }),
2156 };
2157
2158 for collaborator in &collaborators {
2159 session
2160 .peer
2161 .send(
2162 collaborator.peer_id.unwrap().into(),
2163 add_project_collaborator.clone(),
2164 )
2165 .trace_err();
2166 }
2167
2168 // First, we send the metadata associated with each worktree.
2169 response.send(proto::JoinProjectResponse {
2170 project_id: project.id.0 as u64,
2171 worktrees: worktrees.clone(),
2172 replica_id: replica_id.0 as u32,
2173 collaborators: collaborators.clone(),
2174 language_servers: project.language_servers.clone(),
2175 role: project.role.into(),
2176 dev_server_project_id: project
2177 .dev_server_project_id
2178 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2179 })?;
2180
2181 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2182 #[cfg(any(test, feature = "test-support"))]
2183 const MAX_CHUNK_SIZE: usize = 2;
2184 #[cfg(not(any(test, feature = "test-support")))]
2185 const MAX_CHUNK_SIZE: usize = 256;
2186
2187 // Stream this worktree's entries.
2188 let message = proto::UpdateWorktree {
2189 project_id: project_id.to_proto(),
2190 worktree_id,
2191 abs_path: worktree.abs_path.clone(),
2192 root_name: worktree.root_name,
2193 updated_entries: worktree.entries,
2194 removed_entries: Default::default(),
2195 scan_id: worktree.scan_id,
2196 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2197 updated_repositories: worktree.repository_entries.into_values().collect(),
2198 removed_repositories: Default::default(),
2199 };
2200 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2201 session.peer.send(session.connection_id, update.clone())?;
2202 }
2203
2204 // Stream this worktree's diagnostics.
2205 for summary in worktree.diagnostic_summaries {
2206 session.peer.send(
2207 session.connection_id,
2208 proto::UpdateDiagnosticSummary {
2209 project_id: project_id.to_proto(),
2210 worktree_id: worktree.id,
2211 summary: Some(summary),
2212 },
2213 )?;
2214 }
2215
2216 for settings_file in worktree.settings_files {
2217 session.peer.send(
2218 session.connection_id,
2219 proto::UpdateWorktreeSettings {
2220 project_id: project_id.to_proto(),
2221 worktree_id: worktree.id,
2222 path: settings_file.path,
2223 content: Some(settings_file.content),
2224 kind: Some(proto::update_user_settings::Kind::Settings.into()),
2225 },
2226 )?;
2227 }
2228 }
2229
2230 for language_server in &project.language_servers {
2231 session.peer.send(
2232 session.connection_id,
2233 proto::UpdateLanguageServer {
2234 project_id: project_id.to_proto(),
2235 language_server_id: language_server.id,
2236 variant: Some(
2237 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2238 proto::LspDiskBasedDiagnosticsUpdated {},
2239 ),
2240 ),
2241 },
2242 )?;
2243 }
2244
2245 Ok(())
2246}
2247
2248/// Leave someone elses shared project.
2249async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2250 let sender_id = session.connection_id;
2251 let project_id = ProjectId::from_proto(request.project_id);
2252 let db = session.db().await;
2253 if db.is_hosted_project(project_id).await? {
2254 let project = db.leave_hosted_project(project_id, sender_id).await?;
2255 project_left(&project, &session);
2256 return Ok(());
2257 }
2258
2259 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2260 tracing::info!(
2261 %project_id,
2262 "leave project"
2263 );
2264
2265 project_left(project, &session);
2266 if let Some(room) = room {
2267 room_updated(room, &session.peer);
2268 }
2269
2270 Ok(())
2271}
2272
2273async fn join_hosted_project(
2274 request: proto::JoinHostedProject,
2275 response: Response<proto::JoinHostedProject>,
2276 session: UserSession,
2277) -> Result<()> {
2278 let (mut project, replica_id) = session
2279 .db()
2280 .await
2281 .join_hosted_project(
2282 ProjectId(request.project_id as i32),
2283 session.user_id(),
2284 session.connection_id,
2285 )
2286 .await?;
2287
2288 join_project_internal(response, session, &mut project, &replica_id)
2289}
2290
2291async fn list_remote_directory(
2292 request: proto::ListRemoteDirectory,
2293 response: Response<proto::ListRemoteDirectory>,
2294 session: UserSession,
2295) -> Result<()> {
2296 let dev_server_id = DevServerId(request.dev_server_id as i32);
2297 let dev_server_connection_id = session
2298 .connection_pool()
2299 .await
2300 .online_dev_server_connection_id(dev_server_id)?;
2301
2302 session
2303 .db()
2304 .await
2305 .get_dev_server_for_user(dev_server_id, session.user_id())
2306 .await?;
2307
2308 response.send(
2309 session
2310 .peer
2311 .forward_request(session.connection_id, dev_server_connection_id, request)
2312 .await?,
2313 )?;
2314 Ok(())
2315}
2316
2317async fn update_dev_server_project(
2318 request: proto::UpdateDevServerProject,
2319 response: Response<proto::UpdateDevServerProject>,
2320 session: UserSession,
2321) -> Result<()> {
2322 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2323
2324 let (dev_server_project, update) = session
2325 .db()
2326 .await
2327 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2328 .await?;
2329
2330 let projects = session
2331 .db()
2332 .await
2333 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2334 .await?;
2335
2336 let dev_server_connection_id = session
2337 .connection_pool()
2338 .await
2339 .online_dev_server_connection_id(dev_server_project.dev_server_id)?;
2340
2341 session.peer.send(
2342 dev_server_connection_id,
2343 proto::DevServerInstructions { projects },
2344 )?;
2345
2346 send_dev_server_projects_update(session.user_id(), update, &session).await;
2347
2348 response.send(proto::Ack {})
2349}
2350
2351async fn create_dev_server_project(
2352 request: proto::CreateDevServerProject,
2353 response: Response<proto::CreateDevServerProject>,
2354 session: UserSession,
2355) -> Result<()> {
2356 let dev_server_id = DevServerId(request.dev_server_id as i32);
2357 let dev_server_connection_id = session
2358 .connection_pool()
2359 .await
2360 .dev_server_connection_id(dev_server_id);
2361 let Some(dev_server_connection_id) = dev_server_connection_id else {
2362 Err(ErrorCode::DevServerOffline
2363 .message("Cannot create a remote project when the dev server is offline".to_string())
2364 .anyhow())?
2365 };
2366
2367 let path = request.path.clone();
2368 //Check that the path exists on the dev server
2369 session
2370 .peer
2371 .forward_request(
2372 session.connection_id,
2373 dev_server_connection_id,
2374 proto::ValidateDevServerProjectRequest { path: path.clone() },
2375 )
2376 .await?;
2377
2378 let (dev_server_project, update) = session
2379 .db()
2380 .await
2381 .create_dev_server_project(
2382 DevServerId(request.dev_server_id as i32),
2383 &request.path,
2384 session.user_id(),
2385 )
2386 .await?;
2387
2388 let projects = session
2389 .db()
2390 .await
2391 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2392 .await?;
2393
2394 session.peer.send(
2395 dev_server_connection_id,
2396 proto::DevServerInstructions { projects },
2397 )?;
2398
2399 send_dev_server_projects_update(session.user_id(), update, &session).await;
2400
2401 response.send(proto::CreateDevServerProjectResponse {
2402 dev_server_project: Some(dev_server_project.to_proto(None)),
2403 })?;
2404 Ok(())
2405}
2406
2407async fn create_dev_server(
2408 request: proto::CreateDevServer,
2409 response: Response<proto::CreateDevServer>,
2410 session: UserSession,
2411) -> Result<()> {
2412 let access_token = auth::random_token();
2413 let hashed_access_token = auth::hash_access_token(&access_token);
2414
2415 if request.name.is_empty() {
2416 return Err(proto::ErrorCode::Forbidden
2417 .message("Dev server name cannot be empty".to_string())
2418 .anyhow())?;
2419 }
2420
2421 let (dev_server, status) = session
2422 .db()
2423 .await
2424 .create_dev_server(
2425 &request.name,
2426 request.ssh_connection_string.as_deref(),
2427 &hashed_access_token,
2428 session.user_id(),
2429 )
2430 .await?;
2431
2432 send_dev_server_projects_update(session.user_id(), status, &session).await;
2433
2434 response.send(proto::CreateDevServerResponse {
2435 dev_server_id: dev_server.id.0 as u64,
2436 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2437 name: request.name,
2438 })?;
2439 Ok(())
2440}
2441
2442async fn regenerate_dev_server_token(
2443 request: proto::RegenerateDevServerToken,
2444 response: Response<proto::RegenerateDevServerToken>,
2445 session: UserSession,
2446) -> Result<()> {
2447 let dev_server_id = DevServerId(request.dev_server_id as i32);
2448 let access_token = auth::random_token();
2449 let hashed_access_token = auth::hash_access_token(&access_token);
2450
2451 let connection_id = session
2452 .connection_pool()
2453 .await
2454 .dev_server_connection_id(dev_server_id);
2455 if let Some(connection_id) = connection_id {
2456 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2457 session.peer.send(
2458 connection_id,
2459 proto::ShutdownDevServer {
2460 reason: Some("dev server token was regenerated".to_string()),
2461 },
2462 )?;
2463 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2464 }
2465
2466 let status = session
2467 .db()
2468 .await
2469 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2470 .await?;
2471
2472 send_dev_server_projects_update(session.user_id(), status, &session).await;
2473
2474 response.send(proto::RegenerateDevServerTokenResponse {
2475 dev_server_id: dev_server_id.to_proto(),
2476 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2477 })?;
2478 Ok(())
2479}
2480
2481async fn rename_dev_server(
2482 request: proto::RenameDevServer,
2483 response: Response<proto::RenameDevServer>,
2484 session: UserSession,
2485) -> Result<()> {
2486 if request.name.trim().is_empty() {
2487 return Err(proto::ErrorCode::Forbidden
2488 .message("Dev server name cannot be empty".to_string())
2489 .anyhow())?;
2490 }
2491
2492 let dev_server_id = DevServerId(request.dev_server_id as i32);
2493 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2494 if dev_server.user_id != session.user_id() {
2495 return Err(anyhow!(ErrorCode::Forbidden))?;
2496 }
2497
2498 let status = session
2499 .db()
2500 .await
2501 .rename_dev_server(
2502 dev_server_id,
2503 &request.name,
2504 request.ssh_connection_string.as_deref(),
2505 session.user_id(),
2506 )
2507 .await?;
2508
2509 send_dev_server_projects_update(session.user_id(), status, &session).await;
2510
2511 response.send(proto::Ack {})?;
2512 Ok(())
2513}
2514
2515async fn delete_dev_server(
2516 request: proto::DeleteDevServer,
2517 response: Response<proto::DeleteDevServer>,
2518 session: UserSession,
2519) -> Result<()> {
2520 let dev_server_id = DevServerId(request.dev_server_id as i32);
2521 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2522 if dev_server.user_id != session.user_id() {
2523 return Err(anyhow!(ErrorCode::Forbidden))?;
2524 }
2525
2526 let connection_id = session
2527 .connection_pool()
2528 .await
2529 .dev_server_connection_id(dev_server_id);
2530 if let Some(connection_id) = connection_id {
2531 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2532 session.peer.send(
2533 connection_id,
2534 proto::ShutdownDevServer {
2535 reason: Some("dev server was deleted".to_string()),
2536 },
2537 )?;
2538 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2539 }
2540
2541 let status = session
2542 .db()
2543 .await
2544 .delete_dev_server(dev_server_id, session.user_id())
2545 .await?;
2546
2547 send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549 response.send(proto::Ack {})?;
2550 Ok(())
2551}
2552
2553async fn delete_dev_server_project(
2554 request: proto::DeleteDevServerProject,
2555 response: Response<proto::DeleteDevServerProject>,
2556 session: UserSession,
2557) -> Result<()> {
2558 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2559 let dev_server_project = session
2560 .db()
2561 .await
2562 .get_dev_server_project(dev_server_project_id)
2563 .await?;
2564
2565 let dev_server = session
2566 .db()
2567 .await
2568 .get_dev_server(dev_server_project.dev_server_id)
2569 .await?;
2570 if dev_server.user_id != session.user_id() {
2571 return Err(anyhow!(ErrorCode::Forbidden))?;
2572 }
2573
2574 let dev_server_connection_id = session
2575 .connection_pool()
2576 .await
2577 .dev_server_connection_id(dev_server.id);
2578
2579 if let Some(dev_server_connection_id) = dev_server_connection_id {
2580 let project = session
2581 .db()
2582 .await
2583 .find_dev_server_project(dev_server_project_id)
2584 .await;
2585 if let Ok(project) = project {
2586 unshare_project_internal(
2587 project.id,
2588 dev_server_connection_id,
2589 Some(session.user_id()),
2590 &session,
2591 )
2592 .await?;
2593 }
2594 }
2595
2596 let (projects, status) = session
2597 .db()
2598 .await
2599 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2600 .await?;
2601
2602 if let Some(dev_server_connection_id) = dev_server_connection_id {
2603 session.peer.send(
2604 dev_server_connection_id,
2605 proto::DevServerInstructions { projects },
2606 )?;
2607 }
2608
2609 send_dev_server_projects_update(session.user_id(), status, &session).await;
2610
2611 response.send(proto::Ack {})?;
2612 Ok(())
2613}
2614
2615async fn rejoin_dev_server_projects(
2616 request: proto::RejoinRemoteProjects,
2617 response: Response<proto::RejoinRemoteProjects>,
2618 session: UserSession,
2619) -> Result<()> {
2620 let mut rejoined_projects = {
2621 let db = session.db().await;
2622 db.rejoin_dev_server_projects(
2623 &request.rejoined_projects,
2624 session.user_id(),
2625 session.0.connection_id,
2626 )
2627 .await?
2628 };
2629 response.send(proto::RejoinRemoteProjectsResponse {
2630 rejoined_projects: rejoined_projects
2631 .iter()
2632 .map(|project| project.to_proto())
2633 .collect(),
2634 })?;
2635 notify_rejoined_projects(&mut rejoined_projects, &session)
2636}
2637
/// Handles a dev server reconnecting and re-sharing its projects.
///
/// After `reshare_dev_server_projects` records the reshare, for each project:
/// - every collaborator is told the host's peer id changed from
///   `project.old_connection_id` to the new connection (best-effort); and
/// - the project's current worktrees are broadcast to its collaborators as
///   an `UpdateProject`.
///
/// Finally, replies with each reshared project's id and collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // Scope the database handle so it is released before messaging peers.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Point collaborators at the host's new connection.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        // Re-send the worktree list to the project's collaborators.
        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2703
2704async fn shutdown_dev_server(
2705 _: proto::ShutdownDevServer,
2706 response: Response<proto::ShutdownDevServer>,
2707 session: DevServerSession,
2708) -> Result<()> {
2709 response.send(proto::Ack {})?;
2710 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2711 remove_dev_server_connection(session.dev_server_id(), &session).await
2712}
2713
2714async fn shutdown_dev_server_internal(
2715 dev_server_id: DevServerId,
2716 connection_id: ConnectionId,
2717 session: &Session,
2718) -> Result<()> {
2719 let (dev_server_projects, dev_server) = {
2720 let db = session.db().await;
2721 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2722 let dev_server = db.get_dev_server(dev_server_id).await?;
2723 (dev_server_projects, dev_server)
2724 };
2725
2726 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2727 unshare_project_internal(
2728 ProjectId::from_proto(project_id),
2729 connection_id,
2730 None,
2731 session,
2732 )
2733 .await?;
2734 }
2735
2736 session
2737 .connection_pool()
2738 .await
2739 .set_dev_server_offline(dev_server_id);
2740
2741 let status = session
2742 .db()
2743 .await
2744 .dev_server_projects_update(dev_server.user_id)
2745 .await?;
2746 send_dev_server_projects_update(dev_server.user_id, status, session).await;
2747
2748 Ok(())
2749}
2750
2751async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2752 let dev_server_connection = session
2753 .connection_pool()
2754 .await
2755 .dev_server_connection_id(dev_server_id);
2756
2757 if let Some(dev_server_connection) = dev_server_connection {
2758 session
2759 .connection_pool()
2760 .await
2761 .remove_connection(dev_server_connection)?;
2762 }
2763 Ok(())
2764}
2765
2766/// Updates other participants with changes to the project
2767async fn update_project(
2768 request: proto::UpdateProject,
2769 response: Response<proto::UpdateProject>,
2770 session: Session,
2771) -> Result<()> {
2772 let project_id = ProjectId::from_proto(request.project_id);
2773 let (room, guest_connection_ids) = &*session
2774 .db()
2775 .await
2776 .update_project(project_id, session.connection_id, &request.worktrees)
2777 .await?;
2778 broadcast(
2779 Some(session.connection_id),
2780 guest_connection_ids.iter().copied(),
2781 |connection_id| {
2782 session
2783 .peer
2784 .forward_send(session.connection_id, connection_id, request.clone())
2785 },
2786 );
2787 if let Some(room) = room {
2788 room_updated(room, &session.peer);
2789 }
2790 response.send(proto::Ack {})?;
2791
2792 Ok(())
2793}
2794
2795/// Updates other participants with changes to the worktree
2796async fn update_worktree(
2797 request: proto::UpdateWorktree,
2798 response: Response<proto::UpdateWorktree>,
2799 session: Session,
2800) -> Result<()> {
2801 let guest_connection_ids = session
2802 .db()
2803 .await
2804 .update_worktree(&request, session.connection_id)
2805 .await?;
2806
2807 broadcast(
2808 Some(session.connection_id),
2809 guest_connection_ids.iter().copied(),
2810 |connection_id| {
2811 session
2812 .peer
2813 .forward_send(session.connection_id, connection_id, request.clone())
2814 },
2815 );
2816 response.send(proto::Ack {})?;
2817 Ok(())
2818}
2819
2820/// Updates other participants with changes to the diagnostics
2821async fn update_diagnostic_summary(
2822 message: proto::UpdateDiagnosticSummary,
2823 session: Session,
2824) -> Result<()> {
2825 let guest_connection_ids = session
2826 .db()
2827 .await
2828 .update_diagnostic_summary(&message, session.connection_id)
2829 .await?;
2830
2831 broadcast(
2832 Some(session.connection_id),
2833 guest_connection_ids.iter().copied(),
2834 |connection_id| {
2835 session
2836 .peer
2837 .forward_send(session.connection_id, connection_id, message.clone())
2838 },
2839 );
2840
2841 Ok(())
2842}
2843
2844/// Updates other participants with changes to the worktree settings
2845async fn update_worktree_settings(
2846 message: proto::UpdateWorktreeSettings,
2847 session: Session,
2848) -> Result<()> {
2849 let guest_connection_ids = session
2850 .db()
2851 .await
2852 .update_worktree_settings(&message, session.connection_id)
2853 .await?;
2854
2855 broadcast(
2856 Some(session.connection_id),
2857 guest_connection_ids.iter().copied(),
2858 |connection_id| {
2859 session
2860 .peer
2861 .forward_send(session.connection_id, connection_id, message.clone())
2862 },
2863 );
2864
2865 Ok(())
2866}
2867
2868/// Notify other participants that a language server has started.
2869async fn start_language_server(
2870 request: proto::StartLanguageServer,
2871 session: Session,
2872) -> Result<()> {
2873 let guest_connection_ids = session
2874 .db()
2875 .await
2876 .start_language_server(&request, session.connection_id)
2877 .await?;
2878
2879 broadcast(
2880 Some(session.connection_id),
2881 guest_connection_ids.iter().copied(),
2882 |connection_id| {
2883 session
2884 .peer
2885 .forward_send(session.connection_id, connection_id, request.clone())
2886 },
2887 );
2888 Ok(())
2889}
2890
2891/// Notify other participants that a language server has changed.
2892async fn update_language_server(
2893 request: proto::UpdateLanguageServer,
2894 session: Session,
2895) -> Result<()> {
2896 let project_id = ProjectId::from_proto(request.project_id);
2897 let project_connection_ids = session
2898 .db()
2899 .await
2900 .project_connection_ids(project_id, session.connection_id, true)
2901 .await?;
2902 broadcast(
2903 Some(session.connection_id),
2904 project_connection_ids.iter().copied(),
2905 |connection_id| {
2906 session
2907 .peer
2908 .forward_send(session.connection_id, connection_id, request.clone())
2909 },
2910 );
2911 Ok(())
2912}
2913
2914/// forward a project request to the host. These requests should be read only
2915/// as guests are allowed to send them.
2916async fn forward_read_only_project_request<T>(
2917 request: T,
2918 response: Response<T>,
2919 session: UserSession,
2920) -> Result<()>
2921where
2922 T: EntityMessage + RequestMessage,
2923{
2924 let project_id = ProjectId::from_proto(request.remote_entity_id());
2925 let host_connection_id = session
2926 .db()
2927 .await
2928 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2929 .await?;
2930 let payload = session
2931 .peer
2932 .forward_request(session.connection_id, host_connection_id, request)
2933 .await?;
2934 response.send(payload)?;
2935 Ok(())
2936}
2937
2938async fn forward_find_search_candidates_request(
2939 request: proto::FindSearchCandidates,
2940 response: Response<proto::FindSearchCandidates>,
2941 session: UserSession,
2942) -> Result<()> {
2943 let project_id = ProjectId::from_proto(request.remote_entity_id());
2944 let host_connection_id = session
2945 .db()
2946 .await
2947 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2948 .await?;
2949 let payload = session
2950 .peer
2951 .forward_request(session.connection_id, host_connection_id, request)
2952 .await?;
2953 response.send(payload)?;
2954 Ok(())
2955}
2956
2957/// forward a project request to the dev server. Only allowed
2958/// if it's your dev server.
2959async fn forward_project_request_for_owner<T>(
2960 request: T,
2961 response: Response<T>,
2962 session: UserSession,
2963) -> Result<()>
2964where
2965 T: EntityMessage + RequestMessage,
2966{
2967 let project_id = ProjectId::from_proto(request.remote_entity_id());
2968
2969 let host_connection_id = session
2970 .db()
2971 .await
2972 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2973 .await?;
2974 let payload = session
2975 .peer
2976 .forward_request(session.connection_id, host_connection_id, request)
2977 .await?;
2978 response.send(payload)?;
2979 Ok(())
2980}
2981
2982/// forward a project request to the host. These requests are disallowed
2983/// for guests.
2984async fn forward_mutating_project_request<T>(
2985 request: T,
2986 response: Response<T>,
2987 session: UserSession,
2988) -> Result<()>
2989where
2990 T: EntityMessage + RequestMessage,
2991{
2992 let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994 let host_connection_id = session
2995 .db()
2996 .await
2997 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2998 .await?;
2999 let payload = session
3000 .peer
3001 .forward_request(session.connection_id, host_connection_id, request)
3002 .await?;
3003 response.send(payload)?;
3004 Ok(())
3005}
3006
3007/// Notify other participants that a new buffer has been created
3008async fn create_buffer_for_peer(
3009 request: proto::CreateBufferForPeer,
3010 session: Session,
3011) -> Result<()> {
3012 session
3013 .db()
3014 .await
3015 .check_user_is_project_host(
3016 ProjectId::from_proto(request.project_id),
3017 session.connection_id,
3018 )
3019 .await?;
3020 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3021 session
3022 .peer
3023 .forward_send(session.connection_id, peer_id.into(), request)?;
3024 Ok(())
3025}
3026
/// Notify other participants that a buffer has been updated. This is
/// allowed for guests as long as the update is limited to selections.
///
/// The update is forwarded to the guests fire-and-forget. It is sent to the
/// host as a request, and the host's reply is awaited before the sender is
/// acknowledged (unless the sender *is* the host).
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Selection-only (or empty) operations need read access only; any other
    // operation requires write access.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The guard returned by `connections_for_buffer_update` is scoped to this
    // block so it is dropped before awaiting the host below; only the host's
    // connection id is carried out.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // Sent as a request so the sender is only acknowledged once the host has
    // processed the update.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3081
3082async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3083 let project_id = ProjectId::from_proto(message.project_id);
3084
3085 let operation = message.operation.as_ref().context("invalid operation")?;
3086 let capability = match operation.variant.as_ref() {
3087 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3088 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3089 match buffer_op.variant {
3090 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3091 Capability::ReadOnly
3092 }
3093 _ => Capability::ReadWrite,
3094 }
3095 } else {
3096 Capability::ReadWrite
3097 }
3098 }
3099 Some(_) => Capability::ReadWrite,
3100 None => Capability::ReadOnly,
3101 };
3102
3103 let guard = session
3104 .db()
3105 .await
3106 .connections_for_buffer_update(
3107 project_id,
3108 session.principal_id(),
3109 session.connection_id,
3110 capability,
3111 )
3112 .await?;
3113
3114 let (host, guests) = &*guard;
3115
3116 broadcast(
3117 Some(session.connection_id),
3118 guests.iter().chain([host]).copied(),
3119 |connection_id| {
3120 session
3121 .peer
3122 .forward_send(session.connection_id, connection_id, message.clone())
3123 },
3124 );
3125
3126 Ok(())
3127}
3128
3129/// Notify other participants that a project has been updated.
3130async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3131 request: T,
3132 session: Session,
3133) -> Result<()> {
3134 let project_id = ProjectId::from_proto(request.remote_entity_id());
3135 let project_connection_ids = session
3136 .db()
3137 .await
3138 .project_connection_ids(project_id, session.connection_id, false)
3139 .await?;
3140
3141 broadcast(
3142 Some(session.connection_id),
3143 project_connection_ids.iter().copied(),
3144 |connection_id| {
3145 session
3146 .peer
3147 .forward_send(session.connection_id, connection_id, request.clone())
3148 },
3149 );
3150 Ok(())
3151}
3152
3153/// Start following another user in a call.
3154async fn follow(
3155 request: proto::Follow,
3156 response: Response<proto::Follow>,
3157 session: UserSession,
3158) -> Result<()> {
3159 let room_id = RoomId::from_proto(request.room_id);
3160 let project_id = request.project_id.map(ProjectId::from_proto);
3161 let leader_id = request
3162 .leader_id
3163 .ok_or_else(|| anyhow!("invalid leader id"))?
3164 .into();
3165 let follower_id = session.connection_id;
3166
3167 session
3168 .db()
3169 .await
3170 .check_room_participants(room_id, leader_id, session.connection_id)
3171 .await?;
3172
3173 let response_payload = session
3174 .peer
3175 .forward_request(session.connection_id, leader_id, request)
3176 .await?;
3177 response.send(response_payload)?;
3178
3179 if let Some(project_id) = project_id {
3180 let room = session
3181 .db()
3182 .await
3183 .follow(room_id, project_id, leader_id, follower_id)
3184 .await?;
3185 room_updated(&room, &session.peer);
3186 }
3187
3188 Ok(())
3189}
3190
3191/// Stop following another user in a call.
3192async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3193 let room_id = RoomId::from_proto(request.room_id);
3194 let project_id = request.project_id.map(ProjectId::from_proto);
3195 let leader_id = request
3196 .leader_id
3197 .ok_or_else(|| anyhow!("invalid leader id"))?
3198 .into();
3199 let follower_id = session.connection_id;
3200
3201 session
3202 .db()
3203 .await
3204 .check_room_participants(room_id, leader_id, session.connection_id)
3205 .await?;
3206
3207 session
3208 .peer
3209 .forward_send(session.connection_id, leader_id, request)?;
3210
3211 if let Some(project_id) = project_id {
3212 let room = session
3213 .db()
3214 .await
3215 .unfollow(room_id, project_id, leader_id, follower_id)
3216 .await?;
3217 room_updated(&room, &session.peer);
3218 }
3219
3220 Ok(())
3221}
3222
/// Notify everyone following you of your current location.
///
/// The `UpdateFollowers` message is forwarded to every connection in the
/// project (when `project_id` is set) or otherwise in the room. The sender is
/// always skipped, and for `UpdateView` variants the view's leader is skipped
/// too. The server does not filter recipients by follow state here.
async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    // NOTE(review): unlike sibling handlers this locks `session.db` directly
    // rather than via `session.db()`; the guard lives until the function
    // returns, i.e. across every forward below.
    let database = session.db.lock().await;

    let connection_ids = if let Some(project_id) = request.project_id {
        let project_id = ProjectId::from_proto(project_id);
        // NOTE(review): meaning of the `true` flag is defined in
        // `project_connection_ids` — confirm there.
        database
            .project_connection_ids(project_id, session.connection_id, true)
            .await?
    } else {
        database
            .room_connection_ids(room_id, session.connection_id)
            .await?
    };

    // For now, don't send view update messages back to that view's current leader.
    let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
        proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
        _ => None,
    });

    for connection_id in connection_ids.iter().cloned() {
        if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
            session
                .peer
                .forward_send(session.connection_id, connection_id, request.clone())?;
        }
    }
    Ok(())
}
3254
3255/// Get public data about users.
3256async fn get_users(
3257 request: proto::GetUsers,
3258 response: Response<proto::GetUsers>,
3259 session: Session,
3260) -> Result<()> {
3261 let user_ids = request
3262 .user_ids
3263 .into_iter()
3264 .map(UserId::from_proto)
3265 .collect();
3266 let users = session
3267 .db()
3268 .await
3269 .get_users_by_ids(user_ids)
3270 .await?
3271 .into_iter()
3272 .map(|user| proto::User {
3273 id: user.id.to_proto(),
3274 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3275 github_login: user.github_login,
3276 })
3277 .collect();
3278 response.send(proto::UsersResponse { users })?;
3279 Ok(())
3280}
3281
/// Search for users (to invite) by GitHub login.
///
/// An empty query yields no results. A one- or two-character query performs
/// an exact login lookup. Longer queries run a fuzzy search capped at 10
/// results. The requesting user is always filtered out of the response.
async fn fuzzy_search_users(
    request: proto::FuzzySearchUsers,
    response: Response<proto::FuzzySearchUsers>,
    session: UserSession,
) -> Result<()> {
    let query = request.query;
    // NOTE(review): `len()` counts UTF-8 bytes, not characters, so a short
    // non-ASCII query can take the fuzzy-search branch — confirm intended.
    let users = match query.len() {
        0 => vec![],
        1 | 2 => session
            .db()
            .await
            .get_user_by_github_login(&query)
            .await?
            .into_iter()
            .collect(),
        _ => session.db().await.fuzzy_search_users(&query, 10).await?,
    };
    let users = users
        .into_iter()
        .filter(|user| user.id != session.user_id())
        .map(|user| proto::User {
            id: user.id.to_proto(),
            avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
            github_login: user.github_login,
        })
        .collect();
    response.send(proto::UsersResponse { users })?;
    Ok(())
}
3312
/// Send a contact request to another user.
///
/// After the database records the request, the following happens:
/// - Every connection of the requester gets an `UpdateContacts` with the new
///   outgoing request.
/// - Every connection of the responder gets one with the new incoming request.
/// - Any notifications returned by the database are delivered.
///
/// Errors if the user tries to add themselves.
async fn request_contact(
    request: proto::RequestContact,
    response: Response<proto::RequestContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.responder_id);
    if requester_id == responder_id {
        return Err(anyhow!("cannot add yourself as a contact"))?;
    }

    let notifications = session
        .db()
        .await
        .send_contact_request(requester_id, responder_id)
        .await?;

    // Update outgoing contact requests of requester
    let mut update = proto::UpdateContacts::default();
    update.outgoing_requests.push(responder_id.to_proto());
    // This pool guard is a temporary of the `for` expression, so it is released
    // when the loop ends; the pool is re-acquired below.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(requester_id)
    {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update incoming contact requests of responder
    let mut update = proto::UpdateContacts::default();
    update
        .incoming_requests
        .push(proto::IncomingContactRequest {
            requester_id: requester_id.to_proto(),
        });
    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3359
/// Accept or decline a contact request, or dismiss its notification.
///
/// A `Dismiss` response only calls `dismiss_contact_notification`.
///
/// Any other response is recorded, and counts as acceptance only when it is
/// `Accept`. Both users' connections then receive an `UpdateContacts` that
/// removes the pending request. On acceptance, the update also adds the other
/// user as a contact, using their busy state as read from the database. Any
/// notifications returned by the database are delivered last.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3417
/// Remove a contact, or withdraw a contact request that was never accepted.
///
/// The database reports whether the contact had been accepted:
/// - If it had, both users get an `UpdateContacts` removing each other from
///   `remove_contacts`.
/// - If not, the requester's outgoing request and the responder's incoming
///   request are removed instead.
///
/// If the database also returned the id of a deleted notification, the
/// responder's connections are told to delete it.
async fn remove_contact(
    request: proto::RemoveContact,
    response: Response<proto::RemoveContact>,
    session: UserSession,
) -> Result<()> {
    let requester_id = session.user_id();
    let responder_id = UserId::from_proto(request.user_id);
    let db = session.db().await;
    let (contact_accepted, deleted_notification_id) =
        db.remove_contact(requester_id, responder_id).await?;

    let pool = session.connection_pool().await;
    // Update the requester: drop the contact, or the outgoing request.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(responder_id.to_proto());
    } else {
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(requester_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    // Update the responder: drop the contact, or the incoming request, and
    // retract any notification the database deleted.
    let mut update = proto::UpdateContacts::default();
    if contact_accepted {
        update.remove_contacts.push(requester_id.to_proto());
    } else {
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
    }
    for connection_id in pool.user_connection_ids(responder_id) {
        session.peer.send(connection_id, update.clone())?;
        if let Some(notification_id) = deleted_notification_id {
            session.peer.send(
                connection_id,
                proto::DeleteNotification {
                    notification_id: notification_id.to_proto(),
                },
            )?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3468
3469fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3470 version.0.minor() < 139
3471}
3472
/// Sends the session's current plan to this session's connection as an
/// `UpdateUserPlan` message.
///
/// The `_user_id` argument is unused; the plan comes from
/// `session.current_plan`. A failed send is only logged (`trace_err`), so the
/// only errors that propagate come from looking up the plan.
async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
    let plan = session.current_plan(session.db().await).await?;

    session
        .peer
        .send(
            session.connection_id,
            proto::UpdateUserPlan { plan: plan.into() },
        )
        .trace_err();

    Ok(())
}
3486
3487async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3488 subscribe_user_to_channels(
3489 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3490 &session,
3491 )
3492 .await?;
3493 Ok(())
3494}
3495
/// Subscribes `user_id`, in the connection pool, to every channel the database
/// lists as a membership. It then sends this session's connection an
/// `UpdateUserChannels` and an `UpdateChannels`, both built from the same
/// `get_channels_for_user` result.
///
/// The connection-pool guard is held until the function returns, i.e. while
/// both messages are sent.
async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
    let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
    let mut pool = session.connection_pool().await;
    for membership in &channels_for_user.channel_memberships {
        pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
    }
    session.peer.send(
        session.connection_id,
        build_update_user_channels(&channels_for_user),
    )?;
    // Built last because it consumes `channels_for_user`.
    session.peer.send(
        session.connection_id,
        build_channels_update(channels_for_user),
    )?;
    Ok(())
}
3512
3513/// Creates a new channel.
3514async fn create_channel(
3515 request: proto::CreateChannel,
3516 response: Response<proto::CreateChannel>,
3517 session: UserSession,
3518) -> Result<()> {
3519 let db = session.db().await;
3520
3521 let parent_id = request.parent_id.map(ChannelId::from_proto);
3522 let (channel, membership) = db
3523 .create_channel(&request.name, parent_id, session.user_id())
3524 .await?;
3525
3526 let root_id = channel.root_id();
3527 let channel = Channel::from_model(channel);
3528
3529 response.send(proto::CreateChannelResponse {
3530 channel: Some(channel.to_proto()),
3531 parent_id: request.parent_id,
3532 })?;
3533
3534 let mut connection_pool = session.connection_pool().await;
3535 if let Some(membership) = membership {
3536 connection_pool.subscribe_to_channel(
3537 membership.user_id,
3538 membership.channel_id,
3539 membership.role,
3540 );
3541 let update = proto::UpdateUserChannels {
3542 channel_memberships: vec![proto::ChannelMembership {
3543 channel_id: membership.channel_id.to_proto(),
3544 role: membership.role.into(),
3545 }],
3546 ..Default::default()
3547 };
3548 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3549 session.peer.send(connection_id, update.clone())?;
3550 }
3551 }
3552
3553 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3554 if !role.can_see_channel(channel.visibility) {
3555 continue;
3556 }
3557
3558 let update = proto::UpdateChannels {
3559 channels: vec![channel.to_proto()],
3560 ..Default::default()
3561 };
3562 session.peer.send(connection_id, update.clone())?;
3563 }
3564
3565 Ok(())
3566}
3567
3568/// Delete a channel
3569async fn delete_channel(
3570 request: proto::DeleteChannel,
3571 response: Response<proto::DeleteChannel>,
3572 session: UserSession,
3573) -> Result<()> {
3574 let db = session.db().await;
3575
3576 let channel_id = request.channel_id;
3577 let (root_channel, removed_channels) = db
3578 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3579 .await?;
3580 response.send(proto::Ack {})?;
3581
3582 // Notify members of removed channels
3583 let mut update = proto::UpdateChannels::default();
3584 update
3585 .delete_channels
3586 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3587
3588 let connection_pool = session.connection_pool().await;
3589 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3590 session.peer.send(connection_id, update.clone())?;
3591 }
3592
3593 Ok(())
3594}
3595
/// Invite someone to join a channel.
///
/// Records the invitation with the requested role. The invitee's connections
/// then receive an `UpdateChannels` listing the channel under
/// `channel_invitations`. Any resulting notifications are delivered before the
/// request is acknowledged.
async fn invite_channel_member(
    request: proto::InviteChannelMember,
    response: Response<proto::InviteChannelMember>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let invitee_id = UserId::from_proto(request.user_id);
    let InviteMemberResult {
        channel,
        notifications,
    } = db
        .invite_channel_member(
            channel_id,
            invitee_id,
            session.user_id(),
            request.role().into(),
        )
        .await?;

    let update = proto::UpdateChannels {
        channel_invitations: vec![channel.to_proto()],
        ..Default::default()
    };

    let connection_pool = session.connection_pool().await;
    for connection_id in connection_pool.user_connection_ids(invitee_id) {
        session.peer.send(connection_id, update.clone())?;
    }

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;
    Ok(())
}
3632
3633/// remove someone from a channel
3634async fn remove_channel_member(
3635 request: proto::RemoveChannelMember,
3636 response: Response<proto::RemoveChannelMember>,
3637 session: UserSession,
3638) -> Result<()> {
3639 let db = session.db().await;
3640 let channel_id = ChannelId::from_proto(request.channel_id);
3641 let member_id = UserId::from_proto(request.user_id);
3642
3643 let RemoveChannelMemberResult {
3644 membership_update,
3645 notification_id,
3646 } = db
3647 .remove_channel_member(channel_id, member_id, session.user_id())
3648 .await?;
3649
3650 let mut connection_pool = session.connection_pool().await;
3651 notify_membership_updated(
3652 &mut connection_pool,
3653 membership_update,
3654 member_id,
3655 &session.peer,
3656 );
3657 for connection_id in connection_pool.user_connection_ids(member_id) {
3658 if let Some(notification_id) = notification_id {
3659 session
3660 .peer
3661 .send(
3662 connection_id,
3663 proto::DeleteNotification {
3664 notification_id: notification_id.to_proto(),
3665 },
3666 )
3667 .trace_err();
3668 }
3669 }
3670
3671 response.send(proto::Ack {})?;
3672 Ok(())
3673}
3674
/// Toggle the channel between public and private.
///
/// Public channels may only descend from public channels, though members-only
/// channels can appear at any point in the hierarchy. That invariant is
/// presumably enforced by the database's `set_channel_visibility`; it is not
/// checked in this handler.
///
/// Every user returned by `channel_user_ids` for the channel's root is handled
/// according to their role:
/// - If the role can see the channel at its new visibility, the user is
///   subscribed and sent the channel.
/// - Otherwise the user is unsubscribed and told to delete the channel.
async fn set_channel_visibility(
    request: proto::SetChannelVisibility,
    response: Response<proto::SetChannelVisibility>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let visibility = request.visibility().into();

    let channel_model = db
        .set_channel_visibility(channel_id, visibility, session.user_id())
        .await?;
    let root_id = channel_model.root_id();
    let channel = Channel::from_model(channel_model);

    let mut connection_pool = session.connection_pool().await;
    // Collect first: the loop body mutates `connection_pool`, which would
    // conflict with iterating it directly.
    for (user_id, role) in connection_pool
        .channel_user_ids(root_id)
        .collect::<Vec<_>>()
        .into_iter()
    {
        let update = if role.can_see_channel(channel.visibility) {
            connection_pool.subscribe_to_channel(user_id, channel_id, role);
            proto::UpdateChannels {
                channels: vec![channel.to_proto()],
                ..Default::default()
            }
        } else {
            connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
            proto::UpdateChannels {
                delete_channels: vec![channel.id.to_proto()],
                ..Default::default()
            }
        };

        for connection_id in connection_pool.user_connection_ids(user_id) {
            session.peer.send(connection_id, update.clone())?;
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3721
/// Alter the role for a user in the channel.
///
/// The database reports one of two outcomes:
/// - `MembershipUpdated`: the member already belonged to the channel, and the
///   change is pushed through `notify_membership_updated`.
/// - `InviteUpdated`: the member only had a pending invite, and their
///   connections receive the refreshed invitation under
///   `channel_invitations`.
///
/// The request is acknowledged in both cases.
async fn set_channel_member_role(
    request: proto::SetChannelMemberRole,
    response: Response<proto::SetChannelMemberRole>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let member_id = UserId::from_proto(request.user_id);
    let result = db
        .set_channel_member_role(
            channel_id,
            session.user_id(),
            member_id,
            request.role().into(),
        )
        .await?;

    match result {
        db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
            let mut connection_pool = session.connection_pool().await;
            notify_membership_updated(
                &mut connection_pool,
                membership_update,
                member_id,
                &session.peer,
            )
        }
        db::SetMemberRoleResult::InviteUpdated(channel) => {
            let update = proto::UpdateChannels {
                channel_invitations: vec![channel.to_proto()],
                ..Default::default()
            };

            for connection_id in session
                .connection_pool()
                .await
                .user_connection_ids(member_id)
            {
                session.peer.send(connection_id, update.clone())?;
            }
        }
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3769
3770/// Change the name of a channel
3771async fn rename_channel(
3772 request: proto::RenameChannel,
3773 response: Response<proto::RenameChannel>,
3774 session: UserSession,
3775) -> Result<()> {
3776 let db = session.db().await;
3777 let channel_id = ChannelId::from_proto(request.channel_id);
3778 let channel_model = db
3779 .rename_channel(channel_id, session.user_id(), &request.name)
3780 .await?;
3781 let root_id = channel_model.root_id();
3782 let channel = Channel::from_model(channel_model);
3783
3784 response.send(proto::RenameChannelResponse {
3785 channel: Some(channel.to_proto()),
3786 })?;
3787
3788 let connection_pool = session.connection_pool().await;
3789 let update = proto::UpdateChannels {
3790 channels: vec![channel.to_proto()],
3791 ..Default::default()
3792 };
3793 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3794 if role.can_see_channel(channel.visibility) {
3795 session.peer.send(connection_id, update.clone())?;
3796 }
3797 }
3798
3799 Ok(())
3800}
3801
3802/// Move a channel to a new parent.
3803async fn move_channel(
3804 request: proto::MoveChannel,
3805 response: Response<proto::MoveChannel>,
3806 session: UserSession,
3807) -> Result<()> {
3808 let channel_id = ChannelId::from_proto(request.channel_id);
3809 let to = ChannelId::from_proto(request.to);
3810
3811 let (root_id, channels) = session
3812 .db()
3813 .await
3814 .move_channel(channel_id, to, session.user_id())
3815 .await?;
3816
3817 let connection_pool = session.connection_pool().await;
3818 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3819 let channels = channels
3820 .iter()
3821 .filter_map(|channel| {
3822 if role.can_see_channel(channel.visibility) {
3823 Some(channel.to_proto())
3824 } else {
3825 None
3826 }
3827 })
3828 .collect::<Vec<_>>();
3829 if channels.is_empty() {
3830 continue;
3831 }
3832
3833 let update = proto::UpdateChannels {
3834 channels,
3835 ..Default::default()
3836 };
3837
3838 session.peer.send(connection_id, update.clone())?;
3839 }
3840
3841 response.send(Ack {})?;
3842 Ok(())
3843}
3844
3845/// Get the list of channel members
3846async fn get_channel_members(
3847 request: proto::GetChannelMembers,
3848 response: Response<proto::GetChannelMembers>,
3849 session: UserSession,
3850) -> Result<()> {
3851 let db = session.db().await;
3852 let channel_id = ChannelId::from_proto(request.channel_id);
3853 let limit = if request.limit == 0 {
3854 u16::MAX as u64
3855 } else {
3856 request.limit
3857 };
3858 let (members, users) = db
3859 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3860 .await?;
3861 response.send(proto::GetChannelMembersResponse { members, users })?;
3862 Ok(())
3863}
3864
/// Accept or decline a channel invitation.
///
/// What gets sent depends on the database result:
/// - If it reports a membership change (the invite was accepted), the change
///   is pushed through `notify_membership_updated`.
/// - Otherwise the user's connections are only told to drop the invitation.
///
/// Any resulting notifications are delivered before the request is
/// acknowledged.
async fn respond_to_channel_invite(
    request: proto::RespondToChannelInvite,
    response: Response<proto::RespondToChannelInvite>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);
    let RespondToChannelInvite {
        membership_update,
        notifications,
    } = db
        .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
        .await?;

    let mut connection_pool = session.connection_pool().await;
    if let Some(membership_update) = membership_update {
        notify_membership_updated(
            &mut connection_pool,
            membership_update,
            session.user_id(),
            &session.peer,
        );
    } else {
        let update = proto::UpdateChannels {
            remove_channel_invitations: vec![channel_id.to_proto()],
            ..Default::default()
        };

        for connection_id in connection_pool.user_connection_ids(session.user_id()) {
            session.peer.send(connection_id, update.clone())?;
        }
    };

    send_notifications(&connection_pool, &session.peer, notifications);

    response.send(proto::Ack {})?;

    Ok(())
}
3905
3906/// Join the channels' room
3907async fn join_channel(
3908 request: proto::JoinChannel,
3909 response: Response<proto::JoinChannel>,
3910 session: UserSession,
3911) -> Result<()> {
3912 let channel_id = ChannelId::from_proto(request.channel_id);
3913 join_channel_internal(channel_id, Box::new(response), session).await
3914}
3915
/// Abstracts over the request types whose reply is a `JoinRoomResponse`
/// (`JoinChannel` and `JoinRoom`), so `join_channel_internal` can answer
/// either one.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Reply to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Reply to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3929
/// Joins this session's user to the room of `channel_id` and replies through
/// `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for the user, that
///    connection is removed with `leave_room_for_session`. The db guard is
///    released during the removal and re-acquired afterwards.
/// 2. The user joins the channel's room in the database.
/// 3. If a LiveKit client is configured, a LiveKit token is minted: a guest
///    token with `can_publish = false` for `Guest` roles, otherwise a room
///    token with `can_publish = true`. If minting fails, it is logged and no
///    connection info is sent.
/// 4. The `JoinRoomResponse` is sent, any membership change is pushed, and the
///    room is broadcast.
/// 5. After the scoped guards are dropped, `channel_updated` and
///    `update_user_contacts` run.
///
/// Errors if the database does not return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The block scopes the db and connection-pool guards so both are released
    // before `channel_updated` re-acquires the pool and contacts are updated.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the db guard: `leave_room_for_session` acquires it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests may join the call but not publish media.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4025
/// Start editing the channel notes.
///
/// Joins the channel buffer in the database and replies with the buffer state.
/// The new collaborator list is then sent to every collaborator that has a
/// peer id, via `channel_buffer_updated`, which is given this connection as
/// the sender.
async fn join_channel_buffer(
    request: proto::JoinChannelBuffer,
    response: Response<proto::JoinChannelBuffer>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let open_response = db
        .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
        .await?;

    // Clone before `open_response` is moved into the reply.
    let collaborators = open_response.collaborators.clone();
    response.send(open_response)?;

    let update = UpdateChannelBufferCollaborators {
        channel_id: channel_id.to_proto(),
        collaborators: collaborators.clone(),
    };
    channel_buffer_updated(
        session.connection_id,
        collaborators
            .iter()
            .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
        &update,
        &session.peer,
    );

    Ok(())
}
4057
/// Edit the channel notes.
///
/// Applies the operations in the database, then notifies two groups:
/// - Current collaborators are forwarded the operations.
/// - Every other connection returned by `channel_connection_ids(channel_id)`
///   only receives the buffer's new epoch and version.
async fn update_channel_buffer(
    request: proto::UpdateChannelBuffer,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let (collaborators, epoch, version) = db
        .update_channel_buffer(channel_id, session.user_id(), &request.operations)
        .await?;

    channel_buffer_updated(
        session.connection_id,
        collaborators.clone(),
        &proto::UpdateChannelBuffer {
            channel_id: channel_id.to_proto(),
            operations: request.operations,
        },
        &session.peer,
    );

    let pool = &*session.connection_pool().await;

    // NOTE(review): other handlers pass a *root* id to `channel_connection_ids`;
    // here it receives `channel_id` itself — confirm which connections are
    // intended. `contains` is a linear scan per connection.
    let non_collaborators =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if collaborators.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });

    broadcast(None, non_collaborators, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
                    channel_id: channel_id.to_proto(),
                    // NOTE(review): `as` would wrap a negative epoch — presumably
                    // epochs are never negative; confirm.
                    epoch: epoch as u64,
                    version: version.clone(),
                }],
                ..Default::default()
            },
        )
    });

    Ok(())
}
4108
/// Rejoin the channel notes after a connection blip.
///
/// Re-registers this connection on each requested buffer. For every rejoined
/// buffer, its collaborators (those with a peer id) are sent the refreshed
/// collaborator list. Finally the requester is sent all rejoined buffers.
async fn rejoin_channel_buffers(
    request: proto::RejoinChannelBuffers,
    response: Response<proto::RejoinChannelBuffers>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let buffers = db
        .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
        .await?;

    for rejoined_buffer in &buffers {
        let collaborators_to_notify = rejoined_buffer
            .buffer
            .collaborators
            .iter()
            .filter_map(|c| Some(c.peer_id?.into()));
        channel_buffer_updated(
            session.connection_id,
            collaborators_to_notify,
            &proto::UpdateChannelBufferCollaborators {
                channel_id: rejoined_buffer.buffer.channel_id,
                collaborators: rejoined_buffer.buffer.collaborators.clone(),
            },
            &session.peer,
        );
    }

    response.send(proto::RejoinChannelBuffersResponse {
        buffers: buffers.into_iter().map(|b| b.buffer).collect(),
    })?;

    Ok(())
}
4143
/// Stop editing the channel notes.
///
/// Removes this connection from the buffer in the database and acknowledges
/// the request. The remaining connections returned by the database are then
/// sent the updated collaborator list.
async fn leave_channel_buffer(
    request: proto::LeaveChannelBuffer,
    response: Response<proto::LeaveChannelBuffer>,
    session: UserSession,
) -> Result<()> {
    let db = session.db().await;
    let channel_id = ChannelId::from_proto(request.channel_id);

    let left_buffer = db
        .leave_channel_buffer(channel_id, session.connection_id)
        .await?;

    response.send(Ack {})?;

    channel_buffer_updated(
        session.connection_id,
        left_buffer.connections,
        &proto::UpdateChannelBufferCollaborators {
            channel_id: channel_id.to_proto(),
            collaborators: left_buffer.collaborators,
        },
        &session.peer,
    );

    Ok(())
}
4171
4172fn channel_buffer_updated<T: EnvelopedMessage>(
4173 sender_id: ConnectionId,
4174 collaborators: impl IntoIterator<Item = ConnectionId>,
4175 message: &T,
4176 peer: &Peer,
4177) {
4178 broadcast(Some(sender_id), collaborators, |peer_id| {
4179 peer.send(peer_id, message.clone())
4180 });
4181}
4182
4183fn send_notifications(
4184 connection_pool: &ConnectionPool,
4185 peer: &Peer,
4186 notifications: db::NotificationBatch,
4187) {
4188 for (user_id, notification) in notifications {
4189 for connection_id in connection_pool.user_connection_ids(user_id) {
4190 if let Err(error) = peer.send(
4191 connection_id,
4192 proto::AddNotification {
4193 notification: Some(notification.clone()),
4194 },
4195 ) {
4196 tracing::error!(
4197 "failed to send notification to {:?} {}",
4198 connection_id,
4199 error
4200 );
4201 }
4202 }
4203 }
4204}
4205
/// Send a message to the channel.
///
/// The body is trimmed and then rejected if it is longer than
/// `MAX_MESSAGE_LEN` bytes or empty. A nonce is required. After the message is
/// stored, four things happen:
/// - `broadcast` sends the participant connections a `ChannelMessageSent`. The
///   sender's connection is passed as the one to skip.
/// - The requester gets the stored message as the response.
/// - Every other connection returned by `channel_connection_ids(channel_id)`
///   only learns the latest message id.
/// - Any notifications (e.g. for mentions) are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // (mention ranges in `request.mentions` still refer to the untrimmed body,
    // so leading whitespace shifts them.)

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Connections on the channel that are not chat participants only get a
    // lightweight "latest message id" update.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4300
/// Delete a channel message.
///
/// Removes the message in the database. The original request is then
/// re-broadcast to the returned connections, followed by a `DeleteNotification`
/// for each notification id the database returned. `broadcast` is given the
/// sender's connection, presumably to skip it. Finally the request is
/// acknowledged.
async fn remove_channel_message(
    request: proto::RemoveChannelMessage,
    response: Response<proto::RemoveChannelMessage>,
    session: UserSession,
) -> Result<()> {
    let channel_id = ChannelId::from_proto(request.channel_id);
    let message_id = MessageId::from_proto(request.message_id);
    let (connection_ids, existing_notification_ids) = session
        .db()
        .await
        .remove_channel_message(channel_id, message_id, session.user_id())
        .await?;

    broadcast(
        Some(session.connection_id),
        connection_ids,
        move |connection| {
            session.peer.send(connection, request.clone())?;

            for notification_id in &existing_notification_ids {
                session.peer.send(
                    connection,
                    proto::DeleteNotification {
                        notification_id: (*notification_id).to_proto(),
                    },
                )?;
            }

            Ok(())
        },
    );
    response.send(proto::Ack {})?;
    Ok(())
}
4336
4337async fn update_channel_message(
4338 request: proto::UpdateChannelMessage,
4339 response: Response<proto::UpdateChannelMessage>,
4340 session: UserSession,
4341) -> Result<()> {
4342 let channel_id = ChannelId::from_proto(request.channel_id);
4343 let message_id = MessageId::from_proto(request.message_id);
4344 let updated_at = OffsetDateTime::now_utc();
4345 let UpdatedChannelMessage {
4346 message_id,
4347 participant_connection_ids,
4348 notifications,
4349 reply_to_message_id,
4350 timestamp,
4351 deleted_mention_notification_ids,
4352 updated_mention_notifications,
4353 } = session
4354 .db()
4355 .await
4356 .update_channel_message(
4357 channel_id,
4358 message_id,
4359 session.user_id(),
4360 request.body.as_str(),
4361 &request.mentions,
4362 updated_at,
4363 )
4364 .await?;
4365
4366 let nonce = request
4367 .nonce
4368 .clone()
4369 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4370
4371 let message = proto::ChannelMessage {
4372 sender_id: session.user_id().to_proto(),
4373 id: message_id.to_proto(),
4374 body: request.body.clone(),
4375 mentions: request.mentions.clone(),
4376 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4377 nonce: Some(nonce),
4378 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4379 edited_at: Some(updated_at.unix_timestamp() as u64),
4380 };
4381
4382 response.send(proto::Ack {})?;
4383
4384 let pool = &*session.connection_pool().await;
4385 broadcast(
4386 Some(session.connection_id),
4387 participant_connection_ids,
4388 |connection| {
4389 session.peer.send(
4390 connection,
4391 proto::ChannelMessageUpdate {
4392 channel_id: channel_id.to_proto(),
4393 message: Some(message.clone()),
4394 },
4395 )?;
4396
4397 for notification_id in &deleted_mention_notification_ids {
4398 session.peer.send(
4399 connection,
4400 proto::DeleteNotification {
4401 notification_id: (*notification_id).to_proto(),
4402 },
4403 )?;
4404 }
4405
4406 for notification in &updated_mention_notifications {
4407 session.peer.send(
4408 connection,
4409 proto::UpdateNotification {
4410 notification: Some(notification.clone()),
4411 },
4412 )?;
4413 }
4414
4415 Ok(())
4416 },
4417 );
4418
4419 send_notifications(pool, &session.peer, notifications);
4420
4421 Ok(())
4422}
4423
4424/// Mark a channel message as read
4425async fn acknowledge_channel_message(
4426 request: proto::AckChannelMessage,
4427 session: UserSession,
4428) -> Result<()> {
4429 let channel_id = ChannelId::from_proto(request.channel_id);
4430 let message_id = MessageId::from_proto(request.message_id);
4431 let notifications = session
4432 .db()
4433 .await
4434 .observe_channel_message(channel_id, session.user_id(), message_id)
4435 .await?;
4436 send_notifications(
4437 &*session.connection_pool().await,
4438 &session.peer,
4439 notifications,
4440 );
4441 Ok(())
4442}
4443
4444/// Mark a buffer version as synced
4445async fn acknowledge_buffer_version(
4446 request: proto::AckBufferOperation,
4447 session: UserSession,
4448) -> Result<()> {
4449 let buffer_id = BufferId::from_proto(request.buffer_id);
4450 session
4451 .db()
4452 .await
4453 .observe_buffer_version(
4454 buffer_id,
4455 session.user_id(),
4456 request.epoch as i32,
4457 &request.version,
4458 )
4459 .await?;
4460 Ok(())
4461}
4462
4463async fn count_language_model_tokens(
4464 request: proto::CountLanguageModelTokens,
4465 response: Response<proto::CountLanguageModelTokens>,
4466 session: Session,
4467 config: &Config,
4468) -> Result<()> {
4469 let Some(session) = session.for_user() else {
4470 return Err(anyhow!("user not found"))?;
4471 };
4472 authorize_access_to_legacy_llm_endpoints(&session).await?;
4473
4474 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4475 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4476 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4477 };
4478
4479 session
4480 .app_state
4481 .rate_limiter
4482 .check(&*rate_limit, session.user_id())
4483 .await?;
4484
4485 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4486 Some(proto::LanguageModelProvider::Google) => {
4487 let api_key = config
4488 .google_ai_api_key
4489 .as_ref()
4490 .context("no Google AI API key configured on the server")?;
4491 google_ai::count_tokens(
4492 session.http_client.as_ref(),
4493 google_ai::API_URL,
4494 api_key,
4495 serde_json::from_str(&request.request)?,
4496 None,
4497 )
4498 .await?
4499 }
4500 _ => return Err(anyhow!("unsupported provider"))?,
4501 };
4502
4503 response.send(proto::CountLanguageModelTokensResponse {
4504 token_count: result.total_tokens as u32,
4505 })?;
4506
4507 Ok(())
4508}
4509
4510struct ZedProCountLanguageModelTokensRateLimit;
4511
4512impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4513 fn capacity(&self) -> usize {
4514 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4515 .ok()
4516 .and_then(|v| v.parse().ok())
4517 .unwrap_or(600) // Picked arbitrarily
4518 }
4519
4520 fn refill_duration(&self) -> chrono::Duration {
4521 chrono::Duration::hours(1)
4522 }
4523
4524 fn db_name(&self) -> &'static str {
4525 "zed-pro:count-language-model-tokens"
4526 }
4527}
4528
4529struct FreeCountLanguageModelTokensRateLimit;
4530
4531impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4532 fn capacity(&self) -> usize {
4533 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4534 .ok()
4535 .and_then(|v| v.parse().ok())
4536 .unwrap_or(600 / 10) // Picked arbitrarily
4537 }
4538
4539 fn refill_duration(&self) -> chrono::Duration {
4540 chrono::Duration::hours(1)
4541 }
4542
4543 fn db_name(&self) -> &'static str {
4544 "free:count-language-model-tokens"
4545 }
4546}
4547
4548struct ZedProComputeEmbeddingsRateLimit;
4549
4550impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4551 fn capacity(&self) -> usize {
4552 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4553 .ok()
4554 .and_then(|v| v.parse().ok())
4555 .unwrap_or(5000) // Picked arbitrarily
4556 }
4557
4558 fn refill_duration(&self) -> chrono::Duration {
4559 chrono::Duration::hours(1)
4560 }
4561
4562 fn db_name(&self) -> &'static str {
4563 "zed-pro:compute-embeddings"
4564 }
4565}
4566
4567struct FreeComputeEmbeddingsRateLimit;
4568
4569impl RateLimit for FreeComputeEmbeddingsRateLimit {
4570 fn capacity(&self) -> usize {
4571 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4572 .ok()
4573 .and_then(|v| v.parse().ok())
4574 .unwrap_or(5000 / 10) // Picked arbitrarily
4575 }
4576
4577 fn refill_duration(&self) -> chrono::Duration {
4578 chrono::Duration::hours(1)
4579 }
4580
4581 fn db_name(&self) -> &'static str {
4582 "free:compute-embeddings"
4583 }
4584}
4585
4586async fn compute_embeddings(
4587 request: proto::ComputeEmbeddings,
4588 response: Response<proto::ComputeEmbeddings>,
4589 session: UserSession,
4590 api_key: Option<Arc<str>>,
4591) -> Result<()> {
4592 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4593 authorize_access_to_legacy_llm_endpoints(&session).await?;
4594
4595 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4596 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4597 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4598 };
4599
4600 session
4601 .app_state
4602 .rate_limiter
4603 .check(&*rate_limit, session.user_id())
4604 .await?;
4605
4606 let embeddings = match request.model.as_str() {
4607 "openai/text-embedding-3-small" => {
4608 open_ai::embed(
4609 session.http_client.as_ref(),
4610 OPEN_AI_API_URL,
4611 &api_key,
4612 OpenAiEmbeddingModel::TextEmbedding3Small,
4613 request.texts.iter().map(|text| text.as_str()),
4614 )
4615 .await?
4616 }
4617 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4618 };
4619
4620 let embeddings = request
4621 .texts
4622 .iter()
4623 .map(|text| {
4624 let mut hasher = sha2::Sha256::new();
4625 hasher.update(text.as_bytes());
4626 let result = hasher.finalize();
4627 result.to_vec()
4628 })
4629 .zip(
4630 embeddings
4631 .data
4632 .into_iter()
4633 .map(|embedding| embedding.embedding),
4634 )
4635 .collect::<HashMap<_, _>>();
4636
4637 let db = session.db().await;
4638 db.save_embeddings(&request.model, &embeddings)
4639 .await
4640 .context("failed to save embeddings")
4641 .trace_err();
4642
4643 response.send(proto::ComputeEmbeddingsResponse {
4644 embeddings: embeddings
4645 .into_iter()
4646 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4647 .collect(),
4648 })?;
4649 Ok(())
4650}
4651
4652async fn get_cached_embeddings(
4653 request: proto::GetCachedEmbeddings,
4654 response: Response<proto::GetCachedEmbeddings>,
4655 session: UserSession,
4656) -> Result<()> {
4657 authorize_access_to_legacy_llm_endpoints(&session).await?;
4658
4659 let db = session.db().await;
4660 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4661
4662 response.send(proto::GetCachedEmbeddingsResponse {
4663 embeddings: embeddings
4664 .into_iter()
4665 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4666 .collect(),
4667 })?;
4668 Ok(())
4669}
4670
4671/// This is leftover from before the LLM service.
4672///
4673/// The endpoints protected by this check will be moved there eventually.
4674async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4675 if session.is_staff() {
4676 Ok(())
4677 } else {
4678 Err(anyhow!("permission denied"))?
4679 }
4680}
4681
4682/// Get a Supermaven API key for the user
4683async fn get_supermaven_api_key(
4684 _request: proto::GetSupermavenApiKey,
4685 response: Response<proto::GetSupermavenApiKey>,
4686 session: UserSession,
4687) -> Result<()> {
4688 let user_id: String = session.user_id().to_string();
4689 if !session.is_staff() {
4690 return Err(anyhow!("supermaven not enabled for this account"))?;
4691 }
4692
4693 let email = session
4694 .email()
4695 .ok_or_else(|| anyhow!("user must have an email"))?;
4696
4697 let supermaven_admin_api = session
4698 .supermaven_client
4699 .as_ref()
4700 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4701
4702 let result = supermaven_admin_api
4703 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4704 .await?;
4705
4706 response.send(proto::GetSupermavenApiKeyResponse {
4707 api_key: result.api_key,
4708 })?;
4709
4710 Ok(())
4711}
4712
4713/// Start receiving chat updates for a channel
4714async fn join_channel_chat(
4715 request: proto::JoinChannelChat,
4716 response: Response<proto::JoinChannelChat>,
4717 session: UserSession,
4718) -> Result<()> {
4719 let channel_id = ChannelId::from_proto(request.channel_id);
4720
4721 let db = session.db().await;
4722 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4723 .await?;
4724 let messages = db
4725 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4726 .await?;
4727 response.send(proto::JoinChannelChatResponse {
4728 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4729 messages,
4730 })?;
4731 Ok(())
4732}
4733
4734/// Stop receiving chat updates for a channel
4735async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4736 let channel_id = ChannelId::from_proto(request.channel_id);
4737 session
4738 .db()
4739 .await
4740 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4741 .await?;
4742 Ok(())
4743}
4744
4745/// Retrieve the chat history for a channel
4746async fn get_channel_messages(
4747 request: proto::GetChannelMessages,
4748 response: Response<proto::GetChannelMessages>,
4749 session: UserSession,
4750) -> Result<()> {
4751 let channel_id = ChannelId::from_proto(request.channel_id);
4752 let messages = session
4753 .db()
4754 .await
4755 .get_channel_messages(
4756 channel_id,
4757 session.user_id(),
4758 MESSAGE_COUNT_PER_PAGE,
4759 Some(MessageId::from_proto(request.before_message_id)),
4760 )
4761 .await?;
4762 response.send(proto::GetChannelMessagesResponse {
4763 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4764 messages,
4765 })?;
4766 Ok(())
4767}
4768
4769/// Retrieve specific chat messages
4770async fn get_channel_messages_by_id(
4771 request: proto::GetChannelMessagesById,
4772 response: Response<proto::GetChannelMessagesById>,
4773 session: UserSession,
4774) -> Result<()> {
4775 let message_ids = request
4776 .message_ids
4777 .iter()
4778 .map(|id| MessageId::from_proto(*id))
4779 .collect::<Vec<_>>();
4780 let messages = session
4781 .db()
4782 .await
4783 .get_channel_messages_by_id(session.user_id(), &message_ids)
4784 .await?;
4785 response.send(proto::GetChannelMessagesResponse {
4786 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4787 messages,
4788 })?;
4789 Ok(())
4790}
4791
4792/// Retrieve the current users notifications
4793async fn get_notifications(
4794 request: proto::GetNotifications,
4795 response: Response<proto::GetNotifications>,
4796 session: UserSession,
4797) -> Result<()> {
4798 let notifications = session
4799 .db()
4800 .await
4801 .get_notifications(
4802 session.user_id(),
4803 NOTIFICATION_COUNT_PER_PAGE,
4804 request.before_id.map(db::NotificationId::from_proto),
4805 )
4806 .await?;
4807 response.send(proto::GetNotificationsResponse {
4808 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4809 notifications,
4810 })?;
4811 Ok(())
4812}
4813
4814/// Mark notifications as read
4815async fn mark_notification_as_read(
4816 request: proto::MarkNotificationRead,
4817 response: Response<proto::MarkNotificationRead>,
4818 session: UserSession,
4819) -> Result<()> {
4820 let database = &session.db().await;
4821 let notifications = database
4822 .mark_notification_as_read_by_id(
4823 session.user_id(),
4824 NotificationId::from_proto(request.notification_id),
4825 )
4826 .await?;
4827 send_notifications(
4828 &*session.connection_pool().await,
4829 &session.peer,
4830 notifications,
4831 );
4832 response.send(proto::Ack {})?;
4833 Ok(())
4834}
4835
4836/// Get the current users information
4837async fn get_private_user_info(
4838 _request: proto::GetPrivateUserInfo,
4839 response: Response<proto::GetPrivateUserInfo>,
4840 session: UserSession,
4841) -> Result<()> {
4842 let db = session.db().await;
4843
4844 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4845 let user = db
4846 .get_user_by_id(session.user_id())
4847 .await?
4848 .ok_or_else(|| anyhow!("user not found"))?;
4849 let flags = db.get_user_flags(session.user_id()).await?;
4850
4851 response.send(proto::GetPrivateUserInfoResponse {
4852 metrics_id,
4853 staff: user.admin,
4854 flags,
4855 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4856 })?;
4857 Ok(())
4858}
4859
4860/// Accept the terms of service (tos) on behalf of the current user
4861async fn accept_terms_of_service(
4862 _request: proto::AcceptTermsOfService,
4863 response: Response<proto::AcceptTermsOfService>,
4864 session: UserSession,
4865) -> Result<()> {
4866 let db = session.db().await;
4867
4868 let accepted_tos_at = Utc::now();
4869 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4870 .await?;
4871
4872 response.send(proto::AcceptTermsOfServiceResponse {
4873 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4874 })?;
4875 Ok(())
4876}
4877
4878/// The minimum account age an account must have in order to use the LLM service.
4879const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4880
4881async fn get_llm_api_token(
4882 _request: proto::GetLlmToken,
4883 response: Response<proto::GetLlmToken>,
4884 session: UserSession,
4885) -> Result<()> {
4886 let db = session.db().await;
4887
4888 let flags = db.get_user_flags(session.user_id()).await?;
4889 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4890 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4891
4892 if !session.is_staff() && !has_language_models_feature_flag {
4893 Err(anyhow!("permission denied"))?
4894 }
4895
4896 let user_id = session.user_id();
4897 let user = db
4898 .get_user_by_id(user_id)
4899 .await?
4900 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4901
4902 if user.accepted_tos_at.is_none() {
4903 Err(anyhow!("terms of service not accepted"))?
4904 }
4905
4906 let mut account_created_at = user.created_at;
4907 if let Some(github_created_at) = user.github_user_created_at {
4908 account_created_at = account_created_at.min(github_created_at);
4909 }
4910 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4911 Err(anyhow!("account too young"))?
4912 }
4913 let token = LlmTokenClaims::create(
4914 user.id,
4915 user.github_login.clone(),
4916 session.is_staff(),
4917 has_llm_closed_beta_feature_flag,
4918 session.current_plan(db).await?,
4919 &session.app_state.config,
4920 )?;
4921 response.send(proto::GetLlmTokenResponse { token })?;
4922 Ok(())
4923}
4924
4925fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4926 let message = match message {
4927 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4928 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4929 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4930 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4931 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4932 code: frame.code.into(),
4933 reason: frame.reason,
4934 })),
4935 // We should never receive a frame while reading the message, according
4936 // to the `tungstenite` maintainers:
4937 //
4938 // > It cannot occur when you read messages from the WebSocket, but it
4939 // > can be used when you want to send the raw frames (e.g. you want to
4940 // > send the frames to the WebSocket without composing the full message first).
4941 // >
4942 // > — https://github.com/snapview/tungstenite-rs/issues/268
4943 TungsteniteMessage::Frame(_) => {
4944 bail!("received an unexpected frame while reading the message")
4945 }
4946 };
4947
4948 Ok(message)
4949}
4950
4951fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4952 match message {
4953 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4954 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4955 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4956 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4957 AxumMessage::Close(frame) => {
4958 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4959 code: frame.code.into(),
4960 reason: frame.reason,
4961 }))
4962 }
4963 }
4964}
4965
4966fn notify_membership_updated(
4967 connection_pool: &mut ConnectionPool,
4968 result: MembershipUpdated,
4969 user_id: UserId,
4970 peer: &Peer,
4971) {
4972 for membership in &result.new_channels.channel_memberships {
4973 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4974 }
4975 for channel_id in &result.removed_channels {
4976 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4977 }
4978
4979 let user_channels_update = proto::UpdateUserChannels {
4980 channel_memberships: result
4981 .new_channels
4982 .channel_memberships
4983 .iter()
4984 .map(|cm| proto::ChannelMembership {
4985 channel_id: cm.channel_id.to_proto(),
4986 role: cm.role.into(),
4987 })
4988 .collect(),
4989 ..Default::default()
4990 };
4991
4992 let mut update = build_channels_update(result.new_channels);
4993 update.delete_channels = result
4994 .removed_channels
4995 .into_iter()
4996 .map(|id| id.to_proto())
4997 .collect();
4998 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4999
5000 for connection_id in connection_pool.user_connection_ids(user_id) {
5001 peer.send(connection_id, user_channels_update.clone())
5002 .trace_err();
5003 peer.send(connection_id, update.clone()).trace_err();
5004 }
5005}
5006
5007fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5008 proto::UpdateUserChannels {
5009 channel_memberships: channels
5010 .channel_memberships
5011 .iter()
5012 .map(|m| proto::ChannelMembership {
5013 channel_id: m.channel_id.to_proto(),
5014 role: m.role.into(),
5015 })
5016 .collect(),
5017 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5018 observed_channel_message_id: channels.observed_channel_messages.clone(),
5019 }
5020}
5021
5022fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5023 let mut update = proto::UpdateChannels::default();
5024
5025 for channel in channels.channels {
5026 update.channels.push(channel.to_proto());
5027 }
5028
5029 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5030 update.latest_channel_message_ids = channels.latest_channel_messages;
5031
5032 for (channel_id, participants) in channels.channel_participants {
5033 update
5034 .channel_participants
5035 .push(proto::ChannelParticipants {
5036 channel_id: channel_id.to_proto(),
5037 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5038 });
5039 }
5040
5041 for channel in channels.invited_channels {
5042 update.channel_invitations.push(channel.to_proto());
5043 }
5044
5045 update.hosted_projects = channels.hosted_projects;
5046 update
5047}
5048
5049fn build_initial_contacts_update(
5050 contacts: Vec<db::Contact>,
5051 pool: &ConnectionPool,
5052) -> proto::UpdateContacts {
5053 let mut update = proto::UpdateContacts::default();
5054
5055 for contact in contacts {
5056 match contact {
5057 db::Contact::Accepted { user_id, busy } => {
5058 update.contacts.push(contact_for_user(user_id, busy, pool));
5059 }
5060 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5061 db::Contact::Incoming { user_id } => {
5062 update
5063 .incoming_requests
5064 .push(proto::IncomingContactRequest {
5065 requester_id: user_id.to_proto(),
5066 })
5067 }
5068 }
5069 }
5070
5071 update
5072}
5073
5074fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5075 proto::Contact {
5076 user_id: user_id.to_proto(),
5077 online: pool.is_user_online(user_id),
5078 busy,
5079 }
5080}
5081
5082fn room_updated(room: &proto::Room, peer: &Peer) {
5083 broadcast(
5084 None,
5085 room.participants
5086 .iter()
5087 .filter_map(|participant| Some(participant.peer_id?.into())),
5088 |peer_id| {
5089 peer.send(
5090 peer_id,
5091 proto::RoomUpdated {
5092 room: Some(room.clone()),
5093 },
5094 )
5095 },
5096 );
5097}
5098
5099fn channel_updated(
5100 channel: &db::channel::Model,
5101 room: &proto::Room,
5102 peer: &Peer,
5103 pool: &ConnectionPool,
5104) {
5105 let participants = room
5106 .participants
5107 .iter()
5108 .map(|p| p.user_id)
5109 .collect::<Vec<_>>();
5110
5111 broadcast(
5112 None,
5113 pool.channel_connection_ids(channel.root_id())
5114 .filter_map(|(channel_id, role)| {
5115 role.can_see_channel(channel.visibility)
5116 .then_some(channel_id)
5117 }),
5118 |peer_id| {
5119 peer.send(
5120 peer_id,
5121 proto::UpdateChannels {
5122 channel_participants: vec![proto::ChannelParticipants {
5123 channel_id: channel.id.to_proto(),
5124 participant_user_ids: participants.clone(),
5125 }],
5126 ..Default::default()
5127 },
5128 )
5129 },
5130 );
5131}
5132
5133async fn send_dev_server_projects_update(
5134 user_id: UserId,
5135 mut status: proto::DevServerProjectsUpdate,
5136 session: &Session,
5137) {
5138 let pool = session.connection_pool().await;
5139 for dev_server in &mut status.dev_servers {
5140 dev_server.status =
5141 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5142 }
5143 let connections = pool.user_connection_ids(user_id);
5144 for connection_id in connections {
5145 session.peer.send(connection_id, status.clone()).trace_err();
5146 }
5147}
5148
5149async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5150 let db = session.db().await;
5151
5152 let contacts = db.get_contacts(user_id).await?;
5153 let busy = db.is_user_busy(user_id).await?;
5154
5155 let pool = session.connection_pool().await;
5156 let updated_contact = contact_for_user(user_id, busy, &pool);
5157 for contact in contacts {
5158 if let db::Contact::Accepted {
5159 user_id: contact_user_id,
5160 ..
5161 } = contact
5162 {
5163 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5164 session
5165 .peer
5166 .send(
5167 contact_conn_id,
5168 proto::UpdateContacts {
5169 contacts: vec![updated_contact.clone()],
5170 remove_contacts: Default::default(),
5171 incoming_requests: Default::default(),
5172 remove_incoming_requests: Default::default(),
5173 outgoing_requests: Default::default(),
5174 remove_outgoing_requests: Default::default(),
5175 },
5176 )
5177 .trace_err();
5178 }
5179 }
5180 }
5181 Ok(())
5182}
5183
5184async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5185 log::info!("lost dev server connection, unsharing projects");
5186 let project_ids = session
5187 .db()
5188 .await
5189 .get_stale_dev_server_projects(session.connection_id)
5190 .await?;
5191
5192 for project_id in project_ids {
5193 // not unshare re-checks the connection ids match, so we get away with no transaction
5194 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5195 }
5196
5197 let user_id = session.dev_server().user_id;
5198 let update = session
5199 .db()
5200 .await
5201 .dev_server_projects_update(user_id)
5202 .await?;
5203
5204 send_dev_server_projects_update(user_id, update, session).await;
5205
5206 Ok(())
5207}
5208
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// Returns `Ok(())` without doing anything when the database reports that the
/// connection was not in a room. Otherwise:
/// - each project the session left is handled by `project_left`,
/// - room participants receive a `RoomUpdated`,
/// - the room's channel (if any) gets its participant list refreshed,
/// - users whose calls were canceled receive `CallCanceled`,
/// - contact status is refreshed for this user and for those callees,
/// - the user is removed from the LiveKit room, and the LiveKit room is also
///   deleted when the database marked it as deleted.
///
/// LiveKit and send failures are logged rather than returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // outlive `left_room`, which goes out of scope when that block ends.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the `if let` scope ends `left_room`'s lifetime, and that of
    // the database-guard temporary in the scrutinee, before the awaits further
    // down. If `leave_room` returns a guard type, restructuring this into
    // `let ... else` would keep it alive across those awaits. Confirm before
    // changing the shape.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // `mem::take` moves each field out and leaves its default behind. The
        // LiveKit room name is taken before `room` itself, so the `room` sent
        // below carries the default `live_kit_room` value.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // LiveKit cleanup is best-effort: errors are only logged.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5282
5283async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5284 let left_channel_buffers = session
5285 .db()
5286 .await
5287 .leave_channel_buffers(session.connection_id)
5288 .await?;
5289
5290 for left_buffer in left_channel_buffers {
5291 channel_buffer_updated(
5292 session.connection_id,
5293 left_buffer.connections,
5294 &proto::UpdateChannelBufferCollaborators {
5295 channel_id: left_buffer.channel_id.to_proto(),
5296 collaborators: left_buffer.collaborators,
5297 },
5298 &session.peer,
5299 );
5300 }
5301
5302 Ok(())
5303}
5304
5305fn project_left(project: &db::LeftProject, session: &UserSession) {
5306 for connection_id in &project.connection_ids {
5307 if project.should_unshare {
5308 session
5309 .peer
5310 .send(
5311 *connection_id,
5312 proto::UnshareProject {
5313 project_id: project.id.to_proto(),
5314 },
5315 )
5316 .trace_err();
5317 } else {
5318 session
5319 .peer
5320 .send(
5321 *connection_id,
5322 proto::RemoveProjectCollaborator {
5323 project_id: project.id.to_proto(),
5324 peer_id: Some(session.connection_id.into()),
5325 },
5326 )
5327 .trace_err();
5328 }
5329 }
5330}
5331
5332pub trait ResultExt {
5333 type Ok;
5334
5335 fn trace_err(self) -> Option<Self::Ok>;
5336}
5337
5338impl<T, E> ResultExt for Result<T, E>
5339where
5340 E: std::fmt::Debug,
5341{
5342 type Ok = T;
5343
5344 #[track_caller]
5345 fn trace_err(self) -> Option<T> {
5346 match self {
5347 Ok(value) => Some(value),
5348 Err(error) => {
5349 tracing::error!("{:?}", error);
5350 None
5351 }
5352 }
5353 }
5354}