1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::{
5 auth,
6 db::{
7 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
8 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
9 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
10 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
11 ServerId, UpdatedChannelMessage, User, UserId,
12 },
13 executor::Executor,
14 AppState, Config, Error, RateLimit, RateLimiter, Result,
15};
16use anyhow::{anyhow, bail, Context as _};
17use async_tungstenite::tungstenite::{
18 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
19};
20use axum::{
21 body::Body,
22 extract::{
23 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
24 ConnectInfo, WebSocketUpgrade,
25 },
26 headers::{Header, HeaderName},
27 http::StatusCode,
28 middleware,
29 response::IntoResponse,
30 routing::get,
31 Extension, Router, TypedHeader,
32};
33use collections::{HashMap, HashSet};
34pub use connection_pool::{ConnectionPool, ZedVersion};
35use core::fmt::{self, Debug, Formatter};
36use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
37use sha2::Digest;
38use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
39
40use futures::{
41 channel::oneshot,
42 future::{self, BoxFuture},
43 stream::FuturesUnordered,
44 FutureExt, SinkExt, StreamExt, TryStreamExt,
45};
46use http_client::IsahcHttpClient;
47use prometheus::{register_int_gauge, IntGauge};
48use rpc::{
49 proto::{
50 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
51 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
52 },
53 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
54};
55use semantic_version::SemanticVersion;
56use serde::{Serialize, Serializer};
57use std::{
58 any::TypeId,
59 future::Future,
60 marker::PhantomData,
61 mem,
62 net::SocketAddr,
63 ops::{Deref, DerefMut},
64 rc::Rc,
65 sync::{
66 atomic::{AtomicBool, Ordering::SeqCst},
67 Arc, OnceLock,
68 },
69 time::{Duration, Instant},
70};
71use time::OffsetDateTime;
72use tokio::sync::{watch, Semaphore};
73use tower::ServiceBuilder;
74use tracing::{
75 field::{self},
76 info_span, instrument, Instrument,
77};
78
79use self::connection_pool::VersionedMessage;
80
/// Reconnect grace period of thirty seconds (its consumers live outside this section).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` looks up and clears resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10 s to shut down gracefully; waiting a little longer
/// ensures the old pods are gone before their leftover resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size used when loading channel messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Page size used when loading notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Maximum length of a channel message body.
const MAX_MESSAGE_LEN: usize = 1024;
89
90type MessageHandler =
91 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
92
93struct Response<R> {
94 peer: Arc<Peer>,
95 receipt: Receipt<R>,
96 responded: Arc<AtomicBool>,
97}
98
99impl<R: RequestMessage> Response<R> {
100 fn send(self, payload: R::Response) -> Result<()> {
101 self.responded.store(true, SeqCst);
102 self.peer.respond(self.receipt, payload)?;
103 Ok(())
104 }
105}
106
107struct StreamingResponse<R: RequestMessage> {
108 peer: Arc<Peer>,
109 receipt: Receipt<R>,
110}
111
112impl<R: RequestMessage> StreamingResponse<R> {
113 fn send(&self, payload: R::Response) -> Result<()> {
114 self.peer.respond(self.receipt, payload)?;
115 Ok(())
116 }
117}
118
119#[derive(Clone, Debug)]
120pub enum Principal {
121 User(User),
122 Impersonated { user: User, admin: User },
123 DevServer(dev_server::Model),
124}
125
126impl Principal {
127 fn update_span(&self, span: &tracing::Span) {
128 match &self {
129 Principal::User(user) => {
130 span.record("user_id", &user.id.0);
131 span.record("login", &user.github_login);
132 }
133 Principal::Impersonated { user, admin } => {
134 span.record("user_id", &user.id.0);
135 span.record("login", &user.github_login);
136 span.record("impersonator", &admin.github_login);
137 }
138 Principal::DevServer(dev_server) => {
139 span.record("dev_server_id", &dev_server.id.0);
140 }
141 }
142 }
143}
144
/// Per-connection state handed to every message handler (see `MessageHandler`).
///
/// Derives `Clone`; the shared pieces are held behind `Arc`s, so clones share them.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`. `handle_connection`
    /// creates a fresh mutex per connection, so this serializes database access per connection,
    /// not across the whole server.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with `Server::connection_pool`; lock it with `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Present only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    /// The GeoIP country code for the user.
    // NOTE(review): `allow(unused)` silences the dead-code lint; presumably no handler reads
    // this field yet — confirm before removing the attribute.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
161
162impl Session {
163 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
164 #[cfg(test)]
165 tokio::task::yield_now().await;
166 let guard = self.db.lock().await;
167 #[cfg(test)]
168 tokio::task::yield_now().await;
169 guard
170 }
171
172 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
173 #[cfg(test)]
174 tokio::task::yield_now().await;
175 let guard = self.connection_pool.lock();
176 ConnectionPoolGuard {
177 guard,
178 _not_send: PhantomData,
179 }
180 }
181
182 fn for_user(self) -> Option<UserSession> {
183 UserSession::new(self)
184 }
185
186 fn for_dev_server(self) -> Option<DevServerSession> {
187 DevServerSession::new(self)
188 }
189
190 fn user_id(&self) -> Option<UserId> {
191 match &self.principal {
192 Principal::User(user) => Some(user.id),
193 Principal::Impersonated { user, .. } => Some(user.id),
194 Principal::DevServer(_) => None,
195 }
196 }
197
198 fn is_staff(&self) -> bool {
199 match &self.principal {
200 Principal::User(user) => user.admin,
201 Principal::Impersonated { .. } => true,
202 Principal::DevServer(_) => false,
203 }
204 }
205
206 pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
207 if self.is_staff() {
208 return Ok(proto::Plan::ZedPro);
209 }
210
211 let Some(user_id) = self.user_id() else {
212 return Ok(proto::Plan::Free);
213 };
214
215 let db = self.db().await;
216 if db.has_active_billing_subscription(user_id).await? {
217 Ok(proto::Plan::ZedPro)
218 } else {
219 Ok(proto::Plan::Free)
220 }
221 }
222
223 fn dev_server_id(&self) -> Option<DevServerId> {
224 match &self.principal {
225 Principal::User(_) | Principal::Impersonated { .. } => None,
226 Principal::DevServer(dev_server) => Some(dev_server.id),
227 }
228 }
229
230 fn principal_id(&self) -> PrincipalId {
231 match &self.principal {
232 Principal::User(user) => PrincipalId::UserId(user.id),
233 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
234 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
235 }
236 }
237}
238
239impl Debug for Session {
240 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
241 let mut result = f.debug_struct("Session");
242 match &self.principal {
243 Principal::User(user) => {
244 result.field("user", &user.github_login);
245 }
246 Principal::Impersonated { user, admin } => {
247 result.field("user", &user.github_login);
248 result.field("impersonator", &admin.github_login);
249 }
250 Principal::DevServer(dev_server) => {
251 result.field("dev_server", &dev_server.id);
252 }
253 }
254 result.field("connection_id", &self.connection_id).finish()
255 }
256}
257
258struct UserSession(Session);
259
260impl UserSession {
261 pub fn new(s: Session) -> Option<Self> {
262 s.user_id().map(|_| UserSession(s))
263 }
264 pub fn user_id(&self) -> UserId {
265 self.0.user_id().unwrap()
266 }
267
268 pub fn email(&self) -> Option<String> {
269 match &self.0.principal {
270 Principal::User(user) => user.email_address.clone(),
271 Principal::Impersonated { user, .. } => user.email_address.clone(),
272 Principal::DevServer(..) => None,
273 }
274 }
275}
276
277impl Deref for UserSession {
278 type Target = Session;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284impl DerefMut for UserSession {
285 fn deref_mut(&mut self) -> &mut Self::Target {
286 &mut self.0
287 }
288}
289
290struct DevServerSession(Session);
291
292impl DevServerSession {
293 pub fn new(s: Session) -> Option<Self> {
294 s.dev_server_id().map(|_| DevServerSession(s))
295 }
296 pub fn dev_server_id(&self) -> DevServerId {
297 self.0.dev_server_id().unwrap()
298 }
299
300 fn dev_server(&self) -> &dev_server::Model {
301 match &self.0.principal {
302 Principal::DevServer(dev_server) => dev_server,
303 _ => unreachable!(),
304 }
305 }
306}
307
308impl Deref for DevServerSession {
309 type Target = Session;
310
311 fn deref(&self) -> &Self::Target {
312 &self.0
313 }
314}
315impl DerefMut for DevServerSession {
316 fn deref_mut(&mut self) -> &mut Self::Target {
317 &mut self.0
318 }
319}
320
321fn user_handler<M: RequestMessage, Fut>(
322 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
324where
325 Fut: Send + Future<Output = Result<()>>,
326{
327 let handler = Arc::new(handler);
328 move |message, response, session| {
329 let handler = handler.clone();
330 Box::pin(async move {
331 if let Some(user_session) = session.for_user() {
332 Ok(handler(message, response, user_session).await?)
333 } else {
334 Err(Error::Internal(anyhow!(
335 "must be a user to call {}",
336 M::NAME
337 )))
338 }
339 })
340 }
341}
342
343fn dev_server_handler<M: RequestMessage, Fut>(
344 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
346where
347 Fut: Send + Future<Output = Result<()>>,
348{
349 let handler = Arc::new(handler);
350 move |message, response, session| {
351 let handler = handler.clone();
352 Box::pin(async move {
353 if let Some(dev_server_session) = session.for_dev_server() {
354 Ok(handler(message, response, dev_server_session).await?)
355 } else {
356 Err(Error::Internal(anyhow!(
357 "must be a dev server to call {}",
358 M::NAME
359 )))
360 }
361 })
362 }
363}
364
365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
366 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
368where
369 InnertRetFut: Send + Future<Output = Result<()>>,
370{
371 let handler = Arc::new(handler);
372 move |message, session| {
373 let handler = handler.clone();
374 Box::pin(async move {
375 if let Some(user_session) = session.for_user() {
376 Ok(handler(message, user_session).await?)
377 } else {
378 Err(Error::Internal(anyhow!(
379 "must be a user to call {}",
380 M::NAME
381 )))
382 }
383 })
384 }
385}
386
387struct DbHandle(Arc<Database>);
388
389impl Deref for DbHandle {
390 type Target = Database;
391
392 fn deref(&self) -> &Self::Target {
393 self.0.as_ref()
394 }
395}
396
/// The collaboration RPC server: owns the peer, the pool of live connections, and the table of
/// per-message-type handlers.
pub struct Server {
    /// Mutable so that the test-only `Server::reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    /// Shared with every `Session` created for this server's connections.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they accept; filled by `add_handler`.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `Server::teardown`; `handle_connection` refuses to serve a connection
    /// while it reads `true`.
    teardown: watch::Sender<bool>,
}
405
/// Lock guard over the shared `ConnectionPool`, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`; presumably
    // so the compiler rejects holding the pool lock across an `.await` in a `Send` future.
    _not_send: PhantomData<Rc<()>>,
}
410
/// Serializable view of the server's peer and connection pool.
///
/// It owns a `ConnectionPoolGuard`, so the connection-pool lock is held for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is not serialized itself; `serialize_deref` serializes what it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
417
418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
419where
420 S: Serializer,
421 T: Deref<Target = U>,
422 U: Serialize,
423{
424 Serialize::serialize(value.deref(), serializer)
425}
426
427impl Server {
/// Builds a server with id `id` and registers every RPC handler it serves.
///
/// Each handler is stored in `handlers` under its message type; `add_handler` panics if a
/// message type is registered twice. Handlers wrapped in `user_handler` or
/// `user_message_handler` reject dev-server sessions, and those wrapped in `dev_server_handler`
/// reject user sessions. Unwrapped handlers receive the raw `Session` for any principal.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        // NOTE(review): the server id is narrowed with `as u32` for the peer; confirm ids fit.
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // Only the sender is kept; `handle_connection` subscribes a receiver per connection.
        teardown: watch::channel(false).0,
    };

    server
        // Liveness, rooms and calls.
        .add_request_handler(ping)
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Project sharing and dev-server management (user side).
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        // Requests only dev servers may send.
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        // Project state updates.
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests forwarded to another peer; the `forward_*` helper named in each
        // registration decides who may send the request and where it is routed.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        // Buffer traffic and host-originated broadcasts.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels: membership, buffers and chat.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        // Notifications, channel moves and following.
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Assistant contexts.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // Language-model and embedding handlers; each closure owns its own `AppState` clone
        // so it can pass `config` (or the OpenAI key) to the handler.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    complete_with_language_model(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    stream_complete_with_language_model(
                        request,
                        response,
                        session,
                        &app_state.config,
                    )
                    .await
                }
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
686
/// Schedules cleanup of resources left behind by stale servers in this environment.
///
/// Spawns a detached task and returns without waiting for it. The task sleeps for
/// `CLEANUP_TIMEOUT` and then asks the database for room and channel-buffer ids still attributed
/// to stale servers. It removes stale channel-buffer collaborators and room participants and
/// tells affected clients about the change: room/channel updates, cancelled calls and contact
/// busy-state. LiveKit rooms that ended up with no participants are deleted. Finally, the stale
/// server records themselves are deleted.
///
/// # Errors
/// Always returns `Ok(())`; failures inside the task go through `trace_err` and are not
/// propagated.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // Awaited inside the spawned task below.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Drop stale collaborators from channel buffers and push the refreshed
                // collaborator list to every remaining connection.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                for room_id in room_ids {
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    // Remove stale participants, broadcast the new room state, and collect
                    // the users whose calls were cancelled or whose contacts need a refresh.
                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the (synchronous) pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Push each affected user's updated contact entry to their accepted
                    // contacts' connections.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    // Only rooms left with no participants are deleted from LiveKit.
                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
