1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use chrono::Utc;
35use collections::{HashMap, HashSet};
36pub use connection_pool::{ConnectionPool, ZedVersion};
37use core::fmt::{self, Debug, Formatter};
38use http_client::HttpClient;
39use isahc_http_client::IsahcHttpClient;
40use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
41use sha2::Digest;
42use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
43
44use futures::{
45 channel::oneshot,
46 future::{self, BoxFuture},
47 stream::FuturesUnordered,
48 FutureExt, SinkExt, StreamExt, TryStreamExt,
49};
50use prometheus::{register_int_gauge, IntGauge};
51use rpc::{
52 proto::{
53 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
54 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
55 },
56 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
57};
58use semantic_version::SemanticVersion;
59use serde::{Serialize, Serializer};
60use std::{
61 any::TypeId,
62 future::Future,
63 marker::PhantomData,
64 mem,
65 net::SocketAddr,
66 ops::{Deref, DerefMut},
67 rc::Rc,
68 sync::{
69 atomic::{AtomicBool, Ordering::SeqCst},
70 Arc, OnceLock,
71 },
72 time::{Duration, Instant},
73};
74use time::OffsetDateTime;
75use tokio::sync::{watch, MutexGuard, Semaphore};
76use tower::ServiceBuilder;
77use tracing::{
78 field::{self},
79 info_span, instrument, Instrument,
80};
81
/// Grace period associated with client reconnection.
/// NOTE(review): the consumers of this constant are outside this chunk — confirm exact semantics there.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long `Server::start` waits before cleaning up resources left by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting somewhat
/// longer ensures they are gone before we clean up their old resources.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Presumably the page size for channel-message history (usage not visible here).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the page size for notification listings (usage not visible here).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum accepted message length — unit (bytes vs chars) TODO confirm at use site.
const MAX_MESSAGE_LEN: usize = 1024;
90
91type MessageHandler =
92 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108#[derive(Clone, Debug)]
109pub enum Principal {
110 User(User),
111 Impersonated { user: User, admin: User },
112 DevServer(dev_server::Model),
113}
114
115impl Principal {
116 fn update_span(&self, span: &tracing::Span) {
117 match &self {
118 Principal::User(user) => {
119 span.record("user_id", user.id.0);
120 span.record("login", &user.github_login);
121 }
122 Principal::Impersonated { user, admin } => {
123 span.record("user_id", user.id.0);
124 span.record("login", &user.github_login);
125 span.record("impersonator", &admin.github_login);
126 }
127 Principal::DevServer(dev_server) => {
128 span.record("dev_server_id", dev_server.id.0);
129 }
130 }
131 }
132}
133
/// Per-connection state handed to every message handler.
///
/// `handle_connection` builds one per connection and clones it for each
/// incoming message it dispatches.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// This connection's id, as assigned by `peer.add_connection`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex. `handle_connection` creates a fresh
    /// mutex per connection, so database access is serialized per connection
    /// (across all clones of this session), not server-wide.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// The server's RPC peer, shared by all connections.
    peer: Arc<Peer>,
    /// Server-wide connection registry, shared with `Server`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<dyn HttpClient>,
    /// The GeoIP country code for the user.
    // NOTE(review): presumably not read by any handler yet, hence the allow.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    /// Held by the session but not read directly (hence the leading underscore).
    _executor: Executor,
}
149
150impl Session {
151 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
152 #[cfg(test)]
153 tokio::task::yield_now().await;
154 let guard = self.db.lock().await;
155 #[cfg(test)]
156 tokio::task::yield_now().await;
157 guard
158 }
159
160 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
161 #[cfg(test)]
162 tokio::task::yield_now().await;
163 let guard = self.connection_pool.lock();
164 ConnectionPoolGuard {
165 guard,
166 _not_send: PhantomData,
167 }
168 }
169
170 fn for_user(self) -> Option<UserSession> {
171 UserSession::new(self)
172 }
173
174 fn for_dev_server(self) -> Option<DevServerSession> {
175 DevServerSession::new(self)
176 }
177
178 fn user_id(&self) -> Option<UserId> {
179 match &self.principal {
180 Principal::User(user) => Some(user.id),
181 Principal::Impersonated { user, .. } => Some(user.id),
182 Principal::DevServer(_) => None,
183 }
184 }
185
186 fn is_staff(&self) -> bool {
187 match &self.principal {
188 Principal::User(user) => user.admin,
189 Principal::Impersonated { .. } => true,
190 Principal::DevServer(_) => false,
191 }
192 }
193
194 pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
195 if self.is_staff() {
196 return Ok(proto::Plan::ZedPro);
197 }
198
199 let Some(user_id) = self.user_id() else {
200 return Ok(proto::Plan::Free);
201 };
202
203 if db.has_active_billing_subscription(user_id).await? {
204 Ok(proto::Plan::ZedPro)
205 } else {
206 Ok(proto::Plan::Free)
207 }
208 }
209
210 fn dev_server_id(&self) -> Option<DevServerId> {
211 match &self.principal {
212 Principal::User(_) | Principal::Impersonated { .. } => None,
213 Principal::DevServer(dev_server) => Some(dev_server.id),
214 }
215 }
216
217 fn principal_id(&self) -> PrincipalId {
218 match &self.principal {
219 Principal::User(user) => PrincipalId::UserId(user.id),
220 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
221 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
222 }
223 }
224}
225
226impl Debug for Session {
227 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
228 let mut result = f.debug_struct("Session");
229 match &self.principal {
230 Principal::User(user) => {
231 result.field("user", &user.github_login);
232 }
233 Principal::Impersonated { user, admin } => {
234 result.field("user", &user.github_login);
235 result.field("impersonator", &admin.github_login);
236 }
237 Principal::DevServer(dev_server) => {
238 result.field("dev_server", &dev_server.id);
239 }
240 }
241 result.field("connection_id", &self.connection_id).finish()
242 }
243}
244
245struct UserSession(Session);
246
247impl UserSession {
248 pub fn new(s: Session) -> Option<Self> {
249 s.user_id().map(|_| UserSession(s))
250 }
251 pub fn user_id(&self) -> UserId {
252 self.0.user_id().unwrap()
253 }
254
255 pub fn email(&self) -> Option<String> {
256 match &self.0.principal {
257 Principal::User(user) => user.email_address.clone(),
258 Principal::Impersonated { user, .. } => user.email_address.clone(),
259 Principal::DevServer(..) => None,
260 }
261 }
262}
263
264impl Deref for UserSession {
265 type Target = Session;
266
267 fn deref(&self) -> &Self::Target {
268 &self.0
269 }
270}
271impl DerefMut for UserSession {
272 fn deref_mut(&mut self) -> &mut Self::Target {
273 &mut self.0
274 }
275}
276
277struct DevServerSession(Session);
278
279impl DevServerSession {
280 pub fn new(s: Session) -> Option<Self> {
281 s.dev_server_id().map(|_| DevServerSession(s))
282 }
283 pub fn dev_server_id(&self) -> DevServerId {
284 self.0.dev_server_id().unwrap()
285 }
286
287 fn dev_server(&self) -> &dev_server::Model {
288 match &self.0.principal {
289 Principal::DevServer(dev_server) => dev_server,
290 _ => unreachable!(),
291 }
292 }
293}
294
295impl Deref for DevServerSession {
296 type Target = Session;
297
298 fn deref(&self) -> &Self::Target {
299 &self.0
300 }
301}
302impl DerefMut for DevServerSession {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut self.0
305 }
306}
307
308fn user_handler<M: RequestMessage, Fut>(
309 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
310) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
311where
312 Fut: Send + Future<Output = Result<()>>,
313{
314 let handler = Arc::new(handler);
315 move |message, response, session| {
316 let handler = handler.clone();
317 Box::pin(async move {
318 if let Some(user_session) = session.for_user() {
319 Ok(handler(message, response, user_session).await?)
320 } else {
321 Err(Error::Internal(anyhow!(
322 "must be a user to call {}",
323 M::NAME
324 )))
325 }
326 })
327 }
328}
329
330fn dev_server_handler<M: RequestMessage, Fut>(
331 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
332) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
333where
334 Fut: Send + Future<Output = Result<()>>,
335{
336 let handler = Arc::new(handler);
337 move |message, response, session| {
338 let handler = handler.clone();
339 Box::pin(async move {
340 if let Some(dev_server_session) = session.for_dev_server() {
341 Ok(handler(message, response, dev_server_session).await?)
342 } else {
343 Err(Error::Internal(anyhow!(
344 "must be a dev server to call {}",
345 M::NAME
346 )))
347 }
348 })
349 }
350}
351
352fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
353 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
354) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
355where
356 InnertRetFut: Send + Future<Output = Result<()>>,
357{
358 let handler = Arc::new(handler);
359 move |message, session| {
360 let handler = handler.clone();
361 Box::pin(async move {
362 if let Some(user_session) = session.for_user() {
363 Ok(handler(message, user_session).await?)
364 } else {
365 Err(Error::Internal(anyhow!(
366 "must be a user to call {}",
367 M::NAME
368 )))
369 }
370 })
371 }
372}
373
374struct DbHandle(Arc<Database>);
375
376impl Deref for DbHandle {
377 type Target = Database;
378
379 fn deref(&self) -> &Self::Target {
380 self.0.as_ref()
381 }
382}
383
/// The collaboration RPC server: owns the RPC peer, the connection registry and
/// the table of message handlers.
pub struct Server {
    /// Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Registry of live connections, shared with every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; each connection task subscribes to it.
    teardown: watch::Sender<bool>,
}
392
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send` (`Rc` is `!Send`).
/// Holding it across an `.await` therefore makes the future `!Send`, which keeps
/// the synchronous pool lock from being held across suspension points in `Send`
/// futures.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
397
/// Serializable view of the server's peer and connection pool (presumably for
/// debugging/inspection endpoints).
///
/// Holds the connection-pool lock for as long as the snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialized as the `ConnectionPool` the guard dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
404
405pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
406where
407 S: Serializer,
408 T: Deref<Target = U>,
409 U: Serialize,
410{
411 Serialize::serialize(value.deref(), serializer)
412}
413
414impl Server {
    /// Creates a server with the given `id` and registers every RPC handler.
    ///
    /// The wrapper around each handler decides who may call it:
    /// - `user_handler` / `user_message_handler`: user principals only (plain or
    ///   impersonated); other callers get an `Error::Internal`.
    /// - `dev_server_handler`: dev-server principals only.
    /// - unwrapped handlers: receive the raw `Session`, whatever the principal.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if the same message type is registered twice.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` wraps if `id.0` is negative or wider than u32 — confirm range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped right away; connection tasks subscribe later.
            teardown: watch::channel(false).0,
        };

        server
            // Health check, rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(forward_find_search_candidates_request))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::ResolveInlayHint>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            // Account, LLM and completion-provider access.
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_request_handler(user_handler(accept_terms_of_service))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Not wrapped in `user_handler`: receives the raw `Session` plus the server config.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            // Captures `app_state` to pass the configured OpenAI API key along.
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
651
    /// Schedules cleanup of resources left behind by stale servers and returns
    /// immediately with `Ok(())`.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT` and then does the following:
    /// - Fetches stale room and channel-buffer ids via
    ///   `stale_server_resource_ids(zed_environment, server_id)` (presumably
    ///   resources still attributed to dead servers in this environment —
    ///   confirm in the db layer).
    /// - Clears stale channel-buffer collaborators and notifies the remaining ones.
    /// - Clears stale room participants, broadcasts room and channel updates,
    ///   cancels their pending calls, refreshes affected contacts, and deletes the
    ///   LiveKit room once it has no participants left.
    /// - Deletes the stale server records.
    ///
    /// Every fallible step is logged via `trace_err` and skipped; nothing aborts
    /// the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here, awaited inside the detached task.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the
                    // refreshed collaborator list to everyone still connected to it.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        // Remove stale participants and gather follow-up work for this room.
                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Tell every connection of each affected callee that its call is gone.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their
                        // accepted contacts' connections.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked only after the awaits above; no await while held.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        // Only tear down the LiveKit room once nobody is left in it.
                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
797
798 pub fn teardown(&self) {
799 self.peer.teardown();
800 self.connection_pool.lock().reset();
801 let _ = self.teardown.send(true);
802 }
803
804 #[cfg(test)]
805 pub fn reset(&self, id: ServerId) {
806 self.teardown();
807 *self.id.lock() = id;
808 self.peer.reset(id.0 as u32);
809 let _ = self.teardown.send(false);
810 }
811
812 #[cfg(test)]
813 pub fn id(&self) -> ServerId {
814 *self.id.lock()
815 }
816
817 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
818 where
819 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
820 Fut: 'static + Send + Future<Output = Result<()>>,
821 M: EnvelopedMessage,
822 {
823 let prev_handler = self.handlers.insert(
824 TypeId::of::<M>(),
825 Box::new(move |envelope, session| {
826 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
827 let received_at = envelope.received_at;
828 tracing::info!("message received");
829 let start_time = Instant::now();
830 let future = (handler)(*envelope, session);
831 async move {
832 let result = future.await;
833 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
834 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
835 let queue_duration_ms = total_duration_ms - processing_duration_ms;
836 let payload_type = M::NAME;
837
838 match result {
839 Err(error) => {
840 tracing::error!(
841 ?error,
842 total_duration_ms,
843 processing_duration_ms,
844 queue_duration_ms,
845 payload_type,
846 "error handling message"
847 )
848 }
849 Ok(()) => tracing::info!(
850 total_duration_ms,
851 processing_duration_ms,
852 queue_duration_ms,
853 "finished handling message"
854 ),
855 }
856 }
857 .boxed()
858 }),
859 );
860 if prev_handler.is_some() {
861 panic!("registered a handler for the same message twice");
862 }
863 self
864 }
865
866 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
869 Fut: 'static + Send + Future<Output = Result<()>>,
870 M: EnvelopedMessage,
871 {
872 self.add_handler(move |envelope, session| handler(envelope.payload, session));
873 self
874 }
875
876 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
879 Fut: Send + Future<Output = Result<()>>,
880 M: RequestMessage,
881 {
882 let handler = Arc::new(handler);
883 self.add_handler(move |envelope, session| {
884 let receipt = envelope.receipt();
885 let handler = handler.clone();
886 async move {
887 let peer = session.peer.clone();
888 let responded = Arc::new(AtomicBool::default());
889 let response = Response {
890 peer: peer.clone(),
891 responded: responded.clone(),
892 receipt,
893 };
894 match (handler)(envelope.payload, response, session).await {
895 Ok(()) => {
896 if responded.load(std::sync::atomic::Ordering::SeqCst) {
897 Ok(())
898 } else {
899 Err(anyhow!("handler did not send a response"))?
900 }
901 }
902 Err(error) => {
903 let proto_err = match &error {
904 Error::Internal(err) => err.to_proto(),
905 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
906 };
907 peer.respond_with_error(receipt, proto_err)?;
908 Err(error)
909 }
910 }
911 }
912 })
913 }
914
915 #[allow(clippy::too_many_arguments)]
916 pub fn handle_connection(
917 self: &Arc<Self>,
918 connection: Connection,
919 address: String,
920 principal: Principal,
921 zed_version: ZedVersion,
922 geoip_country_code: Option<String>,
923 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
924 executor: Executor,
925 ) -> impl Future<Output = ()> {
926 let this = self.clone();
927 let span = info_span!("handle connection", %address,
928 connection_id=field::Empty,
929 user_id=field::Empty,
930 login=field::Empty,
931 impersonator=field::Empty,
932 dev_server_id=field::Empty,
933 geoip_country_code=field::Empty
934 );
935 principal.update_span(&span);
936 if let Some(country_code) = geoip_country_code.as_ref() {
937 span.record("geoip_country_code", country_code);
938 }
939
940 let mut teardown = self.teardown.subscribe();
941 async move {
942 if *teardown.borrow() {
943 tracing::error!("server is tearing down");
944 return
945 }
946 let (connection_id, handle_io, mut incoming_rx) = this
947 .peer
948 .add_connection(connection, {
949 let executor = executor.clone();
950 move |duration| executor.sleep(duration)
951 });
952 tracing::Span::current().record("connection_id", format!("{}", connection_id));
953
954 tracing::info!("connection opened");
955
956 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
957 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
958 Ok(http_client) => Arc::new(IsahcHttpClient::from(http_client)),
959 Err(error) => {
960 tracing::error!(?error, "failed to create HTTP client");
961 return;
962 }
963 };
964
965 let supermaven_client = this.app_state.config.supermaven_admin_api_key.clone().map(|supermaven_admin_api_key| Arc::new(SupermavenAdminApi::new(
966 supermaven_admin_api_key.to_string(),
967 http_client.clone(),
968 )));
969
970 let session = Session {
971 principal: principal.clone(),
972 connection_id,
973 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
974 peer: this.peer.clone(),
975 connection_pool: this.connection_pool.clone(),
976 app_state: this.app_state.clone(),
977 http_client,
978 geoip_country_code,
979 _executor: executor.clone(),
980 supermaven_client,
981 };
982
983 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
984 tracing::error!(?error, "failed to send initial client update");
985 return;
986 }
987
988 let handle_io = handle_io.fuse();
989 futures::pin_mut!(handle_io);
990
991 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
992 // This prevents deadlocks when e.g., client A performs a request to client B and
993 // client B performs a request to client A. If both clients stop processing further
994 // messages until their respective request completes, they won't have a chance to
995 // respond to the other client's request and cause a deadlock.
996 //
997 // This arrangement ensures we will attempt to process earlier messages first, but fall
998 // back to processing messages arrived later in the spirit of making progress.
999 let mut foreground_message_handlers = FuturesUnordered::new();
1000 let concurrent_handlers = Arc::new(Semaphore::new(256));
1001 loop {
1002 let next_message = async {
1003 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1004 let message = incoming_rx.next().await;
1005 (permit, message)
1006 }.fuse();
1007 futures::pin_mut!(next_message);
1008 futures::select_biased! {
1009 _ = teardown.changed().fuse() => return,
1010 result = handle_io => {
1011 if let Err(error) = result {
1012 tracing::error!(?error, "error handling I/O");
1013 }
1014 break;
1015 }
1016 _ = foreground_message_handlers.next() => {}
1017 next_message = next_message => {
1018 let (permit, message) = next_message;
1019 if let Some(message) = message {
1020 let type_name = message.payload_type_name();
1021 // note: we copy all the fields from the parent span so we can query them in the logs.
1022 // (https://github.com/tokio-rs/tracing/issues/2670).
1023 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1024 user_id=field::Empty,
1025 login=field::Empty,
1026 impersonator=field::Empty,
1027 dev_server_id=field::Empty
1028 );
1029 principal.update_span(&span);
1030 let span_enter = span.enter();
1031 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1032 let is_background = message.is_background();
1033 let handle_message = (handler)(message, session.clone());
1034 drop(span_enter);
1035
1036 let handle_message = async move {
1037 handle_message.await;
1038 drop(permit);
1039 }.instrument(span);
1040 if is_background {
1041 executor.spawn_detached(handle_message);
1042 } else {
1043 foreground_message_handlers.push(handle_message);
1044 }
1045 } else {
1046 tracing::error!("no message handler");
1047 }
1048 } else {
1049 tracing::info!("connection closed");
1050 break;
1051 }
1052 }
1053 }
1054 }
1055
1056 drop(foreground_message_handlers);
1057 tracing::info!("signing out");
1058 if let Err(error) = connection_lost(session, teardown, executor).await {
1059 tracing::error!(?error, "error signing out");
1060 }
1061
1062 }.instrument(span)
1063 }
1064
    /// Performs the initial handshake for a freshly added connection.
    ///
    /// Always sends `proto::Hello` carrying the connection's peer id, then hands the
    /// connection id to `send_connection_id` when a sender was supplied. The rest depends on
    /// the principal:
    ///
    /// * Users (including impersonated ones): on the user's first-ever connection a
    ///   `proto::ShowContacts` is sent and the user is marked as connected once. The user's
    ///   plan is updated, contacts and dev-server project status are loaded concurrently, the
    ///   connection is registered in the pool and an initial contacts update is sent. Newer
    ///   clients (per `should_auto_subscribe_to_channels`) are subscribed to their channels;
    ///   then dev-server project status, any pending incoming call, and a contacts refresh
    ///   are sent.
    /// * Dev servers: an existing connection for the same dev server is sent
    ///   `proto::ShutdownDevServer` and removed from the pool before this connection is
    ///   registered. The dev server then receives its project list, and its owning user is
    ///   sent refreshed dev-server project status.
    ///
    /// # Errors
    /// Returns the first failing send, database call, or pool operation; the remaining
    /// handshake steps are skipped.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // A failed send (e.g. the receiving side was dropped) is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries are independent, so they are awaited together.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The connection is registered before the contacts update is built, so the
                // update is computed against a pool that already includes it. The pool lock
                // is confined to this block and released before the awaits below.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one connection per dev server is kept: evict any previous one before
                // registering this connection. The pool lock ends with this block.
                {
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Status updates go to the dev server's owning user, not the dev server.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, session).await;
            }
        }

        Ok(())
    }
1162
1163 pub async fn invite_code_redeemed(
1164 self: &Arc<Self>,
1165 inviter_id: UserId,
1166 invitee_id: UserId,
1167 ) -> Result<()> {
1168 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1169 if let Some(code) = &user.invite_code {
1170 let pool = self.connection_pool.lock();
1171 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1172 for connection_id in pool.user_connection_ids(inviter_id) {
1173 self.peer.send(
1174 connection_id,
1175 proto::UpdateContacts {
1176 contacts: vec![invitee_contact.clone()],
1177 ..Default::default()
1178 },
1179 )?;
1180 self.peer.send(
1181 connection_id,
1182 proto::UpdateInviteInfo {
1183 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1184 count: user.invite_count as u32,
1185 },
1186 )?;
1187 }
1188 }
1189 }
1190 Ok(())
1191 }
1192
1193 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1194 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1195 if let Some(invite_code) = &user.invite_code {
1196 let pool = self.connection_pool.lock();
1197 for connection_id in pool.user_connection_ids(user_id) {
1198 self.peer.send(
1199 connection_id,
1200 proto::UpdateInviteInfo {
1201 url: format!(
1202 "{}{}",
1203 self.app_state.config.invite_link_prefix, invite_code
1204 ),
1205 count: user.invite_count as u32,
1206 },
1207 )?;
1208 }
1209 }
1210 }
1211 Ok(())
1212 }
1213
1214 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1215 ServerSnapshot {
1216 connection_pool: ConnectionPoolGuard {
1217 guard: self.connection_pool.lock(),
1218 _not_send: PhantomData,
1219 },
1220 peer: &self.peer,
1221 }
1222 }
1223}
1224
1225impl<'a> Deref for ConnectionPoolGuard<'a> {
1226 type Target = ConnectionPool;
1227
1228 fn deref(&self) -> &Self::Target {
1229 &self.guard
1230 }
1231}
1232
1233impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1234 fn deref_mut(&mut self) -> &mut Self::Target {
1235 &mut self.guard
1236 }
1237}
1238
/// Verifies the pool's invariants each time a guard is released (test builds only).
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // The check is compiled only under `cfg(test)`; in other builds this impl adds
        // nothing to dropping the guard.
        #[cfg(test)]
        self.check_invariants();
    }
}
1245
1246fn broadcast<F>(
1247 sender_id: Option<ConnectionId>,
1248 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1249 mut f: F,
1250) where
1251 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1252{
1253 for receiver_id in receiver_ids {
1254 if Some(receiver_id) != sender_id {
1255 if let Err(error) = f(receiver_id) {
1256 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1257 }
1258 }
1259 }
1260}
1261
1262pub struct ProtocolVersion(u32);
1263
1264impl Header for ProtocolVersion {
1265 fn name() -> &'static HeaderName {
1266 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1267 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1268 }
1269
1270 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1271 where
1272 Self: Sized,
1273 I: Iterator<Item = &'i axum::http::HeaderValue>,
1274 {
1275 let version = values
1276 .next()
1277 .ok_or_else(axum::headers::Error::invalid)?
1278 .to_str()
1279 .map_err(|_| axum::headers::Error::invalid())?
1280 .parse()
1281 .map_err(|_| axum::headers::Error::invalid())?;
1282 Ok(Self(version))
1283 }
1284
1285 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1286 values.extend([self.0.to_string().parse().unwrap()]);
1287 }
1288}
1289
1290pub struct AppVersionHeader(SemanticVersion);
1291impl Header for AppVersionHeader {
1292 fn name() -> &'static HeaderName {
1293 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1294 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1295 }
1296
1297 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1298 where
1299 Self: Sized,
1300 I: Iterator<Item = &'i axum::http::HeaderValue>,
1301 {
1302 let version = values
1303 .next()
1304 .ok_or_else(axum::headers::Error::invalid)?
1305 .to_str()
1306 .map_err(|_| axum::headers::Error::invalid())?
1307 .parse()
1308 .map_err(|_| axum::headers::Error::invalid())?;
1309 Ok(Self(version))
1310 }
1311
1312 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1313 values.extend([self.0.to_string().parse().unwrap()]);
1314 }
1315}
1316
1317pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1318 Router::new()
1319 .route("/rpc", get(handle_websocket_request))
1320 .layer(
1321 ServiceBuilder::new()
1322 .layer(Extension(server.app_state.clone()))
1323 .layer(middleware::from_fn(auth::validate_header)),
1324 )
1325 .route("/metrics", get(handle_metrics))
1326 .layer(Extension(server))
1327}
1328
1329pub async fn handle_websocket_request(
1330 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1331 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1332 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1333 Extension(server): Extension<Arc<Server>>,
1334 Extension(principal): Extension<Principal>,
1335 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1336 ws: WebSocketUpgrade,
1337) -> axum::response::Response {
1338 if protocol_version != rpc::PROTOCOL_VERSION {
1339 return (
1340 StatusCode::UPGRADE_REQUIRED,
1341 "client must be upgraded".to_string(),
1342 )
1343 .into_response();
1344 }
1345
1346 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1347 return (
1348 StatusCode::UPGRADE_REQUIRED,
1349 "no version header found".to_string(),
1350 )
1351 .into_response();
1352 };
1353
1354 if !version.can_collaborate() {
1355 return (
1356 StatusCode::UPGRADE_REQUIRED,
1357 "client must be upgraded".to_string(),
1358 )
1359 .into_response();
1360 }
1361
1362 let socket_address = socket_address.to_string();
1363 ws.on_upgrade(move |socket| {
1364 let socket = socket
1365 .map_ok(to_tungstenite_message)
1366 .err_into()
1367 .with(|message| async move { to_axum_message(message) });
1368 let connection = Connection::new(Box::pin(socket));
1369 async move {
1370 server
1371 .handle_connection(
1372 connection,
1373 socket_address,
1374 principal,
1375 version,
1376 country_code_header.map(|header| header.to_string()),
1377 None,
1378 Executor::Production,
1379 )
1380 .await;
1381 }
1382 })
1383}
1384
1385pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1386 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1387 let connections_metric = CONNECTIONS_METRIC
1388 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1389
1390 let connections = server
1391 .connection_pool
1392 .lock()
1393 .connections()
1394 .filter(|connection| !connection.admin)
1395 .count();
1396 connections_metric.set(connections as _);
1397
1398 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1399 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1400 register_int_gauge!(
1401 "shared_projects",
1402 "number of open projects with one or more guests"
1403 )
1404 .unwrap()
1405 });
1406
1407 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1408 shared_projects_metric.set(shared_projects as _);
1409
1410 let encoder = prometheus::TextEncoder::new();
1411 let metric_families = prometheus::gather();
1412 let encoded_metrics = encoder
1413 .encode_to_string(&metric_families)
1414 .map_err(|err| anyhow!("{}", err))?;
1415 Ok(encoded_metrics)
1416}
1417
/// Cleans up after a connection has gone away.
///
/// Immediately, the peer is disconnected, the connection is removed from the pool, and the
/// database is told the connection was lost (a database error here is only traced). The
/// function then waits `RECONNECT_TIMEOUT`. If the teardown signal changes first, it returns
/// without further cleanup. Otherwise, for users it leaves the room and channel buffers,
/// asks the database to decline a pending call when the user has no other online
/// connection, and refreshes contacts. For dev servers it delegates to
/// `lost_dev_server_connection`.
///
/// # Errors
/// Fails if removing the connection from the pool fails, or if the post-timeout contact
/// refresh or dev-server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Give the client the reconnect window before releasing its resources; a teardown
    // signal short-circuits the remaining cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Expected to succeed because the principal matched a user variant.
                    // NOTE(review): confirm `for_user` depends only on the principal.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1472
1473/// Acknowledges a ping from a client, used to keep the connection alive.
1474async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1475 response.send(proto::Ack {})?;
1476 Ok(())
1477}
1478
1479/// Creates a new room for calling (outside of channels)
1480async fn create_room(
1481 _request: proto::CreateRoom,
1482 response: Response<proto::CreateRoom>,
1483 session: UserSession,
1484) -> Result<()> {
1485 let live_kit_room = nanoid::nanoid!(30);
1486
1487 let live_kit_connection_info = util::maybe!(async {
1488 let live_kit = session.app_state.live_kit_client.as_ref();
1489 let live_kit = live_kit?;
1490 let user_id = session.user_id().to_string();
1491
1492 let token = live_kit
1493 .room_token(&live_kit_room, &user_id.to_string())
1494 .trace_err()?;
1495
1496 Some(proto::LiveKitConnectionInfo {
1497 server_url: live_kit.url().into(),
1498 token,
1499 can_publish: true,
1500 })
1501 })
1502 .await;
1503
1504 let room = session
1505 .db()
1506 .await
1507 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1508 .await?;
1509
1510 response.send(proto::CreateRoomResponse {
1511 room: Some(room.clone()),
1512 live_kit_connection_info,
1513 })?;
1514
1515 update_user_contacts(session.user_id(), &session).await?;
1516 Ok(())
1517}
1518
1519/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1520async fn join_room(
1521 request: proto::JoinRoom,
1522 response: Response<proto::JoinRoom>,
1523 session: UserSession,
1524) -> Result<()> {
1525 let room_id = RoomId::from_proto(request.id);
1526
1527 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1528
1529 if let Some(channel_id) = channel_id {
1530 return join_channel_internal(channel_id, Box::new(response), session).await;
1531 }
1532
1533 let joined_room = {
1534 let room = session
1535 .db()
1536 .await
1537 .join_room(room_id, session.user_id(), session.connection_id)
1538 .await?;
1539 room_updated(&room.room, &session.peer);
1540 room.into_inner()
1541 };
1542
1543 for connection_id in session
1544 .connection_pool()
1545 .await
1546 .user_connection_ids(session.user_id())
1547 {
1548 session
1549 .peer
1550 .send(
1551 connection_id,
1552 proto::CallCanceled {
1553 room_id: room_id.to_proto(),
1554 },
1555 )
1556 .trace_err();
1557 }
1558
1559 let live_kit_connection_info =
1560 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1561 live_kit
1562 .room_token(
1563 &joined_room.room.live_kit_room,
1564 &session.user_id().to_string(),
1565 )
1566 .trace_err()
1567 .map(|token| proto::LiveKitConnectionInfo {
1568 server_url: live_kit.url().into(),
1569 token,
1570 can_publish: true,
1571 })
1572 } else {
1573 None
1574 };
1575
1576 response.send(proto::JoinRoomResponse {
1577 room: Some(joined_room.room),
1578 channel_id: None,
1579 live_kit_connection_info,
1580 })?;
1581
1582 update_user_contacts(session.user_id(), &session).await?;
1583 Ok(())
1584}
1585
/// Rejoins a room after a connection error.
///
/// The database rejoin runs first, and the caller is immediately answered with the room,
/// the projects it re-shares and the projects it rejoins. Then the room is re-broadcast.
/// Collaborators on each re-shared project learn the caller's new peer id and receive the
/// project's current worktrees. Rejoined projects are handled by `notify_rejoined_projects`.
/// Finally, the room's channel (if any) and the user's contacts are refreshed.
///
/// # Errors
/// Fails on database errors, on failure to send the reply or to notify rejoined projects,
/// or if contacts cannot be refreshed. Failures to reach individual collaborators are only
/// traced or logged.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The guard returned by the database rejoin is confined to this block; only the room and
    // channel are carried out of it.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell each collaborator that this project's peer moved to the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the worktree list, skipping the rejoining connection itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1677