832
833 pub fn teardown(&self) {
834 self.peer.teardown();
835 self.connection_pool.lock().reset();
836 let _ = self.teardown.send(true);
837 }
838
839 #[cfg(test)]
840 pub fn reset(&self, id: ServerId) {
841 self.teardown();
842 *self.id.lock() = id;
843 self.peer.reset(id.0 as u32);
844 let _ = self.teardown.send(false);
845 }
846
847 #[cfg(test)]
848 pub fn id(&self) -> ServerId {
849 *self.id.lock()
850 }
851
852 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
853 where
854 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
855 Fut: 'static + Send + Future<Output = Result<()>>,
856 M: EnvelopedMessage,
857 {
858 let prev_handler = self.handlers.insert(
859 TypeId::of::<M>(),
860 Box::new(move |envelope, session| {
861 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
862 let received_at = envelope.received_at;
863 tracing::info!("message received");
864 let start_time = Instant::now();
865 let future = (handler)(*envelope, session);
866 async move {
867 let result = future.await;
868 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
869 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
870 let queue_duration_ms = total_duration_ms - processing_duration_ms;
871 let payload_type = M::NAME;
872
873 match result {
874 Err(error) => {
875 tracing::error!(
876 ?error,
877 total_duration_ms,
878 processing_duration_ms,
879 queue_duration_ms,
880 payload_type,
881 "error handling message"
882 )
883 }
884 Ok(()) => tracing::info!(
885 total_duration_ms,
886 processing_duration_ms,
887 queue_duration_ms,
888 "finished handling message"
889 ),
890 }
891 }
892 .boxed()
893 }),
894 );
895 if prev_handler.is_some() {
896 panic!("registered a handler for the same message twice");
897 }
898 self
899 }
900
901 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
902 where
903 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
904 Fut: 'static + Send + Future<Output = Result<()>>,
905 M: EnvelopedMessage,
906 {
907 self.add_handler(move |envelope, session| handler(envelope.payload, session));
908 self
909 }
910
911 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
912 where
913 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
914 Fut: Send + Future<Output = Result<()>>,
915 M: RequestMessage,
916 {
917 let handler = Arc::new(handler);
918 self.add_handler(move |envelope, session| {
919 let receipt = envelope.receipt();
920 let handler = handler.clone();
921 async move {
922 let peer = session.peer.clone();
923 let responded = Arc::new(AtomicBool::default());
924 let response = Response {
925 peer: peer.clone(),
926 responded: responded.clone(),
927 receipt,
928 };
929 match (handler)(envelope.payload, response, session).await {
930 Ok(()) => {
931 if responded.load(std::sync::atomic::Ordering::SeqCst) {
932 Ok(())
933 } else {
934 Err(anyhow!("handler did not send a response"))?
935 }
936 }
937 Err(error) => {
938 let proto_err = match &error {
939 Error::Internal(err) => err.to_proto(),
940 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
941 };
942 peer.respond_with_error(receipt, proto_err)?;
943 Err(error)
944 }
945 }
946 }
947 })
948 }
949
950 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
951 where
952 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
953 Fut: Send + Future<Output = Result<()>>,
954 M: RequestMessage,
955 {
956 let handler = Arc::new(handler);
957 self.add_handler(move |envelope, session| {
958 let receipt = envelope.receipt();
959 let handler = handler.clone();
960 async move {
961 let peer = session.peer.clone();
962 let response = StreamingResponse {
963 peer: peer.clone(),
964 receipt,
965 };
966 match (handler)(envelope.payload, response, session).await {
967 Ok(()) => {
968 peer.end_stream(receipt)?;
969 Ok(())
970 }
971 Err(error) => {
972 let proto_err = match &error {
973 Error::Internal(err) => err.to_proto(),
974 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
975 };
976 peer.respond_with_error(receipt, proto_err)?;
977 Err(error)
978 }
979 }
980 }
981 })
982 }
983
984 #[allow(clippy::too_many_arguments)]
985 pub fn handle_connection(
986 self: &Arc<Self>,
987 connection: Connection,
988 address: String,
989 principal: Principal,
990 zed_version: ZedVersion,
991 geoip_country_code: Option<String>,
992 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
993 executor: Executor,
994 ) -> impl Future<Output = ()> {
995 let this = self.clone();
996 let span = info_span!("handle connection", %address,
997 connection_id=field::Empty,
998 user_id=field::Empty,
999 login=field::Empty,
1000 impersonator=field::Empty,
1001 dev_server_id=field::Empty,
1002 geoip_country_code=field::Empty
1003 );
1004 principal.update_span(&span);
1005 if let Some(country_code) = geoip_country_code.as_ref() {
1006 span.record("geoip_country_code", country_code);
1007 }
1008
1009 let mut teardown = self.teardown.subscribe();
1010 async move {
1011 if *teardown.borrow() {
1012 tracing::error!("server is tearing down");
1013 return
1014 }
1015 let (connection_id, handle_io, mut incoming_rx) = this
1016 .peer
1017 .add_connection(connection, {
1018 let executor = executor.clone();
1019 move |duration| executor.sleep(duration)
1020 });
1021 tracing::Span::current().record("connection_id", format!("{}", connection_id));
1022
1023 tracing::info!("connection opened");
1024
1025 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
1026 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
1027 Ok(http_client) => Arc::new(http_client),
1028 Err(error) => {
1029 tracing::error!(?error, "failed to create HTTP client");
1030 return;
1031 }
1032 };
1033
1034 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1035 Some(Arc::new(SupermavenAdminApi::new(
1036 supermaven_admin_api_key.to_string(),
1037 http_client.clone(),
1038 )))
1039 } else {
1040 None
1041 };
1042
1043 let session = Session {
1044 principal: principal.clone(),
1045 connection_id,
1046 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1047 peer: this.peer.clone(),
1048 connection_pool: this.connection_pool.clone(),
1049 live_kit_client: this.app_state.live_kit_client.clone(),
1050 http_client,
1051 rate_limiter: this.app_state.rate_limiter.clone(),
1052 geoip_country_code,
1053 _executor: executor.clone(),
1054 supermaven_client,
1055 };
1056
1057 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1058 tracing::error!(?error, "failed to send initial client update");
1059 return;
1060 }
1061
1062 let handle_io = handle_io.fuse();
1063 futures::pin_mut!(handle_io);
1064
1065 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1066 // This prevents deadlocks when e.g., client A performs a request to client B and
1067 // client B performs a request to client A. If both clients stop processing further
1068 // messages until their respective request completes, they won't have a chance to
1069 // respond to the other client's request and cause a deadlock.
1070 //
1071 // This arrangement ensures we will attempt to process earlier messages first, but fall
1072 // back to processing messages arrived later in the spirit of making progress.
1073 let mut foreground_message_handlers = FuturesUnordered::new();
1074 let concurrent_handlers = Arc::new(Semaphore::new(256));
1075 loop {
1076 let next_message = async {
1077 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1078 let message = incoming_rx.next().await;
1079 (permit, message)
1080 }.fuse();
1081 futures::pin_mut!(next_message);
1082 futures::select_biased! {
1083 _ = teardown.changed().fuse() => return,
1084 result = handle_io => {
1085 if let Err(error) = result {
1086 tracing::error!(?error, "error handling I/O");
1087 }
1088 break;
1089 }
1090 _ = foreground_message_handlers.next() => {}
1091 next_message = next_message => {
1092 let (permit, message) = next_message;
1093 if let Some(message) = message {
1094 let type_name = message.payload_type_name();
1095 // note: we copy all the fields from the parent span so we can query them in the logs.
1096 // (https://github.com/tokio-rs/tracing/issues/2670).
1097 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1098 user_id=field::Empty,
1099 login=field::Empty,
1100 impersonator=field::Empty,
1101 dev_server_id=field::Empty
1102 );
1103 principal.update_span(&span);
1104 let span_enter = span.enter();
1105 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1106 let is_background = message.is_background();
1107 let handle_message = (handler)(message, session.clone());
1108 drop(span_enter);
1109
1110 let handle_message = async move {
1111 handle_message.await;
1112 drop(permit);
1113 }.instrument(span);
1114 if is_background {
1115 executor.spawn_detached(handle_message);
1116 } else {
1117 foreground_message_handlers.push(handle_message);
1118 }
1119 } else {
1120 tracing::error!("no message handler");
1121 }
1122 } else {
1123 tracing::info!("connection closed");
1124 break;
1125 }
1126 }
1127 }
1128 }
1129
1130 drop(foreground_message_handlers);
1131 tracing::info!("signing out");
1132 if let Err(error) = connection_lost(session, teardown, executor).await {
1133 tracing::error!(?error, "error signing out");
1134 }
1135
1136 }.instrument(span)
1137 }
1138
1139 async fn send_initial_client_update(
1140 &self,
1141 connection_id: ConnectionId,
1142 principal: &Principal,
1143 zed_version: ZedVersion,
1144 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1145 session: &Session,
1146 ) -> Result<()> {
1147 self.peer.send(
1148 connection_id,
1149 proto::Hello {
1150 peer_id: Some(connection_id.into()),
1151 },
1152 )?;
1153 tracing::info!("sent hello message");
1154 if let Some(send_connection_id) = send_connection_id.take() {
1155 let _ = send_connection_id.send(connection_id);
1156 }
1157
1158 match principal {
1159 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1160 if !user.connected_once {
1161 self.peer.send(connection_id, proto::ShowContacts {})?;
1162 self.app_state
1163 .db
1164 .set_user_connected_once(user.id, true)
1165 .await?;
1166 }
1167
1168 update_user_plan(user.id, session).await?;
1169
1170 let (contacts, dev_server_projects) = future::try_join(
1171 self.app_state.db.get_contacts(user.id),
1172 self.app_state.db.dev_server_projects_update(user.id),
1173 )
1174 .await?;
1175
1176 {
1177 let mut pool = self.connection_pool.lock();
1178 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1179 self.peer.send(
1180 connection_id,
1181 build_initial_contacts_update(contacts, &pool),
1182 )?;
1183 }
1184
1185 if should_auto_subscribe_to_channels(zed_version) {
1186 subscribe_user_to_channels(user.id, session).await?;
1187 }
1188
1189 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1190
1191 if let Some(incoming_call) =
1192 self.app_state.db.incoming_call_for_user(user.id).await?
1193 {
1194 self.peer.send(connection_id, incoming_call)?;
1195 }
1196
1197 update_user_contacts(user.id, &session).await?;
1198 }
1199 Principal::DevServer(dev_server) => {
1200 {
1201 let mut pool = self.connection_pool.lock();
1202 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1203 {
1204 self.peer.send(
1205 stale_connection_id,
1206 proto::ShutdownDevServer {
1207 reason: Some(
1208 "another dev server connected with the same token".to_string(),
1209 ),
1210 },
1211 )?;
1212 pool.remove_connection(stale_connection_id)?;
1213 };
1214 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1215 }
1216
1217 let projects = self
1218 .app_state
1219 .db
1220 .get_projects_for_dev_server(dev_server.id)
1221 .await?;
1222 self.peer
1223 .send(connection_id, proto::DevServerInstructions { projects })?;
1224
1225 let status = self
1226 .app_state
1227 .db
1228 .dev_server_projects_update(dev_server.user_id)
1229 .await?;
1230 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1231 }
1232 }
1233
1234 Ok(())
1235 }
1236
1237 pub async fn invite_code_redeemed(
1238 self: &Arc<Self>,
1239 inviter_id: UserId,
1240 invitee_id: UserId,
1241 ) -> Result<()> {
1242 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1243 if let Some(code) = &user.invite_code {
1244 let pool = self.connection_pool.lock();
1245 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1246 for connection_id in pool.user_connection_ids(inviter_id) {
1247 self.peer.send(
1248 connection_id,
1249 proto::UpdateContacts {
1250 contacts: vec![invitee_contact.clone()],
1251 ..Default::default()
1252 },
1253 )?;
1254 self.peer.send(
1255 connection_id,
1256 proto::UpdateInviteInfo {
1257 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1258 count: user.invite_count as u32,
1259 },
1260 )?;
1261 }
1262 }
1263 }
1264 Ok(())
1265 }
1266
1267 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1268 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1269 if let Some(invite_code) = &user.invite_code {
1270 let pool = self.connection_pool.lock();
1271 for connection_id in pool.user_connection_ids(user_id) {
1272 self.peer.send(
1273 connection_id,
1274 proto::UpdateInviteInfo {
1275 url: format!(
1276 "{}{}",
1277 self.app_state.config.invite_link_prefix, invite_code
1278 ),
1279 count: user.invite_count as u32,
1280 },
1281 )?;
1282 }
1283 }
1284 }
1285 Ok(())
1286 }
1287
1288 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1289 ServerSnapshot {
1290 connection_pool: ConnectionPoolGuard {
1291 guard: self.connection_pool.lock(),
1292 _not_send: PhantomData,
1293 },
1294 peer: &self.peer,
1295 }
1296 }
1297}
1298
1299impl<'a> Deref for ConnectionPoolGuard<'a> {
1300 type Target = ConnectionPool;
1301
1302 fn deref(&self) -> &Self::Target {
1303 &self.guard
1304 }
1305}
1306
1307impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1308 fn deref_mut(&mut self) -> &mut Self::Target {
1309 &mut self.guard
1310 }
1311}
1312
/// In test builds, verifies the connection pool's invariants whenever a guard is
/// dropped, i.e. after every batch of changes made through it. The lock itself is
/// released when the inner `guard` field is dropped afterwards.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1319
1320fn broadcast<F>(
1321 sender_id: Option<ConnectionId>,
1322 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1323 mut f: F,
1324) where
1325 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1326{
1327 for receiver_id in receiver_ids {
1328 if Some(receiver_id) != sender_id {
1329 if let Err(error) = f(receiver_id) {
1330 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1331 }
1332 }
1333 }
1334}
1335
1336pub struct ProtocolVersion(u32);
1337
1338impl Header for ProtocolVersion {
1339 fn name() -> &'static HeaderName {
1340 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1341 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1342 }
1343
1344 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1345 where
1346 Self: Sized,
1347 I: Iterator<Item = &'i axum::http::HeaderValue>,
1348 {
1349 let version = values
1350 .next()
1351 .ok_or_else(axum::headers::Error::invalid)?
1352 .to_str()
1353 .map_err(|_| axum::headers::Error::invalid())?
1354 .parse()
1355 .map_err(|_| axum::headers::Error::invalid())?;
1356 Ok(Self(version))
1357 }
1358
1359 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1360 values.extend([self.0.to_string().parse().unwrap()]);
1361 }
1362}
1363
1364pub struct AppVersionHeader(SemanticVersion);
1365impl Header for AppVersionHeader {
1366 fn name() -> &'static HeaderName {
1367 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1368 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1369 }
1370
1371 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1372 where
1373 Self: Sized,
1374 I: Iterator<Item = &'i axum::http::HeaderValue>,
1375 {
1376 let version = values
1377 .next()
1378 .ok_or_else(axum::headers::Error::invalid)?
1379 .to_str()
1380 .map_err(|_| axum::headers::Error::invalid())?
1381 .parse()
1382 .map_err(|_| axum::headers::Error::invalid())?;
1383 Ok(Self(version))
1384 }
1385
1386 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1387 values.extend([self.0.to_string().parse().unwrap()]);
1388 }
1389}
1390
1391pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1392 Router::new()
1393 .route("/rpc", get(handle_websocket_request))
1394 .layer(
1395 ServiceBuilder::new()
1396 .layer(Extension(server.app_state.clone()))
1397 .layer(middleware::from_fn(auth::validate_header)),
1398 )
1399 .route("/metrics", get(handle_metrics))
1400 .layer(Extension(server))
1401}
1402
1403pub async fn handle_websocket_request(
1404 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1405 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1406 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1407 Extension(server): Extension<Arc<Server>>,
1408 Extension(principal): Extension<Principal>,
1409 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1410 ws: WebSocketUpgrade,
1411) -> axum::response::Response {
1412 if protocol_version != rpc::PROTOCOL_VERSION {
1413 return (
1414 StatusCode::UPGRADE_REQUIRED,
1415 "client must be upgraded".to_string(),
1416 )
1417 .into_response();
1418 }
1419
1420 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1421 return (
1422 StatusCode::UPGRADE_REQUIRED,
1423 "no version header found".to_string(),
1424 )
1425 .into_response();
1426 };
1427
1428 if !version.can_collaborate() {
1429 return (
1430 StatusCode::UPGRADE_REQUIRED,
1431 "client must be upgraded".to_string(),
1432 )
1433 .into_response();
1434 }
1435
1436 let socket_address = socket_address.to_string();
1437 ws.on_upgrade(move |socket| {
1438 let socket = socket
1439 .map_ok(to_tungstenite_message)
1440 .err_into()
1441 .with(|message| async move { to_axum_message(message) });
1442 let connection = Connection::new(Box::pin(socket));
1443 async move {
1444 server
1445 .handle_connection(
1446 connection,
1447 socket_address,
1448 principal,
1449 version,
1450 country_code_header.map(|header| header.to_string()),
1451 None,
1452 Executor::Production,
1453 )
1454 .await;
1455 }
1456 })
1457}
1458
1459pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1460 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1461 let connections_metric = CONNECTIONS_METRIC
1462 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1463
1464 let connections = server
1465 .connection_pool
1466 .lock()
1467 .connections()
1468 .filter(|connection| !connection.admin)
1469 .count();
1470 connections_metric.set(connections as _);
1471
1472 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1473 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1474 register_int_gauge!(
1475 "shared_projects",
1476 "number of open projects with one or more guests"
1477 )
1478 .unwrap()
1479 });
1480
1481 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1482 shared_projects_metric.set(shared_projects as _);
1483
1484 let encoder = prometheus::TextEncoder::new();
1485 let metric_families = prometheus::gather();
1486 let encoded_metrics = encoder
1487 .encode_to_string(&metric_families)
1488 .map_err(|err| anyhow!("{}", err))?;
1489 Ok(encoded_metrics)
1490}
1491
/// Cleans up after a connection has ended.
///
/// Immediately disconnects the connection from the peer, removes it from the
/// connection pool, and tells the database the connection was lost (a database
/// error at this step is only passed to `trace_err`).
///
/// It then waits `RECONNECT_TIMEOUT` (presumably to give the client a chance to
/// reconnect) before releasing the principal's remaining resources:
///
/// * users leave their room and channel buffers. If they have no other live
///   connection, any pending incoming call is declined. Their contacts are then
///   updated.
/// * dev servers are passed to `lost_dev_server_connection`.
///
/// If `teardown` changes before the timeout elapses, this deferred cleanup is
/// skipped.
///
/// # Errors
/// Fails if the connection cannot be removed from the pool, or if updating
/// contacts or the dev server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the timeout branch before the teardown branch.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The match above guarantees this session belongs to a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this user
                    // is still online to receive it.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1546
1547/// Acknowledges a ping from a client, used to keep the connection alive.
1548async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1549 response.send(proto::Ack {})?;
1550 Ok(())
1551}
1552
1553/// Creates a new room for calling (outside of channels)
1554async fn create_room(
1555 _request: proto::CreateRoom,
1556 response: Response<proto::CreateRoom>,
1557 session: UserSession,
1558) -> Result<()> {
1559 let live_kit_room = nanoid::nanoid!(30);
1560
1561 let live_kit_connection_info = util::maybe!(async {
1562 let live_kit = session.live_kit_client.as_ref();
1563 let live_kit = live_kit?;
1564 let user_id = session.user_id().to_string();
1565
1566 let token = live_kit
1567 .room_token(&live_kit_room, &user_id.to_string())
1568 .trace_err()?;
1569
1570 Some(proto::LiveKitConnectionInfo {
1571 server_url: live_kit.url().into(),
1572 token,
1573 can_publish: true,
1574 })
1575 })
1576 .await;
1577
1578 let room = session
1579 .db()
1580 .await
1581 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1582 .await?;
1583
1584 response.send(proto::CreateRoomResponse {
1585 room: Some(room.clone()),
1586 live_kit_connection_info,
1587 })?;
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
1593/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1594async fn join_room(
1595 request: proto::JoinRoom,
1596 response: Response<proto::JoinRoom>,
1597 session: UserSession,
1598) -> Result<()> {
1599 let room_id = RoomId::from_proto(request.id);
1600
1601 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1602
1603 if let Some(channel_id) = channel_id {
1604 return join_channel_internal(channel_id, Box::new(response), session).await;
1605 }
1606
1607 let joined_room = {
1608 let room = session
1609 .db()
1610 .await
1611 .join_room(room_id, session.user_id(), session.connection_id)
1612 .await?;
1613 room_updated(&room.room, &session.peer);
1614 room.into_inner()
1615 };
1616
1617 for connection_id in session
1618 .connection_pool()
1619 .await
1620 .user_connection_ids(session.user_id())
1621 {
1622 session
1623 .peer
1624 .send(
1625 connection_id,
1626 proto::CallCanceled {
1627 room_id: room_id.to_proto(),
1628 },
1629 )
1630 .trace_err();
1631 }
1632
1633 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1634 if let Some(token) = live_kit
1635 .room_token(
1636 &joined_room.room.live_kit_room,
1637 &session.user_id().to_string(),
1638 )
1639 .trace_err()
1640 {
1641 Some(proto::LiveKitConnectionInfo {
1642 server_url: live_kit.url().into(),
1643 token,
1644 can_publish: true,
1645 })
1646 } else {
1647 None
1648 }
1649 } else {
1650 None
1651 };
1652
1653 response.send(proto::JoinRoomResponse {
1654 room: Some(joined_room.room),
1655 channel_id: None,
1656 live_kit_connection_info,
1657 })?;
1658
1659 update_user_contacts(session.user_id(), &session).await?;
1660 Ok(())
1661}
1662
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Re-admits the caller's current connection to the room described by `request`,
/// then:
/// * responds with the room, the projects the caller re-shared, and the projects
///   it re-joined, each with their collaborators;
/// * broadcasts the updated room;
/// * for every re-shared project, tells each collaborator that the caller's peer
///   id changed from the project's old connection id to the current one, and
///   forwards the project's worktree list to the other collaborators;
/// * replays the state of every re-joined project to the caller via
///   `notify_rejoined_projects`;
/// * if the room belongs to a channel, broadcasts the channel update;
/// * finally updates the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // The rejoined-room guard is confined to this block; only `room` and `channel`
    // are moved out of it for the follow-up notifications below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Remap the caller's old peer id to the new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to everyone except the
            // caller, as if sent by the caller's connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Stream the state of every re-joined project back to the caller.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1754