1678fn notify_rejoined_projects(
1679 rejoined_projects: &mut Vec<RejoinedProject>,
1680 session: &UserSession,
1681) -> Result<()> {
1682 for project in rejoined_projects.iter() {
1683 for collaborator in &project.collaborators {
1684 session
1685 .peer
1686 .send(
1687 collaborator.connection_id,
1688 proto::UpdateProjectCollaborator {
1689 project_id: project.id.to_proto(),
1690 old_peer_id: Some(project.old_connection_id.into()),
1691 new_peer_id: Some(session.connection_id.into()),
1692 },
1693 )
1694 .trace_err();
1695 }
1696 }
1697
1698 for project in rejoined_projects {
1699 for worktree in mem::take(&mut project.worktrees) {
1700 #[cfg(any(test, feature = "test-support"))]
1701 const MAX_CHUNK_SIZE: usize = 2;
1702 #[cfg(not(any(test, feature = "test-support")))]
1703 const MAX_CHUNK_SIZE: usize = 256;
1704
1705 // Stream this worktree's entries.
1706 let message = proto::UpdateWorktree {
1707 project_id: project.id.to_proto(),
1708 worktree_id: worktree.id,
1709 abs_path: worktree.abs_path.clone(),
1710 root_name: worktree.root_name,
1711 updated_entries: worktree.updated_entries,
1712 removed_entries: worktree.removed_entries,
1713 scan_id: worktree.scan_id,
1714 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1715 updated_repositories: worktree.updated_repositories,
1716 removed_repositories: worktree.removed_repositories,
1717 };
1718 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1719 session.peer.send(session.connection_id, update.clone())?;
1720 }
1721
1722 // Stream this worktree's diagnostics.
1723 for summary in worktree.diagnostic_summaries {
1724 session.peer.send(
1725 session.connection_id,
1726 proto::UpdateDiagnosticSummary {
1727 project_id: project.id.to_proto(),
1728 worktree_id: worktree.id,
1729 summary: Some(summary),
1730 },
1731 )?;
1732 }
1733
1734 for settings_file in worktree.settings_files {
1735 session.peer.send(
1736 session.connection_id,
1737 proto::UpdateWorktreeSettings {
1738 project_id: project.id.to_proto(),
1739 worktree_id: worktree.id,
1740 path: settings_file.path,
1741 content: Some(settings_file.content),
1742 },
1743 )?;
1744 }
1745 }
1746
1747 for language_server in &project.language_servers {
1748 session.peer.send(
1749 session.connection_id,
1750 proto::UpdateLanguageServer {
1751 project_id: project.id.to_proto(),
1752 language_server_id: language_server.id,
1753 variant: Some(
1754 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1755 proto::LspDiskBasedDiagnosticsUpdated {},
1756 ),
1757 ),
1758 },
1759 )?;
1760 }
1761 }
1762 Ok(())
1763}
1764
1765/// leave room disconnects from the room.
1766async fn leave_room(
1767 _: proto::LeaveRoom,
1768 response: Response<proto::LeaveRoom>,
1769 session: UserSession,
1770) -> Result<()> {
1771 leave_room_for_session(&session, session.connection_id).await?;
1772 response.send(proto::Ack {})?;
1773 Ok(())
1774}
1775
1776/// Updates the permissions of someone else in the room.
1777async fn set_room_participant_role(
1778 request: proto::SetRoomParticipantRole,
1779 response: Response<proto::SetRoomParticipantRole>,
1780 session: UserSession,
1781) -> Result<()> {
1782 let user_id = UserId::from_proto(request.user_id);
1783 let role = ChannelRole::from(request.role());
1784
1785 let (live_kit_room, can_publish) = {
1786 let room = session
1787 .db()
1788 .await
1789 .set_room_participant_role(
1790 session.user_id(),
1791 RoomId::from_proto(request.room_id),
1792 user_id,
1793 role,
1794 )
1795 .await?;
1796
1797 let live_kit_room = room.live_kit_room.clone();
1798 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1799 room_updated(&room, &session.peer);
1800 (live_kit_room, can_publish)
1801 };
1802
1803 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1804 live_kit
1805 .update_participant(
1806 live_kit_room.clone(),
1807 request.user_id.to_string(),
1808 live_kit_server::proto::ParticipantPermission {
1809 can_subscribe: true,
1810 can_publish,
1811 can_publish_data: can_publish,
1812 hidden: false,
1813 recorder: false,
1814 },
1815 )
1816 .await
1817 .trace_err();
1818 }
1819
1820 response.send(proto::Ack {})?;
1821 Ok(())
1822}
1823
/// Rings another user into the current room.
///
/// Only contacts can be called. The call is recorded in the database, the room is
/// re-broadcast, and the called user's contacts are refreshed. Then every connection of the
/// called user is sent the incoming-call request concurrently. The first successful reply
/// completes this request with an `Ack`. If every connection fails (or there are none), the
/// call is marked failed in the database, the room is re-broadcast, the called user's
/// contacts are refreshed again, and an error is returned.
///
/// # Errors
/// Fails if the callee is not a contact, on database errors, or when no connection of the
/// callee accepted the ring.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(room, &session.peer);
        // Move the message out so it outlives the database guard that ends with this block.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // Replies are handled in completion order: the first success wins; failures are traced
    // and the next reply is awaited.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1892
1893/// Cancel an outgoing call.
1894async fn cancel_call(
1895 request: proto::CancelCall,
1896 response: Response<proto::CancelCall>,
1897 session: UserSession,
1898) -> Result<()> {
1899 let called_user_id = UserId::from_proto(request.called_user_id);
1900 let room_id = RoomId::from_proto(request.room_id);
1901 {
1902 let room = session
1903 .db()
1904 .await
1905 .cancel_call(room_id, session.connection_id, called_user_id)
1906 .await?;
1907 room_updated(&room, &session.peer);
1908 }
1909
1910 for connection_id in session
1911 .connection_pool()
1912 .await
1913 .user_connection_ids(called_user_id)
1914 {
1915 session
1916 .peer
1917 .send(
1918 connection_id,
1919 proto::CallCanceled {
1920 room_id: room_id.to_proto(),
1921 },
1922 )
1923 .trace_err();
1924 }
1925 response.send(proto::Ack {})?;
1926
1927 update_user_contacts(called_user_id, &session).await?;
1928 Ok(())
1929}
1930
1931/// Decline an incoming call.
1932async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1933 let room_id = RoomId::from_proto(message.room_id);
1934 {
1935 let room = session
1936 .db()
1937 .await
1938 .decline_call(Some(room_id), session.user_id())
1939 .await?
1940 .ok_or_else(|| anyhow!("failed to decline call"))?;
1941 room_updated(&room, &session.peer);
1942 }
1943
1944 for connection_id in session
1945 .connection_pool()
1946 .await
1947 .user_connection_ids(session.user_id())
1948 {
1949 session
1950 .peer
1951 .send(
1952 connection_id,
1953 proto::CallCanceled {
1954 room_id: room_id.to_proto(),
1955 },
1956 )
1957 .trace_err();
1958 }
1959 update_user_contacts(session.user_id(), &session).await?;
1960 Ok(())
1961}
1962
1963/// Updates other participants in the room with your current location.
1964async fn update_participant_location(
1965 request: proto::UpdateParticipantLocation,
1966 response: Response<proto::UpdateParticipantLocation>,
1967 session: UserSession,
1968) -> Result<()> {
1969 let room_id = RoomId::from_proto(request.room_id);
1970 let location = request
1971 .location
1972 .ok_or_else(|| anyhow!("invalid location"))?;
1973
1974 let db = session.db().await;
1975 let room = db
1976 .update_room_participant_location(room_id, session.connection_id, location)
1977 .await?;
1978
1979 room_updated(&room, &session.peer);
1980 response.send(proto::Ack {})?;
1981 Ok(())
1982}
1983
1984/// Share a project into the room.
1985async fn share_project(
1986 request: proto::ShareProject,
1987 response: Response<proto::ShareProject>,
1988 session: UserSession,
1989) -> Result<()> {
1990 let (project_id, room) = &*session
1991 .db()
1992 .await
1993 .share_project(
1994 RoomId::from_proto(request.room_id),
1995 session.connection_id,
1996 &request.worktrees,
1997 request.is_ssh_project,
1998 request
1999 .dev_server_project_id
2000 .map(DevServerProjectId::from_proto),
2001 )
2002 .await?;
2003 response.send(proto::ShareProjectResponse {
2004 project_id: project_id.to_proto(),
2005 })?;
2006 room_updated(room, &session.peer);
2007
2008 Ok(())
2009}
2010
2011/// Unshare a project from the room.
2012async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2013 let project_id = ProjectId::from_proto(message.project_id);
2014 unshare_project_internal(
2015 project_id,
2016 session.connection_id,
2017 session.user_id(),
2018 &session,
2019 )
2020 .await
2021}
2022
/// Stops sharing `project_id` for `connection_id` (and `user_id`, when known).
///
/// Guests are sent `proto::UnshareProject`; the unsharing connection itself is skipped. The
/// room is re-broadcast if the database returned one. The project is deleted when the
/// database reports it should be.
///
/// # Errors
/// Propagates database errors from unsharing or deleting the project. Failures to notify
/// individual guests are only logged.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Notify everyone while the database guard is alive, then release it before deleting.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2061
2062/// DevServer makes a project available online
2063async fn share_dev_server_project(
2064 request: proto::ShareDevServerProject,
2065 response: Response<proto::ShareDevServerProject>,
2066 session: DevServerSession,
2067) -> Result<()> {
2068 let (dev_server_project, user_id, status) = session
2069 .db()
2070 .await
2071 .share_dev_server_project(
2072 DevServerProjectId::from_proto(request.dev_server_project_id),
2073 session.dev_server_id(),
2074 session.connection_id,
2075 &request.worktrees,
2076 )
2077 .await?;
2078 let Some(project_id) = dev_server_project.project_id else {
2079 return Err(anyhow!("failed to share remote project"))?;
2080 };
2081
2082 send_dev_server_projects_update(user_id, status, &session).await;
2083
2084 response.send(proto::ShareProjectResponse { project_id })?;
2085
2086 Ok(())
2087}
2088
/// Joins another user's shared project.
///
/// Adds this connection to the project in the database. The database handle is released
/// before handing off to `join_project_internal`, which notifies the existing
/// collaborators and sends the project state to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle; the project data stays borrowed from the returned guard.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2107
/// A reply channel that accepts a `proto::JoinProjectResponse`.
///
/// Lets `join_project_internal` answer both `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the responder.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for this request type.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request with the same response type as `JoinProject`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Delegates to `Response`'s own `send` for this request type.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2121
2122fn join_project_internal(
2123 response: impl JoinProjectInternalResponse,
2124 session: UserSession,
2125 project: &mut Project,
2126 replica_id: &ReplicaId,
2127) -> Result<()> {
2128 let collaborators = project
2129 .collaborators
2130 .iter()
2131 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2132 .map(|collaborator| collaborator.to_proto())
2133 .collect::<Vec<_>>();
2134 let project_id = project.id;
2135 let guest_user_id = session.user_id();
2136
2137 let worktrees = project
2138 .worktrees
2139 .iter()
2140 .map(|(id, worktree)| proto::WorktreeMetadata {
2141 id: *id,
2142 root_name: worktree.root_name.clone(),
2143 visible: worktree.visible,
2144 abs_path: worktree.abs_path.clone(),
2145 })
2146 .collect::<Vec<_>>();
2147
2148 let add_project_collaborator = proto::AddProjectCollaborator {
2149 project_id: project_id.to_proto(),
2150 collaborator: Some(proto::Collaborator {
2151 peer_id: Some(session.connection_id.into()),
2152 replica_id: replica_id.0 as u32,
2153 user_id: guest_user_id.to_proto(),
2154 }),
2155 };
2156
2157 for collaborator in &collaborators {
2158 session
2159 .peer
2160 .send(
2161 collaborator.peer_id.unwrap().into(),
2162 add_project_collaborator.clone(),
2163 )
2164 .trace_err();
2165 }
2166
2167 // First, we send the metadata associated with each worktree.
2168 response.send(proto::JoinProjectResponse {
2169 project_id: project.id.0 as u64,
2170 worktrees: worktrees.clone(),
2171 replica_id: replica_id.0 as u32,
2172 collaborators: collaborators.clone(),
2173 language_servers: project.language_servers.clone(),
2174 role: project.role.into(),
2175 dev_server_project_id: project
2176 .dev_server_project_id
2177 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2178 })?;
2179
2180 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2181 #[cfg(any(test, feature = "test-support"))]
2182 const MAX_CHUNK_SIZE: usize = 2;
2183 #[cfg(not(any(test, feature = "test-support")))]
2184 const MAX_CHUNK_SIZE: usize = 256;
2185
2186 // Stream this worktree's entries.
2187 let message = proto::UpdateWorktree {
2188 project_id: project_id.to_proto(),
2189 worktree_id,
2190 abs_path: worktree.abs_path.clone(),
2191 root_name: worktree.root_name,
2192 updated_entries: worktree.entries,
2193 removed_entries: Default::default(),
2194 scan_id: worktree.scan_id,
2195 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2196 updated_repositories: worktree.repository_entries.into_values().collect(),
2197 removed_repositories: Default::default(),
2198 };
2199 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2200 session.peer.send(session.connection_id, update.clone())?;
2201 }
2202
2203 // Stream this worktree's diagnostics.
2204 for summary in worktree.diagnostic_summaries {
2205 session.peer.send(
2206 session.connection_id,
2207 proto::UpdateDiagnosticSummary {
2208 project_id: project_id.to_proto(),
2209 worktree_id: worktree.id,
2210 summary: Some(summary),
2211 },
2212 )?;
2213 }
2214
2215 for settings_file in worktree.settings_files {
2216 session.peer.send(
2217 session.connection_id,
2218 proto::UpdateWorktreeSettings {
2219 project_id: project_id.to_proto(),
2220 worktree_id: worktree.id,
2221 path: settings_file.path,
2222 content: Some(settings_file.content),
2223 },
2224 )?;
2225 }
2226 }
2227
2228 for language_server in &project.language_servers {
2229 session.peer.send(
2230 session.connection_id,
2231 proto::UpdateLanguageServer {
2232 project_id: project_id.to_proto(),
2233 language_server_id: language_server.id,
2234 variant: Some(
2235 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2236 proto::LspDiskBasedDiagnosticsUpdated {},
2237 ),
2238 ),
2239 },
2240 )?;
2241 }
2242
2243 Ok(())
2244}
2245
2246/// Leave someone elses shared project.
2247async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2248 let sender_id = session.connection_id;
2249 let project_id = ProjectId::from_proto(request.project_id);
2250 let db = session.db().await;
2251 if db.is_hosted_project(project_id).await? {
2252 let project = db.leave_hosted_project(project_id, sender_id).await?;
2253 project_left(&project, &session);
2254 return Ok(());
2255 }
2256
2257 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2258 tracing::info!(
2259 %project_id,
2260 "leave project"
2261 );
2262
2263 project_left(project, &session);
2264 if let Some(room) = room {
2265 room_updated(room, &session.peer);
2266 }
2267
2268 Ok(())
2269}
2270
2271async fn join_hosted_project(
2272 request: proto::JoinHostedProject,
2273 response: Response<proto::JoinHostedProject>,
2274 session: UserSession,
2275) -> Result<()> {
2276 let (mut project, replica_id) = session
2277 .db()
2278 .await
2279 .join_hosted_project(
2280 ProjectId(request.project_id as i32),
2281 session.user_id(),
2282 session.connection_id,
2283 )
2284 .await?;
2285
2286 join_project_internal(response, session, &mut project, &replica_id)
2287}
2288
2289async fn list_remote_directory(
2290 request: proto::ListRemoteDirectory,
2291 response: Response<proto::ListRemoteDirectory>,
2292 session: UserSession,
2293) -> Result<()> {
2294 let dev_server_id = DevServerId(request.dev_server_id as i32);
2295 let dev_server_connection_id = session
2296 .connection_pool()
2297 .await
2298 .online_dev_server_connection_id(dev_server_id)?;
2299
2300 session
2301 .db()
2302 .await
2303 .get_dev_server_for_user(dev_server_id, session.user_id())
2304 .await?;
2305
2306 response.send(
2307 session
2308 .peer
2309 .forward_request(session.connection_id, dev_server_connection_id, request)
2310 .await?,
2311 )?;
2312 Ok(())
2313}
2314
/// Updates the paths of a dev server project.
///
/// The database update is made with the caller's user id. Afterwards the dev
/// server is sent the full list of its projects (`DevServerInstructions`), the
/// caller receives the refreshed dev-server status, and the request is
/// acknowledged.
///
/// NOTE(review): the database is written before the dev server's connection
/// is looked up. If `online_dev_server_connection_id` fails (presumably when
/// the dev server is offline), the change has already been persisted, but the
/// handler returns an error and `send_dev_server_projects_update` is skipped,
/// which may leave the caller's view stale. `create_dev_server_project` checks
/// the connection first; confirm which behavior is intended here.
async fn update_dev_server_project(
    request: proto::UpdateDevServerProject,
    response: Response<proto::UpdateDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);

    let (dev_server_project, update) = session
        .db()
        .await
        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
        .await?;

    // Reload every project of this dev server; the instructions below carry
    // the complete list, not just the project that changed.
    let projects = session
        .db()
        .await
        .get_projects_for_dev_server(dev_server_project.dev_server_id)
        .await?;

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .online_dev_server_connection_id(dev_server_project.dev_server_id)?;

    session.peer.send(
        dev_server_connection_id,
        proto::DevServerInstructions { projects },
    )?;

    send_dev_server_projects_update(session.user_id(), update, &session).await;

    response.send(proto::Ack {})
}
2348
2349async fn create_dev_server_project(
2350 request: proto::CreateDevServerProject,
2351 response: Response<proto::CreateDevServerProject>,
2352 session: UserSession,
2353) -> Result<()> {
2354 let dev_server_id = DevServerId(request.dev_server_id as i32);
2355 let dev_server_connection_id = session
2356 .connection_pool()
2357 .await
2358 .dev_server_connection_id(dev_server_id);
2359 let Some(dev_server_connection_id) = dev_server_connection_id else {
2360 Err(ErrorCode::DevServerOffline
2361 .message("Cannot create a remote project when the dev server is offline".to_string())
2362 .anyhow())?
2363 };
2364
2365 let path = request.path.clone();
2366 //Check that the path exists on the dev server
2367 session
2368 .peer
2369 .forward_request(
2370 session.connection_id,
2371 dev_server_connection_id,
2372 proto::ValidateDevServerProjectRequest { path: path.clone() },
2373 )
2374 .await?;
2375
2376 let (dev_server_project, update) = session
2377 .db()
2378 .await
2379 .create_dev_server_project(
2380 DevServerId(request.dev_server_id as i32),
2381 &request.path,
2382 session.user_id(),
2383 )
2384 .await?;
2385
2386 let projects = session
2387 .db()
2388 .await
2389 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2390 .await?;
2391
2392 session.peer.send(
2393 dev_server_connection_id,
2394 proto::DevServerInstructions { projects },
2395 )?;
2396
2397 send_dev_server_projects_update(session.user_id(), update, &session).await;
2398
2399 response.send(proto::CreateDevServerProjectResponse {
2400 dev_server_project: Some(dev_server_project.to_proto(None)),
2401 })?;
2402 Ok(())
2403}
2404
2405async fn create_dev_server(
2406 request: proto::CreateDevServer,
2407 response: Response<proto::CreateDevServer>,
2408 session: UserSession,
2409) -> Result<()> {
2410 let access_token = auth::random_token();
2411 let hashed_access_token = auth::hash_access_token(&access_token);
2412
2413 if request.name.is_empty() {
2414 return Err(proto::ErrorCode::Forbidden
2415 .message("Dev server name cannot be empty".to_string())
2416 .anyhow())?;
2417 }
2418
2419 let (dev_server, status) = session
2420 .db()
2421 .await
2422 .create_dev_server(
2423 &request.name,
2424 request.ssh_connection_string.as_deref(),
2425 &hashed_access_token,
2426 session.user_id(),
2427 )
2428 .await?;
2429
2430 send_dev_server_projects_update(session.user_id(), status, &session).await;
2431
2432 response.send(proto::CreateDevServerResponse {
2433 dev_server_id: dev_server.id.0 as u64,
2434 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2435 name: request.name,
2436 })?;
2437 Ok(())
2438}
2439
2440async fn regenerate_dev_server_token(
2441 request: proto::RegenerateDevServerToken,
2442 response: Response<proto::RegenerateDevServerToken>,
2443 session: UserSession,
2444) -> Result<()> {
2445 let dev_server_id = DevServerId(request.dev_server_id as i32);
2446 let access_token = auth::random_token();
2447 let hashed_access_token = auth::hash_access_token(&access_token);
2448
2449 let connection_id = session
2450 .connection_pool()
2451 .await
2452 .dev_server_connection_id(dev_server_id);
2453 if let Some(connection_id) = connection_id {
2454 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2455 session.peer.send(
2456 connection_id,
2457 proto::ShutdownDevServer {
2458 reason: Some("dev server token was regenerated".to_string()),
2459 },
2460 )?;
2461 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2462 }
2463
2464 let status = session
2465 .db()
2466 .await
2467 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2468 .await?;
2469
2470 send_dev_server_projects_update(session.user_id(), status, &session).await;
2471
2472 response.send(proto::RegenerateDevServerTokenResponse {
2473 dev_server_id: dev_server_id.to_proto(),
2474 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2475 })?;
2476 Ok(())
2477}
2478
2479async fn rename_dev_server(
2480 request: proto::RenameDevServer,
2481 response: Response<proto::RenameDevServer>,
2482 session: UserSession,
2483) -> Result<()> {
2484 if request.name.trim().is_empty() {
2485 return Err(proto::ErrorCode::Forbidden
2486 .message("Dev server name cannot be empty".to_string())
2487 .anyhow())?;
2488 }
2489
2490 let dev_server_id = DevServerId(request.dev_server_id as i32);
2491 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2492 if dev_server.user_id != session.user_id() {
2493 return Err(anyhow!(ErrorCode::Forbidden))?;
2494 }
2495
2496 let status = session
2497 .db()
2498 .await
2499 .rename_dev_server(
2500 dev_server_id,
2501 &request.name,
2502 request.ssh_connection_string.as_deref(),
2503 session.user_id(),
2504 )
2505 .await?;
2506
2507 send_dev_server_projects_update(session.user_id(), status, &session).await;
2508
2509 response.send(proto::Ack {})?;
2510 Ok(())
2511}
2512
2513async fn delete_dev_server(
2514 request: proto::DeleteDevServer,
2515 response: Response<proto::DeleteDevServer>,
2516 session: UserSession,
2517) -> Result<()> {
2518 let dev_server_id = DevServerId(request.dev_server_id as i32);
2519 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2520 if dev_server.user_id != session.user_id() {
2521 return Err(anyhow!(ErrorCode::Forbidden))?;
2522 }
2523
2524 let connection_id = session
2525 .connection_pool()
2526 .await
2527 .dev_server_connection_id(dev_server_id);
2528 if let Some(connection_id) = connection_id {
2529 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2530 session.peer.send(
2531 connection_id,
2532 proto::ShutdownDevServer {
2533 reason: Some("dev server was deleted".to_string()),
2534 },
2535 )?;
2536 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2537 }
2538
2539 let status = session
2540 .db()
2541 .await
2542 .delete_dev_server(dev_server_id, session.user_id())
2543 .await?;
2544
2545 send_dev_server_projects_update(session.user_id(), status, &session).await;
2546
2547 response.send(proto::Ack {})?;
2548 Ok(())
2549}
2550
/// Deletes a dev server project whose dev server belongs to the caller.
///
/// Returns `Forbidden` if the project's dev server is owned by another user.
/// If the dev server is connected and a project is found for this dev server
/// project, that project is unshared first. After the database deletion, a
/// connected dev server receives its remaining projects, and the caller
/// receives the refreshed status.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    // Only the owner of the dev server may delete its projects.
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // Best-effort: if the lookup fails, unsharing is skipped and the
        // deletion proceeds anyway.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    // Tell a connected dev server which projects it still has.
    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2612
2613async fn rejoin_dev_server_projects(
2614 request: proto::RejoinRemoteProjects,
2615 response: Response<proto::RejoinRemoteProjects>,
2616 session: UserSession,
2617) -> Result<()> {
2618 let mut rejoined_projects = {
2619 let db = session.db().await;
2620 db.rejoin_dev_server_projects(
2621 &request.rejoined_projects,
2622 session.user_id(),
2623 session.0.connection_id,
2624 )
2625 .await?
2626 };
2627 response.send(proto::RejoinRemoteProjectsResponse {
2628 rejoined_projects: rejoined_projects
2629 .iter()
2630 .map(|project| project.to_proto())
2631 .collect(),
2632 })?;
2633 notify_rejoined_projects(&mut rejoined_projects, &session)
2634}
2635
/// Handles a dev server reconnecting under a new connection id.
///
/// Re-shares the dev server's projects in the database. For each re-shared
/// project it tells every collaborator about the host's new peer id and
/// forwards the project's current worktrees. It then replies with the
/// collaborators of each re-shared project.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // Scoped so the database guard is released before messages are sent.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Replace the host's old peer id with the new one for each collaborator.
        // Send errors go to `trace_err` instead of being propagated.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        // Push the re-shared project's current worktree list to collaborators.
        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2701