1755fn notify_rejoined_projects(
1756 rejoined_projects: &mut Vec<RejoinedProject>,
1757 session: &UserSession,
1758) -> Result<()> {
1759 for project in rejoined_projects.iter() {
1760 for collaborator in &project.collaborators {
1761 session
1762 .peer
1763 .send(
1764 collaborator.connection_id,
1765 proto::UpdateProjectCollaborator {
1766 project_id: project.id.to_proto(),
1767 old_peer_id: Some(project.old_connection_id.into()),
1768 new_peer_id: Some(session.connection_id.into()),
1769 },
1770 )
1771 .trace_err();
1772 }
1773 }
1774
1775 for project in rejoined_projects {
1776 for worktree in mem::take(&mut project.worktrees) {
1777 #[cfg(any(test, feature = "test-support"))]
1778 const MAX_CHUNK_SIZE: usize = 2;
1779 #[cfg(not(any(test, feature = "test-support")))]
1780 const MAX_CHUNK_SIZE: usize = 256;
1781
1782 // Stream this worktree's entries.
1783 let message = proto::UpdateWorktree {
1784 project_id: project.id.to_proto(),
1785 worktree_id: worktree.id,
1786 abs_path: worktree.abs_path.clone(),
1787 root_name: worktree.root_name,
1788 updated_entries: worktree.updated_entries,
1789 removed_entries: worktree.removed_entries,
1790 scan_id: worktree.scan_id,
1791 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1792 updated_repositories: worktree.updated_repositories,
1793 removed_repositories: worktree.removed_repositories,
1794 };
1795 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1796 session.peer.send(session.connection_id, update.clone())?;
1797 }
1798
1799 // Stream this worktree's diagnostics.
1800 for summary in worktree.diagnostic_summaries {
1801 session.peer.send(
1802 session.connection_id,
1803 proto::UpdateDiagnosticSummary {
1804 project_id: project.id.to_proto(),
1805 worktree_id: worktree.id,
1806 summary: Some(summary),
1807 },
1808 )?;
1809 }
1810
1811 for settings_file in worktree.settings_files {
1812 session.peer.send(
1813 session.connection_id,
1814 proto::UpdateWorktreeSettings {
1815 project_id: project.id.to_proto(),
1816 worktree_id: worktree.id,
1817 path: settings_file.path,
1818 content: Some(settings_file.content),
1819 },
1820 )?;
1821 }
1822 }
1823
1824 for language_server in &project.language_servers {
1825 session.peer.send(
1826 session.connection_id,
1827 proto::UpdateLanguageServer {
1828 project_id: project.id.to_proto(),
1829 language_server_id: language_server.id,
1830 variant: Some(
1831 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1832 proto::LspDiskBasedDiagnosticsUpdated {},
1833 ),
1834 ),
1835 },
1836 )?;
1837 }
1838 }
1839 Ok(())
1840}
1841
1842/// leave room disconnects from the room.
1843async fn leave_room(
1844 _: proto::LeaveRoom,
1845 response: Response<proto::LeaveRoom>,
1846 session: UserSession,
1847) -> Result<()> {
1848 leave_room_for_session(&session, session.connection_id).await?;
1849 response.send(proto::Ack {})?;
1850 Ok(())
1851}
1852
1853/// Updates the permissions of someone else in the room.
1854async fn set_room_participant_role(
1855 request: proto::SetRoomParticipantRole,
1856 response: Response<proto::SetRoomParticipantRole>,
1857 session: UserSession,
1858) -> Result<()> {
1859 let user_id = UserId::from_proto(request.user_id);
1860 let role = ChannelRole::from(request.role());
1861
1862 let (live_kit_room, can_publish) = {
1863 let room = session
1864 .db()
1865 .await
1866 .set_room_participant_role(
1867 session.user_id(),
1868 RoomId::from_proto(request.room_id),
1869 user_id,
1870 role,
1871 )
1872 .await?;
1873
1874 let live_kit_room = room.live_kit_room.clone();
1875 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1876 room_updated(&room, &session.peer);
1877 (live_kit_room, can_publish)
1878 };
1879
1880 if let Some(live_kit) = session.live_kit_client.as_ref() {
1881 live_kit
1882 .update_participant(
1883 live_kit_room.clone(),
1884 request.user_id.to_string(),
1885 live_kit_server::proto::ParticipantPermission {
1886 can_subscribe: true,
1887 can_publish,
1888 can_publish_data: can_publish,
1889 hidden: false,
1890 recorder: false,
1891 },
1892 )
1893 .await
1894 .trace_err();
1895 }
1896
1897 response.send(proto::Ack {})?;
1898 Ok(())
1899}
1900
1901/// Call someone else into the current room
1902async fn call(
1903 request: proto::Call,
1904 response: Response<proto::Call>,
1905 session: UserSession,
1906) -> Result<()> {
1907 let room_id = RoomId::from_proto(request.room_id);
1908 let calling_user_id = session.user_id();
1909 let calling_connection_id = session.connection_id;
1910 let called_user_id = UserId::from_proto(request.called_user_id);
1911 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1912 if !session
1913 .db()
1914 .await
1915 .has_contact(calling_user_id, called_user_id)
1916 .await?
1917 {
1918 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1919 }
1920
1921 let incoming_call = {
1922 let (room, incoming_call) = &mut *session
1923 .db()
1924 .await
1925 .call(
1926 room_id,
1927 calling_user_id,
1928 calling_connection_id,
1929 called_user_id,
1930 initial_project_id,
1931 )
1932 .await?;
1933 room_updated(&room, &session.peer);
1934 mem::take(incoming_call)
1935 };
1936 update_user_contacts(called_user_id, &session).await?;
1937
1938 let mut calls = session
1939 .connection_pool()
1940 .await
1941 .user_connection_ids(called_user_id)
1942 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1943 .collect::<FuturesUnordered<_>>();
1944
1945 while let Some(call_response) = calls.next().await {
1946 match call_response.as_ref() {
1947 Ok(_) => {
1948 response.send(proto::Ack {})?;
1949 return Ok(());
1950 }
1951 Err(_) => {
1952 call_response.trace_err();
1953 }
1954 }
1955 }
1956
1957 {
1958 let room = session
1959 .db()
1960 .await
1961 .call_failed(room_id, called_user_id)
1962 .await?;
1963 room_updated(&room, &session.peer);
1964 }
1965 update_user_contacts(called_user_id, &session).await?;
1966
1967 Err(anyhow!("failed to ring user"))?
1968}
1969
1970/// Cancel an outgoing call.
1971async fn cancel_call(
1972 request: proto::CancelCall,
1973 response: Response<proto::CancelCall>,
1974 session: UserSession,
1975) -> Result<()> {
1976 let called_user_id = UserId::from_proto(request.called_user_id);
1977 let room_id = RoomId::from_proto(request.room_id);
1978 {
1979 let room = session
1980 .db()
1981 .await
1982 .cancel_call(room_id, session.connection_id, called_user_id)
1983 .await?;
1984 room_updated(&room, &session.peer);
1985 }
1986
1987 for connection_id in session
1988 .connection_pool()
1989 .await
1990 .user_connection_ids(called_user_id)
1991 {
1992 session
1993 .peer
1994 .send(
1995 connection_id,
1996 proto::CallCanceled {
1997 room_id: room_id.to_proto(),
1998 },
1999 )
2000 .trace_err();
2001 }
2002 response.send(proto::Ack {})?;
2003
2004 update_user_contacts(called_user_id, &session).await?;
2005 Ok(())
2006}
2007
2008/// Decline an incoming call.
2009async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2010 let room_id = RoomId::from_proto(message.room_id);
2011 {
2012 let room = session
2013 .db()
2014 .await
2015 .decline_call(Some(room_id), session.user_id())
2016 .await?
2017 .ok_or_else(|| anyhow!("failed to decline call"))?;
2018 room_updated(&room, &session.peer);
2019 }
2020
2021 for connection_id in session
2022 .connection_pool()
2023 .await
2024 .user_connection_ids(session.user_id())
2025 {
2026 session
2027 .peer
2028 .send(
2029 connection_id,
2030 proto::CallCanceled {
2031 room_id: room_id.to_proto(),
2032 },
2033 )
2034 .trace_err();
2035 }
2036 update_user_contacts(session.user_id(), &session).await?;
2037 Ok(())
2038}
2039
2040/// Updates other participants in the room with your current location.
2041async fn update_participant_location(
2042 request: proto::UpdateParticipantLocation,
2043 response: Response<proto::UpdateParticipantLocation>,
2044 session: UserSession,
2045) -> Result<()> {
2046 let room_id = RoomId::from_proto(request.room_id);
2047 let location = request
2048 .location
2049 .ok_or_else(|| anyhow!("invalid location"))?;
2050
2051 let db = session.db().await;
2052 let room = db
2053 .update_room_participant_location(room_id, session.connection_id, location)
2054 .await?;
2055
2056 room_updated(&room, &session.peer);
2057 response.send(proto::Ack {})?;
2058 Ok(())
2059}
2060
2061/// Share a project into the room.
2062async fn share_project(
2063 request: proto::ShareProject,
2064 response: Response<proto::ShareProject>,
2065 session: UserSession,
2066) -> Result<()> {
2067 let (project_id, room) = &*session
2068 .db()
2069 .await
2070 .share_project(
2071 RoomId::from_proto(request.room_id),
2072 session.connection_id,
2073 &request.worktrees,
2074 request
2075 .dev_server_project_id
2076 .map(|id| DevServerProjectId::from_proto(id)),
2077 )
2078 .await?;
2079 response.send(proto::ShareProjectResponse {
2080 project_id: project_id.to_proto(),
2081 })?;
2082 room_updated(&room, &session.peer);
2083
2084 Ok(())
2085}
2086
2087/// Unshare a project from the room.
2088async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2089 let project_id = ProjectId::from_proto(message.project_id);
2090 unshare_project_internal(
2091 project_id,
2092 session.connection_id,
2093 session.user_id(),
2094 &session,
2095 )
2096 .await
2097}
2098
2099async fn unshare_project_internal(
2100 project_id: ProjectId,
2101 connection_id: ConnectionId,
2102 user_id: Option<UserId>,
2103 session: &Session,
2104) -> Result<()> {
2105 let delete = {
2106 let room_guard = session
2107 .db()
2108 .await
2109 .unshare_project(project_id, connection_id, user_id)
2110 .await?;
2111
2112 let (delete, room, guest_connection_ids) = &*room_guard;
2113
2114 let message = proto::UnshareProject {
2115 project_id: project_id.to_proto(),
2116 };
2117
2118 broadcast(
2119 Some(connection_id),
2120 guest_connection_ids.iter().copied(),
2121 |conn_id| session.peer.send(conn_id, message.clone()),
2122 );
2123 if let Some(room) = room {
2124 room_updated(room, &session.peer);
2125 }
2126
2127 *delete
2128 };
2129
2130 if delete {
2131 let db = session.db().await;
2132 db.delete_project(project_id).await?;
2133 }
2134
2135 Ok(())
2136}
2137
2138/// DevServer makes a project available online
2139async fn share_dev_server_project(
2140 request: proto::ShareDevServerProject,
2141 response: Response<proto::ShareDevServerProject>,
2142 session: DevServerSession,
2143) -> Result<()> {
2144 let (dev_server_project, user_id, status) = session
2145 .db()
2146 .await
2147 .share_dev_server_project(
2148 DevServerProjectId::from_proto(request.dev_server_project_id),
2149 session.dev_server_id(),
2150 session.connection_id,
2151 &request.worktrees,
2152 )
2153 .await?;
2154 let Some(project_id) = dev_server_project.project_id else {
2155 return Err(anyhow!("failed to share remote project"))?;
2156 };
2157
2158 send_dev_server_projects_update(user_id, status, &session).await;
2159
2160 response.send(proto::ShareProjectResponse { project_id })?;
2161
2162 Ok(())
2163}
2164
2165/// Join someone elses shared project.
2166async fn join_project(
2167 request: proto::JoinProject,
2168 response: Response<proto::JoinProject>,
2169 session: UserSession,
2170) -> Result<()> {
2171 let project_id = ProjectId::from_proto(request.project_id);
2172
2173 tracing::info!(%project_id, "join project");
2174
2175 let db = session.db().await;
2176 let (project, replica_id) = &mut *db
2177 .join_project(project_id, session.connection_id, session.user_id())
2178 .await?;
2179 drop(db);
2180 tracing::info!(%project_id, "join remote project");
2181 join_project_internal(response, session, project, replica_id)
2182}
2183
/// Abstracts over the request types whose reply is a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` answer both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// Delegates to `Response::send` for `JoinProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
// Delegates to `Response::send` for `JoinHostedProject` requests.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2197
2198fn join_project_internal(
2199 response: impl JoinProjectInternalResponse,
2200 session: UserSession,
2201 project: &mut Project,
2202 replica_id: &ReplicaId,
2203) -> Result<()> {
2204 let collaborators = project
2205 .collaborators
2206 .iter()
2207 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2208 .map(|collaborator| collaborator.to_proto())
2209 .collect::<Vec<_>>();
2210 let project_id = project.id;
2211 let guest_user_id = session.user_id();
2212
2213 let worktrees = project
2214 .worktrees
2215 .iter()
2216 .map(|(id, worktree)| proto::WorktreeMetadata {
2217 id: *id,
2218 root_name: worktree.root_name.clone(),
2219 visible: worktree.visible,
2220 abs_path: worktree.abs_path.clone(),
2221 })
2222 .collect::<Vec<_>>();
2223
2224 let add_project_collaborator = proto::AddProjectCollaborator {
2225 project_id: project_id.to_proto(),
2226 collaborator: Some(proto::Collaborator {
2227 peer_id: Some(session.connection_id.into()),
2228 replica_id: replica_id.0 as u32,
2229 user_id: guest_user_id.to_proto(),
2230 }),
2231 };
2232
2233 for collaborator in &collaborators {
2234 session
2235 .peer
2236 .send(
2237 collaborator.peer_id.unwrap().into(),
2238 add_project_collaborator.clone(),
2239 )
2240 .trace_err();
2241 }
2242
2243 // First, we send the metadata associated with each worktree.
2244 response.send(proto::JoinProjectResponse {
2245 project_id: project.id.0 as u64,
2246 worktrees: worktrees.clone(),
2247 replica_id: replica_id.0 as u32,
2248 collaborators: collaborators.clone(),
2249 language_servers: project.language_servers.clone(),
2250 role: project.role.into(),
2251 dev_server_project_id: project
2252 .dev_server_project_id
2253 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2254 })?;
2255
2256 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2257 #[cfg(any(test, feature = "test-support"))]
2258 const MAX_CHUNK_SIZE: usize = 2;
2259 #[cfg(not(any(test, feature = "test-support")))]
2260 const MAX_CHUNK_SIZE: usize = 256;
2261
2262 // Stream this worktree's entries.
2263 let message = proto::UpdateWorktree {
2264 project_id: project_id.to_proto(),
2265 worktree_id,
2266 abs_path: worktree.abs_path.clone(),
2267 root_name: worktree.root_name,
2268 updated_entries: worktree.entries,
2269 removed_entries: Default::default(),
2270 scan_id: worktree.scan_id,
2271 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2272 updated_repositories: worktree.repository_entries.into_values().collect(),
2273 removed_repositories: Default::default(),
2274 };
2275 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2276 session.peer.send(session.connection_id, update.clone())?;
2277 }
2278
2279 // Stream this worktree's diagnostics.
2280 for summary in worktree.diagnostic_summaries {
2281 session.peer.send(
2282 session.connection_id,
2283 proto::UpdateDiagnosticSummary {
2284 project_id: project_id.to_proto(),
2285 worktree_id: worktree.id,
2286 summary: Some(summary),
2287 },
2288 )?;
2289 }
2290
2291 for settings_file in worktree.settings_files {
2292 session.peer.send(
2293 session.connection_id,
2294 proto::UpdateWorktreeSettings {
2295 project_id: project_id.to_proto(),
2296 worktree_id: worktree.id,
2297 path: settings_file.path,
2298 content: Some(settings_file.content),
2299 },
2300 )?;
2301 }
2302 }
2303
2304 for language_server in &project.language_servers {
2305 session.peer.send(
2306 session.connection_id,
2307 proto::UpdateLanguageServer {
2308 project_id: project_id.to_proto(),
2309 language_server_id: language_server.id,
2310 variant: Some(
2311 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2312 proto::LspDiskBasedDiagnosticsUpdated {},
2313 ),
2314 ),
2315 },
2316 )?;
2317 }
2318
2319 Ok(())
2320}
2321
2322/// Leave someone elses shared project.
2323async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2324 let sender_id = session.connection_id;
2325 let project_id = ProjectId::from_proto(request.project_id);
2326 let db = session.db().await;
2327 if db.is_hosted_project(project_id).await? {
2328 let project = db.leave_hosted_project(project_id, sender_id).await?;
2329 project_left(&project, &session);
2330 return Ok(());
2331 }
2332
2333 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2334 tracing::info!(
2335 %project_id,
2336 "leave project"
2337 );
2338
2339 project_left(&project, &session);
2340 if let Some(room) = room {
2341 room_updated(&room, &session.peer);
2342 }
2343
2344 Ok(())
2345}
2346
2347async fn join_hosted_project(
2348 request: proto::JoinHostedProject,
2349 response: Response<proto::JoinHostedProject>,
2350 session: UserSession,
2351) -> Result<()> {
2352 let (mut project, replica_id) = session
2353 .db()
2354 .await
2355 .join_hosted_project(
2356 ProjectId(request.project_id as i32),
2357 session.user_id(),
2358 session.connection_id,
2359 )
2360 .await?;
2361
2362 join_project_internal(response, session, &mut project, &replica_id)
2363}
2364
2365async fn list_remote_directory(
2366 request: proto::ListRemoteDirectory,
2367 response: Response<proto::ListRemoteDirectory>,
2368 session: UserSession,
2369) -> Result<()> {
2370 let dev_server_id = DevServerId(request.dev_server_id as i32);
2371 let dev_server_connection_id = session
2372 .connection_pool()
2373 .await
2374 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2375
2376 session
2377 .db()
2378 .await
2379 .get_dev_server_for_user(dev_server_id, session.user_id())
2380 .await?;
2381
2382 response.send(
2383 session
2384 .peer
2385 .forward_request(session.connection_id, dev_server_connection_id, request)
2386 .await?,
2387 )?;
2388 Ok(())
2389}
2390
2391async fn update_dev_server_project(
2392 request: proto::UpdateDevServerProject,
2393 response: Response<proto::UpdateDevServerProject>,
2394 session: UserSession,
2395) -> Result<()> {
2396 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2397
2398 let (dev_server_project, update) = session
2399 .db()
2400 .await
2401 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2402 .await?;
2403
2404 let projects = session
2405 .db()
2406 .await
2407 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2408 .await?;
2409
2410 let dev_server_connection_id = session
2411 .connection_pool()
2412 .await
2413 .dev_server_connection_id_supporting(
2414 dev_server_project.dev_server_id,
2415 ZedVersion::with_list_directory(),
2416 )?;
2417
2418 session.peer.send(
2419 dev_server_connection_id,
2420 proto::DevServerInstructions { projects },
2421 )?;
2422
2423 send_dev_server_projects_update(session.user_id(), update, &session).await;
2424
2425 response.send(proto::Ack {})
2426}
2427
2428async fn create_dev_server_project(
2429 request: proto::CreateDevServerProject,
2430 response: Response<proto::CreateDevServerProject>,
2431 session: UserSession,
2432) -> Result<()> {
2433 let dev_server_id = DevServerId(request.dev_server_id as i32);
2434 let dev_server_connection_id = session
2435 .connection_pool()
2436 .await
2437 .dev_server_connection_id(dev_server_id);
2438 let Some(dev_server_connection_id) = dev_server_connection_id else {
2439 Err(ErrorCode::DevServerOffline
2440 .message("Cannot create a remote project when the dev server is offline".to_string())
2441 .anyhow())?
2442 };
2443
2444 let path = request.path.clone();
2445 //Check that the path exists on the dev server
2446 session
2447 .peer
2448 .forward_request(
2449 session.connection_id,
2450 dev_server_connection_id,
2451 proto::ValidateDevServerProjectRequest { path: path.clone() },
2452 )
2453 .await?;
2454
2455 let (dev_server_project, update) = session
2456 .db()
2457 .await
2458 .create_dev_server_project(
2459 DevServerId(request.dev_server_id as i32),
2460 &request.path,
2461 session.user_id(),
2462 )
2463 .await?;
2464
2465 let projects = session
2466 .db()
2467 .await
2468 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2469 .await?;
2470
2471 session.peer.send(
2472 dev_server_connection_id,
2473 proto::DevServerInstructions { projects },
2474 )?;
2475
2476 send_dev_server_projects_update(session.user_id(), update, &session).await;
2477
2478 response.send(proto::CreateDevServerProjectResponse {
2479 dev_server_project: Some(dev_server_project.to_proto(None)),
2480 })?;
2481 Ok(())
2482}
2483
2484async fn create_dev_server(
2485 request: proto::CreateDevServer,
2486 response: Response<proto::CreateDevServer>,
2487 session: UserSession,
2488) -> Result<()> {
2489 let access_token = auth::random_token();
2490 let hashed_access_token = auth::hash_access_token(&access_token);
2491
2492 if request.name.is_empty() {
2493 return Err(proto::ErrorCode::Forbidden
2494 .message("Dev server name cannot be empty".to_string())
2495 .anyhow())?;
2496 }
2497
2498 let (dev_server, status) = session
2499 .db()
2500 .await
2501 .create_dev_server(
2502 &request.name,
2503 request.ssh_connection_string.as_deref(),
2504 &hashed_access_token,
2505 session.user_id(),
2506 )
2507 .await?;
2508
2509 send_dev_server_projects_update(session.user_id(), status, &session).await;
2510
2511 response.send(proto::CreateDevServerResponse {
2512 dev_server_id: dev_server.id.0 as u64,
2513 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2514 name: request.name,
2515 })?;
2516 Ok(())
2517}
2518
2519async fn regenerate_dev_server_token(
2520 request: proto::RegenerateDevServerToken,
2521 response: Response<proto::RegenerateDevServerToken>,
2522 session: UserSession,
2523) -> Result<()> {
2524 let dev_server_id = DevServerId(request.dev_server_id as i32);
2525 let access_token = auth::random_token();
2526 let hashed_access_token = auth::hash_access_token(&access_token);
2527
2528 let connection_id = session
2529 .connection_pool()
2530 .await
2531 .dev_server_connection_id(dev_server_id);
2532 if let Some(connection_id) = connection_id {
2533 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2534 session.peer.send(
2535 connection_id,
2536 proto::ShutdownDevServer {
2537 reason: Some("dev server token was regenerated".to_string()),
2538 },
2539 )?;
2540 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2541 }
2542
2543 let status = session
2544 .db()
2545 .await
2546 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2547 .await?;
2548
2549 send_dev_server_projects_update(session.user_id(), status, &session).await;
2550
2551 response.send(proto::RegenerateDevServerTokenResponse {
2552 dev_server_id: dev_server_id.to_proto(),
2553 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2554 })?;
2555 Ok(())
2556}
2557
2558async fn rename_dev_server(
2559 request: proto::RenameDevServer,
2560 response: Response<proto::RenameDevServer>,
2561 session: UserSession,
2562) -> Result<()> {
2563 if request.name.trim().is_empty() {
2564 return Err(proto::ErrorCode::Forbidden
2565 .message("Dev server name cannot be empty".to_string())
2566 .anyhow())?;
2567 }
2568
2569 let dev_server_id = DevServerId(request.dev_server_id as i32);
2570 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2571 if dev_server.user_id != session.user_id() {
2572 return Err(anyhow!(ErrorCode::Forbidden))?;
2573 }
2574
2575 let status = session
2576 .db()
2577 .await
2578 .rename_dev_server(
2579 dev_server_id,
2580 &request.name,
2581 request.ssh_connection_string.as_deref(),
2582 session.user_id(),
2583 )
2584 .await?;
2585
2586 send_dev_server_projects_update(session.user_id(), status, &session).await;
2587
2588 response.send(proto::Ack {})?;
2589 Ok(())
2590}
2591
2592async fn delete_dev_server(
2593 request: proto::DeleteDevServer,
2594 response: Response<proto::DeleteDevServer>,
2595 session: UserSession,
2596) -> Result<()> {
2597 let dev_server_id = DevServerId(request.dev_server_id as i32);
2598 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2599 if dev_server.user_id != session.user_id() {
2600 return Err(anyhow!(ErrorCode::Forbidden))?;
2601 }
2602
2603 let connection_id = session
2604 .connection_pool()
2605 .await
2606 .dev_server_connection_id(dev_server_id);
2607 if let Some(connection_id) = connection_id {
2608 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2609 session.peer.send(
2610 connection_id,
2611 proto::ShutdownDevServer {
2612 reason: Some("dev server was deleted".to_string()),
2613 },
2614 )?;
2615 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2616 }
2617
2618 let status = session
2619 .db()
2620 .await
2621 .delete_dev_server(dev_server_id, session.user_id())
2622 .await?;
2623
2624 send_dev_server_projects_update(session.user_id(), status, &session).await;
2625
2626 response.send(proto::Ack {})?;
2627 Ok(())
2628}
2629
2630async fn delete_dev_server_project(
2631 request: proto::DeleteDevServerProject,
2632 response: Response<proto::DeleteDevServerProject>,
2633 session: UserSession,
2634) -> Result<()> {
2635 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2636 let dev_server_project = session
2637 .db()
2638 .await
2639 .get_dev_server_project(dev_server_project_id)
2640 .await?;
2641
2642 let dev_server = session
2643 .db()
2644 .await
2645 .get_dev_server(dev_server_project.dev_server_id)
2646 .await?;
2647 if dev_server.user_id != session.user_id() {
2648 return Err(anyhow!(ErrorCode::Forbidden))?;
2649 }
2650
2651 let dev_server_connection_id = session
2652 .connection_pool()
2653 .await
2654 .dev_server_connection_id(dev_server.id);
2655
2656 if let Some(dev_server_connection_id) = dev_server_connection_id {
2657 let project = session
2658 .db()
2659 .await
2660 .find_dev_server_project(dev_server_project_id)
2661 .await;
2662 if let Ok(project) = project {
2663 unshare_project_internal(
2664 project.id,
2665 dev_server_connection_id,
2666 Some(session.user_id()),
2667 &session,
2668 )
2669 .await?;
2670 }
2671 }
2672
2673 let (projects, status) = session
2674 .db()
2675 .await
2676 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2677 .await?;
2678
2679 if let Some(dev_server_connection_id) = dev_server_connection_id {
2680 session.peer.send(
2681 dev_server_connection_id,
2682 proto::DevServerInstructions { projects },
2683 )?;
2684 }
2685
2686 send_dev_server_projects_update(session.user_id(), status, &session).await;
2687
2688 response.send(proto::Ack {})?;
2689 Ok(())
2690}
2691
2692async fn rejoin_dev_server_projects(
2693 request: proto::RejoinRemoteProjects,
2694 response: Response<proto::RejoinRemoteProjects>,
2695 session: UserSession,
2696) -> Result<()> {
2697 let mut rejoined_projects = {
2698 let db = session.db().await;
2699 db.rejoin_dev_server_projects(
2700 &request.rejoined_projects,
2701 session.user_id(),
2702 session.0.connection_id,
2703 )
2704 .await?
2705 };
2706 response.send(proto::RejoinRemoteProjectsResponse {
2707 rejoined_projects: rejoined_projects
2708 .iter()
2709 .map(|project| project.to_proto())
2710 .collect(),
2711 })?;
2712 notify_rejoined_projects(&mut rejoined_projects, &session)
2713}
2714
2715async fn reconnect_dev_server(
2716 request: proto::ReconnectDevServer,
2717 response: Response<proto::ReconnectDevServer>,
2718 session: DevServerSession,
2719) -> Result<()> {
2720 let reshared_projects = {
2721 let db = session.db().await;
2722 db.reshare_dev_server_projects(
2723 &request.reshared_projects,
2724 session.dev_server_id(),
2725 session.0.connection_id,
2726 )
2727 .await?
2728 };
2729
2730 for project in &reshared_projects {
2731 for collaborator in &project.collaborators {
2732 session
2733 .peer
2734 .send(
2735 collaborator.connection_id,
2736 proto::UpdateProjectCollaborator {
2737 project_id: project.id.to_proto(),
2738 old_peer_id: Some(project.old_connection_id.into()),
2739 new_peer_id: Some(session.connection_id.into()),
2740 },
2741 )
2742 .trace_err();
2743 }
2744
2745 broadcast(
2746 Some(session.connection_id),
2747 project
2748 .collaborators
2749 .iter()
2750 .map(|collaborator| collaborator.connection_id),
2751 |connection_id| {
2752 session.peer.forward_send(
2753 session.connection_id,
2754 connection_id,
2755 proto::UpdateProject {
2756 project_id: project.id.to_proto(),
2757 worktrees: project.worktrees.clone(),
2758 },
2759 )
2760 },
2761 );
2762 }
2763
2764 response.send(proto::ReconnectDevServerResponse {
2765 reshared_projects: reshared_projects
2766 .iter()
2767 .map(|project| proto::ResharedProject {
2768 id: project.id.to_proto(),
2769 collaborators: project
2770 .collaborators
2771 .iter()
2772 .map(|collaborator| collaborator.to_proto())
2773 .collect(),
2774 })
2775 .collect(),
2776 })?;
2777
2778 Ok(())
2779}
2780
2781async fn shutdown_dev_server(
2782 _: proto::ShutdownDevServer,
2783 response: Response<proto::ShutdownDevServer>,
2784 session: DevServerSession,
2785) -> Result<()> {
2786 response.send(proto::Ack {})?;
2787 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2788 remove_dev_server_connection(session.dev_server_id(), &session).await
2789}
2790
2791async fn shutdown_dev_server_internal(
2792 dev_server_id: DevServerId,
2793 connection_id: ConnectionId,
2794 session: &Session,
2795) -> Result<()> {
2796 let (dev_server_projects, dev_server) = {
2797 let db = session.db().await;
2798 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2799 let dev_server = db.get_dev_server(dev_server_id).await?;
2800 (dev_server_projects, dev_server)
2801 };
2802
2803 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2804 unshare_project_internal(
2805 ProjectId::from_proto(project_id),
2806 connection_id,
2807 None,
2808 session,
2809 )
2810 .await?;
2811 }
2812
2813 session
2814 .connection_pool()
2815 .await
2816 .set_dev_server_offline(dev_server_id);
2817
2818 let status = session
2819 .db()
2820 .await
2821 .dev_server_projects_update(dev_server.user_id)
2822 .await?;
2823 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2824
2825 Ok(())
2826}
2827
2828async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2829 let dev_server_connection = session
2830 .connection_pool()
2831 .await
2832 .dev_server_connection_id(dev_server_id);
2833
2834 if let Some(dev_server_connection) = dev_server_connection {
2835 session
2836 .connection_pool()
2837 .await
2838 .remove_connection(dev_server_connection)?;
2839 }
2840 Ok(())
2841}
2842
2843/// Updates other participants with changes to the project
2844async fn update_project(
2845 request: proto::UpdateProject,
2846 response: Response<proto::UpdateProject>,
2847 session: Session,
2848) -> Result<()> {
2849 let project_id = ProjectId::from_proto(request.project_id);
2850 let (room, guest_connection_ids) = &*session
2851 .db()
2852 .await
2853 .update_project(project_id, session.connection_id, &request.worktrees)
2854 .await?;
2855 broadcast(
2856 Some(session.connection_id),
2857 guest_connection_ids.iter().copied(),
2858 |connection_id| {
2859 session
2860 .peer
2861 .forward_send(session.connection_id, connection_id, request.clone())
2862 },
2863 );
2864 if let Some(room) = room {
2865 room_updated(&room, &session.peer);
2866 }
2867 response.send(proto::Ack {})?;
2868
2869 Ok(())
2870}
2871
2872/// Updates other participants with changes to the worktree
2873async fn update_worktree(
2874 request: proto::UpdateWorktree,
2875 response: Response<proto::UpdateWorktree>,
2876 session: Session,
2877) -> Result<()> {
2878 let guest_connection_ids = session
2879 .db()
2880 .await
2881 .update_worktree(&request, session.connection_id)
2882 .await?;
2883
2884 broadcast(
2885 Some(session.connection_id),
2886 guest_connection_ids.iter().copied(),
2887 |connection_id| {
2888 session
2889 .peer
2890 .forward_send(session.connection_id, connection_id, request.clone())
2891 },
2892 );
2893 response.send(proto::Ack {})?;
2894 Ok(())
2895}
2896
2897/// Updates other participants with changes to the diagnostics
2898async fn update_diagnostic_summary(
2899 message: proto::UpdateDiagnosticSummary,
2900 session: Session,
2901) -> Result<()> {
2902 let guest_connection_ids = session
2903 .db()
2904 .await
2905 .update_diagnostic_summary(&message, session.connection_id)
2906 .await?;
2907
2908 broadcast(
2909 Some(session.connection_id),
2910 guest_connection_ids.iter().copied(),
2911 |connection_id| {
2912 session
2913 .peer
2914 .forward_send(session.connection_id, connection_id, message.clone())
2915 },
2916 );
2917
2918 Ok(())
2919}
2920
2921/// Updates other participants with changes to the worktree settings
2922async fn update_worktree_settings(
2923 message: proto::UpdateWorktreeSettings,
2924 session: Session,
2925) -> Result<()> {
2926 let guest_connection_ids = session
2927 .db()
2928 .await
2929 .update_worktree_settings(&message, session.connection_id)
2930 .await?;
2931
2932 broadcast(
2933 Some(session.connection_id),
2934 guest_connection_ids.iter().copied(),
2935 |connection_id| {
2936 session
2937 .peer
2938 .forward_send(session.connection_id, connection_id, message.clone())
2939 },
2940 );
2941
2942 Ok(())
2943}
2944
2945/// Notify other participants that a language server has started.
2946async fn start_language_server(
2947 request: proto::StartLanguageServer,
2948 session: Session,
2949) -> Result<()> {
2950 let guest_connection_ids = session
2951 .db()
2952 .await
2953 .start_language_server(&request, session.connection_id)
2954 .await?;
2955
2956 broadcast(
2957 Some(session.connection_id),
2958 guest_connection_ids.iter().copied(),
2959 |connection_id| {
2960 session
2961 .peer
2962 .forward_send(session.connection_id, connection_id, request.clone())
2963 },
2964 );
2965 Ok(())
2966}
2967
2968/// Notify other participants that a language server has changed.
2969async fn update_language_server(
2970 request: proto::UpdateLanguageServer,
2971 session: Session,
2972) -> Result<()> {
2973 let project_id = ProjectId::from_proto(request.project_id);
2974 let project_connection_ids = session
2975 .db()
2976 .await
2977 .project_connection_ids(project_id, session.connection_id, true)
2978 .await?;
2979 broadcast(
2980 Some(session.connection_id),
2981 project_connection_ids.iter().copied(),
2982 |connection_id| {
2983 session
2984 .peer
2985 .forward_send(session.connection_id, connection_id, request.clone())
2986 },
2987 );
2988 Ok(())
2989}
2990
2991/// forward a project request to the host. These requests should be read only
2992/// as guests are allowed to send them.
2993async fn forward_read_only_project_request<T>(
2994 request: T,
2995 response: Response<T>,
2996 session: UserSession,
2997) -> Result<()>
2998where
2999 T: EntityMessage + RequestMessage,
3000{
3001 let project_id = ProjectId::from_proto(request.remote_entity_id());
3002 let host_connection_id = session
3003 .db()
3004 .await
3005 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3006 .await?;
3007 let payload = session
3008 .peer
3009 .forward_request(session.connection_id, host_connection_id, request)
3010 .await?;
3011 response.send(payload)?;
3012 Ok(())
3013}
3014
3015/// forward a project request to the dev server. Only allowed
3016/// if it's your dev server.
3017async fn forward_project_request_for_owner<T>(
3018 request: T,
3019 response: Response<T>,
3020 session: UserSession,
3021) -> Result<()>
3022where
3023 T: EntityMessage + RequestMessage,
3024{
3025 let project_id = ProjectId::from_proto(request.remote_entity_id());
3026
3027 let host_connection_id = session
3028 .db()
3029 .await
3030 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3031 .await?;
3032 let payload = session
3033 .peer
3034 .forward_request(session.connection_id, host_connection_id, request)
3035 .await?;
3036 response.send(payload)?;
3037 Ok(())
3038}
3039
3040/// forward a project request to the host. These requests are disallowed
3041/// for guests.
3042async fn forward_mutating_project_request<T>(
3043 request: T,
3044 response: Response<T>,
3045 session: UserSession,
3046) -> Result<()>
3047where
3048 T: EntityMessage + RequestMessage,
3049{
3050 let project_id = ProjectId::from_proto(request.remote_entity_id());
3051
3052 let host_connection_id = session
3053 .db()
3054 .await
3055 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3056 .await?;
3057 let payload = session
3058 .peer
3059 .forward_request(session.connection_id, host_connection_id, request)
3060 .await?;
3061 response.send(payload)?;
3062 Ok(())
3063}
3064
3065/// forward a project request to the host. These requests are disallowed
3066/// for guests.
3067async fn forward_versioned_mutating_project_request<T>(
3068 request: T,
3069 response: Response<T>,
3070 session: UserSession,
3071) -> Result<()>
3072where
3073 T: EntityMessage + RequestMessage + VersionedMessage,
3074{
3075 let project_id = ProjectId::from_proto(request.remote_entity_id());
3076
3077 let host_connection_id = session
3078 .db()
3079 .await
3080 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3081 .await?;
3082 if let Some(host_version) = session
3083 .connection_pool()
3084 .await
3085 .connection(host_connection_id)
3086 .map(|c| c.zed_version)
3087 {
3088 if let Some(min_required_version) = request.required_host_version() {
3089 if min_required_version > host_version {
3090 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3091 .with_tag("required", &min_required_version.to_string())))?;
3092 }
3093 }
3094 }
3095
3096 let payload = session
3097 .peer
3098 .forward_request(session.connection_id, host_connection_id, request)
3099 .await?;
3100 response.send(payload)?;
3101 Ok(())
3102}
3103
3104/// Notify other participants that a new buffer has been created
3105async fn create_buffer_for_peer(
3106 request: proto::CreateBufferForPeer,
3107 session: Session,
3108) -> Result<()> {
3109 session
3110 .db()
3111 .await
3112 .check_user_is_project_host(
3113 ProjectId::from_proto(request.project_id),
3114 session.connection_id,
3115 )
3116 .await?;
3117 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3118 session
3119 .peer
3120 .forward_send(session.connection_id, peer_id.into(), request)?;
3121 Ok(())
3122}
3123
3124/// Notify other participants that a buffer has been updated. This is
3125/// allowed for guests as long as the update is limited to selections.
3126async fn update_buffer(
3127 request: proto::UpdateBuffer,
3128 response: Response<proto::UpdateBuffer>,
3129 session: Session,
3130) -> Result<()> {
3131 let project_id = ProjectId::from_proto(request.project_id);
3132 let mut capability = Capability::ReadOnly;
3133
3134 for op in request.operations.iter() {
3135 match op.variant {
3136 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3137 Some(_) => capability = Capability::ReadWrite,
3138 }
3139 }
3140
3141 let host = {
3142 let guard = session
3143 .db()
3144 .await
3145 .connections_for_buffer_update(
3146 project_id,
3147 session.principal_id(),
3148 session.connection_id,
3149 capability,
3150 )
3151 .await?;
3152
3153 let (host, guests) = &*guard;
3154
3155 broadcast(
3156 Some(session.connection_id),
3157 guests.clone(),
3158 |connection_id| {
3159 session
3160 .peer
3161 .forward_send(session.connection_id, connection_id, request.clone())
3162 },
3163 );
3164
3165 *host
3166 };
3167
3168 if host != session.connection_id {
3169 session
3170 .peer
3171 .forward_request(session.connection_id, host, request.clone())
3172 .await?;
3173 }
3174
3175 response.send(proto::Ack {})?;
3176 Ok(())
3177}
3178
3179async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3180 let project_id = ProjectId::from_proto(message.project_id);
3181
3182 let operation = message.operation.as_ref().context("invalid operation")?;
3183 let capability = match operation.variant.as_ref() {
3184 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3185 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3186 match buffer_op.variant {
3187 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3188 Capability::ReadOnly
3189 }
3190 _ => Capability::ReadWrite,
3191 }
3192 } else {
3193 Capability::ReadWrite
3194 }
3195 }
3196 Some(_) => Capability::ReadWrite,
3197 None => Capability::ReadOnly,
3198 };
3199
3200 let guard = session
3201 .db()
3202 .await
3203 .connections_for_buffer_update(
3204 project_id,
3205 session.principal_id(),
3206 session.connection_id,
3207 capability,
3208 )
3209 .await?;
3210
3211 let (host, guests) = &*guard;
3212
3213 broadcast(
3214 Some(session.connection_id),
3215 guests.iter().chain([host]).copied(),
3216 |connection_id| {
3217 session
3218 .peer
3219 .forward_send(session.connection_id, connection_id, message.clone())
3220 },
3221 );
3222
3223 Ok(())
3224}
3225
3226/// Notify other participants that a project has been updated.
3227async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3228 request: T,
3229 session: Session,
3230) -> Result<()> {
3231 let project_id = ProjectId::from_proto(request.remote_entity_id());
3232 let project_connection_ids = session
3233 .db()
3234 .await
3235 .project_connection_ids(project_id, session.connection_id, false)
3236 .await?;
3237
3238 broadcast(
3239 Some(session.connection_id),
3240 project_connection_ids.iter().copied(),
3241 |connection_id| {
3242 session
3243 .peer
3244 .forward_send(session.connection_id, connection_id, request.clone())
3245 },
3246 );
3247 Ok(())
3248}
3249
3250/// Start following another user in a call.
3251async fn follow(
3252 request: proto::Follow,
3253 response: Response<proto::Follow>,
3254 session: UserSession,
3255) -> Result<()> {
3256 let room_id = RoomId::from_proto(request.room_id);
3257 let project_id = request.project_id.map(ProjectId::from_proto);
3258 let leader_id = request
3259 .leader_id
3260 .ok_or_else(|| anyhow!("invalid leader id"))?
3261 .into();
3262 let follower_id = session.connection_id;
3263
3264 session
3265 .db()
3266 .await
3267 .check_room_participants(room_id, leader_id, session.connection_id)
3268 .await?;
3269
3270 let response_payload = session
3271 .peer
3272 .forward_request(session.connection_id, leader_id, request)
3273 .await?;
3274 response.send(response_payload)?;
3275
3276 if let Some(project_id) = project_id {
3277 let room = session
3278 .db()
3279 .await
3280 .follow(room_id, project_id, leader_id, follower_id)
3281 .await?;
3282 room_updated(&room, &session.peer);
3283 }
3284
3285 Ok(())
3286}
3287
3288/// Stop following another user in a call.
3289async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3290 let room_id = RoomId::from_proto(request.room_id);
3291 let project_id = request.project_id.map(ProjectId::from_proto);
3292 let leader_id = request
3293 .leader_id
3294 .ok_or_else(|| anyhow!("invalid leader id"))?
3295 .into();
3296 let follower_id = session.connection_id;
3297
3298 session
3299 .db()
3300 .await
3301 .check_room_participants(room_id, leader_id, session.connection_id)
3302 .await?;
3303
3304 session
3305 .peer
3306 .forward_send(session.connection_id, leader_id, request)?;
3307
3308 if let Some(project_id) = project_id {
3309 let room = session
3310 .db()
3311 .await
3312 .unfollow(room_id, project_id, leader_id, follower_id)
3313 .await?;
3314 room_updated(&room, &session.peer);
3315 }
3316
3317 Ok(())
3318}
3319
3320/// Notify everyone following you of your current location.
3321async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3322 let room_id = RoomId::from_proto(request.room_id);
3323 let database = session.db.lock().await;
3324
3325 let connection_ids = if let Some(project_id) = request.project_id {
3326 let project_id = ProjectId::from_proto(project_id);
3327 database
3328 .project_connection_ids(project_id, session.connection_id, true)
3329 .await?
3330 } else {
3331 database
3332 .room_connection_ids(room_id, session.connection_id)
3333 .await?
3334 };
3335
3336 // For now, don't send view update messages back to that view's current leader.
3337 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3338 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3339 _ => None,
3340 });
3341
3342 for connection_id in connection_ids.iter().cloned() {
3343 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3344 session
3345 .peer
3346 .forward_send(session.connection_id, connection_id, request.clone())?;
3347 }
3348 }
3349 Ok(())
3350}
3351
3352/// Get public data about users.
3353async fn get_users(
3354 request: proto::GetUsers,
3355 response: Response<proto::GetUsers>,
3356 session: Session,
3357) -> Result<()> {
3358 let user_ids = request
3359 .user_ids
3360 .into_iter()
3361 .map(UserId::from_proto)
3362 .collect();
3363 let users = session
3364 .db()
3365 .await
3366 .get_users_by_ids(user_ids)
3367 .await?
3368 .into_iter()
3369 .map(|user| proto::User {
3370 id: user.id.to_proto(),
3371 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3372 github_login: user.github_login,
3373 })
3374 .collect();
3375 response.send(proto::UsersResponse { users })?;
3376 Ok(())
3377}
3378
3379/// Search for users (to invite) buy Github login
3380async fn fuzzy_search_users(
3381 request: proto::FuzzySearchUsers,
3382 response: Response<proto::FuzzySearchUsers>,
3383 session: UserSession,
3384) -> Result<()> {
3385 let query = request.query;
3386 let users = match query.len() {
3387 0 => vec![],
3388 1 | 2 => session
3389 .db()
3390 .await
3391 .get_user_by_github_login(&query)
3392 .await?
3393 .into_iter()
3394 .collect(),
3395 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3396 };
3397 let users = users
3398 .into_iter()
3399 .filter(|user| user.id != session.user_id())
3400 .map(|user| proto::User {
3401 id: user.id.to_proto(),
3402 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3403 github_login: user.github_login,
3404 })
3405 .collect();
3406 response.send(proto::UsersResponse { users })?;
3407 Ok(())
3408}
3409
3410/// Send a contact request to another user.
3411async fn request_contact(
3412 request: proto::RequestContact,
3413 response: Response<proto::RequestContact>,
3414 session: UserSession,
3415) -> Result<()> {
3416 let requester_id = session.user_id();
3417 let responder_id = UserId::from_proto(request.responder_id);
3418 if requester_id == responder_id {
3419 return Err(anyhow!("cannot add yourself as a contact"))?;
3420 }
3421
3422 let notifications = session
3423 .db()
3424 .await
3425 .send_contact_request(requester_id, responder_id)
3426 .await?;
3427
3428 // Update outgoing contact requests of requester
3429 let mut update = proto::UpdateContacts::default();
3430 update.outgoing_requests.push(responder_id.to_proto());
3431 for connection_id in session
3432 .connection_pool()
3433 .await
3434 .user_connection_ids(requester_id)
3435 {
3436 session.peer.send(connection_id, update.clone())?;
3437 }
3438
3439 // Update incoming contact requests of responder
3440 let mut update = proto::UpdateContacts::default();
3441 update
3442 .incoming_requests
3443 .push(proto::IncomingContactRequest {
3444 requester_id: requester_id.to_proto(),
3445 });
3446 let connection_pool = session.connection_pool().await;
3447 for connection_id in connection_pool.user_connection_ids(responder_id) {
3448 session.peer.send(connection_id, update.clone())?;
3449 }
3450
3451 send_notifications(&connection_pool, &session.peer, notifications);
3452
3453 response.send(proto::Ack {})?;
3454 Ok(())
3455}
3456
3457/// Accept or decline a contact request
3458async fn respond_to_contact_request(
3459 request: proto::RespondToContactRequest,
3460 response: Response<proto::RespondToContactRequest>,
3461 session: UserSession,
3462) -> Result<()> {
3463 let responder_id = session.user_id();
3464 let requester_id = UserId::from_proto(request.requester_id);
3465 let db = session.db().await;
3466 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3467 db.dismiss_contact_notification(responder_id, requester_id)
3468 .await?;
3469 } else {
3470 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3471
3472 let notifications = db
3473 .respond_to_contact_request(responder_id, requester_id, accept)
3474 .await?;
3475 let requester_busy = db.is_user_busy(requester_id).await?;
3476 let responder_busy = db.is_user_busy(responder_id).await?;
3477
3478 let pool = session.connection_pool().await;
3479 // Update responder with new contact
3480 let mut update = proto::UpdateContacts::default();
3481 if accept {
3482 update
3483 .contacts
3484 .push(contact_for_user(requester_id, requester_busy, &pool));
3485 }
3486 update
3487 .remove_incoming_requests
3488 .push(requester_id.to_proto());
3489 for connection_id in pool.user_connection_ids(responder_id) {
3490 session.peer.send(connection_id, update.clone())?;
3491 }
3492
3493 // Update requester with new contact
3494 let mut update = proto::UpdateContacts::default();
3495 if accept {
3496 update
3497 .contacts
3498 .push(contact_for_user(responder_id, responder_busy, &pool));
3499 }
3500 update
3501 .remove_outgoing_requests
3502 .push(responder_id.to_proto());
3503
3504 for connection_id in pool.user_connection_ids(requester_id) {
3505 session.peer.send(connection_id, update.clone())?;
3506 }
3507
3508 send_notifications(&pool, &session.peer, notifications);
3509 }
3510
3511 response.send(proto::Ack {})?;
3512 Ok(())
3513}
3514
3515/// Remove a contact.
3516async fn remove_contact(
3517 request: proto::RemoveContact,
3518 response: Response<proto::RemoveContact>,
3519 session: UserSession,
3520) -> Result<()> {
3521 let requester_id = session.user_id();
3522 let responder_id = UserId::from_proto(request.user_id);
3523 let db = session.db().await;
3524 let (contact_accepted, deleted_notification_id) =
3525 db.remove_contact(requester_id, responder_id).await?;
3526
3527 let pool = session.connection_pool().await;
3528 // Update outgoing contact requests of requester
3529 let mut update = proto::UpdateContacts::default();
3530 if contact_accepted {
3531 update.remove_contacts.push(responder_id.to_proto());
3532 } else {
3533 update
3534 .remove_outgoing_requests
3535 .push(responder_id.to_proto());
3536 }
3537 for connection_id in pool.user_connection_ids(requester_id) {
3538 session.peer.send(connection_id, update.clone())?;
3539 }
3540
3541 // Update incoming contact requests of responder
3542 let mut update = proto::UpdateContacts::default();
3543 if contact_accepted {
3544 update.remove_contacts.push(requester_id.to_proto());
3545 } else {
3546 update
3547 .remove_incoming_requests
3548 .push(requester_id.to_proto());
3549 }
3550 for connection_id in pool.user_connection_ids(responder_id) {
3551 session.peer.send(connection_id, update.clone())?;
3552 if let Some(notification_id) = deleted_notification_id {
3553 session.peer.send(
3554 connection_id,
3555 proto::DeleteNotification {
3556 notification_id: notification_id.to_proto(),
3557 },
3558 )?;
3559 }
3560 }
3561
3562 response.send(proto::Ack {})?;
3563 Ok(())
3564}
3565
3566fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3567 version.0.minor() < 139
3568}
3569
3570async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3571 let plan = session.current_plan().await?;
3572
3573 session
3574 .peer
3575 .send(
3576 session.connection_id,
3577 proto::UpdateUserPlan { plan: plan.into() },
3578 )
3579 .trace_err();
3580
3581 Ok(())
3582}
3583
3584async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3585 subscribe_user_to_channels(
3586 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3587 &session,
3588 )
3589 .await?;
3590 Ok(())
3591}
3592
3593async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3594 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3595 let mut pool = session.connection_pool().await;
3596 for membership in &channels_for_user.channel_memberships {
3597 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3598 }
3599 session.peer.send(
3600 session.connection_id,
3601 build_update_user_channels(&channels_for_user),
3602 )?;
3603 session.peer.send(
3604 session.connection_id,
3605 build_channels_update(channels_for_user),
3606 )?;
3607 Ok(())
3608}
3609
3610/// Creates a new channel.
3611async fn create_channel(
3612 request: proto::CreateChannel,
3613 response: Response<proto::CreateChannel>,
3614 session: UserSession,
3615) -> Result<()> {
3616 let db = session.db().await;
3617
3618 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3619 let (channel, membership) = db
3620 .create_channel(&request.name, parent_id, session.user_id())
3621 .await?;
3622
3623 let root_id = channel.root_id();
3624 let channel = Channel::from_model(channel);
3625
3626 response.send(proto::CreateChannelResponse {
3627 channel: Some(channel.to_proto()),
3628 parent_id: request.parent_id,
3629 })?;
3630
3631 let mut connection_pool = session.connection_pool().await;
3632 if let Some(membership) = membership {
3633 connection_pool.subscribe_to_channel(
3634 membership.user_id,
3635 membership.channel_id,
3636 membership.role,
3637 );
3638 let update = proto::UpdateUserChannels {
3639 channel_memberships: vec![proto::ChannelMembership {
3640 channel_id: membership.channel_id.to_proto(),
3641 role: membership.role.into(),
3642 }],
3643 ..Default::default()
3644 };
3645 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3646 session.peer.send(connection_id, update.clone())?;
3647 }
3648 }
3649
3650 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3651 if !role.can_see_channel(channel.visibility) {
3652 continue;
3653 }
3654
3655 let update = proto::UpdateChannels {
3656 channels: vec![channel.to_proto()],
3657 ..Default::default()
3658 };
3659 session.peer.send(connection_id, update.clone())?;
3660 }
3661
3662 Ok(())
3663}
3664
3665/// Delete a channel
3666async fn delete_channel(
3667 request: proto::DeleteChannel,
3668 response: Response<proto::DeleteChannel>,
3669 session: UserSession,
3670) -> Result<()> {
3671 let db = session.db().await;
3672
3673 let channel_id = request.channel_id;
3674 let (root_channel, removed_channels) = db
3675 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3676 .await?;
3677 response.send(proto::Ack {})?;
3678
3679 // Notify members of removed channels
3680 let mut update = proto::UpdateChannels::default();
3681 update
3682 .delete_channels
3683 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3684
3685 let connection_pool = session.connection_pool().await;
3686 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3687 session.peer.send(connection_id, update.clone())?;
3688 }
3689
3690 Ok(())
3691}
3692
3693/// Invite someone to join a channel.
3694async fn invite_channel_member(
3695 request: proto::InviteChannelMember,
3696 response: Response<proto::InviteChannelMember>,
3697 session: UserSession,
3698) -> Result<()> {
3699 let db = session.db().await;
3700 let channel_id = ChannelId::from_proto(request.channel_id);
3701 let invitee_id = UserId::from_proto(request.user_id);
3702 let InviteMemberResult {
3703 channel,
3704 notifications,
3705 } = db
3706 .invite_channel_member(
3707 channel_id,
3708 invitee_id,
3709 session.user_id(),
3710 request.role().into(),
3711 )
3712 .await?;
3713
3714 let update = proto::UpdateChannels {
3715 channel_invitations: vec![channel.to_proto()],
3716 ..Default::default()
3717 };
3718
3719 let connection_pool = session.connection_pool().await;
3720 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3721 session.peer.send(connection_id, update.clone())?;
3722 }
3723
3724 send_notifications(&connection_pool, &session.peer, notifications);
3725
3726 response.send(proto::Ack {})?;
3727 Ok(())
3728}
3729
3730/// remove someone from a channel
3731async fn remove_channel_member(
3732 request: proto::RemoveChannelMember,
3733 response: Response<proto::RemoveChannelMember>,
3734 session: UserSession,
3735) -> Result<()> {
3736 let db = session.db().await;
3737 let channel_id = ChannelId::from_proto(request.channel_id);
3738 let member_id = UserId::from_proto(request.user_id);
3739
3740 let RemoveChannelMemberResult {
3741 membership_update,
3742 notification_id,
3743 } = db
3744 .remove_channel_member(channel_id, member_id, session.user_id())
3745 .await?;
3746
3747 let mut connection_pool = session.connection_pool().await;
3748 notify_membership_updated(
3749 &mut connection_pool,
3750 membership_update,
3751 member_id,
3752 &session.peer,
3753 );
3754 for connection_id in connection_pool.user_connection_ids(member_id) {
3755 if let Some(notification_id) = notification_id {
3756 session
3757 .peer
3758 .send(
3759 connection_id,
3760 proto::DeleteNotification {
3761 notification_id: notification_id.to_proto(),
3762 },
3763 )
3764 .trace_err();
3765 }
3766 }
3767
3768 response.send(proto::Ack {})?;
3769 Ok(())
3770}
3771
3772/// Toggle the channel between public and private.
3773/// Care is taken to maintain the invariant that public channels only descend from public channels,
3774/// (though members-only channels can appear at any point in the hierarchy).
3775async fn set_channel_visibility(
3776 request: proto::SetChannelVisibility,
3777 response: Response<proto::SetChannelVisibility>,
3778 session: UserSession,
3779) -> Result<()> {
3780 let db = session.db().await;
3781 let channel_id = ChannelId::from_proto(request.channel_id);
3782 let visibility = request.visibility().into();
3783
3784 let channel_model = db
3785 .set_channel_visibility(channel_id, visibility, session.user_id())
3786 .await?;
3787 let root_id = channel_model.root_id();
3788 let channel = Channel::from_model(channel_model);
3789
3790 let mut connection_pool = session.connection_pool().await;
3791 for (user_id, role) in connection_pool
3792 .channel_user_ids(root_id)
3793 .collect::<Vec<_>>()
3794 .into_iter()
3795 {
3796 let update = if role.can_see_channel(channel.visibility) {
3797 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3798 proto::UpdateChannels {
3799 channels: vec![channel.to_proto()],
3800 ..Default::default()
3801 }
3802 } else {
3803 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3804 proto::UpdateChannels {
3805 delete_channels: vec![channel.id.to_proto()],
3806 ..Default::default()
3807 }
3808 };
3809
3810 for connection_id in connection_pool.user_connection_ids(user_id) {
3811 session.peer.send(connection_id, update.clone())?;
3812 }
3813 }
3814
3815 response.send(proto::Ack {})?;
3816 Ok(())
3817}
3818
3819/// Alter the role for a user in the channel.
3820async fn set_channel_member_role(
3821 request: proto::SetChannelMemberRole,
3822 response: Response<proto::SetChannelMemberRole>,
3823 session: UserSession,
3824) -> Result<()> {
3825 let db = session.db().await;
3826 let channel_id = ChannelId::from_proto(request.channel_id);
3827 let member_id = UserId::from_proto(request.user_id);
3828 let result = db
3829 .set_channel_member_role(
3830 channel_id,
3831 session.user_id(),
3832 member_id,
3833 request.role().into(),
3834 )
3835 .await?;
3836
3837 match result {
3838 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3839 let mut connection_pool = session.connection_pool().await;
3840 notify_membership_updated(
3841 &mut connection_pool,
3842 membership_update,
3843 member_id,
3844 &session.peer,
3845 )
3846 }
3847 db::SetMemberRoleResult::InviteUpdated(channel) => {
3848 let update = proto::UpdateChannels {
3849 channel_invitations: vec![channel.to_proto()],
3850 ..Default::default()
3851 };
3852
3853 for connection_id in session
3854 .connection_pool()
3855 .await
3856 .user_connection_ids(member_id)
3857 {
3858 session.peer.send(connection_id, update.clone())?;
3859 }
3860 }
3861 }
3862
3863 response.send(proto::Ack {})?;
3864 Ok(())
3865}
3866
3867/// Change the name of a channel
3868async fn rename_channel(
3869 request: proto::RenameChannel,
3870 response: Response<proto::RenameChannel>,
3871 session: UserSession,
3872) -> Result<()> {
3873 let db = session.db().await;
3874 let channel_id = ChannelId::from_proto(request.channel_id);
3875 let channel_model = db
3876 .rename_channel(channel_id, session.user_id(), &request.name)
3877 .await?;
3878 let root_id = channel_model.root_id();
3879 let channel = Channel::from_model(channel_model);
3880
3881 response.send(proto::RenameChannelResponse {
3882 channel: Some(channel.to_proto()),
3883 })?;
3884
3885 let connection_pool = session.connection_pool().await;
3886 let update = proto::UpdateChannels {
3887 channels: vec![channel.to_proto()],
3888 ..Default::default()
3889 };
3890 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3891 if role.can_see_channel(channel.visibility) {
3892 session.peer.send(connection_id, update.clone())?;
3893 }
3894 }
3895
3896 Ok(())
3897}
3898
3899/// Move a channel to a new parent.
3900async fn move_channel(
3901 request: proto::MoveChannel,
3902 response: Response<proto::MoveChannel>,
3903 session: UserSession,
3904) -> Result<()> {
3905 let channel_id = ChannelId::from_proto(request.channel_id);
3906 let to = ChannelId::from_proto(request.to);
3907
3908 let (root_id, channels) = session
3909 .db()
3910 .await
3911 .move_channel(channel_id, to, session.user_id())
3912 .await?;
3913
3914 let connection_pool = session.connection_pool().await;
3915 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3916 let channels = channels
3917 .iter()
3918 .filter_map(|channel| {
3919 if role.can_see_channel(channel.visibility) {
3920 Some(channel.to_proto())
3921 } else {
3922 None
3923 }
3924 })
3925 .collect::<Vec<_>>();
3926 if channels.is_empty() {
3927 continue;
3928 }
3929
3930 let update = proto::UpdateChannels {
3931 channels,
3932 ..Default::default()
3933 };
3934
3935 session.peer.send(connection_id, update.clone())?;
3936 }
3937
3938 response.send(Ack {})?;
3939 Ok(())
3940}
3941
3942/// Get the list of channel members
3943async fn get_channel_members(
3944 request: proto::GetChannelMembers,
3945 response: Response<proto::GetChannelMembers>,
3946 session: UserSession,
3947) -> Result<()> {
3948 let db = session.db().await;
3949 let channel_id = ChannelId::from_proto(request.channel_id);
3950 let limit = if request.limit == 0 {
3951 u16::MAX as u64
3952 } else {
3953 request.limit
3954 };
3955 let (members, users) = db
3956 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3957 .await?;
3958 response.send(proto::GetChannelMembersResponse { members, users })?;
3959 Ok(())
3960}
3961
3962/// Accept or decline a channel invitation.
3963async fn respond_to_channel_invite(
3964 request: proto::RespondToChannelInvite,
3965 response: Response<proto::RespondToChannelInvite>,
3966 session: UserSession,
3967) -> Result<()> {
3968 let db = session.db().await;
3969 let channel_id = ChannelId::from_proto(request.channel_id);
3970 let RespondToChannelInvite {
3971 membership_update,
3972 notifications,
3973 } = db
3974 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3975 .await?;
3976
3977 let mut connection_pool = session.connection_pool().await;
3978 if let Some(membership_update) = membership_update {
3979 notify_membership_updated(
3980 &mut connection_pool,
3981 membership_update,
3982 session.user_id(),
3983 &session.peer,
3984 );
3985 } else {
3986 let update = proto::UpdateChannels {
3987 remove_channel_invitations: vec![channel_id.to_proto()],
3988 ..Default::default()
3989 };
3990
3991 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3992 session.peer.send(connection_id, update.clone())?;
3993 }
3994 };
3995
3996 send_notifications(&connection_pool, &session.peer, notifications);
3997
3998 response.send(proto::Ack {})?;
3999
4000 Ok(())
4001}
4002
4003/// Join the channels' room
4004async fn join_channel(
4005 request: proto::JoinChannel,
4006 response: Response<proto::JoinChannel>,
4007 session: UserSession,
4008) -> Result<()> {
4009 let channel_id = ChannelId::from_proto(request.channel_id);
4010 join_channel_internal(channel_id, Box::new(response), session).await
4011}
4012
/// A response handle that can deliver a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`,
/// so `join_channel_internal` can be called with either kind of response.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each impl delegates to `Response`'s own `send`. The explicit `Response::<…>::send` path
// is intended to resolve to the inherent method, since inherent associated items take
// precedence over trait items in path lookup. Do not replace it with `Self::send` or a
// trait-qualified call without checking that it cannot recurse into this method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4026
/// Shared implementation for joining a channel's call room.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room
///    first.
/// 2. Join the channel's room in the database. This returns the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. Build LiveKit credentials and send the `JoinRoomResponse`. Guests get a guest
///    token without publish rights; every other role gets a room token with publish
///    rights.
/// 4. Propagate any membership update and broadcast the room update.
/// 5. After the database and connection-pool guards are released, broadcast the channel
///    update and refresh the user's contacts.
///
/// Errors if the joined room carries no channel ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes the database and connection-pool guards so they are dropped
    // before `channel_updated` re-locks the connection pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database handle while leaving the old room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` locks the database handle
            // itself, so holding it here would deadlock. Confirm before reordering.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when the session has no LiveKit client. A token-generation failure is
        // passed to `trace_err()` and, via `?` inside the closure, also yields `None`.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4117