2702async fn shutdown_dev_server(
2703 _: proto::ShutdownDevServer,
2704 response: Response<proto::ShutdownDevServer>,
2705 session: DevServerSession,
2706) -> Result<()> {
2707 response.send(proto::Ack {})?;
2708 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2709 remove_dev_server_connection(session.dev_server_id(), &session).await
2710}
2711
/// Takes a dev server offline.
///
/// Unshares every project of the dev server that currently has a project id,
/// acting as `connection_id`. Marks the dev server offline in the connection
/// pool and pushes the refreshed status to the dev server's owner.
///
/// This does not remove the dev server's connection from the pool; the
/// callers in this file call `remove_dev_server_connection` separately.
async fn shutdown_dev_server_internal(
    dev_server_id: DevServerId,
    connection_id: ConnectionId,
    session: &Session,
) -> Result<()> {
    // Scoped so the database guard is released before the loop below.
    let (dev_server_projects, dev_server) = {
        let db = session.db().await;
        let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
        let dev_server = db.get_dev_server(dev_server_id).await?;
        (dev_server_projects, dev_server)
    };

    // Only projects with a project id are unshared.
    for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
        unshare_project_internal(
            ProjectId::from_proto(project_id),
            connection_id,
            None,
            session,
        )
        .await?;
    }

    session
        .connection_pool()
        .await
        .set_dev_server_offline(dev_server_id);

    let status = session
        .db()
        .await
        .dev_server_projects_update(dev_server.user_id)
        .await?;
    send_dev_server_projects_update(dev_server.user_id, status, session).await;

    Ok(())
}
2748
2749async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2750 let dev_server_connection = session
2751 .connection_pool()
2752 .await
2753 .dev_server_connection_id(dev_server_id);
2754
2755 if let Some(dev_server_connection) = dev_server_connection {
2756 session
2757 .connection_pool()
2758 .await
2759 .remove_connection(dev_server_connection)?;
2760 }
2761 Ok(())
2762}
2763
2764/// Updates other participants with changes to the project
2765async fn update_project(
2766 request: proto::UpdateProject,
2767 response: Response<proto::UpdateProject>,
2768 session: Session,
2769) -> Result<()> {
2770 let project_id = ProjectId::from_proto(request.project_id);
2771 let (room, guest_connection_ids) = &*session
2772 .db()
2773 .await
2774 .update_project(project_id, session.connection_id, &request.worktrees)
2775 .await?;
2776 broadcast(
2777 Some(session.connection_id),
2778 guest_connection_ids.iter().copied(),
2779 |connection_id| {
2780 session
2781 .peer
2782 .forward_send(session.connection_id, connection_id, request.clone())
2783 },
2784 );
2785 if let Some(room) = room {
2786 room_updated(room, &session.peer);
2787 }
2788 response.send(proto::Ack {})?;
2789
2790 Ok(())
2791}
2792
2793/// Updates other participants with changes to the worktree
2794async fn update_worktree(
2795 request: proto::UpdateWorktree,
2796 response: Response<proto::UpdateWorktree>,
2797 session: Session,
2798) -> Result<()> {
2799 let guest_connection_ids = session
2800 .db()
2801 .await
2802 .update_worktree(&request, session.connection_id)
2803 .await?;
2804
2805 broadcast(
2806 Some(session.connection_id),
2807 guest_connection_ids.iter().copied(),
2808 |connection_id| {
2809 session
2810 .peer
2811 .forward_send(session.connection_id, connection_id, request.clone())
2812 },
2813 );
2814 response.send(proto::Ack {})?;
2815 Ok(())
2816}
2817
2818/// Updates other participants with changes to the diagnostics
2819async fn update_diagnostic_summary(
2820 message: proto::UpdateDiagnosticSummary,
2821 session: Session,
2822) -> Result<()> {
2823 let guest_connection_ids = session
2824 .db()
2825 .await
2826 .update_diagnostic_summary(&message, session.connection_id)
2827 .await?;
2828
2829 broadcast(
2830 Some(session.connection_id),
2831 guest_connection_ids.iter().copied(),
2832 |connection_id| {
2833 session
2834 .peer
2835 .forward_send(session.connection_id, connection_id, message.clone())
2836 },
2837 );
2838
2839 Ok(())
2840}
2841
2842/// Updates other participants with changes to the worktree settings
2843async fn update_worktree_settings(
2844 message: proto::UpdateWorktreeSettings,
2845 session: Session,
2846) -> Result<()> {
2847 let guest_connection_ids = session
2848 .db()
2849 .await
2850 .update_worktree_settings(&message, session.connection_id)
2851 .await?;
2852
2853 broadcast(
2854 Some(session.connection_id),
2855 guest_connection_ids.iter().copied(),
2856 |connection_id| {
2857 session
2858 .peer
2859 .forward_send(session.connection_id, connection_id, message.clone())
2860 },
2861 );
2862
2863 Ok(())
2864}
2865
2866/// Notify other participants that a language server has started.
2867async fn start_language_server(
2868 request: proto::StartLanguageServer,
2869 session: Session,
2870) -> Result<()> {
2871 let guest_connection_ids = session
2872 .db()
2873 .await
2874 .start_language_server(&request, session.connection_id)
2875 .await?;
2876
2877 broadcast(
2878 Some(session.connection_id),
2879 guest_connection_ids.iter().copied(),
2880 |connection_id| {
2881 session
2882 .peer
2883 .forward_send(session.connection_id, connection_id, request.clone())
2884 },
2885 );
2886 Ok(())
2887}
2888
2889/// Notify other participants that a language server has changed.
2890async fn update_language_server(
2891 request: proto::UpdateLanguageServer,
2892 session: Session,
2893) -> Result<()> {
2894 let project_id = ProjectId::from_proto(request.project_id);
2895 let project_connection_ids = session
2896 .db()
2897 .await
2898 .project_connection_ids(project_id, session.connection_id, true)
2899 .await?;
2900 broadcast(
2901 Some(session.connection_id),
2902 project_connection_ids.iter().copied(),
2903 |connection_id| {
2904 session
2905 .peer
2906 .forward_send(session.connection_id, connection_id, request.clone())
2907 },
2908 );
2909 Ok(())
2910}
2911
2912/// forward a project request to the host. These requests should be read only
2913/// as guests are allowed to send them.
2914async fn forward_read_only_project_request<T>(
2915 request: T,
2916 response: Response<T>,
2917 session: UserSession,
2918) -> Result<()>
2919where
2920 T: EntityMessage + RequestMessage,
2921{
2922 let project_id = ProjectId::from_proto(request.remote_entity_id());
2923 let host_connection_id = session
2924 .db()
2925 .await
2926 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2927 .await?;
2928 let payload = session
2929 .peer
2930 .forward_request(session.connection_id, host_connection_id, request)
2931 .await?;
2932 response.send(payload)?;
2933 Ok(())
2934}
2935
2936async fn forward_find_search_candidates_request(
2937 request: proto::FindSearchCandidates,
2938 response: Response<proto::FindSearchCandidates>,
2939 session: UserSession,
2940) -> Result<()> {
2941 let project_id = ProjectId::from_proto(request.remote_entity_id());
2942 let host_connection_id = session
2943 .db()
2944 .await
2945 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2946 .await?;
2947 let payload = session
2948 .peer
2949 .forward_request(session.connection_id, host_connection_id, request)
2950 .await?;
2951 response.send(payload)?;
2952 Ok(())
2953}
2954
2955/// forward a project request to the dev server. Only allowed
2956/// if it's your dev server.
2957async fn forward_project_request_for_owner<T>(
2958 request: T,
2959 response: Response<T>,
2960 session: UserSession,
2961) -> Result<()>
2962where
2963 T: EntityMessage + RequestMessage,
2964{
2965 let project_id = ProjectId::from_proto(request.remote_entity_id());
2966
2967 let host_connection_id = session
2968 .db()
2969 .await
2970 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2971 .await?;
2972 let payload = session
2973 .peer
2974 .forward_request(session.connection_id, host_connection_id, request)
2975 .await?;
2976 response.send(payload)?;
2977 Ok(())
2978}
2979
2980/// forward a project request to the host. These requests are disallowed
2981/// for guests.
2982async fn forward_mutating_project_request<T>(
2983 request: T,
2984 response: Response<T>,
2985 session: UserSession,
2986) -> Result<()>
2987where
2988 T: EntityMessage + RequestMessage,
2989{
2990 let project_id = ProjectId::from_proto(request.remote_entity_id());
2991
2992 let host_connection_id = session
2993 .db()
2994 .await
2995 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2996 .await?;
2997 let payload = session
2998 .peer
2999 .forward_request(session.connection_id, host_connection_id, request)
3000 .await?;
3001 response.send(payload)?;
3002 Ok(())
3003}
3004
3005/// Notify other participants that a new buffer has been created
3006async fn create_buffer_for_peer(
3007 request: proto::CreateBufferForPeer,
3008 session: Session,
3009) -> Result<()> {
3010 session
3011 .db()
3012 .await
3013 .check_user_is_project_host(
3014 ProjectId::from_proto(request.project_id),
3015 session.connection_id,
3016 )
3017 .await?;
3018 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3019 session
3020 .peer
3021 .forward_send(session.connection_id, peer_id.into(), request)?;
3022 Ok(())
3023}
3024
3025/// Notify other participants that a buffer has been updated. This is
3026/// allowed for guests as long as the update is limited to selections.
3027async fn update_buffer(
3028 request: proto::UpdateBuffer,
3029 response: Response<proto::UpdateBuffer>,
3030 session: Session,
3031) -> Result<()> {
3032 let project_id = ProjectId::from_proto(request.project_id);
3033 let mut capability = Capability::ReadOnly;
3034
3035 for op in request.operations.iter() {
3036 match op.variant {
3037 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3038 Some(_) => capability = Capability::ReadWrite,
3039 }
3040 }
3041
3042 let host = {
3043 let guard = session
3044 .db()
3045 .await
3046 .connections_for_buffer_update(
3047 project_id,
3048 session.principal_id(),
3049 session.connection_id,
3050 capability,
3051 )
3052 .await?;
3053
3054 let (host, guests) = &*guard;
3055
3056 broadcast(
3057 Some(session.connection_id),
3058 guests.clone(),
3059 |connection_id| {
3060 session
3061 .peer
3062 .forward_send(session.connection_id, connection_id, request.clone())
3063 },
3064 );
3065
3066 *host
3067 };
3068
3069 if host != session.connection_id {
3070 session
3071 .peer
3072 .forward_request(session.connection_id, host, request.clone())
3073 .await?;
3074 }
3075
3076 response.send(proto::Ack {})?;
3077 Ok(())
3078}
3079
3080async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3081 let project_id = ProjectId::from_proto(message.project_id);
3082
3083 let operation = message.operation.as_ref().context("invalid operation")?;
3084 let capability = match operation.variant.as_ref() {
3085 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3086 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3087 match buffer_op.variant {
3088 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3089 Capability::ReadOnly
3090 }
3091 _ => Capability::ReadWrite,
3092 }
3093 } else {
3094 Capability::ReadWrite
3095 }
3096 }
3097 Some(_) => Capability::ReadWrite,
3098 None => Capability::ReadOnly,
3099 };
3100
3101 let guard = session
3102 .db()
3103 .await
3104 .connections_for_buffer_update(
3105 project_id,
3106 session.principal_id(),
3107 session.connection_id,
3108 capability,
3109 )
3110 .await?;
3111
3112 let (host, guests) = &*guard;
3113
3114 broadcast(
3115 Some(session.connection_id),
3116 guests.iter().chain([host]).copied(),
3117 |connection_id| {
3118 session
3119 .peer
3120 .forward_send(session.connection_id, connection_id, message.clone())
3121 },
3122 );
3123
3124 Ok(())
3125}
3126
3127/// Notify other participants that a project has been updated.
3128async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3129 request: T,
3130 session: Session,
3131) -> Result<()> {
3132 let project_id = ProjectId::from_proto(request.remote_entity_id());
3133 let project_connection_ids = session
3134 .db()
3135 .await
3136 .project_connection_ids(project_id, session.connection_id, false)
3137 .await?;
3138
3139 broadcast(
3140 Some(session.connection_id),
3141 project_connection_ids.iter().copied(),
3142 |connection_id| {
3143 session
3144 .peer
3145 .forward_send(session.connection_id, connection_id, request.clone())
3146 },
3147 );
3148 Ok(())
3149}
3150
3151/// Start following another user in a call.
3152async fn follow(
3153 request: proto::Follow,
3154 response: Response<proto::Follow>,
3155 session: UserSession,
3156) -> Result<()> {
3157 let room_id = RoomId::from_proto(request.room_id);
3158 let project_id = request.project_id.map(ProjectId::from_proto);
3159 let leader_id = request
3160 .leader_id
3161 .ok_or_else(|| anyhow!("invalid leader id"))?
3162 .into();
3163 let follower_id = session.connection_id;
3164
3165 session
3166 .db()
3167 .await
3168 .check_room_participants(room_id, leader_id, session.connection_id)
3169 .await?;
3170
3171 let response_payload = session
3172 .peer
3173 .forward_request(session.connection_id, leader_id, request)
3174 .await?;
3175 response.send(response_payload)?;
3176
3177 if let Some(project_id) = project_id {
3178 let room = session
3179 .db()
3180 .await
3181 .follow(room_id, project_id, leader_id, follower_id)
3182 .await?;
3183 room_updated(&room, &session.peer);
3184 }
3185
3186 Ok(())
3187}
3188
3189/// Stop following another user in a call.
3190async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3191 let room_id = RoomId::from_proto(request.room_id);
3192 let project_id = request.project_id.map(ProjectId::from_proto);
3193 let leader_id = request
3194 .leader_id
3195 .ok_or_else(|| anyhow!("invalid leader id"))?
3196 .into();
3197 let follower_id = session.connection_id;
3198
3199 session
3200 .db()
3201 .await
3202 .check_room_participants(room_id, leader_id, session.connection_id)
3203 .await?;
3204
3205 session
3206 .peer
3207 .forward_send(session.connection_id, leader_id, request)?;
3208
3209 if let Some(project_id) = project_id {
3210 let room = session
3211 .db()
3212 .await
3213 .unfollow(room_id, project_id, leader_id, follower_id)
3214 .await?;
3215 room_updated(&room, &session.peer);
3216 }
3217
3218 Ok(())
3219}
3220
3221/// Notify everyone following you of your current location.
3222async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3223 let room_id = RoomId::from_proto(request.room_id);
3224 let database = session.db.lock().await;
3225
3226 let connection_ids = if let Some(project_id) = request.project_id {
3227 let project_id = ProjectId::from_proto(project_id);
3228 database
3229 .project_connection_ids(project_id, session.connection_id, true)
3230 .await?
3231 } else {
3232 database
3233 .room_connection_ids(room_id, session.connection_id)
3234 .await?
3235 };
3236
3237 // For now, don't send view update messages back to that view's current leader.
3238 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3239 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3240 _ => None,
3241 });
3242
3243 for connection_id in connection_ids.iter().cloned() {
3244 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3245 session
3246 .peer
3247 .forward_send(session.connection_id, connection_id, request.clone())?;
3248 }
3249 }
3250 Ok(())
3251}
3252
3253/// Get public data about users.
3254async fn get_users(
3255 request: proto::GetUsers,
3256 response: Response<proto::GetUsers>,
3257 session: Session,
3258) -> Result<()> {
3259 let user_ids = request
3260 .user_ids
3261 .into_iter()
3262 .map(UserId::from_proto)
3263 .collect();
3264 let users = session
3265 .db()
3266 .await
3267 .get_users_by_ids(user_ids)
3268 .await?
3269 .into_iter()
3270 .map(|user| proto::User {
3271 id: user.id.to_proto(),
3272 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3273 github_login: user.github_login,
3274 })
3275 .collect();
3276 response.send(proto::UsersResponse { users })?;
3277 Ok(())
3278}
3279
3280/// Search for users (to invite) buy Github login
3281async fn fuzzy_search_users(
3282 request: proto::FuzzySearchUsers,
3283 response: Response<proto::FuzzySearchUsers>,
3284 session: UserSession,
3285) -> Result<()> {
3286 let query = request.query;
3287 let users = match query.len() {
3288 0 => vec![],
3289 1 | 2 => session
3290 .db()
3291 .await
3292 .get_user_by_github_login(&query)
3293 .await?
3294 .into_iter()
3295 .collect(),
3296 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3297 };
3298 let users = users
3299 .into_iter()
3300 .filter(|user| user.id != session.user_id())
3301 .map(|user| proto::User {
3302 id: user.id.to_proto(),
3303 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3304 github_login: user.github_login,
3305 })
3306 .collect();
3307 response.send(proto::UsersResponse { users })?;
3308 Ok(())
3309}
3310
3311/// Send a contact request to another user.
3312async fn request_contact(
3313 request: proto::RequestContact,
3314 response: Response<proto::RequestContact>,
3315 session: UserSession,
3316) -> Result<()> {
3317 let requester_id = session.user_id();
3318 let responder_id = UserId::from_proto(request.responder_id);
3319 if requester_id == responder_id {
3320 return Err(anyhow!("cannot add yourself as a contact"))?;
3321 }
3322
3323 let notifications = session
3324 .db()
3325 .await
3326 .send_contact_request(requester_id, responder_id)
3327 .await?;
3328
3329 // Update outgoing contact requests of requester
3330 let mut update = proto::UpdateContacts::default();
3331 update.outgoing_requests.push(responder_id.to_proto());
3332 for connection_id in session
3333 .connection_pool()
3334 .await
3335 .user_connection_ids(requester_id)
3336 {
3337 session.peer.send(connection_id, update.clone())?;
3338 }
3339
3340 // Update incoming contact requests of responder
3341 let mut update = proto::UpdateContacts::default();
3342 update
3343 .incoming_requests
3344 .push(proto::IncomingContactRequest {
3345 requester_id: requester_id.to_proto(),
3346 });
3347 let connection_pool = session.connection_pool().await;
3348 for connection_id in connection_pool.user_connection_ids(responder_id) {
3349 session.peer.send(connection_id, update.clone())?;
3350 }
3351
3352 send_notifications(&connection_pool, &session.peer, notifications);
3353
3354 response.send(proto::Ack {})?;
3355 Ok(())
3356}
3357
3358/// Accept or decline a contact request
3359async fn respond_to_contact_request(
3360 request: proto::RespondToContactRequest,
3361 response: Response<proto::RespondToContactRequest>,
3362 session: UserSession,
3363) -> Result<()> {
3364 let responder_id = session.user_id();
3365 let requester_id = UserId::from_proto(request.requester_id);
3366 let db = session.db().await;
3367 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3368 db.dismiss_contact_notification(responder_id, requester_id)
3369 .await?;
3370 } else {
3371 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3372
3373 let notifications = db
3374 .respond_to_contact_request(responder_id, requester_id, accept)
3375 .await?;
3376 let requester_busy = db.is_user_busy(requester_id).await?;
3377 let responder_busy = db.is_user_busy(responder_id).await?;
3378
3379 let pool = session.connection_pool().await;
3380 // Update responder with new contact
3381 let mut update = proto::UpdateContacts::default();
3382 if accept {
3383 update
3384 .contacts
3385 .push(contact_for_user(requester_id, requester_busy, &pool));
3386 }
3387 update
3388 .remove_incoming_requests
3389 .push(requester_id.to_proto());
3390 for connection_id in pool.user_connection_ids(responder_id) {
3391 session.peer.send(connection_id, update.clone())?;
3392 }
3393
3394 // Update requester with new contact
3395 let mut update = proto::UpdateContacts::default();
3396 if accept {
3397 update
3398 .contacts
3399 .push(contact_for_user(responder_id, responder_busy, &pool));
3400 }
3401 update
3402 .remove_outgoing_requests
3403 .push(responder_id.to_proto());
3404
3405 for connection_id in pool.user_connection_ids(requester_id) {
3406 session.peer.send(connection_id, update.clone())?;
3407 }
3408
3409 send_notifications(&pool, &session.peer, notifications);
3410 }
3411
3412 response.send(proto::Ack {})?;
3413 Ok(())
3414}
3415
3416/// Remove a contact.
3417async fn remove_contact(
3418 request: proto::RemoveContact,
3419 response: Response<proto::RemoveContact>,
3420 session: UserSession,
3421) -> Result<()> {
3422 let requester_id = session.user_id();
3423 let responder_id = UserId::from_proto(request.user_id);
3424 let db = session.db().await;
3425 let (contact_accepted, deleted_notification_id) =
3426 db.remove_contact(requester_id, responder_id).await?;
3427
3428 let pool = session.connection_pool().await;
3429 // Update outgoing contact requests of requester
3430 let mut update = proto::UpdateContacts::default();
3431 if contact_accepted {
3432 update.remove_contacts.push(responder_id.to_proto());
3433 } else {
3434 update
3435 .remove_outgoing_requests
3436 .push(responder_id.to_proto());
3437 }
3438 for connection_id in pool.user_connection_ids(requester_id) {
3439 session.peer.send(connection_id, update.clone())?;
3440 }
3441
3442 // Update incoming contact requests of responder
3443 let mut update = proto::UpdateContacts::default();
3444 if contact_accepted {
3445 update.remove_contacts.push(requester_id.to_proto());
3446 } else {
3447 update
3448 .remove_incoming_requests
3449 .push(requester_id.to_proto());
3450 }
3451 for connection_id in pool.user_connection_ids(responder_id) {
3452 session.peer.send(connection_id, update.clone())?;
3453 if let Some(notification_id) = deleted_notification_id {
3454 session.peer.send(
3455 connection_id,
3456 proto::DeleteNotification {
3457 notification_id: notification_id.to_proto(),
3458 },
3459 )?;
3460 }
3461 }
3462
3463 response.send(proto::Ack {})?;
3464 Ok(())
3465}
3466
3467fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3468 version.0.minor() < 139
3469}
3470
3471async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3472 let plan = session.current_plan(session.db().await).await?;
3473
3474 session
3475 .peer
3476 .send(
3477 session.connection_id,
3478 proto::UpdateUserPlan { plan: plan.into() },
3479 )
3480 .trace_err();
3481
3482 Ok(())
3483}
3484
3485async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3486 subscribe_user_to_channels(
3487 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3488 &session,
3489 )
3490 .await?;
3491 Ok(())
3492}
3493
3494async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3495 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3496 let mut pool = session.connection_pool().await;
3497 for membership in &channels_for_user.channel_memberships {
3498 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3499 }
3500 session.peer.send(
3501 session.connection_id,
3502 build_update_user_channels(&channels_for_user),
3503 )?;
3504 session.peer.send(
3505 session.connection_id,
3506 build_channels_update(channels_for_user),
3507 )?;
3508 Ok(())
3509}
3510
3511/// Creates a new channel.
3512async fn create_channel(
3513 request: proto::CreateChannel,
3514 response: Response<proto::CreateChannel>,
3515 session: UserSession,
3516) -> Result<()> {
3517 let db = session.db().await;
3518
3519 let parent_id = request.parent_id.map(ChannelId::from_proto);
3520 let (channel, membership) = db
3521 .create_channel(&request.name, parent_id, session.user_id())
3522 .await?;
3523
3524 let root_id = channel.root_id();
3525 let channel = Channel::from_model(channel);
3526
3527 response.send(proto::CreateChannelResponse {
3528 channel: Some(channel.to_proto()),
3529 parent_id: request.parent_id,
3530 })?;
3531
3532 let mut connection_pool = session.connection_pool().await;
3533 if let Some(membership) = membership {
3534 connection_pool.subscribe_to_channel(
3535 membership.user_id,
3536 membership.channel_id,
3537 membership.role,
3538 );
3539 let update = proto::UpdateUserChannels {
3540 channel_memberships: vec![proto::ChannelMembership {
3541 channel_id: membership.channel_id.to_proto(),
3542 role: membership.role.into(),
3543 }],
3544 ..Default::default()
3545 };
3546 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3547 session.peer.send(connection_id, update.clone())?;
3548 }
3549 }
3550
3551 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3552 if !role.can_see_channel(channel.visibility) {
3553 continue;
3554 }
3555
3556 let update = proto::UpdateChannels {
3557 channels: vec![channel.to_proto()],
3558 ..Default::default()
3559 };
3560 session.peer.send(connection_id, update.clone())?;
3561 }
3562
3563 Ok(())
3564}
3565
3566/// Delete a channel
3567async fn delete_channel(
3568 request: proto::DeleteChannel,
3569 response: Response<proto::DeleteChannel>,
3570 session: UserSession,
3571) -> Result<()> {
3572 let db = session.db().await;
3573
3574 let channel_id = request.channel_id;
3575 let (root_channel, removed_channels) = db
3576 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3577 .await?;
3578 response.send(proto::Ack {})?;
3579
3580 // Notify members of removed channels
3581 let mut update = proto::UpdateChannels::default();
3582 update
3583 .delete_channels
3584 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3585
3586 let connection_pool = session.connection_pool().await;
3587 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3588 session.peer.send(connection_id, update.clone())?;
3589 }
3590
3591 Ok(())
3592}
3593
3594/// Invite someone to join a channel.
3595async fn invite_channel_member(
3596 request: proto::InviteChannelMember,
3597 response: Response<proto::InviteChannelMember>,
3598 session: UserSession,
3599) -> Result<()> {
3600 let db = session.db().await;
3601 let channel_id = ChannelId::from_proto(request.channel_id);
3602 let invitee_id = UserId::from_proto(request.user_id);
3603 let InviteMemberResult {
3604 channel,
3605 notifications,
3606 } = db
3607 .invite_channel_member(
3608 channel_id,
3609 invitee_id,
3610 session.user_id(),
3611 request.role().into(),
3612 )
3613 .await?;
3614
3615 let update = proto::UpdateChannels {
3616 channel_invitations: vec![channel.to_proto()],
3617 ..Default::default()
3618 };
3619
3620 let connection_pool = session.connection_pool().await;
3621 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3622 session.peer.send(connection_id, update.clone())?;
3623 }
3624
3625 send_notifications(&connection_pool, &session.peer, notifications);
3626
3627 response.send(proto::Ack {})?;
3628 Ok(())
3629}
3630
3631/// remove someone from a channel
3632async fn remove_channel_member(
3633 request: proto::RemoveChannelMember,
3634 response: Response<proto::RemoveChannelMember>,
3635 session: UserSession,
3636) -> Result<()> {
3637 let db = session.db().await;
3638 let channel_id = ChannelId::from_proto(request.channel_id);
3639 let member_id = UserId::from_proto(request.user_id);
3640
3641 let RemoveChannelMemberResult {
3642 membership_update,
3643 notification_id,
3644 } = db
3645 .remove_channel_member(channel_id, member_id, session.user_id())
3646 .await?;
3647
3648 let mut connection_pool = session.connection_pool().await;
3649 notify_membership_updated(
3650 &mut connection_pool,
3651 membership_update,
3652 member_id,
3653 &session.peer,
3654 );
3655 for connection_id in connection_pool.user_connection_ids(member_id) {
3656 if let Some(notification_id) = notification_id {
3657 session
3658 .peer
3659 .send(
3660 connection_id,
3661 proto::DeleteNotification {
3662 notification_id: notification_id.to_proto(),
3663 },
3664 )
3665 .trace_err();
3666 }
3667 }
3668
3669 response.send(proto::Ack {})?;
3670 Ok(())
3671}
3672
3673/// Toggle the channel between public and private.
3674/// Care is taken to maintain the invariant that public channels only descend from public channels,
3675/// (though members-only channels can appear at any point in the hierarchy).
3676async fn set_channel_visibility(
3677 request: proto::SetChannelVisibility,
3678 response: Response<proto::SetChannelVisibility>,
3679 session: UserSession,
3680) -> Result<()> {
3681 let db = session.db().await;
3682 let channel_id = ChannelId::from_proto(request.channel_id);
3683 let visibility = request.visibility().into();
3684
3685 let channel_model = db
3686 .set_channel_visibility(channel_id, visibility, session.user_id())
3687 .await?;
3688 let root_id = channel_model.root_id();
3689 let channel = Channel::from_model(channel_model);
3690
3691 let mut connection_pool = session.connection_pool().await;
3692 for (user_id, role) in connection_pool
3693 .channel_user_ids(root_id)
3694 .collect::<Vec<_>>()
3695 .into_iter()
3696 {
3697 let update = if role.can_see_channel(channel.visibility) {
3698 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3699 proto::UpdateChannels {
3700 channels: vec![channel.to_proto()],
3701 ..Default::default()
3702 }
3703 } else {
3704 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3705 proto::UpdateChannels {
3706 delete_channels: vec![channel.id.to_proto()],
3707 ..Default::default()
3708 }
3709 };
3710
3711 for connection_id in connection_pool.user_connection_ids(user_id) {
3712 session.peer.send(connection_id, update.clone())?;
3713 }
3714 }
3715
3716 response.send(proto::Ack {})?;
3717 Ok(())
3718}
3719
3720/// Alter the role for a user in the channel.
3721async fn set_channel_member_role(
3722 request: proto::SetChannelMemberRole,
3723 response: Response<proto::SetChannelMemberRole>,
3724 session: UserSession,
3725) -> Result<()> {
3726 let db = session.db().await;
3727 let channel_id = ChannelId::from_proto(request.channel_id);
3728 let member_id = UserId::from_proto(request.user_id);
3729 let result = db
3730 .set_channel_member_role(
3731 channel_id,
3732 session.user_id(),
3733 member_id,
3734 request.role().into(),
3735 )
3736 .await?;
3737
3738 match result {
3739 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3740 let mut connection_pool = session.connection_pool().await;
3741 notify_membership_updated(
3742 &mut connection_pool,
3743 membership_update,
3744 member_id,
3745 &session.peer,
3746 )
3747 }
3748 db::SetMemberRoleResult::InviteUpdated(channel) => {
3749 let update = proto::UpdateChannels {
3750 channel_invitations: vec![channel.to_proto()],
3751 ..Default::default()
3752 };
3753
3754 for connection_id in session
3755 .connection_pool()
3756 .await
3757 .user_connection_ids(member_id)
3758 {
3759 session.peer.send(connection_id, update.clone())?;
3760 }
3761 }
3762 }
3763
3764 response.send(proto::Ack {})?;
3765 Ok(())
3766}
3767
3768/// Change the name of a channel
3769async fn rename_channel(
3770 request: proto::RenameChannel,
3771 response: Response<proto::RenameChannel>,
3772 session: UserSession,
3773) -> Result<()> {
3774 let db = session.db().await;
3775 let channel_id = ChannelId::from_proto(request.channel_id);
3776 let channel_model = db
3777 .rename_channel(channel_id, session.user_id(), &request.name)
3778 .await?;
3779 let root_id = channel_model.root_id();
3780 let channel = Channel::from_model(channel_model);
3781
3782 response.send(proto::RenameChannelResponse {
3783 channel: Some(channel.to_proto()),
3784 })?;
3785
3786 let connection_pool = session.connection_pool().await;
3787 let update = proto::UpdateChannels {
3788 channels: vec![channel.to_proto()],
3789 ..Default::default()
3790 };
3791 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3792 if role.can_see_channel(channel.visibility) {
3793 session.peer.send(connection_id, update.clone())?;
3794 }
3795 }
3796
3797 Ok(())
3798}
3799
3800/// Move a channel to a new parent.
3801async fn move_channel(
3802 request: proto::MoveChannel,
3803 response: Response<proto::MoveChannel>,
3804 session: UserSession,
3805) -> Result<()> {
3806 let channel_id = ChannelId::from_proto(request.channel_id);
3807 let to = ChannelId::from_proto(request.to);
3808
3809 let (root_id, channels) = session
3810 .db()
3811 .await
3812 .move_channel(channel_id, to, session.user_id())
3813 .await?;
3814
3815 let connection_pool = session.connection_pool().await;
3816 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3817 let channels = channels
3818 .iter()
3819 .filter_map(|channel| {
3820 if role.can_see_channel(channel.visibility) {
3821 Some(channel.to_proto())
3822 } else {
3823 None
3824 }
3825 })
3826 .collect::<Vec<_>>();
3827 if channels.is_empty() {
3828 continue;
3829 }
3830
3831 let update = proto::UpdateChannels {
3832 channels,
3833 ..Default::default()
3834 };
3835
3836 session.peer.send(connection_id, update.clone())?;
3837 }
3838
3839 response.send(Ack {})?;
3840 Ok(())
3841}
3842
3843/// Get the list of channel members
3844async fn get_channel_members(
3845 request: proto::GetChannelMembers,
3846 response: Response<proto::GetChannelMembers>,
3847 session: UserSession,
3848) -> Result<()> {
3849 let db = session.db().await;
3850 let channel_id = ChannelId::from_proto(request.channel_id);
3851 let limit = if request.limit == 0 {
3852 u16::MAX as u64
3853 } else {
3854 request.limit
3855 };
3856 let (members, users) = db
3857 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3858 .await?;
3859 response.send(proto::GetChannelMembersResponse { members, users })?;
3860 Ok(())
3861}
3862
3863/// Accept or decline a channel invitation.
3864async fn respond_to_channel_invite(
3865 request: proto::RespondToChannelInvite,
3866 response: Response<proto::RespondToChannelInvite>,
3867 session: UserSession,
3868) -> Result<()> {
3869 let db = session.db().await;
3870 let channel_id = ChannelId::from_proto(request.channel_id);
3871 let RespondToChannelInvite {
3872 membership_update,
3873 notifications,
3874 } = db
3875 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3876 .await?;
3877
3878 let mut connection_pool = session.connection_pool().await;
3879 if let Some(membership_update) = membership_update {
3880 notify_membership_updated(
3881 &mut connection_pool,
3882 membership_update,
3883 session.user_id(),
3884 &session.peer,
3885 );
3886 } else {
3887 let update = proto::UpdateChannels {
3888 remove_channel_invitations: vec![channel_id.to_proto()],
3889 ..Default::default()
3890 };
3891
3892 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3893 session.peer.send(connection_id, update.clone())?;
3894 }
3895 };
3896
3897 send_notifications(&connection_pool, &session.peer, notifications);
3898
3899 response.send(proto::Ack {})?;
3900
3901 Ok(())
3902}
3903
3904/// Join the channels' room
3905async fn join_channel(
3906 request: proto::JoinChannel,
3907 response: Response<proto::JoinChannel>,
3908 session: UserSession,
3909) -> Result<()> {
3910 let channel_id = ChannelId::from_proto(request.channel_id);
3911 join_channel_internal(channel_id, Box::new(response), session).await
3912}
3913
3914trait JoinChannelInternalResponse {
3915 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3916}
3917impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3918 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3919 Response::<proto::JoinChannel>::send(self, result)
3920 }
3921}
3922impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3923 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3924 Response::<proto::JoinRoom>::send(self, result)
3925 }
3926}
3927
/// Join the room of `channel_id` on behalf of `session`, replying through `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first. The database guard is released while doing so.
/// 2. Join the channel's room in the database and reply with a `JoinRoomResponse`.
///    It includes LiveKit connection info when a LiveKit client is configured and
///    a token could be created.
/// 3. Propagate any membership change and broadcast the updated room.
/// 4. Once the database and connection-pool guards taken inside the block below
///    have been dropped, broadcast the channel update and refresh the user's
///    contacts.
///
/// # Errors
/// Propagates database/send errors. Fails with "channel not returned" if the
/// joined room carries no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Everything that needs the database guard (and the first connection-pool guard)
    // lives in this block, so both guards are dropped when it ends.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard before `leave_room_for_session`, then
            // re-acquire it. NOTE(review): presumably because that helper takes the
            // database lock itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; every other role gets a room
        // token and may publish. If token creation fails, the error is logged by
        // `trace_err` and no LiveKit info is sent. The same happens when no LiveKit
        // client is configured.
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The connection pool is acquired again here because the guard taken inside the
    // block above was dropped when the block ended.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4023
4024/// Start editing the channel notes
4025async fn join_channel_buffer(
4026 request: proto::JoinChannelBuffer,
4027 response: Response<proto::JoinChannelBuffer>,
4028 session: UserSession,
4029) -> Result<()> {
4030 let db = session.db().await;
4031 let channel_id = ChannelId::from_proto(request.channel_id);
4032
4033 let open_response = db
4034 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4035 .await?;
4036
4037 let collaborators = open_response.collaborators.clone();
4038 response.send(open_response)?;
4039
4040 let update = UpdateChannelBufferCollaborators {
4041 channel_id: channel_id.to_proto(),
4042 collaborators: collaborators.clone(),
4043 };
4044 channel_buffer_updated(
4045 session.connection_id,
4046 collaborators
4047 .iter()
4048 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4049 &update,
4050 &session.peer,
4051 );
4052
4053 Ok(())
4054}
4055
4056/// Edit the channel notes
4057async fn update_channel_buffer(
4058 request: proto::UpdateChannelBuffer,
4059 session: UserSession,
4060) -> Result<()> {
4061 let db = session.db().await;
4062 let channel_id = ChannelId::from_proto(request.channel_id);
4063
4064 let (collaborators, epoch, version) = db
4065 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4066 .await?;
4067
4068 channel_buffer_updated(
4069 session.connection_id,
4070 collaborators.clone(),
4071 &proto::UpdateChannelBuffer {
4072 channel_id: channel_id.to_proto(),
4073 operations: request.operations,
4074 },
4075 &session.peer,
4076 );
4077
4078 let pool = &*session.connection_pool().await;
4079
4080 let non_collaborators =
4081 pool.channel_connection_ids(channel_id)
4082 .filter_map(|(connection_id, _)| {
4083 if collaborators.contains(&connection_id) {
4084 None
4085 } else {
4086 Some(connection_id)
4087 }
4088 });
4089
4090 broadcast(None, non_collaborators, |peer_id| {
4091 session.peer.send(
4092 peer_id,
4093 proto::UpdateChannels {
4094 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4095 channel_id: channel_id.to_proto(),
4096 epoch: epoch as u64,
4097 version: version.clone(),
4098 }],
4099 ..Default::default()
4100 },
4101 )
4102 });
4103
4104 Ok(())
4105}
4106
4107/// Rejoin the channel notes after a connection blip
4108async fn rejoin_channel_buffers(
4109 request: proto::RejoinChannelBuffers,
4110 response: Response<proto::RejoinChannelBuffers>,
4111 session: UserSession,
4112) -> Result<()> {
4113 let db = session.db().await;
4114 let buffers = db
4115 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4116 .await?;
4117
4118 for rejoined_buffer in &buffers {
4119 let collaborators_to_notify = rejoined_buffer
4120 .buffer
4121 .collaborators
4122 .iter()
4123 .filter_map(|c| Some(c.peer_id?.into()));
4124 channel_buffer_updated(
4125 session.connection_id,
4126 collaborators_to_notify,
4127 &proto::UpdateChannelBufferCollaborators {
4128 channel_id: rejoined_buffer.buffer.channel_id,
4129 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4130 },
4131 &session.peer,
4132 );
4133 }
4134
4135 response.send(proto::RejoinChannelBuffersResponse {
4136 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4137 })?;
4138
4139 Ok(())
4140}
4141
4142/// Stop editing the channel notes
4143async fn leave_channel_buffer(
4144 request: proto::LeaveChannelBuffer,
4145 response: Response<proto::LeaveChannelBuffer>,
4146 session: UserSession,
4147) -> Result<()> {
4148 let db = session.db().await;
4149 let channel_id = ChannelId::from_proto(request.channel_id);
4150
4151 let left_buffer = db
4152 .leave_channel_buffer(channel_id, session.connection_id)
4153 .await?;
4154
4155 response.send(Ack {})?;
4156
4157 channel_buffer_updated(
4158 session.connection_id,
4159 left_buffer.connections,
4160 &proto::UpdateChannelBufferCollaborators {
4161 channel_id: channel_id.to_proto(),
4162 collaborators: left_buffer.collaborators,
4163 },
4164 &session.peer,
4165 );
4166
4167 Ok(())
4168}
4169
4170fn channel_buffer_updated<T: EnvelopedMessage>(
4171 sender_id: ConnectionId,
4172 collaborators: impl IntoIterator<Item = ConnectionId>,
4173 message: &T,
4174 peer: &Peer,
4175) {
4176 broadcast(Some(sender_id), collaborators, |peer_id| {
4177 peer.send(peer_id, message.clone())
4178 });
4179}
4180
4181fn send_notifications(
4182 connection_pool: &ConnectionPool,
4183 peer: &Peer,
4184 notifications: db::NotificationBatch,
4185) {
4186 for (user_id, notification) in notifications {
4187 for connection_id in connection_pool.user_connection_ids(user_id) {
4188 if let Err(error) = peer.send(
4189 connection_id,
4190 proto::AddNotification {
4191 notification: Some(notification.clone()),
4192 },
4193 ) {
4194 tracing::error!(
4195 "failed to send notification to {:?} {}",
4196 connection_id,
4197 error
4198 );
4199 }
4200 }
4201 }
4202}
4203
/// Send a message to the channel
///
/// The body is trimmed and then rejected if it is longer than `MAX_MESSAGE_LEN`
/// bytes or empty; a nonce is also required. After the message is stored:
/// 1. It is broadcast as a `ChannelMessageSent` to the participant connections the
///    database returns, with this connection passed as the sender.
/// 2. It is returned to the caller.
/// 3. Every other connection subscribed to the channel is told the channel's latest
///    message id.
/// 4. Notifications produced by the database are delivered.
async fn send_channel_message(
    request: proto::SendChannelMessage,
    response: Response<proto::SendChannelMessage>,
    session: UserSession,
) -> Result<()> {
    // Validate the message body.
    let body = request.body.trim().to_string();
    // `len()` is the byte length of the trimmed body.
    if body.len() > MAX_MESSAGE_LEN {
        return Err(anyhow!("message is too long"))?;
    }
    if body.is_empty() {
        return Err(anyhow!("message can't be blank"))?;
    }

    // TODO: adjust mentions if body is trimmed
    // NOTE(review): `request.mentions` is stored and echoed unchanged. If mention
    // ranges are offsets into the untrimmed body, removing leading whitespace above
    // shifts them.

    let timestamp = OffsetDateTime::now_utc();
    let nonce = request
        .nonce
        .ok_or_else(|| anyhow!("nonce can't be blank"))?;

    let channel_id = ChannelId::from_proto(request.channel_id);
    let CreatedChannelMessage {
        message_id,
        participant_connection_ids,
        notifications,
    } = session
        .db()
        .await
        .create_channel_message(
            channel_id,
            session.user_id(),
            &body,
            &request.mentions,
            timestamp,
            nonce.clone().into(),
            request.reply_to_message_id.map(MessageId::from_proto),
        )
        .await?;

    let message = proto::ChannelMessage {
        sender_id: session.user_id().to_proto(),
        id: message_id.to_proto(),
        body,
        mentions: request.mentions,
        timestamp: timestamp.unix_timestamp() as u64,
        nonce: Some(nonce),
        reply_to_message_id: request.reply_to_message_id,
        edited_at: None,
    };
    broadcast(
        Some(session.connection_id),
        participant_connection_ids.clone(),
        |connection| {
            session.peer.send(
                connection,
                proto::ChannelMessageSent {
                    channel_id: channel_id.to_proto(),
                    message: Some(message.clone()),
                },
            )
        },
    );
    response.send(proto::SendChannelMessageResponse {
        message: Some(message),
    })?;

    // Channel subscribers who are not participants only learn the latest message id.
    let pool = &*session.connection_pool().await;
    let non_participants =
        pool.channel_connection_ids(channel_id)
            .filter_map(|(connection_id, _)| {
                if participant_connection_ids.contains(&connection_id) {
                    None
                } else {
                    Some(connection_id)
                }
            });
    broadcast(None, non_participants, |peer_id| {
        session.peer.send(
            peer_id,
            proto::UpdateChannels {
                latest_channel_message_ids: vec![proto::ChannelMessageId {
                    channel_id: channel_id.to_proto(),
                    message_id: message_id.to_proto(),
                }],
                ..Default::default()
            },
        )
    });
    send_notifications(pool, &session.peer, notifications);

    Ok(())
}
4298
4299/// Delete a channel message
4300async fn remove_channel_message(
4301 request: proto::RemoveChannelMessage,
4302 response: Response<proto::RemoveChannelMessage>,
4303 session: UserSession,
4304) -> Result<()> {
4305 let channel_id = ChannelId::from_proto(request.channel_id);
4306 let message_id = MessageId::from_proto(request.message_id);
4307 let (connection_ids, existing_notification_ids) = session
4308 .db()
4309 .await
4310 .remove_channel_message(channel_id, message_id, session.user_id())
4311 .await?;
4312
4313 broadcast(
4314 Some(session.connection_id),
4315 connection_ids,
4316 move |connection| {
4317 session.peer.send(connection, request.clone())?;
4318
4319 for notification_id in &existing_notification_ids {
4320 session.peer.send(
4321 connection,
4322 proto::DeleteNotification {
4323 notification_id: (*notification_id).to_proto(),
4324 },
4325 )?;
4326 }
4327
4328 Ok(())
4329 },
4330 );
4331 response.send(proto::Ack {})?;
4332 Ok(())
4333}
4334
4335async fn update_channel_message(
4336 request: proto::UpdateChannelMessage,
4337 response: Response<proto::UpdateChannelMessage>,
4338 session: UserSession,
4339) -> Result<()> {
4340 let channel_id = ChannelId::from_proto(request.channel_id);
4341 let message_id = MessageId::from_proto(request.message_id);
4342 let updated_at = OffsetDateTime::now_utc();
4343 let UpdatedChannelMessage {
4344 message_id,
4345 participant_connection_ids,
4346 notifications,
4347 reply_to_message_id,
4348 timestamp,
4349 deleted_mention_notification_ids,
4350 updated_mention_notifications,
4351 } = session
4352 .db()
4353 .await
4354 .update_channel_message(
4355 channel_id,
4356 message_id,
4357 session.user_id(),
4358 request.body.as_str(),
4359 &request.mentions,
4360 updated_at,
4361 )
4362 .await?;
4363
4364 let nonce = request
4365 .nonce
4366 .clone()
4367 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4368
4369 let message = proto::ChannelMessage {
4370 sender_id: session.user_id().to_proto(),
4371 id: message_id.to_proto(),
4372 body: request.body.clone(),
4373 mentions: request.mentions.clone(),
4374 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4375 nonce: Some(nonce),
4376 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4377 edited_at: Some(updated_at.unix_timestamp() as u64),
4378 };
4379
4380 response.send(proto::Ack {})?;
4381
4382 let pool = &*session.connection_pool().await;
4383 broadcast(
4384 Some(session.connection_id),
4385 participant_connection_ids,
4386 |connection| {
4387 session.peer.send(
4388 connection,
4389 proto::ChannelMessageUpdate {
4390 channel_id: channel_id.to_proto(),
4391 message: Some(message.clone()),
4392 },
4393 )?;
4394
4395 for notification_id in &deleted_mention_notification_ids {
4396 session.peer.send(
4397 connection,
4398 proto::DeleteNotification {
4399 notification_id: (*notification_id).to_proto(),
4400 },
4401 )?;
4402 }
4403
4404 for notification in &updated_mention_notifications {
4405 session.peer.send(
4406 connection,
4407 proto::UpdateNotification {
4408 notification: Some(notification.clone()),
4409 },
4410 )?;
4411 }
4412
4413 Ok(())
4414 },
4415 );
4416
4417 send_notifications(pool, &session.peer, notifications);
4418
4419 Ok(())
4420}
4421
4422/// Mark a channel message as read
4423async fn acknowledge_channel_message(
4424 request: proto::AckChannelMessage,
4425 session: UserSession,
4426) -> Result<()> {
4427 let channel_id = ChannelId::from_proto(request.channel_id);
4428 let message_id = MessageId::from_proto(request.message_id);
4429 let notifications = session
4430 .db()
4431 .await
4432 .observe_channel_message(channel_id, session.user_id(), message_id)
4433 .await?;
4434 send_notifications(
4435 &*session.connection_pool().await,
4436 &session.peer,
4437 notifications,
4438 );
4439 Ok(())
4440}
4441
4442/// Mark a buffer version as synced
4443async fn acknowledge_buffer_version(
4444 request: proto::AckBufferOperation,
4445 session: UserSession,
4446) -> Result<()> {
4447 let buffer_id = BufferId::from_proto(request.buffer_id);
4448 session
4449 .db()
4450 .await
4451 .observe_buffer_version(
4452 buffer_id,
4453 session.user_id(),
4454 request.epoch as i32,
4455 &request.version,
4456 )
4457 .await?;
4458 Ok(())
4459}
4460
4461async fn count_language_model_tokens(
4462 request: proto::CountLanguageModelTokens,
4463 response: Response<proto::CountLanguageModelTokens>,
4464 session: Session,
4465 config: &Config,
4466) -> Result<()> {
4467 let Some(session) = session.for_user() else {
4468 return Err(anyhow!("user not found"))?;
4469 };
4470 authorize_access_to_legacy_llm_endpoints(&session).await?;
4471
4472 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4473 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4474 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4475 };
4476
4477 session
4478 .app_state
4479 .rate_limiter
4480 .check(&*rate_limit, session.user_id())
4481 .await?;
4482
4483 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4484 Some(proto::LanguageModelProvider::Google) => {
4485 let api_key = config
4486 .google_ai_api_key
4487 .as_ref()
4488 .context("no Google AI API key configured on the server")?;
4489 google_ai::count_tokens(
4490 session.http_client.as_ref(),
4491 google_ai::API_URL,
4492 api_key,
4493 serde_json::from_str(&request.request)?,
4494 None,
4495 )
4496 .await?
4497 }
4498 _ => return Err(anyhow!("unsupported provider"))?,
4499 };
4500
4501 response.send(proto::CountLanguageModelTokensResponse {
4502 token_count: result.total_tokens as u32,
4503 })?;
4504
4505 Ok(())
4506}
4507
4508struct ZedProCountLanguageModelTokensRateLimit;
4509
4510impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4511 fn capacity(&self) -> usize {
4512 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4513 .ok()
4514 .and_then(|v| v.parse().ok())
4515 .unwrap_or(600) // Picked arbitrarily
4516 }
4517
4518 fn refill_duration(&self) -> chrono::Duration {
4519 chrono::Duration::hours(1)
4520 }
4521
4522 fn db_name(&self) -> &'static str {
4523 "zed-pro:count-language-model-tokens"
4524 }
4525}
4526
4527struct FreeCountLanguageModelTokensRateLimit;
4528
4529impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4530 fn capacity(&self) -> usize {
4531 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4532 .ok()
4533 .and_then(|v| v.parse().ok())
4534 .unwrap_or(600 / 10) // Picked arbitrarily
4535 }
4536
4537 fn refill_duration(&self) -> chrono::Duration {
4538 chrono::Duration::hours(1)
4539 }
4540
4541 fn db_name(&self) -> &'static str {
4542 "free:count-language-model-tokens"
4543 }
4544}
4545
4546struct ZedProComputeEmbeddingsRateLimit;
4547
4548impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4549 fn capacity(&self) -> usize {
4550 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4551 .ok()
4552 .and_then(|v| v.parse().ok())
4553 .unwrap_or(5000) // Picked arbitrarily
4554 }
4555
4556 fn refill_duration(&self) -> chrono::Duration {
4557 chrono::Duration::hours(1)
4558 }
4559
4560 fn db_name(&self) -> &'static str {
4561 "zed-pro:compute-embeddings"
4562 }
4563}
4564
4565struct FreeComputeEmbeddingsRateLimit;
4566
4567impl RateLimit for FreeComputeEmbeddingsRateLimit {
4568 fn capacity(&self) -> usize {
4569 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4570 .ok()
4571 .and_then(|v| v.parse().ok())
4572 .unwrap_or(5000 / 10) // Picked arbitrarily
4573 }
4574
4575 fn refill_duration(&self) -> chrono::Duration {
4576 chrono::Duration::hours(1)
4577 }
4578
4579 fn db_name(&self) -> &'static str {
4580 "free:compute-embeddings"
4581 }
4582}
4583
4584async fn compute_embeddings(
4585 request: proto::ComputeEmbeddings,
4586 response: Response<proto::ComputeEmbeddings>,
4587 session: UserSession,
4588 api_key: Option<Arc<str>>,
4589) -> Result<()> {
4590 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4591 authorize_access_to_legacy_llm_endpoints(&session).await?;
4592
4593 let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
4594 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4595 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4596 };
4597
4598 session
4599 .app_state
4600 .rate_limiter
4601 .check(&*rate_limit, session.user_id())
4602 .await?;
4603
4604 let embeddings = match request.model.as_str() {
4605 "openai/text-embedding-3-small" => {
4606 open_ai::embed(
4607 session.http_client.as_ref(),
4608 OPEN_AI_API_URL,
4609 &api_key,
4610 OpenAiEmbeddingModel::TextEmbedding3Small,
4611 request.texts.iter().map(|text| text.as_str()),
4612 )
4613 .await?
4614 }
4615 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4616 };
4617
4618 let embeddings = request
4619 .texts
4620 .iter()
4621 .map(|text| {
4622 let mut hasher = sha2::Sha256::new();
4623 hasher.update(text.as_bytes());
4624 let result = hasher.finalize();
4625 result.to_vec()
4626 })
4627 .zip(
4628 embeddings
4629 .data
4630 .into_iter()
4631 .map(|embedding| embedding.embedding),
4632 )
4633 .collect::<HashMap<_, _>>();
4634
4635 let db = session.db().await;
4636 db.save_embeddings(&request.model, &embeddings)
4637 .await
4638 .context("failed to save embeddings")
4639 .trace_err();
4640
4641 response.send(proto::ComputeEmbeddingsResponse {
4642 embeddings: embeddings
4643 .into_iter()
4644 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4645 .collect(),
4646 })?;
4647 Ok(())
4648}
4649
4650async fn get_cached_embeddings(
4651 request: proto::GetCachedEmbeddings,
4652 response: Response<proto::GetCachedEmbeddings>,
4653 session: UserSession,
4654) -> Result<()> {
4655 authorize_access_to_legacy_llm_endpoints(&session).await?;
4656
4657 let db = session.db().await;
4658 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4659
4660 response.send(proto::GetCachedEmbeddingsResponse {
4661 embeddings: embeddings
4662 .into_iter()
4663 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4664 .collect(),
4665 })?;
4666 Ok(())
4667}
4668
4669/// This is leftover from before the LLM service.
4670///
4671/// The endpoints protected by this check will be moved there eventually.
4672async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
4673 if session.is_staff() {
4674 Ok(())
4675 } else {
4676 Err(anyhow!("permission denied"))?
4677 }
4678}
4679
4680/// Get a Supermaven API key for the user
4681async fn get_supermaven_api_key(
4682 _request: proto::GetSupermavenApiKey,
4683 response: Response<proto::GetSupermavenApiKey>,
4684 session: UserSession,
4685) -> Result<()> {
4686 let user_id: String = session.user_id().to_string();
4687 if !session.is_staff() {
4688 return Err(anyhow!("supermaven not enabled for this account"))?;
4689 }
4690
4691 let email = session
4692 .email()
4693 .ok_or_else(|| anyhow!("user must have an email"))?;
4694
4695 let supermaven_admin_api = session
4696 .supermaven_client
4697 .as_ref()
4698 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4699
4700 let result = supermaven_admin_api
4701 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4702 .await?;
4703
4704 response.send(proto::GetSupermavenApiKeyResponse {
4705 api_key: result.api_key,
4706 })?;
4707
4708 Ok(())
4709}
4710
4711/// Start receiving chat updates for a channel
4712async fn join_channel_chat(
4713 request: proto::JoinChannelChat,
4714 response: Response<proto::JoinChannelChat>,
4715 session: UserSession,
4716) -> Result<()> {
4717 let channel_id = ChannelId::from_proto(request.channel_id);
4718
4719 let db = session.db().await;
4720 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4721 .await?;
4722 let messages = db
4723 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4724 .await?;
4725 response.send(proto::JoinChannelChatResponse {
4726 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4727 messages,
4728 })?;
4729 Ok(())
4730}
4731
4732/// Stop receiving chat updates for a channel
4733async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4734 let channel_id = ChannelId::from_proto(request.channel_id);
4735 session
4736 .db()
4737 .await
4738 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4739 .await?;
4740 Ok(())
4741}
4742
4743/// Retrieve the chat history for a channel
4744async fn get_channel_messages(
4745 request: proto::GetChannelMessages,
4746 response: Response<proto::GetChannelMessages>,
4747 session: UserSession,
4748) -> Result<()> {
4749 let channel_id = ChannelId::from_proto(request.channel_id);
4750 let messages = session
4751 .db()
4752 .await
4753 .get_channel_messages(
4754 channel_id,
4755 session.user_id(),
4756 MESSAGE_COUNT_PER_PAGE,
4757 Some(MessageId::from_proto(request.before_message_id)),
4758 )
4759 .await?;
4760 response.send(proto::GetChannelMessagesResponse {
4761 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4762 messages,
4763 })?;
4764 Ok(())
4765}
4766
4767/// Retrieve specific chat messages
4768async fn get_channel_messages_by_id(
4769 request: proto::GetChannelMessagesById,
4770 response: Response<proto::GetChannelMessagesById>,
4771 session: UserSession,
4772) -> Result<()> {
4773 let message_ids = request
4774 .message_ids
4775 .iter()
4776 .map(|id| MessageId::from_proto(*id))
4777 .collect::<Vec<_>>();
4778 let messages = session
4779 .db()
4780 .await
4781 .get_channel_messages_by_id(session.user_id(), &message_ids)
4782 .await?;
4783 response.send(proto::GetChannelMessagesResponse {
4784 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4785 messages,
4786 })?;
4787 Ok(())
4788}
4789
4790/// Retrieve the current users notifications
4791async fn get_notifications(
4792 request: proto::GetNotifications,
4793 response: Response<proto::GetNotifications>,
4794 session: UserSession,
4795) -> Result<()> {
4796 let notifications = session
4797 .db()
4798 .await
4799 .get_notifications(
4800 session.user_id(),
4801 NOTIFICATION_COUNT_PER_PAGE,
4802 request.before_id.map(db::NotificationId::from_proto),
4803 )
4804 .await?;
4805 response.send(proto::GetNotificationsResponse {
4806 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4807 notifications,
4808 })?;
4809 Ok(())
4810}
4811
4812/// Mark notifications as read
4813async fn mark_notification_as_read(
4814 request: proto::MarkNotificationRead,
4815 response: Response<proto::MarkNotificationRead>,
4816 session: UserSession,
4817) -> Result<()> {
4818 let database = &session.db().await;
4819 let notifications = database
4820 .mark_notification_as_read_by_id(
4821 session.user_id(),
4822 NotificationId::from_proto(request.notification_id),
4823 )
4824 .await?;
4825 send_notifications(
4826 &*session.connection_pool().await,
4827 &session.peer,
4828 notifications,
4829 );
4830 response.send(proto::Ack {})?;
4831 Ok(())
4832}
4833
4834/// Get the current users information
4835async fn get_private_user_info(
4836 _request: proto::GetPrivateUserInfo,
4837 response: Response<proto::GetPrivateUserInfo>,
4838 session: UserSession,
4839) -> Result<()> {
4840 let db = session.db().await;
4841
4842 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4843 let user = db
4844 .get_user_by_id(session.user_id())
4845 .await?
4846 .ok_or_else(|| anyhow!("user not found"))?;
4847 let flags = db.get_user_flags(session.user_id()).await?;
4848
4849 response.send(proto::GetPrivateUserInfoResponse {
4850 metrics_id,
4851 staff: user.admin,
4852 flags,
4853 accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
4854 })?;
4855 Ok(())
4856}
4857
4858/// Accept the terms of service (tos) on behalf of the current user
4859async fn accept_terms_of_service(
4860 _request: proto::AcceptTermsOfService,
4861 response: Response<proto::AcceptTermsOfService>,
4862 session: UserSession,
4863) -> Result<()> {
4864 let db = session.db().await;
4865
4866 let accepted_tos_at = Utc::now();
4867 db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
4868 .await?;
4869
4870 response.send(proto::AcceptTermsOfServiceResponse {
4871 accepted_tos_at: accepted_tos_at.timestamp() as u64,
4872 })?;
4873 Ok(())
4874}
4875
/// The minimum account age an account must have in order to use the LLM service.
///
/// Compared against the older of the Zed account creation time and, when known, the
/// GitHub account creation time (see `get_llm_api_token`).
const MIN_ACCOUNT_AGE_FOR_LLM_USE: chrono::Duration = chrono::Duration::days(30);
4878
4879async fn get_llm_api_token(
4880 _request: proto::GetLlmToken,
4881 response: Response<proto::GetLlmToken>,
4882 session: UserSession,
4883) -> Result<()> {
4884 let db = session.db().await;
4885
4886 let flags = db.get_user_flags(session.user_id()).await?;
4887 let has_language_models_feature_flag = flags.iter().any(|flag| flag == "language-models");
4888 let has_llm_closed_beta_feature_flag = flags.iter().any(|flag| flag == "llm-closed-beta");
4889
4890 if !session.is_staff() && !has_language_models_feature_flag {
4891 Err(anyhow!("permission denied"))?
4892 }
4893
4894 let user_id = session.user_id();
4895 let user = db
4896 .get_user_by_id(user_id)
4897 .await?
4898 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
4899
4900 if user.accepted_tos_at.is_none() {
4901 Err(anyhow!("terms of service not accepted"))?
4902 }
4903
4904 let mut account_created_at = user.created_at;
4905 if let Some(github_created_at) = user.github_user_created_at {
4906 account_created_at = account_created_at.min(github_created_at);
4907 }
4908 if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
4909 Err(anyhow!("account too young"))?
4910 }
4911 let token = LlmTokenClaims::create(
4912 user.id,
4913 user.github_login.clone(),
4914 session.is_staff(),
4915 has_llm_closed_beta_feature_flag,
4916 session.current_plan(db).await?,
4917 &session.app_state.config,
4918 )?;
4919 response.send(proto::GetLlmTokenResponse { token })?;
4920 Ok(())
4921}
4922
4923fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4924 let message = match message {
4925 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4926 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4927 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4928 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4929 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4930 code: frame.code.into(),
4931 reason: frame.reason,
4932 })),
4933 // We should never receive a frame while reading the message, according
4934 // to the `tungstenite` maintainers:
4935 //
4936 // > It cannot occur when you read messages from the WebSocket, but it
4937 // > can be used when you want to send the raw frames (e.g. you want to
4938 // > send the frames to the WebSocket without composing the full message first).
4939 // >
4940 // > — https://github.com/snapview/tungstenite-rs/issues/268
4941 TungsteniteMessage::Frame(_) => {
4942 bail!("received an unexpected frame while reading the message")
4943 }
4944 };
4945
4946 Ok(message)
4947}
4948
4949fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4950 match message {
4951 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4952 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4953 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4954 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4955 AxumMessage::Close(frame) => {
4956 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4957 code: frame.code.into(),
4958 reason: frame.reason,
4959 }))
4960 }
4961 }
4962}
4963
4964fn notify_membership_updated(
4965 connection_pool: &mut ConnectionPool,
4966 result: MembershipUpdated,
4967 user_id: UserId,
4968 peer: &Peer,
4969) {
4970 for membership in &result.new_channels.channel_memberships {
4971 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4972 }
4973 for channel_id in &result.removed_channels {
4974 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4975 }
4976
4977 let user_channels_update = proto::UpdateUserChannels {
4978 channel_memberships: result
4979 .new_channels
4980 .channel_memberships
4981 .iter()
4982 .map(|cm| proto::ChannelMembership {
4983 channel_id: cm.channel_id.to_proto(),
4984 role: cm.role.into(),
4985 })
4986 .collect(),
4987 ..Default::default()
4988 };
4989
4990 let mut update = build_channels_update(result.new_channels);
4991 update.delete_channels = result
4992 .removed_channels
4993 .into_iter()
4994 .map(|id| id.to_proto())
4995 .collect();
4996 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4997
4998 for connection_id in connection_pool.user_connection_ids(user_id) {
4999 peer.send(connection_id, user_channels_update.clone())
5000 .trace_err();
5001 peer.send(connection_id, update.clone()).trace_err();
5002 }
5003}
5004
5005fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5006 proto::UpdateUserChannels {
5007 channel_memberships: channels
5008 .channel_memberships
5009 .iter()
5010 .map(|m| proto::ChannelMembership {
5011 channel_id: m.channel_id.to_proto(),
5012 role: m.role.into(),
5013 })
5014 .collect(),
5015 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5016 observed_channel_message_id: channels.observed_channel_messages.clone(),
5017 }
5018}
5019
5020fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5021 let mut update = proto::UpdateChannels::default();
5022
5023 for channel in channels.channels {
5024 update.channels.push(channel.to_proto());
5025 }
5026
5027 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5028 update.latest_channel_message_ids = channels.latest_channel_messages;
5029
5030 for (channel_id, participants) in channels.channel_participants {
5031 update
5032 .channel_participants
5033 .push(proto::ChannelParticipants {
5034 channel_id: channel_id.to_proto(),
5035 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5036 });
5037 }
5038
5039 for channel in channels.invited_channels {
5040 update.channel_invitations.push(channel.to_proto());
5041 }
5042
5043 update.hosted_projects = channels.hosted_projects;
5044 update
5045}
5046
5047fn build_initial_contacts_update(
5048 contacts: Vec<db::Contact>,
5049 pool: &ConnectionPool,
5050) -> proto::UpdateContacts {
5051 let mut update = proto::UpdateContacts::default();
5052
5053 for contact in contacts {
5054 match contact {
5055 db::Contact::Accepted { user_id, busy } => {
5056 update.contacts.push(contact_for_user(user_id, busy, pool));
5057 }
5058 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5059 db::Contact::Incoming { user_id } => {
5060 update
5061 .incoming_requests
5062 .push(proto::IncomingContactRequest {
5063 requester_id: user_id.to_proto(),
5064 })
5065 }
5066 }
5067 }
5068
5069 update
5070}
5071
5072fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5073 proto::Contact {
5074 user_id: user_id.to_proto(),
5075 online: pool.is_user_online(user_id),
5076 busy,
5077 }
5078}
5079
5080fn room_updated(room: &proto::Room, peer: &Peer) {
5081 broadcast(
5082 None,
5083 room.participants
5084 .iter()
5085 .filter_map(|participant| Some(participant.peer_id?.into())),
5086 |peer_id| {
5087 peer.send(
5088 peer_id,
5089 proto::RoomUpdated {
5090 room: Some(room.clone()),
5091 },
5092 )
5093 },
5094 );
5095}
5096
5097fn channel_updated(
5098 channel: &db::channel::Model,
5099 room: &proto::Room,
5100 peer: &Peer,
5101 pool: &ConnectionPool,
5102) {
5103 let participants = room
5104 .participants
5105 .iter()
5106 .map(|p| p.user_id)
5107 .collect::<Vec<_>>();
5108
5109 broadcast(
5110 None,
5111 pool.channel_connection_ids(channel.root_id())
5112 .filter_map(|(channel_id, role)| {
5113 role.can_see_channel(channel.visibility)
5114 .then_some(channel_id)
5115 }),
5116 |peer_id| {
5117 peer.send(
5118 peer_id,
5119 proto::UpdateChannels {
5120 channel_participants: vec![proto::ChannelParticipants {
5121 channel_id: channel.id.to_proto(),
5122 participant_user_ids: participants.clone(),
5123 }],
5124 ..Default::default()
5125 },
5126 )
5127 },
5128 );
5129}
5130
5131async fn send_dev_server_projects_update(
5132 user_id: UserId,
5133 mut status: proto::DevServerProjectsUpdate,
5134 session: &Session,
5135) {
5136 let pool = session.connection_pool().await;
5137 for dev_server in &mut status.dev_servers {
5138 dev_server.status =
5139 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5140 }
5141 let connections = pool.user_connection_ids(user_id);
5142 for connection_id in connections {
5143 session.peer.send(connection_id, status.clone()).trace_err();
5144 }
5145}
5146
5147async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5148 let db = session.db().await;
5149
5150 let contacts = db.get_contacts(user_id).await?;
5151 let busy = db.is_user_busy(user_id).await?;
5152
5153 let pool = session.connection_pool().await;
5154 let updated_contact = contact_for_user(user_id, busy, &pool);
5155 for contact in contacts {
5156 if let db::Contact::Accepted {
5157 user_id: contact_user_id,
5158 ..
5159 } = contact
5160 {
5161 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5162 session
5163 .peer
5164 .send(
5165 contact_conn_id,
5166 proto::UpdateContacts {
5167 contacts: vec![updated_contact.clone()],
5168 remove_contacts: Default::default(),
5169 incoming_requests: Default::default(),
5170 remove_incoming_requests: Default::default(),
5171 outgoing_requests: Default::default(),
5172 remove_outgoing_requests: Default::default(),
5173 },
5174 )
5175 .trace_err();
5176 }
5177 }
5178 }
5179 Ok(())
5180}
5181
5182async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5183 log::info!("lost dev server connection, unsharing projects");
5184 let project_ids = session
5185 .db()
5186 .await
5187 .get_stale_dev_server_projects(session.connection_id)
5188 .await?;
5189
5190 for project_id in project_ids {
5191 // not unshare re-checks the connection ids match, so we get away with no transaction
5192 unshare_project_internal(project_id, session.connection_id, None, session).await?;
5193 }
5194
5195 let user_id = session.dev_server().user_id;
5196 let update = session
5197 .db()
5198 .await
5199 .dev_server_projects_update(user_id)
5200 .await?;
5201
5202 send_dev_server_projects_update(user_id, update, session).await;
5203
5204 Ok(())
5205}
5206
/// Removes `connection_id` from whatever room it is in and notifies everyone affected.
///
/// Returns `Ok(())` immediately if the database reports that the connection was not in a
/// room. Otherwise it:
/// - notifies the collaborators of each project the connection left (via `project_left`);
/// - broadcasts the updated room, and the channel participant list if the room belongs to
///   a channel;
/// - sends `CallCanceled` to every user whose pending call was canceled;
/// - refreshes contact state for the current user and for each canceled callee;
/// - removes the user from the LiveKit room, and deletes that room when the database
///   marked it as deleted.
///
/// # Errors
/// Propagates errors from the database `leave_room` call and from `update_user_contacts`.
/// LiveKit failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // Users whose contact entry must be re-broadcast once the room has been left.
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: these are assigned only inside the `if let` below. The
    // `else` branch returns early, so all of them are initialized wherever they are used.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): under Rust 2021 temporary-scope rules, the guard returned by
    // `session.db().await` lives until the end of this whole `if let`/`else`, so it is
    // still held while `project_left` and `room_updated` run. Confirm this is intended
    // before restructuring (e.g. into `let ... else`, which would drop it sooner).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Order matters: `live_kit_room` is taken out of `left_room.room` before the whole
        // room is taken, so the `room` broadcast below carries a defaulted `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before `update_user_contacts`,
    // which acquires the pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, session).await?;
    }

    // Best-effort LiveKit cleanup: failures are logged via `trace_err`, not propagated.
    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5280