4118/// Start editing the channel notes
4119async fn join_channel_buffer(
4120 request: proto::JoinChannelBuffer,
4121 response: Response<proto::JoinChannelBuffer>,
4122 session: UserSession,
4123) -> Result<()> {
4124 let db = session.db().await;
4125 let channel_id = ChannelId::from_proto(request.channel_id);
4126
4127 let open_response = db
4128 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4129 .await?;
4130
4131 let collaborators = open_response.collaborators.clone();
4132 response.send(open_response)?;
4133
4134 let update = UpdateChannelBufferCollaborators {
4135 channel_id: channel_id.to_proto(),
4136 collaborators: collaborators.clone(),
4137 };
4138 channel_buffer_updated(
4139 session.connection_id,
4140 collaborators
4141 .iter()
4142 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4143 &update,
4144 &session.peer,
4145 );
4146
4147 Ok(())
4148}
4149
4150/// Edit the channel notes
4151async fn update_channel_buffer(
4152 request: proto::UpdateChannelBuffer,
4153 session: UserSession,
4154) -> Result<()> {
4155 let db = session.db().await;
4156 let channel_id = ChannelId::from_proto(request.channel_id);
4157
4158 let (collaborators, epoch, version) = db
4159 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4160 .await?;
4161
4162 channel_buffer_updated(
4163 session.connection_id,
4164 collaborators.clone(),
4165 &proto::UpdateChannelBuffer {
4166 channel_id: channel_id.to_proto(),
4167 operations: request.operations,
4168 },
4169 &session.peer,
4170 );
4171
4172 let pool = &*session.connection_pool().await;
4173
4174 let non_collaborators =
4175 pool.channel_connection_ids(channel_id)
4176 .filter_map(|(connection_id, _)| {
4177 if collaborators.contains(&connection_id) {
4178 None
4179 } else {
4180 Some(connection_id)
4181 }
4182 });
4183
4184 broadcast(None, non_collaborators, |peer_id| {
4185 session.peer.send(
4186 peer_id,
4187 proto::UpdateChannels {
4188 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4189 channel_id: channel_id.to_proto(),
4190 epoch: epoch as u64,
4191 version: version.clone(),
4192 }],
4193 ..Default::default()
4194 },
4195 )
4196 });
4197
4198 Ok(())
4199}
4200
4201/// Rejoin the channel notes after a connection blip
4202async fn rejoin_channel_buffers(
4203 request: proto::RejoinChannelBuffers,
4204 response: Response<proto::RejoinChannelBuffers>,
4205 session: UserSession,
4206) -> Result<()> {
4207 let db = session.db().await;
4208 let buffers = db
4209 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4210 .await?;
4211
4212 for rejoined_buffer in &buffers {
4213 let collaborators_to_notify = rejoined_buffer
4214 .buffer
4215 .collaborators
4216 .iter()
4217 .filter_map(|c| Some(c.peer_id?.into()));
4218 channel_buffer_updated(
4219 session.connection_id,
4220 collaborators_to_notify,
4221 &proto::UpdateChannelBufferCollaborators {
4222 channel_id: rejoined_buffer.buffer.channel_id,
4223 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4224 },
4225 &session.peer,
4226 );
4227 }
4228
4229 response.send(proto::RejoinChannelBuffersResponse {
4230 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4231 })?;
4232
4233 Ok(())
4234}
4235
4236/// Stop editing the channel notes
4237async fn leave_channel_buffer(
4238 request: proto::LeaveChannelBuffer,
4239 response: Response<proto::LeaveChannelBuffer>,
4240 session: UserSession,
4241) -> Result<()> {
4242 let db = session.db().await;
4243 let channel_id = ChannelId::from_proto(request.channel_id);
4244
4245 let left_buffer = db
4246 .leave_channel_buffer(channel_id, session.connection_id)
4247 .await?;
4248
4249 response.send(Ack {})?;
4250
4251 channel_buffer_updated(
4252 session.connection_id,
4253 left_buffer.connections,
4254 &proto::UpdateChannelBufferCollaborators {
4255 channel_id: channel_id.to_proto(),
4256 collaborators: left_buffer.collaborators,
4257 },
4258 &session.peer,
4259 );
4260
4261 Ok(())
4262}
4263
4264fn channel_buffer_updated<T: EnvelopedMessage>(
4265 sender_id: ConnectionId,
4266 collaborators: impl IntoIterator<Item = ConnectionId>,
4267 message: &T,
4268 peer: &Peer,
4269) {
4270 broadcast(Some(sender_id), collaborators, |peer_id| {
4271 peer.send(peer_id, message.clone())
4272 });
4273}
4274
4275fn send_notifications(
4276 connection_pool: &ConnectionPool,
4277 peer: &Peer,
4278 notifications: db::NotificationBatch,
4279) {
4280 for (user_id, notification) in notifications {
4281 for connection_id in connection_pool.user_connection_ids(user_id) {
4282 if let Err(error) = peer.send(
4283 connection_id,
4284 proto::AddNotification {
4285 notification: Some(notification.clone()),
4286 },
4287 ) {
4288 tracing::error!(
4289 "failed to send notification to {:?} {}",
4290 connection_id,
4291 error
4292 );
4293 }
4294 }
4295 }
4296}
4297
4298/// Send a message to the channel
4299async fn send_channel_message(
4300 request: proto::SendChannelMessage,
4301 response: Response<proto::SendChannelMessage>,
4302 session: UserSession,
4303) -> Result<()> {
4304 // Validate the message body.
4305 let body = request.body.trim().to_string();
4306 if body.len() > MAX_MESSAGE_LEN {
4307 return Err(anyhow!("message is too long"))?;
4308 }
4309 if body.is_empty() {
4310 return Err(anyhow!("message can't be blank"))?;
4311 }
4312
4313 // TODO: adjust mentions if body is trimmed
4314
4315 let timestamp = OffsetDateTime::now_utc();
4316 let nonce = request
4317 .nonce
4318 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4319
4320 let channel_id = ChannelId::from_proto(request.channel_id);
4321 let CreatedChannelMessage {
4322 message_id,
4323 participant_connection_ids,
4324 notifications,
4325 } = session
4326 .db()
4327 .await
4328 .create_channel_message(
4329 channel_id,
4330 session.user_id(),
4331 &body,
4332 &request.mentions,
4333 timestamp,
4334 nonce.clone().into(),
4335 match request.reply_to_message_id {
4336 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4337 None => None,
4338 },
4339 )
4340 .await?;
4341
4342 let message = proto::ChannelMessage {
4343 sender_id: session.user_id().to_proto(),
4344 id: message_id.to_proto(),
4345 body,
4346 mentions: request.mentions,
4347 timestamp: timestamp.unix_timestamp() as u64,
4348 nonce: Some(nonce),
4349 reply_to_message_id: request.reply_to_message_id,
4350 edited_at: None,
4351 };
4352 broadcast(
4353 Some(session.connection_id),
4354 participant_connection_ids.clone(),
4355 |connection| {
4356 session.peer.send(
4357 connection,
4358 proto::ChannelMessageSent {
4359 channel_id: channel_id.to_proto(),
4360 message: Some(message.clone()),
4361 },
4362 )
4363 },
4364 );
4365 response.send(proto::SendChannelMessageResponse {
4366 message: Some(message),
4367 })?;
4368
4369 let pool = &*session.connection_pool().await;
4370 let non_participants =
4371 pool.channel_connection_ids(channel_id)
4372 .filter_map(|(connection_id, _)| {
4373 if participant_connection_ids.contains(&connection_id) {
4374 None
4375 } else {
4376 Some(connection_id)
4377 }
4378 });
4379 broadcast(None, non_participants, |peer_id| {
4380 session.peer.send(
4381 peer_id,
4382 proto::UpdateChannels {
4383 latest_channel_message_ids: vec![proto::ChannelMessageId {
4384 channel_id: channel_id.to_proto(),
4385 message_id: message_id.to_proto(),
4386 }],
4387 ..Default::default()
4388 },
4389 )
4390 });
4391 send_notifications(pool, &session.peer, notifications);
4392
4393 Ok(())
4394}
4395
4396/// Delete a channel message
4397async fn remove_channel_message(
4398 request: proto::RemoveChannelMessage,
4399 response: Response<proto::RemoveChannelMessage>,
4400 session: UserSession,
4401) -> Result<()> {
4402 let channel_id = ChannelId::from_proto(request.channel_id);
4403 let message_id = MessageId::from_proto(request.message_id);
4404 let (connection_ids, existing_notification_ids) = session
4405 .db()
4406 .await
4407 .remove_channel_message(channel_id, message_id, session.user_id())
4408 .await?;
4409
4410 broadcast(
4411 Some(session.connection_id),
4412 connection_ids,
4413 move |connection| {
4414 session.peer.send(connection, request.clone())?;
4415
4416 for notification_id in &existing_notification_ids {
4417 session.peer.send(
4418 connection,
4419 proto::DeleteNotification {
4420 notification_id: (*notification_id).to_proto(),
4421 },
4422 )?;
4423 }
4424
4425 Ok(())
4426 },
4427 );
4428 response.send(proto::Ack {})?;
4429 Ok(())
4430}
4431
4432async fn update_channel_message(
4433 request: proto::UpdateChannelMessage,
4434 response: Response<proto::UpdateChannelMessage>,
4435 session: UserSession,
4436) -> Result<()> {
4437 let channel_id = ChannelId::from_proto(request.channel_id);
4438 let message_id = MessageId::from_proto(request.message_id);
4439 let updated_at = OffsetDateTime::now_utc();
4440 let UpdatedChannelMessage {
4441 message_id,
4442 participant_connection_ids,
4443 notifications,
4444 reply_to_message_id,
4445 timestamp,
4446 deleted_mention_notification_ids,
4447 updated_mention_notifications,
4448 } = session
4449 .db()
4450 .await
4451 .update_channel_message(
4452 channel_id,
4453 message_id,
4454 session.user_id(),
4455 request.body.as_str(),
4456 &request.mentions,
4457 updated_at,
4458 )
4459 .await?;
4460
4461 let nonce = request
4462 .nonce
4463 .clone()
4464 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4465
4466 let message = proto::ChannelMessage {
4467 sender_id: session.user_id().to_proto(),
4468 id: message_id.to_proto(),
4469 body: request.body.clone(),
4470 mentions: request.mentions.clone(),
4471 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4472 nonce: Some(nonce),
4473 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4474 edited_at: Some(updated_at.unix_timestamp() as u64),
4475 };
4476
4477 response.send(proto::Ack {})?;
4478
4479 let pool = &*session.connection_pool().await;
4480 broadcast(
4481 Some(session.connection_id),
4482 participant_connection_ids,
4483 |connection| {
4484 session.peer.send(
4485 connection,
4486 proto::ChannelMessageUpdate {
4487 channel_id: channel_id.to_proto(),
4488 message: Some(message.clone()),
4489 },
4490 )?;
4491
4492 for notification_id in &deleted_mention_notification_ids {
4493 session.peer.send(
4494 connection,
4495 proto::DeleteNotification {
4496 notification_id: (*notification_id).to_proto(),
4497 },
4498 )?;
4499 }
4500
4501 for notification in &updated_mention_notifications {
4502 session.peer.send(
4503 connection,
4504 proto::UpdateNotification {
4505 notification: Some(notification.clone()),
4506 },
4507 )?;
4508 }
4509
4510 Ok(())
4511 },
4512 );
4513
4514 send_notifications(pool, &session.peer, notifications);
4515
4516 Ok(())
4517}
4518
4519/// Mark a channel message as read
4520async fn acknowledge_channel_message(
4521 request: proto::AckChannelMessage,
4522 session: UserSession,
4523) -> Result<()> {
4524 let channel_id = ChannelId::from_proto(request.channel_id);
4525 let message_id = MessageId::from_proto(request.message_id);
4526 let notifications = session
4527 .db()
4528 .await
4529 .observe_channel_message(channel_id, session.user_id(), message_id)
4530 .await?;
4531 send_notifications(
4532 &*session.connection_pool().await,
4533 &session.peer,
4534 notifications,
4535 );
4536 Ok(())
4537}
4538
4539/// Mark a buffer version as synced
4540async fn acknowledge_buffer_version(
4541 request: proto::AckBufferOperation,
4542 session: UserSession,
4543) -> Result<()> {
4544 let buffer_id = BufferId::from_proto(request.buffer_id);
4545 session
4546 .db()
4547 .await
4548 .observe_buffer_version(
4549 buffer_id,
4550 session.user_id(),
4551 request.epoch as i32,
4552 &request.version,
4553 )
4554 .await?;
4555 Ok(())
4556}
4557
4558struct ZedProCompleteWithLanguageModelRateLimit;
4559
4560impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4561 fn capacity(&self) -> usize {
4562 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4563 .ok()
4564 .and_then(|v| v.parse().ok())
4565 .unwrap_or(120) // Picked arbitrarily
4566 }
4567
4568 fn refill_duration(&self) -> chrono::Duration {
4569 chrono::Duration::hours(1)
4570 }
4571
4572 fn db_name(&self) -> &'static str {
4573 "zed-pro:complete-with-language-model"
4574 }
4575}
4576
4577struct FreeCompleteWithLanguageModelRateLimit;
4578
4579impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4580 fn capacity(&self) -> usize {
4581 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4582 .ok()
4583 .and_then(|v| v.parse().ok())
4584 .unwrap_or(120 / 10) // Picked arbitrarily
4585 }
4586
4587 fn refill_duration(&self) -> chrono::Duration {
4588 chrono::Duration::hours(1)
4589 }
4590
4591 fn db_name(&self) -> &'static str {
4592 "free:complete-with-language-model"
4593 }
4594}
4595
4596async fn complete_with_language_model(
4597 request: proto::CompleteWithLanguageModel,
4598 response: Response<proto::CompleteWithLanguageModel>,
4599 session: Session,
4600 config: &Config,
4601) -> Result<()> {
4602 let Some(session) = session.for_user() else {
4603 return Err(anyhow!("user not found"))?;
4604 };
4605 authorize_access_to_language_models(&session).await?;
4606
4607 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4608 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4609 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4610 };
4611
4612 session
4613 .rate_limiter
4614 .check(&*rate_limit, session.user_id())
4615 .await?;
4616
4617 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4618 Some(proto::LanguageModelProvider::Anthropic) => {
4619 let api_key = config
4620 .anthropic_api_key
4621 .as_ref()
4622 .context("no Anthropic AI API key configured on the server")?;
4623 anthropic::complete(
4624 session.http_client.as_ref(),
4625 anthropic::ANTHROPIC_API_URL,
4626 api_key,
4627 serde_json::from_str(&request.request)?,
4628 )
4629 .await?
4630 }
4631 _ => return Err(anyhow!("unsupported provider"))?,
4632 };
4633
4634 response.send(proto::CompleteWithLanguageModelResponse {
4635 completion: serde_json::to_string(&result)?,
4636 })?;
4637
4638 Ok(())
4639}
4640
4641async fn stream_complete_with_language_model(
4642 request: proto::StreamCompleteWithLanguageModel,
4643 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4644 session: Session,
4645 config: &Config,
4646) -> Result<()> {
4647 let Some(session) = session.for_user() else {
4648 return Err(anyhow!("user not found"))?;
4649 };
4650 authorize_access_to_language_models(&session).await?;
4651
4652 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4653 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4654 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4655 };
4656
4657 session
4658 .rate_limiter
4659 .check(&*rate_limit, session.user_id())
4660 .await?;
4661
4662 match proto::LanguageModelProvider::from_i32(request.provider) {
4663 Some(proto::LanguageModelProvider::Anthropic) => {
4664 let api_key = config
4665 .anthropic_api_key
4666 .as_ref()
4667 .context("no Anthropic AI API key configured on the server")?;
4668 let mut chunks = anthropic::stream_completion(
4669 session.http_client.as_ref(),
4670 anthropic::ANTHROPIC_API_URL,
4671 api_key,
4672 serde_json::from_str(&request.request)?,
4673 None,
4674 )
4675 .await?;
4676 while let Some(event) = chunks.next().await {
4677 let chunk = event?;
4678 response.send(proto::StreamCompleteWithLanguageModelResponse {
4679 event: serde_json::to_string(&chunk)?,
4680 })?;
4681 }
4682 }
4683 Some(proto::LanguageModelProvider::OpenAi) => {
4684 let api_key = config
4685 .openai_api_key
4686 .as_ref()
4687 .context("no OpenAI API key configured on the server")?;
4688 let mut events = open_ai::stream_completion(
4689 session.http_client.as_ref(),
4690 open_ai::OPEN_AI_API_URL,
4691 api_key,
4692 serde_json::from_str(&request.request)?,
4693 None,
4694 )
4695 .await?;
4696 while let Some(event) = events.next().await {
4697 let event = event?;
4698 response.send(proto::StreamCompleteWithLanguageModelResponse {
4699 event: serde_json::to_string(&event)?,
4700 })?;
4701 }
4702 }
4703 Some(proto::LanguageModelProvider::Google) => {
4704 let api_key = config
4705 .google_ai_api_key
4706 .as_ref()
4707 .context("no Google AI API key configured on the server")?;
4708 let mut events = google_ai::stream_generate_content(
4709 session.http_client.as_ref(),
4710 google_ai::API_URL,
4711 api_key,
4712 serde_json::from_str(&request.request)?,
4713 )
4714 .await?;
4715 while let Some(event) = events.next().await {
4716 let event = event?;
4717 response.send(proto::StreamCompleteWithLanguageModelResponse {
4718 event: serde_json::to_string(&event)?,
4719 })?;
4720 }
4721 }
4722 Some(proto::LanguageModelProvider::Zed) => {
4723 let api_key = config
4724 .qwen2_7b_api_key
4725 .as_ref()
4726 .context("no Qwen2-7B API key configured on the server")?;
4727 let api_url = config
4728 .qwen2_7b_api_url
4729 .as_ref()
4730 .context("no Qwen2-7B URL configured on the server")?;
4731 let mut events = open_ai::stream_completion(
4732 session.http_client.as_ref(),
4733 &api_url,
4734 api_key,
4735 serde_json::from_str(&request.request)?,
4736 None,
4737 )
4738 .await?;
4739 while let Some(event) = events.next().await {
4740 let event = event?;
4741 response.send(proto::StreamCompleteWithLanguageModelResponse {
4742 event: serde_json::to_string(&event)?,
4743 })?;
4744 }
4745 }
4746 None => return Err(anyhow!("unknown provider"))?,
4747 }
4748
4749 Ok(())
4750}
4751
4752async fn count_language_model_tokens(
4753 request: proto::CountLanguageModelTokens,
4754 response: Response<proto::CountLanguageModelTokens>,
4755 session: Session,
4756 config: &Config,
4757) -> Result<()> {
4758 let Some(session) = session.for_user() else {
4759 return Err(anyhow!("user not found"))?;
4760 };
4761 authorize_access_to_language_models(&session).await?;
4762
4763 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4764 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4765 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4766 };
4767
4768 session
4769 .rate_limiter
4770 .check(&*rate_limit, session.user_id())
4771 .await?;
4772
4773 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4774 Some(proto::LanguageModelProvider::Google) => {
4775 let api_key = config
4776 .google_ai_api_key
4777 .as_ref()
4778 .context("no Google AI API key configured on the server")?;
4779 google_ai::count_tokens(
4780 session.http_client.as_ref(),
4781 google_ai::API_URL,
4782 api_key,
4783 serde_json::from_str(&request.request)?,
4784 )
4785 .await?
4786 }
4787 _ => return Err(anyhow!("unsupported provider"))?,
4788 };
4789
4790 response.send(proto::CountLanguageModelTokensResponse {
4791 token_count: result.total_tokens as u32,
4792 })?;
4793
4794 Ok(())
4795}
4796
4797struct ZedProCountLanguageModelTokensRateLimit;
4798
4799impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4800 fn capacity(&self) -> usize {
4801 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4802 .ok()
4803 .and_then(|v| v.parse().ok())
4804 .unwrap_or(600) // Picked arbitrarily
4805 }
4806
4807 fn refill_duration(&self) -> chrono::Duration {
4808 chrono::Duration::hours(1)
4809 }
4810
4811 fn db_name(&self) -> &'static str {
4812 "zed-pro:count-language-model-tokens"
4813 }
4814}
4815
4816struct FreeCountLanguageModelTokensRateLimit;
4817
4818impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4819 fn capacity(&self) -> usize {
4820 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4821 .ok()
4822 .and_then(|v| v.parse().ok())
4823 .unwrap_or(600 / 10) // Picked arbitrarily
4824 }
4825
4826 fn refill_duration(&self) -> chrono::Duration {
4827 chrono::Duration::hours(1)
4828 }
4829
4830 fn db_name(&self) -> &'static str {
4831 "free:count-language-model-tokens"
4832 }
4833}
4834
4835struct ZedProComputeEmbeddingsRateLimit;
4836
4837impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4838 fn capacity(&self) -> usize {
4839 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4840 .ok()
4841 .and_then(|v| v.parse().ok())
4842 .unwrap_or(5000) // Picked arbitrarily
4843 }
4844
4845 fn refill_duration(&self) -> chrono::Duration {
4846 chrono::Duration::hours(1)
4847 }
4848
4849 fn db_name(&self) -> &'static str {
4850 "zed-pro:compute-embeddings"
4851 }
4852}
4853
4854struct FreeComputeEmbeddingsRateLimit;
4855
4856impl RateLimit for FreeComputeEmbeddingsRateLimit {
4857 fn capacity(&self) -> usize {
4858 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4859 .ok()
4860 .and_then(|v| v.parse().ok())
4861 .unwrap_or(5000 / 10) // Picked arbitrarily
4862 }
4863
4864 fn refill_duration(&self) -> chrono::Duration {
4865 chrono::Duration::hours(1)
4866 }
4867
4868 fn db_name(&self) -> &'static str {
4869 "free:compute-embeddings"
4870 }
4871}
4872
4873async fn compute_embeddings(
4874 request: proto::ComputeEmbeddings,
4875 response: Response<proto::ComputeEmbeddings>,
4876 session: UserSession,
4877 api_key: Option<Arc<str>>,
4878) -> Result<()> {
4879 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4880 authorize_access_to_language_models(&session).await?;
4881
4882 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4883 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4884 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4885 };
4886
4887 session
4888 .rate_limiter
4889 .check(&*rate_limit, session.user_id())
4890 .await?;
4891
4892 let embeddings = match request.model.as_str() {
4893 "openai/text-embedding-3-small" => {
4894 open_ai::embed(
4895 session.http_client.as_ref(),
4896 OPEN_AI_API_URL,
4897 &api_key,
4898 OpenAiEmbeddingModel::TextEmbedding3Small,
4899 request.texts.iter().map(|text| text.as_str()),
4900 )
4901 .await?
4902 }
4903 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4904 };
4905
4906 let embeddings = request
4907 .texts
4908 .iter()
4909 .map(|text| {
4910 let mut hasher = sha2::Sha256::new();
4911 hasher.update(text.as_bytes());
4912 let result = hasher.finalize();
4913 result.to_vec()
4914 })
4915 .zip(
4916 embeddings
4917 .data
4918 .into_iter()
4919 .map(|embedding| embedding.embedding),
4920 )
4921 .collect::<HashMap<_, _>>();
4922
4923 let db = session.db().await;
4924 db.save_embeddings(&request.model, &embeddings)
4925 .await
4926 .context("failed to save embeddings")
4927 .trace_err();
4928
4929 response.send(proto::ComputeEmbeddingsResponse {
4930 embeddings: embeddings
4931 .into_iter()
4932 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4933 .collect(),
4934 })?;
4935 Ok(())
4936}
4937
4938async fn get_cached_embeddings(
4939 request: proto::GetCachedEmbeddings,
4940 response: Response<proto::GetCachedEmbeddings>,
4941 session: UserSession,
4942) -> Result<()> {
4943 authorize_access_to_language_models(&session).await?;
4944
4945 let db = session.db().await;
4946 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4947
4948 response.send(proto::GetCachedEmbeddingsResponse {
4949 embeddings: embeddings
4950 .into_iter()
4951 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4952 .collect(),
4953 })?;
4954 Ok(())
4955}
4956
4957async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4958 let db = session.db().await;
4959 let flags = db.get_user_flags(session.user_id()).await?;
4960 if flags.iter().any(|flag| flag == "language-models") {
4961 return Ok(());
4962 }
4963
4964 Err(anyhow!("permission denied"))?
4965}
4966
4967/// Get a Supermaven API key for the user
4968async fn get_supermaven_api_key(
4969 _request: proto::GetSupermavenApiKey,
4970 response: Response<proto::GetSupermavenApiKey>,
4971 session: UserSession,
4972) -> Result<()> {
4973 let user_id: String = session.user_id().to_string();
4974 if !session.is_staff() {
4975 return Err(anyhow!("supermaven not enabled for this account"))?;
4976 }
4977
4978 let email = session
4979 .email()
4980 .ok_or_else(|| anyhow!("user must have an email"))?;
4981
4982 let supermaven_admin_api = session
4983 .supermaven_client
4984 .as_ref()
4985 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4986
4987 let result = supermaven_admin_api
4988 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4989 .await?;
4990
4991 response.send(proto::GetSupermavenApiKeyResponse {
4992 api_key: result.api_key,
4993 })?;
4994
4995 Ok(())
4996}
4997
4998/// Start receiving chat updates for a channel
4999async fn join_channel_chat(
5000 request: proto::JoinChannelChat,
5001 response: Response<proto::JoinChannelChat>,
5002 session: UserSession,
5003) -> Result<()> {
5004 let channel_id = ChannelId::from_proto(request.channel_id);
5005
5006 let db = session.db().await;
5007 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5008 .await?;
5009 let messages = db
5010 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5011 .await?;
5012 response.send(proto::JoinChannelChatResponse {
5013 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5014 messages,
5015 })?;
5016 Ok(())
5017}
5018
5019/// Stop receiving chat updates for a channel
5020async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5021 let channel_id = ChannelId::from_proto(request.channel_id);
5022 session
5023 .db()
5024 .await
5025 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5026 .await?;
5027 Ok(())
5028}
5029
5030/// Retrieve the chat history for a channel
5031async fn get_channel_messages(
5032 request: proto::GetChannelMessages,
5033 response: Response<proto::GetChannelMessages>,
5034 session: UserSession,
5035) -> Result<()> {
5036 let channel_id = ChannelId::from_proto(request.channel_id);
5037 let messages = session
5038 .db()
5039 .await
5040 .get_channel_messages(
5041 channel_id,
5042 session.user_id(),
5043 MESSAGE_COUNT_PER_PAGE,
5044 Some(MessageId::from_proto(request.before_message_id)),
5045 )
5046 .await?;
5047 response.send(proto::GetChannelMessagesResponse {
5048 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5049 messages,
5050 })?;
5051 Ok(())
5052}
5053
5054/// Retrieve specific chat messages
5055async fn get_channel_messages_by_id(
5056 request: proto::GetChannelMessagesById,
5057 response: Response<proto::GetChannelMessagesById>,
5058 session: UserSession,
5059) -> Result<()> {
5060 let message_ids = request
5061 .message_ids
5062 .iter()
5063 .map(|id| MessageId::from_proto(*id))
5064 .collect::<Vec<_>>();
5065 let messages = session
5066 .db()
5067 .await
5068 .get_channel_messages_by_id(session.user_id(), &message_ids)
5069 .await?;
5070 response.send(proto::GetChannelMessagesResponse {
5071 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5072 messages,
5073 })?;
5074 Ok(())
5075}
5076
5077/// Retrieve the current users notifications
5078async fn get_notifications(
5079 request: proto::GetNotifications,
5080 response: Response<proto::GetNotifications>,
5081 session: UserSession,
5082) -> Result<()> {
5083 let notifications = session
5084 .db()
5085 .await
5086 .get_notifications(
5087 session.user_id(),
5088 NOTIFICATION_COUNT_PER_PAGE,
5089 request
5090 .before_id
5091 .map(|id| db::NotificationId::from_proto(id)),
5092 )
5093 .await?;
5094 response.send(proto::GetNotificationsResponse {
5095 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5096 notifications,
5097 })?;
5098 Ok(())
5099}
5100
5101/// Mark notifications as read
5102async fn mark_notification_as_read(
5103 request: proto::MarkNotificationRead,
5104 response: Response<proto::MarkNotificationRead>,
5105 session: UserSession,
5106) -> Result<()> {
5107 let database = &session.db().await;
5108 let notifications = database
5109 .mark_notification_as_read_by_id(
5110 session.user_id(),
5111 NotificationId::from_proto(request.notification_id),
5112 )
5113 .await?;
5114 send_notifications(
5115 &*session.connection_pool().await,
5116 &session.peer,
5117 notifications,
5118 );
5119 response.send(proto::Ack {})?;
5120 Ok(())
5121}
5122
5123/// Get the current users information
5124async fn get_private_user_info(
5125 _request: proto::GetPrivateUserInfo,
5126 response: Response<proto::GetPrivateUserInfo>,
5127 session: UserSession,
5128) -> Result<()> {
5129 let db = session.db().await;
5130
5131 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5132 let user = db
5133 .get_user_by_id(session.user_id())
5134 .await?
5135 .ok_or_else(|| anyhow!("user not found"))?;
5136 let flags = db.get_user_flags(session.user_id()).await?;
5137
5138 response.send(proto::GetPrivateUserInfoResponse {
5139 metrics_id,
5140 staff: user.admin,
5141 flags,
5142 })?;
5143 Ok(())
5144}
5145
5146fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5147 let message = match message {
5148 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5149 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5150 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5151 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5152 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5153 code: frame.code.into(),
5154 reason: frame.reason,
5155 })),
5156 // We should never receive a frame while reading the message, according
5157 // to the `tungstenite` maintainers:
5158 //
5159 // > It cannot occur when you read messages from the WebSocket, but it
5160 // > can be used when you want to send the raw frames (e.g. you want to
5161 // > send the frames to the WebSocket without composing the full message first).
5162 // >
5163 // > — https://github.com/snapview/tungstenite-rs/issues/268
5164 TungsteniteMessage::Frame(_) => {
5165 bail!("received an unexpected frame while reading the message")
5166 }
5167 };
5168
5169 Ok(message)
5170}
5171
5172fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5173 match message {
5174 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5175 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5176 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5177 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5178 AxumMessage::Close(frame) => {
5179 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5180 code: frame.code.into(),
5181 reason: frame.reason,
5182 }))
5183 }
5184 }
5185}
5186
5187fn notify_membership_updated(
5188 connection_pool: &mut ConnectionPool,
5189 result: MembershipUpdated,
5190 user_id: UserId,
5191 peer: &Peer,
5192) {
5193 for membership in &result.new_channels.channel_memberships {
5194 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5195 }
5196 for channel_id in &result.removed_channels {
5197 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5198 }
5199
5200 let user_channels_update = proto::UpdateUserChannels {
5201 channel_memberships: result
5202 .new_channels
5203 .channel_memberships
5204 .iter()
5205 .map(|cm| proto::ChannelMembership {
5206 channel_id: cm.channel_id.to_proto(),
5207 role: cm.role.into(),
5208 })
5209 .collect(),
5210 ..Default::default()
5211 };
5212
5213 let mut update = build_channels_update(result.new_channels);
5214 update.delete_channels = result
5215 .removed_channels
5216 .into_iter()
5217 .map(|id| id.to_proto())
5218 .collect();
5219 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5220
5221 for connection_id in connection_pool.user_connection_ids(user_id) {
5222 peer.send(connection_id, user_channels_update.clone())
5223 .trace_err();
5224 peer.send(connection_id, update.clone()).trace_err();
5225 }
5226}
5227
5228fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5229 proto::UpdateUserChannels {
5230 channel_memberships: channels
5231 .channel_memberships
5232 .iter()
5233 .map(|m| proto::ChannelMembership {
5234 channel_id: m.channel_id.to_proto(),
5235 role: m.role.into(),
5236 })
5237 .collect(),
5238 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5239 observed_channel_message_id: channels.observed_channel_messages.clone(),
5240 }
5241}
5242
5243fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5244 let mut update = proto::UpdateChannels::default();
5245
5246 for channel in channels.channels {
5247 update.channels.push(channel.to_proto());
5248 }
5249
5250 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5251 update.latest_channel_message_ids = channels.latest_channel_messages;
5252
5253 for (channel_id, participants) in channels.channel_participants {
5254 update
5255 .channel_participants
5256 .push(proto::ChannelParticipants {
5257 channel_id: channel_id.to_proto(),
5258 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5259 });
5260 }
5261
5262 for channel in channels.invited_channels {
5263 update.channel_invitations.push(channel.to_proto());
5264 }
5265
5266 update.hosted_projects = channels.hosted_projects;
5267 update
5268}
5269
5270fn build_initial_contacts_update(
5271 contacts: Vec<db::Contact>,
5272 pool: &ConnectionPool,
5273) -> proto::UpdateContacts {
5274 let mut update = proto::UpdateContacts::default();
5275
5276 for contact in contacts {
5277 match contact {
5278 db::Contact::Accepted { user_id, busy } => {
5279 update.contacts.push(contact_for_user(user_id, busy, &pool));
5280 }
5281 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5282 db::Contact::Incoming { user_id } => {
5283 update
5284 .incoming_requests
5285 .push(proto::IncomingContactRequest {
5286 requester_id: user_id.to_proto(),
5287 })
5288 }
5289 }
5290 }
5291
5292 update
5293}
5294
5295fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5296 proto::Contact {
5297 user_id: user_id.to_proto(),
5298 online: pool.is_user_online(user_id),
5299 busy,
5300 }
5301}
5302
5303fn room_updated(room: &proto::Room, peer: &Peer) {
5304 broadcast(
5305 None,
5306 room.participants
5307 .iter()
5308 .filter_map(|participant| Some(participant.peer_id?.into())),
5309 |peer_id| {
5310 peer.send(
5311 peer_id,
5312 proto::RoomUpdated {
5313 room: Some(room.clone()),
5314 },
5315 )
5316 },
5317 );
5318}
5319
5320fn channel_updated(
5321 channel: &db::channel::Model,
5322 room: &proto::Room,
5323 peer: &Peer,
5324 pool: &ConnectionPool,
5325) {
5326 let participants = room
5327 .participants
5328 .iter()
5329 .map(|p| p.user_id)
5330 .collect::<Vec<_>>();
5331
5332 broadcast(
5333 None,
5334 pool.channel_connection_ids(channel.root_id())
5335 .filter_map(|(channel_id, role)| {
5336 role.can_see_channel(channel.visibility).then(|| channel_id)
5337 }),
5338 |peer_id| {
5339 peer.send(
5340 peer_id,
5341 proto::UpdateChannels {
5342 channel_participants: vec![proto::ChannelParticipants {
5343 channel_id: channel.id.to_proto(),
5344 participant_user_ids: participants.clone(),
5345 }],
5346 ..Default::default()
5347 },
5348 )
5349 },
5350 );
5351}
5352
5353async fn send_dev_server_projects_update(
5354 user_id: UserId,
5355 mut status: proto::DevServerProjectsUpdate,
5356 session: &Session,
5357) {
5358 let pool = session.connection_pool().await;
5359 for dev_server in &mut status.dev_servers {
5360 dev_server.status =
5361 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5362 }
5363 let connections = pool.user_connection_ids(user_id);
5364 for connection_id in connections {
5365 session.peer.send(connection_id, status.clone()).trace_err();
5366 }
5367}
5368
5369async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5370 let db = session.db().await;
5371
5372 let contacts = db.get_contacts(user_id).await?;
5373 let busy = db.is_user_busy(user_id).await?;
5374
5375 let pool = session.connection_pool().await;
5376 let updated_contact = contact_for_user(user_id, busy, &pool);
5377 for contact in contacts {
5378 if let db::Contact::Accepted {
5379 user_id: contact_user_id,
5380 ..
5381 } = contact
5382 {
5383 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5384 session
5385 .peer
5386 .send(
5387 contact_conn_id,
5388 proto::UpdateContacts {
5389 contacts: vec![updated_contact.clone()],
5390 remove_contacts: Default::default(),
5391 incoming_requests: Default::default(),
5392 remove_incoming_requests: Default::default(),
5393 outgoing_requests: Default::default(),
5394 remove_outgoing_requests: Default::default(),
5395 },
5396 )
5397 .trace_err();
5398 }
5399 }
5400 }
5401 Ok(())
5402}
5403
5404async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5405 log::info!("lost dev server connection, unsharing projects");
5406 let project_ids = session
5407 .db()
5408 .await
5409 .get_stale_dev_server_projects(session.connection_id)
5410 .await?;
5411
5412 for project_id in project_ids {
5413 // not unshare re-checks the connection ids match, so we get away with no transaction
5414 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5415 }
5416
5417 let user_id = session.dev_server().user_id;
5418 let update = session
5419 .db()
5420 .await
5421 .dev_server_projects_update(user_id)
5422 .await?;
5423
5424 send_dev_server_projects_update(user_id, update, session).await;
5425
5426 Ok(())
5427}
5428
5429async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
5430 let mut contacts_to_update = HashSet::default();
5431
5432 let room_id;
5433 let canceled_calls_to_user_ids;
5434 let live_kit_room;
5435 let delete_live_kit_room;
5436 let room;
5437 let channel;
5438
5439 if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
5440 contacts_to_update.insert(session.user_id());
5441
5442 for project in left_room.left_projects.values() {
5443 project_left(project, session);
5444 }
5445
5446 room_id = RoomId::from_proto(left_room.room.id);
5447 canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
5448 live_kit_room = mem::take(&mut left_room.room.live_kit_room);
5449 delete_live_kit_room = left_room.deleted;
5450 room = mem::take(&mut left_room.room);
5451 channel = mem::take(&mut left_room.channel);
5452
5453 room_updated(&room, &session.peer);
5454 } else {
5455 return Ok(());
5456 }
5457
5458 if let Some(channel) = channel {
5459 channel_updated(
5460 &channel,
5461 &room,
5462 &session.peer,
5463 &*session.connection_pool().await,
5464 );
5465 }
5466
5467 {
5468 let pool = session.connection_pool().await;
5469 for canceled_user_id in canceled_calls_to_user_ids {
5470 for connection_id in pool.user_connection_ids(canceled_user_id) {
5471 session
5472 .peer
5473 .send(
5474 connection_id,
5475 proto::CallCanceled {
5476 room_id: room_id.to_proto(),
5477 },
5478 )
5479 .trace_err();
5480 }
5481 contacts_to_update.insert(canceled_user_id);
5482 }
5483 }
5484
5485 for contact_user_id in contacts_to_update {
5486 update_user_contacts(contact_user_id, &session).await?;
5487 }
5488
5489 if let Some(live_kit) = session.live_kit_client.as_ref() {
5490 live_kit
5491 .remove_participant(live_kit_room.clone(), session.user_id().to_string())
5492 .await
5493 .trace_err();
5494
5495 if delete_live_kit_room {
5496 live_kit.delete_room(live_kit_room).await.trace_err();
5497 }
5498 }
5499
5500 Ok(())
5501}
5502
5503async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5504 let left_channel_buffers = session
5505 .db()
5506 .await
5507 .leave_channel_buffers(session.connection_id)
5508 .await?;
5509
5510 for left_buffer in left_channel_buffers {
5511 channel_buffer_updated(
5512 session.connection_id,
5513 left_buffer.connections,
5514 &proto::UpdateChannelBufferCollaborators {
5515 channel_id: left_buffer.channel_id.to_proto(),
5516 collaborators: left_buffer.collaborators,
5517 },
5518 &session.peer,
5519 );
5520 }
5521
5522 Ok(())
5523}
5524
5525fn project_left(project: &db::LeftProject, session: &UserSession) {
5526 for connection_id in &project.connection_ids {
5527 if project.should_unshare {
5528 session
5529 .peer
5530 .send(
5531 *connection_id,
5532 proto::UnshareProject {
5533 project_id: project.id.to_proto(),
5534 },
5535 )
5536 .trace_err();
5537 } else {
5538 session
5539 .peer
5540 .send(
5541 *connection_id,
5542 proto::RemoveProjectCollaborator {
5543 project_id: project.id.to_proto(),
5544 peer_id: Some(session.connection_id.into()),
5545 },
5546 )
5547 .trace_err();
5548 }
5549 }
5550}
5551
/// Extension trait for turning a fallible result into an `Option`, reporting
/// the error instead of propagating it.
pub trait ResultExt {
    /// The success type carried by the result.
    type Ok;

    /// Returns `Some(value)` on success. On failure, reports the error and
    /// returns `None`.
    ///
    /// The `Result<T, E: Debug>` impl in this module reports the error with
    /// `tracing::error!`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5557
5558impl<T, E> ResultExt for Result<T, E>
5559where
5560 E: std::fmt::Debug,
5561{
5562 type Ok = T;
5563
5564 #[track_caller]
5565 fn trace_err(self) -> Option<T> {
5566 match self {
5567 Ok(value) => Some(value),
5568 Err(error) => {
5569 tracing::error!("{:?}", error);
5570 None
5571 }
5572 }
5573 }
5574}