5281async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5282 let left_channel_buffers = session
5283 .db()
5284 .await
5285 .leave_channel_buffers(session.connection_id)
5286 .await?;
5287
5288 for left_buffer in left_channel_buffers {
5289 channel_buffer_updated(
5290 session.connection_id,
5291 left_buffer.connections,
5292 &proto::UpdateChannelBufferCollaborators {
5293 channel_id: left_buffer.channel_id.to_proto(),
5294 collaborators: left_buffer.collaborators,
5295 },
5296 &session.peer,
5297 );
5298 }
5299
5300 Ok(())
5301}
5302
5303fn project_left(project: &db::LeftProject, session: &UserSession) {
5304 for connection_id in &project.connection_ids {
5305 if project.should_unshare {
5306 session
5307 .peer
5308 .send(
5309 *connection_id,
5310 proto::UnshareProject {
5311 project_id: project.id.to_proto(),
5312 },
5313 )
5314 .trace_err();
5315 } else {
5316 session
5317 .peer
5318 .send(
5319 *connection_id,
5320 proto::RemoveProjectCollaborator {
5321 project_id: project.id.to_proto(),
5322 peer_id: Some(session.connection_id.into()),
5323 },
5324 )
5325 .trace_err();
5326 }
5327 }
5328}
5329
/// Extension for fallible values whose errors should be logged rather than propagated.
pub trait ResultExt {
    /// The success type produced when the value is not an error.
    type Ok;

    /// Converts to an `Option`, logging the error (if any) instead of returning it.
    fn trace_err(self) -> Option<Self::Ok>;
}
5335
5336impl<T, E> ResultExt for Result<T, E>
5337where
5338 E: std::fmt::Debug,
5339{
5340 type Ok = T;
5341
5342 #[track_caller]
5343 fn trace_err(self) -> Option<T> {
5344 match self {
5345 Ok(value) => Some(value),
5346 Err(error) => {
5347 tracing::error!("{:?}", error);
5348 None
5349 }
5350 }
5351 }
5352}