1mod connection_pool;
2
3use crate::api::CloudflareIpCountryHeader;
4use crate::llm::LlmTokenClaims;
5use crate::{
6 auth,
7 db::{
8 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
9 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
10 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
11 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
12 ServerId, UpdatedChannelMessage, User, UserId,
13 },
14 executor::Executor,
15 AppState, Config, Error, RateLimit, Result,
16};
17use anyhow::{anyhow, bail, Context as _};
18use async_tungstenite::tungstenite::{
19 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
20};
21use axum::{
22 body::Body,
23 extract::{
24 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
25 ConnectInfo, WebSocketUpgrade,
26 },
27 headers::{Header, HeaderName},
28 http::StatusCode,
29 middleware,
30 response::IntoResponse,
31 routing::get,
32 Extension, Router, TypedHeader,
33};
34use collections::{HashMap, HashSet};
35pub use connection_pool::{ConnectionPool, ZedVersion};
36use core::fmt::{self, Debug, Formatter};
37use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
38use sha2::Digest;
39use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
40
41use futures::{
42 channel::oneshot,
43 future::{self, BoxFuture},
44 stream::FuturesUnordered,
45 FutureExt, SinkExt, StreamExt, TryStreamExt,
46};
47use http_client::IsahcHttpClient;
48use prometheus::{register_int_gauge, IntGauge};
49use rpc::{
50 proto::{
51 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
52 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
53 },
54 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
55};
56use semantic_version::SemanticVersion;
57use serde::{Serialize, Serializer};
58use std::{
59 any::TypeId,
60 future::Future,
61 marker::PhantomData,
62 mem,
63 net::SocketAddr,
64 ops::{Deref, DerefMut},
65 rc::Rc,
66 sync::{
67 atomic::{AtomicBool, Ordering::SeqCst},
68 Arc, OnceLock,
69 },
70 time::{Duration, Instant},
71};
72use time::OffsetDateTime;
73use tokio::sync::{watch, Semaphore};
74use tower::ServiceBuilder;
75use tracing::{
76 field::{self},
77 info_span, instrument, Instrument,
78};
79
80use self::connection_pool::VersionedMessage;
81
/// Reconnection grace period. NOTE(review): not referenced in this chunk; presumably how long
/// a disconnected client may take to reconnect before its state is released — confirm at use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can clean
// up old resources. `Server::start` sleeps for this long before reclaiming rooms and channel
// buffers still attributed to stale servers; 15s leaves a margin over the 10s grace period.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are not referenced in this chunk; the descriptions are
// inferred from their names — confirm at the use sites.
/// Presumably the page size for channel-message history queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`, keyed by message `TypeId`.
///
/// It receives the boxed incoming envelope plus the session of the connection that sent it,
/// and returns a boxed future that performs (and logs) the handling.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
93
94struct Response<R> {
95 peer: Arc<Peer>,
96 receipt: Receipt<R>,
97 responded: Arc<AtomicBool>,
98}
99
100impl<R: RequestMessage> Response<R> {
101 fn send(self, payload: R::Response) -> Result<()> {
102 self.responded.store(true, SeqCst);
103 self.peer.respond(self.receipt, payload)?;
104 Ok(())
105 }
106}
107
108struct StreamingResponse<R: RequestMessage> {
109 peer: Arc<Peer>,
110 receipt: Receipt<R>,
111}
112
113impl<R: RequestMessage> StreamingResponse<R> {
114 fn send(&self, payload: R::Response) -> Result<()> {
115 self.peer.respond(self.receipt, payload)?;
116 Ok(())
117 }
118}
119
120#[derive(Clone, Debug)]
121pub enum Principal {
122 User(User),
123 Impersonated { user: User, admin: User },
124 DevServer(dev_server::Model),
125}
126
127impl Principal {
128 fn update_span(&self, span: &tracing::Span) {
129 match &self {
130 Principal::User(user) => {
131 span.record("user_id", &user.id.0);
132 span.record("login", &user.github_login);
133 }
134 Principal::Impersonated { user, admin } => {
135 span.record("user_id", &user.id.0);
136 span.record("login", &user.github_login);
137 span.record("impersonator", &admin.github_login);
138 }
139 Principal::DevServer(dev_server) => {
140 span.record("dev_server_id", &dev_server.id.0);
141 }
142 }
143 }
144}
145
/// Per-connection state handed to every message handler.
///
/// Cloning is cheap: shared state is behind `Arc`s, so clones refer to the same database
/// handle, peer, and connection pool.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of this connection.
    principal: Principal,
    /// The peer-level id of this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Shared with `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Present only when `supermaven_admin_api_key` is configured (see `handle_connection`).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created per connection with a "Zed Server/<version>" User-Agent.
    http_client: Arc<IsahcHttpClient>,
    /// The GeoIP country code for the user.
    // Not read anywhere in this chunk, hence the `allow(unused)`.
    #[allow(unused)]
    geoip_country_code: Option<String>,
    _executor: Executor,
}
161
162impl Session {
163 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
164 #[cfg(test)]
165 tokio::task::yield_now().await;
166 let guard = self.db.lock().await;
167 #[cfg(test)]
168 tokio::task::yield_now().await;
169 guard
170 }
171
172 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
173 #[cfg(test)]
174 tokio::task::yield_now().await;
175 let guard = self.connection_pool.lock();
176 ConnectionPoolGuard {
177 guard,
178 _not_send: PhantomData,
179 }
180 }
181
182 fn for_user(self) -> Option<UserSession> {
183 UserSession::new(self)
184 }
185
186 fn for_dev_server(self) -> Option<DevServerSession> {
187 DevServerSession::new(self)
188 }
189
190 fn user_id(&self) -> Option<UserId> {
191 match &self.principal {
192 Principal::User(user) => Some(user.id),
193 Principal::Impersonated { user, .. } => Some(user.id),
194 Principal::DevServer(_) => None,
195 }
196 }
197
198 fn is_staff(&self) -> bool {
199 match &self.principal {
200 Principal::User(user) => user.admin,
201 Principal::Impersonated { .. } => true,
202 Principal::DevServer(_) => false,
203 }
204 }
205
206 pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
207 if self.is_staff() {
208 return Ok(proto::Plan::ZedPro);
209 }
210
211 let Some(user_id) = self.user_id() else {
212 return Ok(proto::Plan::Free);
213 };
214
215 let db = self.db().await;
216 if db.has_active_billing_subscription(user_id).await? {
217 Ok(proto::Plan::ZedPro)
218 } else {
219 Ok(proto::Plan::Free)
220 }
221 }
222
223 fn dev_server_id(&self) -> Option<DevServerId> {
224 match &self.principal {
225 Principal::User(_) | Principal::Impersonated { .. } => None,
226 Principal::DevServer(dev_server) => Some(dev_server.id),
227 }
228 }
229
230 fn principal_id(&self) -> PrincipalId {
231 match &self.principal {
232 Principal::User(user) => PrincipalId::UserId(user.id),
233 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
234 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
235 }
236 }
237}
238
239impl Debug for Session {
240 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
241 let mut result = f.debug_struct("Session");
242 match &self.principal {
243 Principal::User(user) => {
244 result.field("user", &user.github_login);
245 }
246 Principal::Impersonated { user, admin } => {
247 result.field("user", &user.github_login);
248 result.field("impersonator", &admin.github_login);
249 }
250 Principal::DevServer(dev_server) => {
251 result.field("dev_server", &dev_server.id);
252 }
253 }
254 result.field("connection_id", &self.connection_id).finish()
255 }
256}
257
258struct UserSession(Session);
259
260impl UserSession {
261 pub fn new(s: Session) -> Option<Self> {
262 s.user_id().map(|_| UserSession(s))
263 }
264 pub fn user_id(&self) -> UserId {
265 self.0.user_id().unwrap()
266 }
267
268 pub fn email(&self) -> Option<String> {
269 match &self.0.principal {
270 Principal::User(user) => user.email_address.clone(),
271 Principal::Impersonated { user, .. } => user.email_address.clone(),
272 Principal::DevServer(..) => None,
273 }
274 }
275}
276
277impl Deref for UserSession {
278 type Target = Session;
279
280 fn deref(&self) -> &Self::Target {
281 &self.0
282 }
283}
284impl DerefMut for UserSession {
285 fn deref_mut(&mut self) -> &mut Self::Target {
286 &mut self.0
287 }
288}
289
290struct DevServerSession(Session);
291
292impl DevServerSession {
293 pub fn new(s: Session) -> Option<Self> {
294 s.dev_server_id().map(|_| DevServerSession(s))
295 }
296 pub fn dev_server_id(&self) -> DevServerId {
297 self.0.dev_server_id().unwrap()
298 }
299
300 fn dev_server(&self) -> &dev_server::Model {
301 match &self.0.principal {
302 Principal::DevServer(dev_server) => dev_server,
303 _ => unreachable!(),
304 }
305 }
306}
307
308impl Deref for DevServerSession {
309 type Target = Session;
310
311 fn deref(&self) -> &Self::Target {
312 &self.0
313 }
314}
315impl DerefMut for DevServerSession {
316 fn deref_mut(&mut self) -> &mut Self::Target {
317 &mut self.0
318 }
319}
320
321fn user_handler<M: RequestMessage, Fut>(
322 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
323) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
324where
325 Fut: Send + Future<Output = Result<()>>,
326{
327 let handler = Arc::new(handler);
328 move |message, response, session| {
329 let handler = handler.clone();
330 Box::pin(async move {
331 if let Some(user_session) = session.for_user() {
332 Ok(handler(message, response, user_session).await?)
333 } else {
334 Err(Error::Internal(anyhow!(
335 "must be a user to call {}",
336 M::NAME
337 )))
338 }
339 })
340 }
341}
342
343fn dev_server_handler<M: RequestMessage, Fut>(
344 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
345) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
346where
347 Fut: Send + Future<Output = Result<()>>,
348{
349 let handler = Arc::new(handler);
350 move |message, response, session| {
351 let handler = handler.clone();
352 Box::pin(async move {
353 if let Some(dev_server_session) = session.for_dev_server() {
354 Ok(handler(message, response, dev_server_session).await?)
355 } else {
356 Err(Error::Internal(anyhow!(
357 "must be a dev server to call {}",
358 M::NAME
359 )))
360 }
361 })
362 }
363}
364
365fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
366 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
367) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
368where
369 InnertRetFut: Send + Future<Output = Result<()>>,
370{
371 let handler = Arc::new(handler);
372 move |message, session| {
373 let handler = handler.clone();
374 Box::pin(async move {
375 if let Some(user_session) = session.for_user() {
376 Ok(handler(message, user_session).await?)
377 } else {
378 Err(Error::Internal(anyhow!(
379 "must be a user to call {}",
380 M::NAME
381 )))
382 }
383 })
384 }
385}
386
387struct DbHandle(Arc<Database>);
388
389impl Deref for DbHandle {
390 type Target = Database;
391
392 fn deref(&self) -> &Self::Target {
393 self.0.as_ref()
394 }
395}
396
/// The collaboration RPC server: owns the peer, the connection pool, and the message dispatch table.
pub struct Server {
    /// This server's id; behind a mutex so tests can swap it via `Server::reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Dispatch table keyed by the `TypeId` of the message type each handler accepts.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `Server::teardown`; `handle_connection` checks it before accepting.
    teardown: watch::Sender<bool>,
}
405
/// Lock guard over the shared [`ConnectionPool`].
///
/// `_not_send` holds a `PhantomData<Rc<()>>`, and `Rc` is `!Send`, which makes the guard
/// `!Send`. Futures that must be `Send` therefore cannot hold it across an `.await`.
/// NOTE(review): the `Deref` impl used by `serialize_deref` is not in this chunk.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
410
/// Serializable point-in-time view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // Serialize the pool the guard dereferences to, not the guard itself.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
417
418pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
419where
420 S: Serializer,
421 T: Deref<Target = U>,
422 U: Serialize,
423{
424 Serialize::serialize(value.deref(), serializer)
425}
426
427impl Server {
    /// Builds the server and registers every RPC message handler.
    ///
    /// Handlers are keyed by message type; registering two handlers for the same type panics
    /// (see `add_handler`). Handlers wrapped in `user_handler` / `user_message_handler` reject
    /// non-user principals, and those wrapped in `dev_server_handler` reject non-dev-server
    /// principals. Unwrapped handlers receive the raw `Session`.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped immediately; listeners call `subscribe()`.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, and chat.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following, account info, and acknowledgements.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_request_handler(user_handler(get_llm_api_token))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model handlers take the raw session plus the server config; each closure
            // clones `app_state` per call so the returned future owns its reference.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            // NOTE(review): this closure only forwards its arguments; it looks equivalent to
            // `user_handler(get_cached_embeddings)`.
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
687
    /// Starts background cleanup of resources left behind by stale servers.
    ///
    /// Returns `Ok(())` immediately and spawns a detached task that:
    /// 1. waits `CLEANUP_TIMEOUT`;
    /// 2. fetches stale room and channel-buffer ids for this environment and server id;
    /// 3. clears stale channel-buffer collaborators and sends the refreshed collaborator list
    ///    to the remaining connections;
    /// 4. for each stale room, clears stale participants, broadcasts room/channel updates,
    ///    sends `CallCanceled` for canceled calls, pushes refreshed contact entries to affected
    ///    users' accepted contacts, and deletes the LiveKit room once it has no participants;
    /// 5. deletes stale server records. This runs even if step 2 failed.
    ///
    /// Database and send failures are passed through `trace_err` and do not stop the task.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled from the refreshed room below; they stay empty/false when
                        // clearing the room failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the awaits below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            // Both queries are awaited before taking the (synchronous) pool lock.
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
833
    /// Shuts the server down.
    ///
    /// Tears down the peer, resets the connection pool, and publishes `true` on the teardown
    /// channel. `handle_connection` checks that channel and refuses connections while it is set.
    pub fn teardown(&self) {
        self.peer.teardown();
        self.connection_pool.lock().reset();
        // `watch::Sender::send` fails only when every receiver has been dropped; that is
        // harmless here, so the result is ignored.
        let _ = self.teardown.send(true);
    }
839
840 #[cfg(test)]
841 pub fn reset(&self, id: ServerId) {
842 self.teardown();
843 *self.id.lock() = id;
844 self.peer.reset(id.0 as u32);
845 let _ = self.teardown.send(false);
846 }
847
848 #[cfg(test)]
849 pub fn id(&self) -> ServerId {
850 *self.id.lock()
851 }
852
853 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
854 where
855 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
856 Fut: 'static + Send + Future<Output = Result<()>>,
857 M: EnvelopedMessage,
858 {
859 let prev_handler = self.handlers.insert(
860 TypeId::of::<M>(),
861 Box::new(move |envelope, session| {
862 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
863 let received_at = envelope.received_at;
864 tracing::info!("message received");
865 let start_time = Instant::now();
866 let future = (handler)(*envelope, session);
867 async move {
868 let result = future.await;
869 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
870 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
871 let queue_duration_ms = total_duration_ms - processing_duration_ms;
872 let payload_type = M::NAME;
873
874 match result {
875 Err(error) => {
876 tracing::error!(
877 ?error,
878 total_duration_ms,
879 processing_duration_ms,
880 queue_duration_ms,
881 payload_type,
882 "error handling message"
883 )
884 }
885 Ok(()) => tracing::info!(
886 total_duration_ms,
887 processing_duration_ms,
888 queue_duration_ms,
889 "finished handling message"
890 ),
891 }
892 }
893 .boxed()
894 }),
895 );
896 if prev_handler.is_some() {
897 panic!("registered a handler for the same message twice");
898 }
899 self
900 }
901
902 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
903 where
904 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
905 Fut: 'static + Send + Future<Output = Result<()>>,
906 M: EnvelopedMessage,
907 {
908 self.add_handler(move |envelope, session| handler(envelope.payload, session));
909 self
910 }
911
912 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
913 where
914 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
915 Fut: Send + Future<Output = Result<()>>,
916 M: RequestMessage,
917 {
918 let handler = Arc::new(handler);
919 self.add_handler(move |envelope, session| {
920 let receipt = envelope.receipt();
921 let handler = handler.clone();
922 async move {
923 let peer = session.peer.clone();
924 let responded = Arc::new(AtomicBool::default());
925 let response = Response {
926 peer: peer.clone(),
927 responded: responded.clone(),
928 receipt,
929 };
930 match (handler)(envelope.payload, response, session).await {
931 Ok(()) => {
932 if responded.load(std::sync::atomic::Ordering::SeqCst) {
933 Ok(())
934 } else {
935 Err(anyhow!("handler did not send a response"))?
936 }
937 }
938 Err(error) => {
939 let proto_err = match &error {
940 Error::Internal(err) => err.to_proto(),
941 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
942 };
943 peer.respond_with_error(receipt, proto_err)?;
944 Err(error)
945 }
946 }
947 }
948 })
949 }
950
951 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
952 where
953 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
954 Fut: Send + Future<Output = Result<()>>,
955 M: RequestMessage,
956 {
957 let handler = Arc::new(handler);
958 self.add_handler(move |envelope, session| {
959 let receipt = envelope.receipt();
960 let handler = handler.clone();
961 async move {
962 let peer = session.peer.clone();
963 let response = StreamingResponse {
964 peer: peer.clone(),
965 receipt,
966 };
967 match (handler)(envelope.payload, response, session).await {
968 Ok(()) => {
969 peer.end_stream(receipt)?;
970 Ok(())
971 }
972 Err(error) => {
973 let proto_err = match &error {
974 Error::Internal(err) => err.to_proto(),
975 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
976 };
977 peer.respond_with_error(receipt, proto_err)?;
978 Err(error)
979 }
980 }
981 }
982 })
983 }
984
985 #[allow(clippy::too_many_arguments)]
986 pub fn handle_connection(
987 self: &Arc<Self>,
988 connection: Connection,
989 address: String,
990 principal: Principal,
991 zed_version: ZedVersion,
992 geoip_country_code: Option<String>,
993 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
994 executor: Executor,
995 ) -> impl Future<Output = ()> {
996 let this = self.clone();
997 let span = info_span!("handle connection", %address,
998 connection_id=field::Empty,
999 user_id=field::Empty,
1000 login=field::Empty,
1001 impersonator=field::Empty,
1002 dev_server_id=field::Empty,
1003 geoip_country_code=field::Empty
1004 );
1005 principal.update_span(&span);
1006 if let Some(country_code) = geoip_country_code.as_ref() {
1007 span.record("geoip_country_code", country_code);
1008 }
1009
1010 let mut teardown = self.teardown.subscribe();
1011 async move {
1012 if *teardown.borrow() {
1013 tracing::error!("server is tearing down");
1014 return
1015 }
1016 let (connection_id, handle_io, mut incoming_rx) = this
1017 .peer
1018 .add_connection(connection, {
1019 let executor = executor.clone();
1020 move |duration| executor.sleep(duration)
1021 });
1022 tracing::Span::current().record("connection_id", format!("{}", connection_id));
1023
1024 tracing::info!("connection opened");
1025
1026 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
1027 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
1028 Ok(http_client) => Arc::new(http_client),
1029 Err(error) => {
1030 tracing::error!(?error, "failed to create HTTP client");
1031 return;
1032 }
1033 };
1034
1035 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1036 Some(Arc::new(SupermavenAdminApi::new(
1037 supermaven_admin_api_key.to_string(),
1038 http_client.clone(),
1039 )))
1040 } else {
1041 None
1042 };
1043
1044 let session = Session {
1045 principal: principal.clone(),
1046 connection_id,
1047 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1048 peer: this.peer.clone(),
1049 connection_pool: this.connection_pool.clone(),
1050 app_state: this.app_state.clone(),
1051 http_client,
1052 geoip_country_code,
1053 _executor: executor.clone(),
1054 supermaven_client,
1055 };
1056
1057 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1058 tracing::error!(?error, "failed to send initial client update");
1059 return;
1060 }
1061
1062 let handle_io = handle_io.fuse();
1063 futures::pin_mut!(handle_io);
1064
1065 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1066 // This prevents deadlocks when e.g., client A performs a request to client B and
1067 // client B performs a request to client A. If both clients stop processing further
1068 // messages until their respective request completes, they won't have a chance to
1069 // respond to the other client's request and cause a deadlock.
1070 //
1071 // This arrangement ensures we will attempt to process earlier messages first, but fall
1072 // back to processing messages arrived later in the spirit of making progress.
1073 let mut foreground_message_handlers = FuturesUnordered::new();
1074 let concurrent_handlers = Arc::new(Semaphore::new(256));
1075 loop {
1076 let next_message = async {
1077 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1078 let message = incoming_rx.next().await;
1079 (permit, message)
1080 }.fuse();
1081 futures::pin_mut!(next_message);
1082 futures::select_biased! {
1083 _ = teardown.changed().fuse() => return,
1084 result = handle_io => {
1085 if let Err(error) = result {
1086 tracing::error!(?error, "error handling I/O");
1087 }
1088 break;
1089 }
1090 _ = foreground_message_handlers.next() => {}
1091 next_message = next_message => {
1092 let (permit, message) = next_message;
1093 if let Some(message) = message {
1094 let type_name = message.payload_type_name();
1095 // note: we copy all the fields from the parent span so we can query them in the logs.
1096 // (https://github.com/tokio-rs/tracing/issues/2670).
1097 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1098 user_id=field::Empty,
1099 login=field::Empty,
1100 impersonator=field::Empty,
1101 dev_server_id=field::Empty
1102 );
1103 principal.update_span(&span);
1104 let span_enter = span.enter();
1105 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1106 let is_background = message.is_background();
1107 let handle_message = (handler)(message, session.clone());
1108 drop(span_enter);
1109
1110 let handle_message = async move {
1111 handle_message.await;
1112 drop(permit);
1113 }.instrument(span);
1114 if is_background {
1115 executor.spawn_detached(handle_message);
1116 } else {
1117 foreground_message_handlers.push(handle_message);
1118 }
1119 } else {
1120 tracing::error!("no message handler");
1121 }
1122 } else {
1123 tracing::info!("connection closed");
1124 break;
1125 }
1126 }
1127 }
1128 }
1129
1130 drop(foreground_message_handlers);
1131 tracing::info!("signing out");
1132 if let Err(error) = connection_lost(session, teardown, executor).await {
1133 tracing::error!(?error, "error signing out");
1134 }
1135
1136 }.instrument(span)
1137 }
1138
    /// Sends the messages a newly connected client needs before regular message handling
    /// starts, and registers the connection in the connection pool.
    ///
    /// A `Hello` carrying the connection's peer id is always sent first. After that,
    /// `connection_id` is handed to `send_connection_id`, if one was provided.
    ///
    /// For user principals (including impersonated ones) this then:
    /// - sends `ShowContacts` and calls `set_user_connected_once`, if the user has never
    ///   connected before;
    /// - calls `update_user_plan`;
    /// - adds the connection to the pool and sends the initial contacts update;
    /// - subscribes the user to channels when `should_auto_subscribe_to_channels(zed_version)`
    ///   allows it;
    /// - sends the user's dev-server project status;
    /// - forwards the pending incoming call, if there is one;
    /// - calls `update_user_contacts` for the user.
    ///
    /// For dev-server principals it first handles an existing pool connection for the same dev
    /// server: that connection is sent `ShutdownDevServer` and removed from the pool. It then
    /// registers this connection, sends the dev server its projects, and sends the owning
    /// user's dev-server project status.
    ///
    /// # Errors
    ///
    /// Returns the first failure from sending a message or querying the database; the
    /// remaining steps are skipped in that case.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The send result is deliberately ignored: the receiving side may no longer be
            // waiting for the id.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                update_user_plan(user.id, session).await?;

                // Both queries are awaited together; the first error aborts the update.
                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                {
                    // Register the connection and build the contacts update under the same
                    // pool lock, so the update is computed from a pool that already contains
                    // this connection.
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    // Only one pool connection is kept per dev server: ask the previous one to
                    // shut down and remove it before registering this connection.
                    let mut pool = self.connection_pool.lock();
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1236
1237 pub async fn invite_code_redeemed(
1238 self: &Arc<Self>,
1239 inviter_id: UserId,
1240 invitee_id: UserId,
1241 ) -> Result<()> {
1242 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1243 if let Some(code) = &user.invite_code {
1244 let pool = self.connection_pool.lock();
1245 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1246 for connection_id in pool.user_connection_ids(inviter_id) {
1247 self.peer.send(
1248 connection_id,
1249 proto::UpdateContacts {
1250 contacts: vec![invitee_contact.clone()],
1251 ..Default::default()
1252 },
1253 )?;
1254 self.peer.send(
1255 connection_id,
1256 proto::UpdateInviteInfo {
1257 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1258 count: user.invite_count as u32,
1259 },
1260 )?;
1261 }
1262 }
1263 }
1264 Ok(())
1265 }
1266
1267 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1268 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1269 if let Some(invite_code) = &user.invite_code {
1270 let pool = self.connection_pool.lock();
1271 for connection_id in pool.user_connection_ids(user_id) {
1272 self.peer.send(
1273 connection_id,
1274 proto::UpdateInviteInfo {
1275 url: format!(
1276 "{}{}",
1277 self.app_state.config.invite_link_prefix, invite_code
1278 ),
1279 count: user.invite_count as u32,
1280 },
1281 )?;
1282 }
1283 }
1284 }
1285 Ok(())
1286 }
1287
1288 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1289 ServerSnapshot {
1290 connection_pool: ConnectionPoolGuard {
1291 guard: self.connection_pool.lock(),
1292 _not_send: PhantomData,
1293 },
1294 peer: &self.peer,
1295 }
1296 }
1297}
1298
1299impl<'a> Deref for ConnectionPoolGuard<'a> {
1300 type Target = ConnectionPool;
1301
1302 fn deref(&self) -> &Self::Target {
1303 &self.guard
1304 }
1305}
1306
1307impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1308 fn deref_mut(&mut self) -> &mut Self::Target {
1309 &mut self.guard
1310 }
1311}
1312
/// Checks the connection pool's invariants when the guard is dropped, in test builds only.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // The attribute removes the call entirely outside `cfg(test)`, so non-test builds do
        // no extra work here.
        #[cfg(test)]
        self.check_invariants();
    }
}
1319
1320fn broadcast<F>(
1321 sender_id: Option<ConnectionId>,
1322 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1323 mut f: F,
1324) where
1325 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1326{
1327 for receiver_id in receiver_ids {
1328 if Some(receiver_id) != sender_id {
1329 if let Err(error) = f(receiver_id) {
1330 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1331 }
1332 }
1333 }
1334}
1335
1336pub struct ProtocolVersion(u32);
1337
1338impl Header for ProtocolVersion {
1339 fn name() -> &'static HeaderName {
1340 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1341 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1342 }
1343
1344 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1345 where
1346 Self: Sized,
1347 I: Iterator<Item = &'i axum::http::HeaderValue>,
1348 {
1349 let version = values
1350 .next()
1351 .ok_or_else(axum::headers::Error::invalid)?
1352 .to_str()
1353 .map_err(|_| axum::headers::Error::invalid())?
1354 .parse()
1355 .map_err(|_| axum::headers::Error::invalid())?;
1356 Ok(Self(version))
1357 }
1358
1359 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1360 values.extend([self.0.to_string().parse().unwrap()]);
1361 }
1362}
1363
1364pub struct AppVersionHeader(SemanticVersion);
1365impl Header for AppVersionHeader {
1366 fn name() -> &'static HeaderName {
1367 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1368 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1369 }
1370
1371 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1372 where
1373 Self: Sized,
1374 I: Iterator<Item = &'i axum::http::HeaderValue>,
1375 {
1376 let version = values
1377 .next()
1378 .ok_or_else(axum::headers::Error::invalid)?
1379 .to_str()
1380 .map_err(|_| axum::headers::Error::invalid())?
1381 .parse()
1382 .map_err(|_| axum::headers::Error::invalid())?;
1383 Ok(Self(version))
1384 }
1385
1386 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1387 values.extend([self.0.to_string().parse().unwrap()]);
1388 }
1389}
1390
/// Builds the router exposing the `/rpc` websocket endpoint and the `/metrics` endpoint.
///
/// The first `.layer(...)` adds the `AppState` extension and the `auth::validate_header`
/// middleware. axum's `Router::layer` only wraps routes registered before the call, so that
/// layer applies to `/rpc` but not to `/metrics`, which is added afterwards. The final layer
/// provides the `Server` extension to both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer above, so it is not wrapped by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1402
1403pub async fn handle_websocket_request(
1404 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1405 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1406 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1407 Extension(server): Extension<Arc<Server>>,
1408 Extension(principal): Extension<Principal>,
1409 country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
1410 ws: WebSocketUpgrade,
1411) -> axum::response::Response {
1412 if protocol_version != rpc::PROTOCOL_VERSION {
1413 return (
1414 StatusCode::UPGRADE_REQUIRED,
1415 "client must be upgraded".to_string(),
1416 )
1417 .into_response();
1418 }
1419
1420 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1421 return (
1422 StatusCode::UPGRADE_REQUIRED,
1423 "no version header found".to_string(),
1424 )
1425 .into_response();
1426 };
1427
1428 if !version.can_collaborate() {
1429 return (
1430 StatusCode::UPGRADE_REQUIRED,
1431 "client must be upgraded".to_string(),
1432 )
1433 .into_response();
1434 }
1435
1436 let socket_address = socket_address.to_string();
1437 ws.on_upgrade(move |socket| {
1438 let socket = socket
1439 .map_ok(to_tungstenite_message)
1440 .err_into()
1441 .with(|message| async move { to_axum_message(message) });
1442 let connection = Connection::new(Box::pin(socket));
1443 async move {
1444 server
1445 .handle_connection(
1446 connection,
1447 socket_address,
1448 principal,
1449 version,
1450 country_code_header.map(|header| header.to_string()),
1451 None,
1452 Executor::Production,
1453 )
1454 .await;
1455 }
1456 })
1457}
1458
1459pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1460 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1461 let connections_metric = CONNECTIONS_METRIC
1462 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1463
1464 let connections = server
1465 .connection_pool
1466 .lock()
1467 .connections()
1468 .filter(|connection| !connection.admin)
1469 .count();
1470 connections_metric.set(connections as _);
1471
1472 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1473 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1474 register_int_gauge!(
1475 "shared_projects",
1476 "number of open projects with one or more guests"
1477 )
1478 .unwrap()
1479 });
1480
1481 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1482 shared_projects_metric.set(shared_projects as _);
1483
1484 let encoder = prometheus::TextEncoder::new();
1485 let metric_families = prometheus::gather();
1486 let encoded_metrics = encoder
1487 .encode_to_string(&metric_families)
1488 .map_err(|err| anyhow!("{}", err))?;
1489 Ok(encoded_metrics)
1490}
1491
/// Cleans up after a connection has gone away.
///
/// Runs in two phases.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - calls the database's `connection_lost` for it (errors there are only logged).
///
/// Then it waits `RECONNECT_TIMEOUT`, presumably to give the client a chance to reconnect. If
/// the server starts tearing down first, nothing else is done. Otherwise:
/// - for users, it leaves the room and channel buffers (errors logged), calls `decline_call`
///   with no specific room when the user has no other connection in the pool, and calls
///   `update_user_contacts`;
/// - for dev servers, it calls `lost_dev_server_connection`.
///
/// # Errors
///
/// Returns an error if removing the connection from the pool fails, or if
/// `update_user_contacts` / `lost_dev_server_connection` fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls the reconnect timer first; a teardown signal arriving first skips
    // the rest of the cleanup.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Expected to succeed because the principal was just matched as a user.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline when no other connection of this user remains in the pool.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1546
1547/// Acknowledges a ping from a client, used to keep the connection alive.
1548async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1549 response.send(proto::Ack {})?;
1550 Ok(())
1551}
1552
1553/// Creates a new room for calling (outside of channels)
1554async fn create_room(
1555 _request: proto::CreateRoom,
1556 response: Response<proto::CreateRoom>,
1557 session: UserSession,
1558) -> Result<()> {
1559 let live_kit_room = nanoid::nanoid!(30);
1560
1561 let live_kit_connection_info = util::maybe!(async {
1562 let live_kit = session.app_state.live_kit_client.as_ref();
1563 let live_kit = live_kit?;
1564 let user_id = session.user_id().to_string();
1565
1566 let token = live_kit
1567 .room_token(&live_kit_room, &user_id.to_string())
1568 .trace_err()?;
1569
1570 Some(proto::LiveKitConnectionInfo {
1571 server_url: live_kit.url().into(),
1572 token,
1573 can_publish: true,
1574 })
1575 })
1576 .await;
1577
1578 let room = session
1579 .db()
1580 .await
1581 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1582 .await?;
1583
1584 response.send(proto::CreateRoomResponse {
1585 room: Some(room.clone()),
1586 live_kit_connection_info,
1587 })?;
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
1593/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1594async fn join_room(
1595 request: proto::JoinRoom,
1596 response: Response<proto::JoinRoom>,
1597 session: UserSession,
1598) -> Result<()> {
1599 let room_id = RoomId::from_proto(request.id);
1600
1601 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1602
1603 if let Some(channel_id) = channel_id {
1604 return join_channel_internal(channel_id, Box::new(response), session).await;
1605 }
1606
1607 let joined_room = {
1608 let room = session
1609 .db()
1610 .await
1611 .join_room(room_id, session.user_id(), session.connection_id)
1612 .await?;
1613 room_updated(&room.room, &session.peer);
1614 room.into_inner()
1615 };
1616
1617 for connection_id in session
1618 .connection_pool()
1619 .await
1620 .user_connection_ids(session.user_id())
1621 {
1622 session
1623 .peer
1624 .send(
1625 connection_id,
1626 proto::CallCanceled {
1627 room_id: room_id.to_proto(),
1628 },
1629 )
1630 .trace_err();
1631 }
1632
1633 let live_kit_connection_info =
1634 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1635 if let Some(token) = live_kit
1636 .room_token(
1637 &joined_room.room.live_kit_room,
1638 &session.user_id().to_string(),
1639 )
1640 .trace_err()
1641 {
1642 Some(proto::LiveKitConnectionInfo {
1643 server_url: live_kit.url().into(),
1644 token,
1645 can_publish: true,
1646 })
1647 } else {
1648 None
1649 }
1650 } else {
1651 None
1652 };
1653
1654 response.send(proto::JoinRoomResponse {
1655 room: Some(joined_room.room),
1656 channel_id: None,
1657 live_kit_connection_info,
1658 })?;
1659
1660 update_user_contacts(session.user_id(), &session).await?;
1661 Ok(())
1662}
1663
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Calls the database's `rejoin_room` for this user and connection, then replies with the
/// room plus the re-shared and rejoined projects, and calls `room_updated`.
///
/// For every re-shared project:
/// - each collaborator is told that this client's peer id changed from the project's
///   `old_connection_id` to the current connection (failures logged);
/// - every collaborator other than this connection is forwarded the project's worktrees.
///
/// Rejoined projects are then refreshed via `notify_rejoined_projects`. If the room belongs
/// to a channel, `channel_updated` is called for it. Finally the user's contacts are
/// refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in from the database result inside the scope below, which ends with
    // `into_inner()` before the channel and contact updates run.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forwarded on behalf of this connection to every collaborator except itself.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1755
1756fn notify_rejoined_projects(
1757 rejoined_projects: &mut Vec<RejoinedProject>,
1758 session: &UserSession,
1759) -> Result<()> {
1760 for project in rejoined_projects.iter() {
1761 for collaborator in &project.collaborators {
1762 session
1763 .peer
1764 .send(
1765 collaborator.connection_id,
1766 proto::UpdateProjectCollaborator {
1767 project_id: project.id.to_proto(),
1768 old_peer_id: Some(project.old_connection_id.into()),
1769 new_peer_id: Some(session.connection_id.into()),
1770 },
1771 )
1772 .trace_err();
1773 }
1774 }
1775
1776 for project in rejoined_projects {
1777 for worktree in mem::take(&mut project.worktrees) {
1778 #[cfg(any(test, feature = "test-support"))]
1779 const MAX_CHUNK_SIZE: usize = 2;
1780 #[cfg(not(any(test, feature = "test-support")))]
1781 const MAX_CHUNK_SIZE: usize = 256;
1782
1783 // Stream this worktree's entries.
1784 let message = proto::UpdateWorktree {
1785 project_id: project.id.to_proto(),
1786 worktree_id: worktree.id,
1787 abs_path: worktree.abs_path.clone(),
1788 root_name: worktree.root_name,
1789 updated_entries: worktree.updated_entries,
1790 removed_entries: worktree.removed_entries,
1791 scan_id: worktree.scan_id,
1792 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1793 updated_repositories: worktree.updated_repositories,
1794 removed_repositories: worktree.removed_repositories,
1795 };
1796 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1797 session.peer.send(session.connection_id, update.clone())?;
1798 }
1799
1800 // Stream this worktree's diagnostics.
1801 for summary in worktree.diagnostic_summaries {
1802 session.peer.send(
1803 session.connection_id,
1804 proto::UpdateDiagnosticSummary {
1805 project_id: project.id.to_proto(),
1806 worktree_id: worktree.id,
1807 summary: Some(summary),
1808 },
1809 )?;
1810 }
1811
1812 for settings_file in worktree.settings_files {
1813 session.peer.send(
1814 session.connection_id,
1815 proto::UpdateWorktreeSettings {
1816 project_id: project.id.to_proto(),
1817 worktree_id: worktree.id,
1818 path: settings_file.path,
1819 content: Some(settings_file.content),
1820 },
1821 )?;
1822 }
1823 }
1824
1825 for language_server in &project.language_servers {
1826 session.peer.send(
1827 session.connection_id,
1828 proto::UpdateLanguageServer {
1829 project_id: project.id.to_proto(),
1830 language_server_id: language_server.id,
1831 variant: Some(
1832 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1833 proto::LspDiskBasedDiagnosticsUpdated {},
1834 ),
1835 ),
1836 },
1837 )?;
1838 }
1839 }
1840 Ok(())
1841}
1842
1843/// leave room disconnects from the room.
1844async fn leave_room(
1845 _: proto::LeaveRoom,
1846 response: Response<proto::LeaveRoom>,
1847 session: UserSession,
1848) -> Result<()> {
1849 leave_room_for_session(&session, session.connection_id).await?;
1850 response.send(proto::Ack {})?;
1851 Ok(())
1852}
1853
1854/// Updates the permissions of someone else in the room.
1855async fn set_room_participant_role(
1856 request: proto::SetRoomParticipantRole,
1857 response: Response<proto::SetRoomParticipantRole>,
1858 session: UserSession,
1859) -> Result<()> {
1860 let user_id = UserId::from_proto(request.user_id);
1861 let role = ChannelRole::from(request.role());
1862
1863 let (live_kit_room, can_publish) = {
1864 let room = session
1865 .db()
1866 .await
1867 .set_room_participant_role(
1868 session.user_id(),
1869 RoomId::from_proto(request.room_id),
1870 user_id,
1871 role,
1872 )
1873 .await?;
1874
1875 let live_kit_room = room.live_kit_room.clone();
1876 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1877 room_updated(&room, &session.peer);
1878 (live_kit_room, can_publish)
1879 };
1880
1881 if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
1882 live_kit
1883 .update_participant(
1884 live_kit_room.clone(),
1885 request.user_id.to_string(),
1886 live_kit_server::proto::ParticipantPermission {
1887 can_subscribe: true,
1888 can_publish,
1889 can_publish_data: can_publish,
1890 hidden: false,
1891 recorder: false,
1892 },
1893 )
1894 .await
1895 .trace_err();
1896 }
1897
1898 response.send(proto::Ack {})?;
1899 Ok(())
1900}
1901
1902/// Call someone else into the current room
1903async fn call(
1904 request: proto::Call,
1905 response: Response<proto::Call>,
1906 session: UserSession,
1907) -> Result<()> {
1908 let room_id = RoomId::from_proto(request.room_id);
1909 let calling_user_id = session.user_id();
1910 let calling_connection_id = session.connection_id;
1911 let called_user_id = UserId::from_proto(request.called_user_id);
1912 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1913 if !session
1914 .db()
1915 .await
1916 .has_contact(calling_user_id, called_user_id)
1917 .await?
1918 {
1919 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1920 }
1921
1922 let incoming_call = {
1923 let (room, incoming_call) = &mut *session
1924 .db()
1925 .await
1926 .call(
1927 room_id,
1928 calling_user_id,
1929 calling_connection_id,
1930 called_user_id,
1931 initial_project_id,
1932 )
1933 .await?;
1934 room_updated(&room, &session.peer);
1935 mem::take(incoming_call)
1936 };
1937 update_user_contacts(called_user_id, &session).await?;
1938
1939 let mut calls = session
1940 .connection_pool()
1941 .await
1942 .user_connection_ids(called_user_id)
1943 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1944 .collect::<FuturesUnordered<_>>();
1945
1946 while let Some(call_response) = calls.next().await {
1947 match call_response.as_ref() {
1948 Ok(_) => {
1949 response.send(proto::Ack {})?;
1950 return Ok(());
1951 }
1952 Err(_) => {
1953 call_response.trace_err();
1954 }
1955 }
1956 }
1957
1958 {
1959 let room = session
1960 .db()
1961 .await
1962 .call_failed(room_id, called_user_id)
1963 .await?;
1964 room_updated(&room, &session.peer);
1965 }
1966 update_user_contacts(called_user_id, &session).await?;
1967
1968 Err(anyhow!("failed to ring user"))?
1969}
1970
1971/// Cancel an outgoing call.
1972async fn cancel_call(
1973 request: proto::CancelCall,
1974 response: Response<proto::CancelCall>,
1975 session: UserSession,
1976) -> Result<()> {
1977 let called_user_id = UserId::from_proto(request.called_user_id);
1978 let room_id = RoomId::from_proto(request.room_id);
1979 {
1980 let room = session
1981 .db()
1982 .await
1983 .cancel_call(room_id, session.connection_id, called_user_id)
1984 .await?;
1985 room_updated(&room, &session.peer);
1986 }
1987
1988 for connection_id in session
1989 .connection_pool()
1990 .await
1991 .user_connection_ids(called_user_id)
1992 {
1993 session
1994 .peer
1995 .send(
1996 connection_id,
1997 proto::CallCanceled {
1998 room_id: room_id.to_proto(),
1999 },
2000 )
2001 .trace_err();
2002 }
2003 response.send(proto::Ack {})?;
2004
2005 update_user_contacts(called_user_id, &session).await?;
2006 Ok(())
2007}
2008
2009/// Decline an incoming call.
2010async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
2011 let room_id = RoomId::from_proto(message.room_id);
2012 {
2013 let room = session
2014 .db()
2015 .await
2016 .decline_call(Some(room_id), session.user_id())
2017 .await?
2018 .ok_or_else(|| anyhow!("failed to decline call"))?;
2019 room_updated(&room, &session.peer);
2020 }
2021
2022 for connection_id in session
2023 .connection_pool()
2024 .await
2025 .user_connection_ids(session.user_id())
2026 {
2027 session
2028 .peer
2029 .send(
2030 connection_id,
2031 proto::CallCanceled {
2032 room_id: room_id.to_proto(),
2033 },
2034 )
2035 .trace_err();
2036 }
2037 update_user_contacts(session.user_id(), &session).await?;
2038 Ok(())
2039}
2040
2041/// Updates other participants in the room with your current location.
2042async fn update_participant_location(
2043 request: proto::UpdateParticipantLocation,
2044 response: Response<proto::UpdateParticipantLocation>,
2045 session: UserSession,
2046) -> Result<()> {
2047 let room_id = RoomId::from_proto(request.room_id);
2048 let location = request
2049 .location
2050 .ok_or_else(|| anyhow!("invalid location"))?;
2051
2052 let db = session.db().await;
2053 let room = db
2054 .update_room_participant_location(room_id, session.connection_id, location)
2055 .await?;
2056
2057 room_updated(&room, &session.peer);
2058 response.send(proto::Ack {})?;
2059 Ok(())
2060}
2061
2062/// Share a project into the room.
2063async fn share_project(
2064 request: proto::ShareProject,
2065 response: Response<proto::ShareProject>,
2066 session: UserSession,
2067) -> Result<()> {
2068 let (project_id, room) = &*session
2069 .db()
2070 .await
2071 .share_project(
2072 RoomId::from_proto(request.room_id),
2073 session.connection_id,
2074 &request.worktrees,
2075 request
2076 .dev_server_project_id
2077 .map(|id| DevServerProjectId::from_proto(id)),
2078 )
2079 .await?;
2080 response.send(proto::ShareProjectResponse {
2081 project_id: project_id.to_proto(),
2082 })?;
2083 room_updated(&room, &session.peer);
2084
2085 Ok(())
2086}
2087
2088/// Unshare a project from the room.
2089async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2090 let project_id = ProjectId::from_proto(message.project_id);
2091 unshare_project_internal(
2092 project_id,
2093 session.connection_id,
2094 session.user_id(),
2095 &session,
2096 )
2097 .await
2098}
2099
/// Unshares `project_id` as `connection_id` (acting as `user_id`, when present).
///
/// Steps:
/// - sends `UnshareProject` to every guest connection reported by the database, except
///   `connection_id` (failures are logged by `broadcast`);
/// - calls `room_updated` if the project was in a room;
/// - deletes the project when the database says it should be deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // The unshare result is scoped to this block; the deletion below only starts after it has
    // been dropped.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2138
2139/// DevServer makes a project available online
2140async fn share_dev_server_project(
2141 request: proto::ShareDevServerProject,
2142 response: Response<proto::ShareDevServerProject>,
2143 session: DevServerSession,
2144) -> Result<()> {
2145 let (dev_server_project, user_id, status) = session
2146 .db()
2147 .await
2148 .share_dev_server_project(
2149 DevServerProjectId::from_proto(request.dev_server_project_id),
2150 session.dev_server_id(),
2151 session.connection_id,
2152 &request.worktrees,
2153 )
2154 .await?;
2155 let Some(project_id) = dev_server_project.project_id else {
2156 return Err(anyhow!("failed to share remote project"))?;
2157 };
2158
2159 send_dev_server_projects_update(user_id, status, &session).await;
2160
2161 response.send(proto::ShareProjectResponse { project_id })?;
2162
2163 Ok(())
2164}
2165
2166/// Join someone elses shared project.
2167async fn join_project(
2168 request: proto::JoinProject,
2169 response: Response<proto::JoinProject>,
2170 session: UserSession,
2171) -> Result<()> {
2172 let project_id = ProjectId::from_proto(request.project_id);
2173
2174 tracing::info!(%project_id, "join project");
2175
2176 let db = session.db().await;
2177 let (project, replica_id) = &mut *db
2178 .join_project(project_id, session.connection_id, session.user_id())
2179 .await?;
2180 drop(db);
2181 tracing::info!(%project_id, "join remote project");
2182 join_project_internal(response, session, project, replica_id)
2183}
2184
2185trait JoinProjectInternalResponse {
2186 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2187}
2188impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2189 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2190 Response::<proto::JoinProject>::send(self, result)
2191 }
2192}
2193impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2194 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2195 Response::<proto::JoinHostedProject>::send(self, result)
2196 }
2197}
2198
2199fn join_project_internal(
2200 response: impl JoinProjectInternalResponse,
2201 session: UserSession,
2202 project: &mut Project,
2203 replica_id: &ReplicaId,
2204) -> Result<()> {
2205 let collaborators = project
2206 .collaborators
2207 .iter()
2208 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2209 .map(|collaborator| collaborator.to_proto())
2210 .collect::<Vec<_>>();
2211 let project_id = project.id;
2212 let guest_user_id = session.user_id();
2213
2214 let worktrees = project
2215 .worktrees
2216 .iter()
2217 .map(|(id, worktree)| proto::WorktreeMetadata {
2218 id: *id,
2219 root_name: worktree.root_name.clone(),
2220 visible: worktree.visible,
2221 abs_path: worktree.abs_path.clone(),
2222 })
2223 .collect::<Vec<_>>();
2224
2225 let add_project_collaborator = proto::AddProjectCollaborator {
2226 project_id: project_id.to_proto(),
2227 collaborator: Some(proto::Collaborator {
2228 peer_id: Some(session.connection_id.into()),
2229 replica_id: replica_id.0 as u32,
2230 user_id: guest_user_id.to_proto(),
2231 }),
2232 };
2233
2234 for collaborator in &collaborators {
2235 session
2236 .peer
2237 .send(
2238 collaborator.peer_id.unwrap().into(),
2239 add_project_collaborator.clone(),
2240 )
2241 .trace_err();
2242 }
2243
2244 // First, we send the metadata associated with each worktree.
2245 response.send(proto::JoinProjectResponse {
2246 project_id: project.id.0 as u64,
2247 worktrees: worktrees.clone(),
2248 replica_id: replica_id.0 as u32,
2249 collaborators: collaborators.clone(),
2250 language_servers: project.language_servers.clone(),
2251 role: project.role.into(),
2252 dev_server_project_id: project
2253 .dev_server_project_id
2254 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2255 })?;
2256
2257 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2258 #[cfg(any(test, feature = "test-support"))]
2259 const MAX_CHUNK_SIZE: usize = 2;
2260 #[cfg(not(any(test, feature = "test-support")))]
2261 const MAX_CHUNK_SIZE: usize = 256;
2262
2263 // Stream this worktree's entries.
2264 let message = proto::UpdateWorktree {
2265 project_id: project_id.to_proto(),
2266 worktree_id,
2267 abs_path: worktree.abs_path.clone(),
2268 root_name: worktree.root_name,
2269 updated_entries: worktree.entries,
2270 removed_entries: Default::default(),
2271 scan_id: worktree.scan_id,
2272 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2273 updated_repositories: worktree.repository_entries.into_values().collect(),
2274 removed_repositories: Default::default(),
2275 };
2276 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2277 session.peer.send(session.connection_id, update.clone())?;
2278 }
2279
2280 // Stream this worktree's diagnostics.
2281 for summary in worktree.diagnostic_summaries {
2282 session.peer.send(
2283 session.connection_id,
2284 proto::UpdateDiagnosticSummary {
2285 project_id: project_id.to_proto(),
2286 worktree_id: worktree.id,
2287 summary: Some(summary),
2288 },
2289 )?;
2290 }
2291
2292 for settings_file in worktree.settings_files {
2293 session.peer.send(
2294 session.connection_id,
2295 proto::UpdateWorktreeSettings {
2296 project_id: project_id.to_proto(),
2297 worktree_id: worktree.id,
2298 path: settings_file.path,
2299 content: Some(settings_file.content),
2300 },
2301 )?;
2302 }
2303 }
2304
2305 for language_server in &project.language_servers {
2306 session.peer.send(
2307 session.connection_id,
2308 proto::UpdateLanguageServer {
2309 project_id: project_id.to_proto(),
2310 language_server_id: language_server.id,
2311 variant: Some(
2312 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2313 proto::LspDiskBasedDiagnosticsUpdated {},
2314 ),
2315 ),
2316 },
2317 )?;
2318 }
2319
2320 Ok(())
2321}
2322
2323/// Leave someone elses shared project.
2324async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2325 let sender_id = session.connection_id;
2326 let project_id = ProjectId::from_proto(request.project_id);
2327 let db = session.db().await;
2328 if db.is_hosted_project(project_id).await? {
2329 let project = db.leave_hosted_project(project_id, sender_id).await?;
2330 project_left(&project, &session);
2331 return Ok(());
2332 }
2333
2334 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2335 tracing::info!(
2336 %project_id,
2337 "leave project"
2338 );
2339
2340 project_left(&project, &session);
2341 if let Some(room) = room {
2342 room_updated(&room, &session.peer);
2343 }
2344
2345 Ok(())
2346}
2347
2348async fn join_hosted_project(
2349 request: proto::JoinHostedProject,
2350 response: Response<proto::JoinHostedProject>,
2351 session: UserSession,
2352) -> Result<()> {
2353 let (mut project, replica_id) = session
2354 .db()
2355 .await
2356 .join_hosted_project(
2357 ProjectId(request.project_id as i32),
2358 session.user_id(),
2359 session.connection_id,
2360 )
2361 .await?;
2362
2363 join_project_internal(response, session, &mut project, &replica_id)
2364}
2365
2366async fn list_remote_directory(
2367 request: proto::ListRemoteDirectory,
2368 response: Response<proto::ListRemoteDirectory>,
2369 session: UserSession,
2370) -> Result<()> {
2371 let dev_server_id = DevServerId(request.dev_server_id as i32);
2372 let dev_server_connection_id = session
2373 .connection_pool()
2374 .await
2375 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2376
2377 session
2378 .db()
2379 .await
2380 .get_dev_server_for_user(dev_server_id, session.user_id())
2381 .await?;
2382
2383 response.send(
2384 session
2385 .peer
2386 .forward_request(session.connection_id, dev_server_connection_id, request)
2387 .await?,
2388 )?;
2389 Ok(())
2390}
2391
2392async fn update_dev_server_project(
2393 request: proto::UpdateDevServerProject,
2394 response: Response<proto::UpdateDevServerProject>,
2395 session: UserSession,
2396) -> Result<()> {
2397 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2398
2399 let (dev_server_project, update) = session
2400 .db()
2401 .await
2402 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2403 .await?;
2404
2405 let projects = session
2406 .db()
2407 .await
2408 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2409 .await?;
2410
2411 let dev_server_connection_id = session
2412 .connection_pool()
2413 .await
2414 .dev_server_connection_id_supporting(
2415 dev_server_project.dev_server_id,
2416 ZedVersion::with_list_directory(),
2417 )?;
2418
2419 session.peer.send(
2420 dev_server_connection_id,
2421 proto::DevServerInstructions { projects },
2422 )?;
2423
2424 send_dev_server_projects_update(session.user_id(), update, &session).await;
2425
2426 response.send(proto::Ack {})
2427}
2428
2429async fn create_dev_server_project(
2430 request: proto::CreateDevServerProject,
2431 response: Response<proto::CreateDevServerProject>,
2432 session: UserSession,
2433) -> Result<()> {
2434 let dev_server_id = DevServerId(request.dev_server_id as i32);
2435 let dev_server_connection_id = session
2436 .connection_pool()
2437 .await
2438 .dev_server_connection_id(dev_server_id);
2439 let Some(dev_server_connection_id) = dev_server_connection_id else {
2440 Err(ErrorCode::DevServerOffline
2441 .message("Cannot create a remote project when the dev server is offline".to_string())
2442 .anyhow())?
2443 };
2444
2445 let path = request.path.clone();
2446 //Check that the path exists on the dev server
2447 session
2448 .peer
2449 .forward_request(
2450 session.connection_id,
2451 dev_server_connection_id,
2452 proto::ValidateDevServerProjectRequest { path: path.clone() },
2453 )
2454 .await?;
2455
2456 let (dev_server_project, update) = session
2457 .db()
2458 .await
2459 .create_dev_server_project(
2460 DevServerId(request.dev_server_id as i32),
2461 &request.path,
2462 session.user_id(),
2463 )
2464 .await?;
2465
2466 let projects = session
2467 .db()
2468 .await
2469 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2470 .await?;
2471
2472 session.peer.send(
2473 dev_server_connection_id,
2474 proto::DevServerInstructions { projects },
2475 )?;
2476
2477 send_dev_server_projects_update(session.user_id(), update, &session).await;
2478
2479 response.send(proto::CreateDevServerProjectResponse {
2480 dev_server_project: Some(dev_server_project.to_proto(None)),
2481 })?;
2482 Ok(())
2483}
2484
2485async fn create_dev_server(
2486 request: proto::CreateDevServer,
2487 response: Response<proto::CreateDevServer>,
2488 session: UserSession,
2489) -> Result<()> {
2490 let access_token = auth::random_token();
2491 let hashed_access_token = auth::hash_access_token(&access_token);
2492
2493 if request.name.is_empty() {
2494 return Err(proto::ErrorCode::Forbidden
2495 .message("Dev server name cannot be empty".to_string())
2496 .anyhow())?;
2497 }
2498
2499 let (dev_server, status) = session
2500 .db()
2501 .await
2502 .create_dev_server(
2503 &request.name,
2504 request.ssh_connection_string.as_deref(),
2505 &hashed_access_token,
2506 session.user_id(),
2507 )
2508 .await?;
2509
2510 send_dev_server_projects_update(session.user_id(), status, &session).await;
2511
2512 response.send(proto::CreateDevServerResponse {
2513 dev_server_id: dev_server.id.0 as u64,
2514 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2515 name: request.name,
2516 })?;
2517 Ok(())
2518}
2519
2520async fn regenerate_dev_server_token(
2521 request: proto::RegenerateDevServerToken,
2522 response: Response<proto::RegenerateDevServerToken>,
2523 session: UserSession,
2524) -> Result<()> {
2525 let dev_server_id = DevServerId(request.dev_server_id as i32);
2526 let access_token = auth::random_token();
2527 let hashed_access_token = auth::hash_access_token(&access_token);
2528
2529 let connection_id = session
2530 .connection_pool()
2531 .await
2532 .dev_server_connection_id(dev_server_id);
2533 if let Some(connection_id) = connection_id {
2534 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2535 session.peer.send(
2536 connection_id,
2537 proto::ShutdownDevServer {
2538 reason: Some("dev server token was regenerated".to_string()),
2539 },
2540 )?;
2541 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2542 }
2543
2544 let status = session
2545 .db()
2546 .await
2547 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2548 .await?;
2549
2550 send_dev_server_projects_update(session.user_id(), status, &session).await;
2551
2552 response.send(proto::RegenerateDevServerTokenResponse {
2553 dev_server_id: dev_server_id.to_proto(),
2554 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2555 })?;
2556 Ok(())
2557}
2558
2559async fn rename_dev_server(
2560 request: proto::RenameDevServer,
2561 response: Response<proto::RenameDevServer>,
2562 session: UserSession,
2563) -> Result<()> {
2564 if request.name.trim().is_empty() {
2565 return Err(proto::ErrorCode::Forbidden
2566 .message("Dev server name cannot be empty".to_string())
2567 .anyhow())?;
2568 }
2569
2570 let dev_server_id = DevServerId(request.dev_server_id as i32);
2571 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2572 if dev_server.user_id != session.user_id() {
2573 return Err(anyhow!(ErrorCode::Forbidden))?;
2574 }
2575
2576 let status = session
2577 .db()
2578 .await
2579 .rename_dev_server(
2580 dev_server_id,
2581 &request.name,
2582 request.ssh_connection_string.as_deref(),
2583 session.user_id(),
2584 )
2585 .await?;
2586
2587 send_dev_server_projects_update(session.user_id(), status, &session).await;
2588
2589 response.send(proto::Ack {})?;
2590 Ok(())
2591}
2592
2593async fn delete_dev_server(
2594 request: proto::DeleteDevServer,
2595 response: Response<proto::DeleteDevServer>,
2596 session: UserSession,
2597) -> Result<()> {
2598 let dev_server_id = DevServerId(request.dev_server_id as i32);
2599 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2600 if dev_server.user_id != session.user_id() {
2601 return Err(anyhow!(ErrorCode::Forbidden))?;
2602 }
2603
2604 let connection_id = session
2605 .connection_pool()
2606 .await
2607 .dev_server_connection_id(dev_server_id);
2608 if let Some(connection_id) = connection_id {
2609 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2610 session.peer.send(
2611 connection_id,
2612 proto::ShutdownDevServer {
2613 reason: Some("dev server was deleted".to_string()),
2614 },
2615 )?;
2616 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2617 }
2618
2619 let status = session
2620 .db()
2621 .await
2622 .delete_dev_server(dev_server_id, session.user_id())
2623 .await?;
2624
2625 send_dev_server_projects_update(session.user_id(), status, &session).await;
2626
2627 response.send(proto::Ack {})?;
2628 Ok(())
2629}
2630
2631async fn delete_dev_server_project(
2632 request: proto::DeleteDevServerProject,
2633 response: Response<proto::DeleteDevServerProject>,
2634 session: UserSession,
2635) -> Result<()> {
2636 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2637 let dev_server_project = session
2638 .db()
2639 .await
2640 .get_dev_server_project(dev_server_project_id)
2641 .await?;
2642
2643 let dev_server = session
2644 .db()
2645 .await
2646 .get_dev_server(dev_server_project.dev_server_id)
2647 .await?;
2648 if dev_server.user_id != session.user_id() {
2649 return Err(anyhow!(ErrorCode::Forbidden))?;
2650 }
2651
2652 let dev_server_connection_id = session
2653 .connection_pool()
2654 .await
2655 .dev_server_connection_id(dev_server.id);
2656
2657 if let Some(dev_server_connection_id) = dev_server_connection_id {
2658 let project = session
2659 .db()
2660 .await
2661 .find_dev_server_project(dev_server_project_id)
2662 .await;
2663 if let Ok(project) = project {
2664 unshare_project_internal(
2665 project.id,
2666 dev_server_connection_id,
2667 Some(session.user_id()),
2668 &session,
2669 )
2670 .await?;
2671 }
2672 }
2673
2674 let (projects, status) = session
2675 .db()
2676 .await
2677 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2678 .await?;
2679
2680 if let Some(dev_server_connection_id) = dev_server_connection_id {
2681 session.peer.send(
2682 dev_server_connection_id,
2683 proto::DevServerInstructions { projects },
2684 )?;
2685 }
2686
2687 send_dev_server_projects_update(session.user_id(), status, &session).await;
2688
2689 response.send(proto::Ack {})?;
2690 Ok(())
2691}
2692
2693async fn rejoin_dev_server_projects(
2694 request: proto::RejoinRemoteProjects,
2695 response: Response<proto::RejoinRemoteProjects>,
2696 session: UserSession,
2697) -> Result<()> {
2698 let mut rejoined_projects = {
2699 let db = session.db().await;
2700 db.rejoin_dev_server_projects(
2701 &request.rejoined_projects,
2702 session.user_id(),
2703 session.0.connection_id,
2704 )
2705 .await?
2706 };
2707 response.send(proto::RejoinRemoteProjectsResponse {
2708 rejoined_projects: rejoined_projects
2709 .iter()
2710 .map(|project| project.to_proto())
2711 .collect(),
2712 })?;
2713 notify_rejoined_projects(&mut rejoined_projects, &session)
2714}
2715
2716async fn reconnect_dev_server(
2717 request: proto::ReconnectDevServer,
2718 response: Response<proto::ReconnectDevServer>,
2719 session: DevServerSession,
2720) -> Result<()> {
2721 let reshared_projects = {
2722 let db = session.db().await;
2723 db.reshare_dev_server_projects(
2724 &request.reshared_projects,
2725 session.dev_server_id(),
2726 session.0.connection_id,
2727 )
2728 .await?
2729 };
2730
2731 for project in &reshared_projects {
2732 for collaborator in &project.collaborators {
2733 session
2734 .peer
2735 .send(
2736 collaborator.connection_id,
2737 proto::UpdateProjectCollaborator {
2738 project_id: project.id.to_proto(),
2739 old_peer_id: Some(project.old_connection_id.into()),
2740 new_peer_id: Some(session.connection_id.into()),
2741 },
2742 )
2743 .trace_err();
2744 }
2745
2746 broadcast(
2747 Some(session.connection_id),
2748 project
2749 .collaborators
2750 .iter()
2751 .map(|collaborator| collaborator.connection_id),
2752 |connection_id| {
2753 session.peer.forward_send(
2754 session.connection_id,
2755 connection_id,
2756 proto::UpdateProject {
2757 project_id: project.id.to_proto(),
2758 worktrees: project.worktrees.clone(),
2759 },
2760 )
2761 },
2762 );
2763 }
2764
2765 response.send(proto::ReconnectDevServerResponse {
2766 reshared_projects: reshared_projects
2767 .iter()
2768 .map(|project| proto::ResharedProject {
2769 id: project.id.to_proto(),
2770 collaborators: project
2771 .collaborators
2772 .iter()
2773 .map(|collaborator| collaborator.to_proto())
2774 .collect(),
2775 })
2776 .collect(),
2777 })?;
2778
2779 Ok(())
2780}
2781
2782async fn shutdown_dev_server(
2783 _: proto::ShutdownDevServer,
2784 response: Response<proto::ShutdownDevServer>,
2785 session: DevServerSession,
2786) -> Result<()> {
2787 response.send(proto::Ack {})?;
2788 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2789 remove_dev_server_connection(session.dev_server_id(), &session).await
2790}
2791
2792async fn shutdown_dev_server_internal(
2793 dev_server_id: DevServerId,
2794 connection_id: ConnectionId,
2795 session: &Session,
2796) -> Result<()> {
2797 let (dev_server_projects, dev_server) = {
2798 let db = session.db().await;
2799 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2800 let dev_server = db.get_dev_server(dev_server_id).await?;
2801 (dev_server_projects, dev_server)
2802 };
2803
2804 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2805 unshare_project_internal(
2806 ProjectId::from_proto(project_id),
2807 connection_id,
2808 None,
2809 session,
2810 )
2811 .await?;
2812 }
2813
2814 session
2815 .connection_pool()
2816 .await
2817 .set_dev_server_offline(dev_server_id);
2818
2819 let status = session
2820 .db()
2821 .await
2822 .dev_server_projects_update(dev_server.user_id)
2823 .await?;
2824 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2825
2826 Ok(())
2827}
2828
2829async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2830 let dev_server_connection = session
2831 .connection_pool()
2832 .await
2833 .dev_server_connection_id(dev_server_id);
2834
2835 if let Some(dev_server_connection) = dev_server_connection {
2836 session
2837 .connection_pool()
2838 .await
2839 .remove_connection(dev_server_connection)?;
2840 }
2841 Ok(())
2842}
2843
2844/// Updates other participants with changes to the project
2845async fn update_project(
2846 request: proto::UpdateProject,
2847 response: Response<proto::UpdateProject>,
2848 session: Session,
2849) -> Result<()> {
2850 let project_id = ProjectId::from_proto(request.project_id);
2851 let (room, guest_connection_ids) = &*session
2852 .db()
2853 .await
2854 .update_project(project_id, session.connection_id, &request.worktrees)
2855 .await?;
2856 broadcast(
2857 Some(session.connection_id),
2858 guest_connection_ids.iter().copied(),
2859 |connection_id| {
2860 session
2861 .peer
2862 .forward_send(session.connection_id, connection_id, request.clone())
2863 },
2864 );
2865 if let Some(room) = room {
2866 room_updated(&room, &session.peer);
2867 }
2868 response.send(proto::Ack {})?;
2869
2870 Ok(())
2871}
2872
2873/// Updates other participants with changes to the worktree
2874async fn update_worktree(
2875 request: proto::UpdateWorktree,
2876 response: Response<proto::UpdateWorktree>,
2877 session: Session,
2878) -> Result<()> {
2879 let guest_connection_ids = session
2880 .db()
2881 .await
2882 .update_worktree(&request, session.connection_id)
2883 .await?;
2884
2885 broadcast(
2886 Some(session.connection_id),
2887 guest_connection_ids.iter().copied(),
2888 |connection_id| {
2889 session
2890 .peer
2891 .forward_send(session.connection_id, connection_id, request.clone())
2892 },
2893 );
2894 response.send(proto::Ack {})?;
2895 Ok(())
2896}
2897
2898/// Updates other participants with changes to the diagnostics
2899async fn update_diagnostic_summary(
2900 message: proto::UpdateDiagnosticSummary,
2901 session: Session,
2902) -> Result<()> {
2903 let guest_connection_ids = session
2904 .db()
2905 .await
2906 .update_diagnostic_summary(&message, session.connection_id)
2907 .await?;
2908
2909 broadcast(
2910 Some(session.connection_id),
2911 guest_connection_ids.iter().copied(),
2912 |connection_id| {
2913 session
2914 .peer
2915 .forward_send(session.connection_id, connection_id, message.clone())
2916 },
2917 );
2918
2919 Ok(())
2920}
2921
2922/// Updates other participants with changes to the worktree settings
2923async fn update_worktree_settings(
2924 message: proto::UpdateWorktreeSettings,
2925 session: Session,
2926) -> Result<()> {
2927 let guest_connection_ids = session
2928 .db()
2929 .await
2930 .update_worktree_settings(&message, session.connection_id)
2931 .await?;
2932
2933 broadcast(
2934 Some(session.connection_id),
2935 guest_connection_ids.iter().copied(),
2936 |connection_id| {
2937 session
2938 .peer
2939 .forward_send(session.connection_id, connection_id, message.clone())
2940 },
2941 );
2942
2943 Ok(())
2944}
2945
2946/// Notify other participants that a language server has started.
2947async fn start_language_server(
2948 request: proto::StartLanguageServer,
2949 session: Session,
2950) -> Result<()> {
2951 let guest_connection_ids = session
2952 .db()
2953 .await
2954 .start_language_server(&request, session.connection_id)
2955 .await?;
2956
2957 broadcast(
2958 Some(session.connection_id),
2959 guest_connection_ids.iter().copied(),
2960 |connection_id| {
2961 session
2962 .peer
2963 .forward_send(session.connection_id, connection_id, request.clone())
2964 },
2965 );
2966 Ok(())
2967}
2968
2969/// Notify other participants that a language server has changed.
2970async fn update_language_server(
2971 request: proto::UpdateLanguageServer,
2972 session: Session,
2973) -> Result<()> {
2974 let project_id = ProjectId::from_proto(request.project_id);
2975 let project_connection_ids = session
2976 .db()
2977 .await
2978 .project_connection_ids(project_id, session.connection_id, true)
2979 .await?;
2980 broadcast(
2981 Some(session.connection_id),
2982 project_connection_ids.iter().copied(),
2983 |connection_id| {
2984 session
2985 .peer
2986 .forward_send(session.connection_id, connection_id, request.clone())
2987 },
2988 );
2989 Ok(())
2990}
2991
2992/// forward a project request to the host. These requests should be read only
2993/// as guests are allowed to send them.
2994async fn forward_read_only_project_request<T>(
2995 request: T,
2996 response: Response<T>,
2997 session: UserSession,
2998) -> Result<()>
2999where
3000 T: EntityMessage + RequestMessage,
3001{
3002 let project_id = ProjectId::from_proto(request.remote_entity_id());
3003 let host_connection_id = session
3004 .db()
3005 .await
3006 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
3007 .await?;
3008 let payload = session
3009 .peer
3010 .forward_request(session.connection_id, host_connection_id, request)
3011 .await?;
3012 response.send(payload)?;
3013 Ok(())
3014}
3015
3016/// forward a project request to the dev server. Only allowed
3017/// if it's your dev server.
3018async fn forward_project_request_for_owner<T>(
3019 request: T,
3020 response: Response<T>,
3021 session: UserSession,
3022) -> Result<()>
3023where
3024 T: EntityMessage + RequestMessage,
3025{
3026 let project_id = ProjectId::from_proto(request.remote_entity_id());
3027
3028 let host_connection_id = session
3029 .db()
3030 .await
3031 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
3032 .await?;
3033 let payload = session
3034 .peer
3035 .forward_request(session.connection_id, host_connection_id, request)
3036 .await?;
3037 response.send(payload)?;
3038 Ok(())
3039}
3040
3041/// forward a project request to the host. These requests are disallowed
3042/// for guests.
3043async fn forward_mutating_project_request<T>(
3044 request: T,
3045 response: Response<T>,
3046 session: UserSession,
3047) -> Result<()>
3048where
3049 T: EntityMessage + RequestMessage,
3050{
3051 let project_id = ProjectId::from_proto(request.remote_entity_id());
3052
3053 let host_connection_id = session
3054 .db()
3055 .await
3056 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3057 .await?;
3058 let payload = session
3059 .peer
3060 .forward_request(session.connection_id, host_connection_id, request)
3061 .await?;
3062 response.send(payload)?;
3063 Ok(())
3064}
3065
3066/// forward a project request to the host. These requests are disallowed
3067/// for guests.
3068async fn forward_versioned_mutating_project_request<T>(
3069 request: T,
3070 response: Response<T>,
3071 session: UserSession,
3072) -> Result<()>
3073where
3074 T: EntityMessage + RequestMessage + VersionedMessage,
3075{
3076 let project_id = ProjectId::from_proto(request.remote_entity_id());
3077
3078 let host_connection_id = session
3079 .db()
3080 .await
3081 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3082 .await?;
3083 if let Some(host_version) = session
3084 .connection_pool()
3085 .await
3086 .connection(host_connection_id)
3087 .map(|c| c.zed_version)
3088 {
3089 if let Some(min_required_version) = request.required_host_version() {
3090 if min_required_version > host_version {
3091 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3092 .with_tag("required", &min_required_version.to_string())))?;
3093 }
3094 }
3095 }
3096
3097 let payload = session
3098 .peer
3099 .forward_request(session.connection_id, host_connection_id, request)
3100 .await?;
3101 response.send(payload)?;
3102 Ok(())
3103}
3104
3105/// Notify other participants that a new buffer has been created
3106async fn create_buffer_for_peer(
3107 request: proto::CreateBufferForPeer,
3108 session: Session,
3109) -> Result<()> {
3110 session
3111 .db()
3112 .await
3113 .check_user_is_project_host(
3114 ProjectId::from_proto(request.project_id),
3115 session.connection_id,
3116 )
3117 .await?;
3118 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3119 session
3120 .peer
3121 .forward_send(session.connection_id, peer_id.into(), request)?;
3122 Ok(())
3123}
3124
3125/// Notify other participants that a buffer has been updated. This is
3126/// allowed for guests as long as the update is limited to selections.
3127async fn update_buffer(
3128 request: proto::UpdateBuffer,
3129 response: Response<proto::UpdateBuffer>,
3130 session: Session,
3131) -> Result<()> {
3132 let project_id = ProjectId::from_proto(request.project_id);
3133 let mut capability = Capability::ReadOnly;
3134
3135 for op in request.operations.iter() {
3136 match op.variant {
3137 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3138 Some(_) => capability = Capability::ReadWrite,
3139 }
3140 }
3141
3142 let host = {
3143 let guard = session
3144 .db()
3145 .await
3146 .connections_for_buffer_update(
3147 project_id,
3148 session.principal_id(),
3149 session.connection_id,
3150 capability,
3151 )
3152 .await?;
3153
3154 let (host, guests) = &*guard;
3155
3156 broadcast(
3157 Some(session.connection_id),
3158 guests.clone(),
3159 |connection_id| {
3160 session
3161 .peer
3162 .forward_send(session.connection_id, connection_id, request.clone())
3163 },
3164 );
3165
3166 *host
3167 };
3168
3169 if host != session.connection_id {
3170 session
3171 .peer
3172 .forward_request(session.connection_id, host, request.clone())
3173 .await?;
3174 }
3175
3176 response.send(proto::Ack {})?;
3177 Ok(())
3178}
3179
3180async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3181 let project_id = ProjectId::from_proto(message.project_id);
3182
3183 let operation = message.operation.as_ref().context("invalid operation")?;
3184 let capability = match operation.variant.as_ref() {
3185 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3186 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3187 match buffer_op.variant {
3188 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3189 Capability::ReadOnly
3190 }
3191 _ => Capability::ReadWrite,
3192 }
3193 } else {
3194 Capability::ReadWrite
3195 }
3196 }
3197 Some(_) => Capability::ReadWrite,
3198 None => Capability::ReadOnly,
3199 };
3200
3201 let guard = session
3202 .db()
3203 .await
3204 .connections_for_buffer_update(
3205 project_id,
3206 session.principal_id(),
3207 session.connection_id,
3208 capability,
3209 )
3210 .await?;
3211
3212 let (host, guests) = &*guard;
3213
3214 broadcast(
3215 Some(session.connection_id),
3216 guests.iter().chain([host]).copied(),
3217 |connection_id| {
3218 session
3219 .peer
3220 .forward_send(session.connection_id, connection_id, message.clone())
3221 },
3222 );
3223
3224 Ok(())
3225}
3226
3227/// Notify other participants that a project has been updated.
3228async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3229 request: T,
3230 session: Session,
3231) -> Result<()> {
3232 let project_id = ProjectId::from_proto(request.remote_entity_id());
3233 let project_connection_ids = session
3234 .db()
3235 .await
3236 .project_connection_ids(project_id, session.connection_id, false)
3237 .await?;
3238
3239 broadcast(
3240 Some(session.connection_id),
3241 project_connection_ids.iter().copied(),
3242 |connection_id| {
3243 session
3244 .peer
3245 .forward_send(session.connection_id, connection_id, request.clone())
3246 },
3247 );
3248 Ok(())
3249}
3250
3251/// Start following another user in a call.
3252async fn follow(
3253 request: proto::Follow,
3254 response: Response<proto::Follow>,
3255 session: UserSession,
3256) -> Result<()> {
3257 let room_id = RoomId::from_proto(request.room_id);
3258 let project_id = request.project_id.map(ProjectId::from_proto);
3259 let leader_id = request
3260 .leader_id
3261 .ok_or_else(|| anyhow!("invalid leader id"))?
3262 .into();
3263 let follower_id = session.connection_id;
3264
3265 session
3266 .db()
3267 .await
3268 .check_room_participants(room_id, leader_id, session.connection_id)
3269 .await?;
3270
3271 let response_payload = session
3272 .peer
3273 .forward_request(session.connection_id, leader_id, request)
3274 .await?;
3275 response.send(response_payload)?;
3276
3277 if let Some(project_id) = project_id {
3278 let room = session
3279 .db()
3280 .await
3281 .follow(room_id, project_id, leader_id, follower_id)
3282 .await?;
3283 room_updated(&room, &session.peer);
3284 }
3285
3286 Ok(())
3287}
3288
3289/// Stop following another user in a call.
3290async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3291 let room_id = RoomId::from_proto(request.room_id);
3292 let project_id = request.project_id.map(ProjectId::from_proto);
3293 let leader_id = request
3294 .leader_id
3295 .ok_or_else(|| anyhow!("invalid leader id"))?
3296 .into();
3297 let follower_id = session.connection_id;
3298
3299 session
3300 .db()
3301 .await
3302 .check_room_participants(room_id, leader_id, session.connection_id)
3303 .await?;
3304
3305 session
3306 .peer
3307 .forward_send(session.connection_id, leader_id, request)?;
3308
3309 if let Some(project_id) = project_id {
3310 let room = session
3311 .db()
3312 .await
3313 .unfollow(room_id, project_id, leader_id, follower_id)
3314 .await?;
3315 room_updated(&room, &session.peer);
3316 }
3317
3318 Ok(())
3319}
3320
3321/// Notify everyone following you of your current location.
3322async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3323 let room_id = RoomId::from_proto(request.room_id);
3324 let database = session.db.lock().await;
3325
3326 let connection_ids = if let Some(project_id) = request.project_id {
3327 let project_id = ProjectId::from_proto(project_id);
3328 database
3329 .project_connection_ids(project_id, session.connection_id, true)
3330 .await?
3331 } else {
3332 database
3333 .room_connection_ids(room_id, session.connection_id)
3334 .await?
3335 };
3336
3337 // For now, don't send view update messages back to that view's current leader.
3338 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3339 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3340 _ => None,
3341 });
3342
3343 for connection_id in connection_ids.iter().cloned() {
3344 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3345 session
3346 .peer
3347 .forward_send(session.connection_id, connection_id, request.clone())?;
3348 }
3349 }
3350 Ok(())
3351}
3352
3353/// Get public data about users.
3354async fn get_users(
3355 request: proto::GetUsers,
3356 response: Response<proto::GetUsers>,
3357 session: Session,
3358) -> Result<()> {
3359 let user_ids = request
3360 .user_ids
3361 .into_iter()
3362 .map(UserId::from_proto)
3363 .collect();
3364 let users = session
3365 .db()
3366 .await
3367 .get_users_by_ids(user_ids)
3368 .await?
3369 .into_iter()
3370 .map(|user| proto::User {
3371 id: user.id.to_proto(),
3372 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3373 github_login: user.github_login,
3374 })
3375 .collect();
3376 response.send(proto::UsersResponse { users })?;
3377 Ok(())
3378}
3379
3380/// Search for users (to invite) buy Github login
3381async fn fuzzy_search_users(
3382 request: proto::FuzzySearchUsers,
3383 response: Response<proto::FuzzySearchUsers>,
3384 session: UserSession,
3385) -> Result<()> {
3386 let query = request.query;
3387 let users = match query.len() {
3388 0 => vec![],
3389 1 | 2 => session
3390 .db()
3391 .await
3392 .get_user_by_github_login(&query)
3393 .await?
3394 .into_iter()
3395 .collect(),
3396 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3397 };
3398 let users = users
3399 .into_iter()
3400 .filter(|user| user.id != session.user_id())
3401 .map(|user| proto::User {
3402 id: user.id.to_proto(),
3403 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3404 github_login: user.github_login,
3405 })
3406 .collect();
3407 response.send(proto::UsersResponse { users })?;
3408 Ok(())
3409}
3410
3411/// Send a contact request to another user.
3412async fn request_contact(
3413 request: proto::RequestContact,
3414 response: Response<proto::RequestContact>,
3415 session: UserSession,
3416) -> Result<()> {
3417 let requester_id = session.user_id();
3418 let responder_id = UserId::from_proto(request.responder_id);
3419 if requester_id == responder_id {
3420 return Err(anyhow!("cannot add yourself as a contact"))?;
3421 }
3422
3423 let notifications = session
3424 .db()
3425 .await
3426 .send_contact_request(requester_id, responder_id)
3427 .await?;
3428
3429 // Update outgoing contact requests of requester
3430 let mut update = proto::UpdateContacts::default();
3431 update.outgoing_requests.push(responder_id.to_proto());
3432 for connection_id in session
3433 .connection_pool()
3434 .await
3435 .user_connection_ids(requester_id)
3436 {
3437 session.peer.send(connection_id, update.clone())?;
3438 }
3439
3440 // Update incoming contact requests of responder
3441 let mut update = proto::UpdateContacts::default();
3442 update
3443 .incoming_requests
3444 .push(proto::IncomingContactRequest {
3445 requester_id: requester_id.to_proto(),
3446 });
3447 let connection_pool = session.connection_pool().await;
3448 for connection_id in connection_pool.user_connection_ids(responder_id) {
3449 session.peer.send(connection_id, update.clone())?;
3450 }
3451
3452 send_notifications(&connection_pool, &session.peer, notifications);
3453
3454 response.send(proto::Ack {})?;
3455 Ok(())
3456}
3457
3458/// Accept or decline a contact request
3459async fn respond_to_contact_request(
3460 request: proto::RespondToContactRequest,
3461 response: Response<proto::RespondToContactRequest>,
3462 session: UserSession,
3463) -> Result<()> {
3464 let responder_id = session.user_id();
3465 let requester_id = UserId::from_proto(request.requester_id);
3466 let db = session.db().await;
3467 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3468 db.dismiss_contact_notification(responder_id, requester_id)
3469 .await?;
3470 } else {
3471 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3472
3473 let notifications = db
3474 .respond_to_contact_request(responder_id, requester_id, accept)
3475 .await?;
3476 let requester_busy = db.is_user_busy(requester_id).await?;
3477 let responder_busy = db.is_user_busy(responder_id).await?;
3478
3479 let pool = session.connection_pool().await;
3480 // Update responder with new contact
3481 let mut update = proto::UpdateContacts::default();
3482 if accept {
3483 update
3484 .contacts
3485 .push(contact_for_user(requester_id, requester_busy, &pool));
3486 }
3487 update
3488 .remove_incoming_requests
3489 .push(requester_id.to_proto());
3490 for connection_id in pool.user_connection_ids(responder_id) {
3491 session.peer.send(connection_id, update.clone())?;
3492 }
3493
3494 // Update requester with new contact
3495 let mut update = proto::UpdateContacts::default();
3496 if accept {
3497 update
3498 .contacts
3499 .push(contact_for_user(responder_id, responder_busy, &pool));
3500 }
3501 update
3502 .remove_outgoing_requests
3503 .push(responder_id.to_proto());
3504
3505 for connection_id in pool.user_connection_ids(requester_id) {
3506 session.peer.send(connection_id, update.clone())?;
3507 }
3508
3509 send_notifications(&pool, &session.peer, notifications);
3510 }
3511
3512 response.send(proto::Ack {})?;
3513 Ok(())
3514}
3515
3516/// Remove a contact.
3517async fn remove_contact(
3518 request: proto::RemoveContact,
3519 response: Response<proto::RemoveContact>,
3520 session: UserSession,
3521) -> Result<()> {
3522 let requester_id = session.user_id();
3523 let responder_id = UserId::from_proto(request.user_id);
3524 let db = session.db().await;
3525 let (contact_accepted, deleted_notification_id) =
3526 db.remove_contact(requester_id, responder_id).await?;
3527
3528 let pool = session.connection_pool().await;
3529 // Update outgoing contact requests of requester
3530 let mut update = proto::UpdateContacts::default();
3531 if contact_accepted {
3532 update.remove_contacts.push(responder_id.to_proto());
3533 } else {
3534 update
3535 .remove_outgoing_requests
3536 .push(responder_id.to_proto());
3537 }
3538 for connection_id in pool.user_connection_ids(requester_id) {
3539 session.peer.send(connection_id, update.clone())?;
3540 }
3541
3542 // Update incoming contact requests of responder
3543 let mut update = proto::UpdateContacts::default();
3544 if contact_accepted {
3545 update.remove_contacts.push(requester_id.to_proto());
3546 } else {
3547 update
3548 .remove_incoming_requests
3549 .push(requester_id.to_proto());
3550 }
3551 for connection_id in pool.user_connection_ids(responder_id) {
3552 session.peer.send(connection_id, update.clone())?;
3553 if let Some(notification_id) = deleted_notification_id {
3554 session.peer.send(
3555 connection_id,
3556 proto::DeleteNotification {
3557 notification_id: notification_id.to_proto(),
3558 },
3559 )?;
3560 }
3561 }
3562
3563 response.send(proto::Ack {})?;
3564 Ok(())
3565}
3566
3567fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3568 version.0.minor() < 139
3569}
3570
3571async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
3572 let plan = session.current_plan().await?;
3573
3574 session
3575 .peer
3576 .send(
3577 session.connection_id,
3578 proto::UpdateUserPlan { plan: plan.into() },
3579 )
3580 .trace_err();
3581
3582 Ok(())
3583}
3584
3585async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3586 subscribe_user_to_channels(
3587 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3588 &session,
3589 )
3590 .await?;
3591 Ok(())
3592}
3593
3594async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3595 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3596 let mut pool = session.connection_pool().await;
3597 for membership in &channels_for_user.channel_memberships {
3598 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3599 }
3600 session.peer.send(
3601 session.connection_id,
3602 build_update_user_channels(&channels_for_user),
3603 )?;
3604 session.peer.send(
3605 session.connection_id,
3606 build_channels_update(channels_for_user),
3607 )?;
3608 Ok(())
3609}
3610
3611/// Creates a new channel.
3612async fn create_channel(
3613 request: proto::CreateChannel,
3614 response: Response<proto::CreateChannel>,
3615 session: UserSession,
3616) -> Result<()> {
3617 let db = session.db().await;
3618
3619 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3620 let (channel, membership) = db
3621 .create_channel(&request.name, parent_id, session.user_id())
3622 .await?;
3623
3624 let root_id = channel.root_id();
3625 let channel = Channel::from_model(channel);
3626
3627 response.send(proto::CreateChannelResponse {
3628 channel: Some(channel.to_proto()),
3629 parent_id: request.parent_id,
3630 })?;
3631
3632 let mut connection_pool = session.connection_pool().await;
3633 if let Some(membership) = membership {
3634 connection_pool.subscribe_to_channel(
3635 membership.user_id,
3636 membership.channel_id,
3637 membership.role,
3638 );
3639 let update = proto::UpdateUserChannels {
3640 channel_memberships: vec![proto::ChannelMembership {
3641 channel_id: membership.channel_id.to_proto(),
3642 role: membership.role.into(),
3643 }],
3644 ..Default::default()
3645 };
3646 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3647 session.peer.send(connection_id, update.clone())?;
3648 }
3649 }
3650
3651 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3652 if !role.can_see_channel(channel.visibility) {
3653 continue;
3654 }
3655
3656 let update = proto::UpdateChannels {
3657 channels: vec![channel.to_proto()],
3658 ..Default::default()
3659 };
3660 session.peer.send(connection_id, update.clone())?;
3661 }
3662
3663 Ok(())
3664}
3665
3666/// Delete a channel
3667async fn delete_channel(
3668 request: proto::DeleteChannel,
3669 response: Response<proto::DeleteChannel>,
3670 session: UserSession,
3671) -> Result<()> {
3672 let db = session.db().await;
3673
3674 let channel_id = request.channel_id;
3675 let (root_channel, removed_channels) = db
3676 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3677 .await?;
3678 response.send(proto::Ack {})?;
3679
3680 // Notify members of removed channels
3681 let mut update = proto::UpdateChannels::default();
3682 update
3683 .delete_channels
3684 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3685
3686 let connection_pool = session.connection_pool().await;
3687 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3688 session.peer.send(connection_id, update.clone())?;
3689 }
3690
3691 Ok(())
3692}
3693
3694/// Invite someone to join a channel.
3695async fn invite_channel_member(
3696 request: proto::InviteChannelMember,
3697 response: Response<proto::InviteChannelMember>,
3698 session: UserSession,
3699) -> Result<()> {
3700 let db = session.db().await;
3701 let channel_id = ChannelId::from_proto(request.channel_id);
3702 let invitee_id = UserId::from_proto(request.user_id);
3703 let InviteMemberResult {
3704 channel,
3705 notifications,
3706 } = db
3707 .invite_channel_member(
3708 channel_id,
3709 invitee_id,
3710 session.user_id(),
3711 request.role().into(),
3712 )
3713 .await?;
3714
3715 let update = proto::UpdateChannels {
3716 channel_invitations: vec![channel.to_proto()],
3717 ..Default::default()
3718 };
3719
3720 let connection_pool = session.connection_pool().await;
3721 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3722 session.peer.send(connection_id, update.clone())?;
3723 }
3724
3725 send_notifications(&connection_pool, &session.peer, notifications);
3726
3727 response.send(proto::Ack {})?;
3728 Ok(())
3729}
3730
3731/// remove someone from a channel
3732async fn remove_channel_member(
3733 request: proto::RemoveChannelMember,
3734 response: Response<proto::RemoveChannelMember>,
3735 session: UserSession,
3736) -> Result<()> {
3737 let db = session.db().await;
3738 let channel_id = ChannelId::from_proto(request.channel_id);
3739 let member_id = UserId::from_proto(request.user_id);
3740
3741 let RemoveChannelMemberResult {
3742 membership_update,
3743 notification_id,
3744 } = db
3745 .remove_channel_member(channel_id, member_id, session.user_id())
3746 .await?;
3747
3748 let mut connection_pool = session.connection_pool().await;
3749 notify_membership_updated(
3750 &mut connection_pool,
3751 membership_update,
3752 member_id,
3753 &session.peer,
3754 );
3755 for connection_id in connection_pool.user_connection_ids(member_id) {
3756 if let Some(notification_id) = notification_id {
3757 session
3758 .peer
3759 .send(
3760 connection_id,
3761 proto::DeleteNotification {
3762 notification_id: notification_id.to_proto(),
3763 },
3764 )
3765 .trace_err();
3766 }
3767 }
3768
3769 response.send(proto::Ack {})?;
3770 Ok(())
3771}
3772
3773/// Toggle the channel between public and private.
3774/// Care is taken to maintain the invariant that public channels only descend from public channels,
3775/// (though members-only channels can appear at any point in the hierarchy).
3776async fn set_channel_visibility(
3777 request: proto::SetChannelVisibility,
3778 response: Response<proto::SetChannelVisibility>,
3779 session: UserSession,
3780) -> Result<()> {
3781 let db = session.db().await;
3782 let channel_id = ChannelId::from_proto(request.channel_id);
3783 let visibility = request.visibility().into();
3784
3785 let channel_model = db
3786 .set_channel_visibility(channel_id, visibility, session.user_id())
3787 .await?;
3788 let root_id = channel_model.root_id();
3789 let channel = Channel::from_model(channel_model);
3790
3791 let mut connection_pool = session.connection_pool().await;
3792 for (user_id, role) in connection_pool
3793 .channel_user_ids(root_id)
3794 .collect::<Vec<_>>()
3795 .into_iter()
3796 {
3797 let update = if role.can_see_channel(channel.visibility) {
3798 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3799 proto::UpdateChannels {
3800 channels: vec![channel.to_proto()],
3801 ..Default::default()
3802 }
3803 } else {
3804 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3805 proto::UpdateChannels {
3806 delete_channels: vec![channel.id.to_proto()],
3807 ..Default::default()
3808 }
3809 };
3810
3811 for connection_id in connection_pool.user_connection_ids(user_id) {
3812 session.peer.send(connection_id, update.clone())?;
3813 }
3814 }
3815
3816 response.send(proto::Ack {})?;
3817 Ok(())
3818}
3819
3820/// Alter the role for a user in the channel.
3821async fn set_channel_member_role(
3822 request: proto::SetChannelMemberRole,
3823 response: Response<proto::SetChannelMemberRole>,
3824 session: UserSession,
3825) -> Result<()> {
3826 let db = session.db().await;
3827 let channel_id = ChannelId::from_proto(request.channel_id);
3828 let member_id = UserId::from_proto(request.user_id);
3829 let result = db
3830 .set_channel_member_role(
3831 channel_id,
3832 session.user_id(),
3833 member_id,
3834 request.role().into(),
3835 )
3836 .await?;
3837
3838 match result {
3839 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3840 let mut connection_pool = session.connection_pool().await;
3841 notify_membership_updated(
3842 &mut connection_pool,
3843 membership_update,
3844 member_id,
3845 &session.peer,
3846 )
3847 }
3848 db::SetMemberRoleResult::InviteUpdated(channel) => {
3849 let update = proto::UpdateChannels {
3850 channel_invitations: vec![channel.to_proto()],
3851 ..Default::default()
3852 };
3853
3854 for connection_id in session
3855 .connection_pool()
3856 .await
3857 .user_connection_ids(member_id)
3858 {
3859 session.peer.send(connection_id, update.clone())?;
3860 }
3861 }
3862 }
3863
3864 response.send(proto::Ack {})?;
3865 Ok(())
3866}
3867
3868/// Change the name of a channel
3869async fn rename_channel(
3870 request: proto::RenameChannel,
3871 response: Response<proto::RenameChannel>,
3872 session: UserSession,
3873) -> Result<()> {
3874 let db = session.db().await;
3875 let channel_id = ChannelId::from_proto(request.channel_id);
3876 let channel_model = db
3877 .rename_channel(channel_id, session.user_id(), &request.name)
3878 .await?;
3879 let root_id = channel_model.root_id();
3880 let channel = Channel::from_model(channel_model);
3881
3882 response.send(proto::RenameChannelResponse {
3883 channel: Some(channel.to_proto()),
3884 })?;
3885
3886 let connection_pool = session.connection_pool().await;
3887 let update = proto::UpdateChannels {
3888 channels: vec![channel.to_proto()],
3889 ..Default::default()
3890 };
3891 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3892 if role.can_see_channel(channel.visibility) {
3893 session.peer.send(connection_id, update.clone())?;
3894 }
3895 }
3896
3897 Ok(())
3898}
3899
3900/// Move a channel to a new parent.
3901async fn move_channel(
3902 request: proto::MoveChannel,
3903 response: Response<proto::MoveChannel>,
3904 session: UserSession,
3905) -> Result<()> {
3906 let channel_id = ChannelId::from_proto(request.channel_id);
3907 let to = ChannelId::from_proto(request.to);
3908
3909 let (root_id, channels) = session
3910 .db()
3911 .await
3912 .move_channel(channel_id, to, session.user_id())
3913 .await?;
3914
3915 let connection_pool = session.connection_pool().await;
3916 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3917 let channels = channels
3918 .iter()
3919 .filter_map(|channel| {
3920 if role.can_see_channel(channel.visibility) {
3921 Some(channel.to_proto())
3922 } else {
3923 None
3924 }
3925 })
3926 .collect::<Vec<_>>();
3927 if channels.is_empty() {
3928 continue;
3929 }
3930
3931 let update = proto::UpdateChannels {
3932 channels,
3933 ..Default::default()
3934 };
3935
3936 session.peer.send(connection_id, update.clone())?;
3937 }
3938
3939 response.send(Ack {})?;
3940 Ok(())
3941}
3942
3943/// Get the list of channel members
3944async fn get_channel_members(
3945 request: proto::GetChannelMembers,
3946 response: Response<proto::GetChannelMembers>,
3947 session: UserSession,
3948) -> Result<()> {
3949 let db = session.db().await;
3950 let channel_id = ChannelId::from_proto(request.channel_id);
3951 let limit = if request.limit == 0 {
3952 u16::MAX as u64
3953 } else {
3954 request.limit
3955 };
3956 let (members, users) = db
3957 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3958 .await?;
3959 response.send(proto::GetChannelMembersResponse { members, users })?;
3960 Ok(())
3961}
3962
3963/// Accept or decline a channel invitation.
3964async fn respond_to_channel_invite(
3965 request: proto::RespondToChannelInvite,
3966 response: Response<proto::RespondToChannelInvite>,
3967 session: UserSession,
3968) -> Result<()> {
3969 let db = session.db().await;
3970 let channel_id = ChannelId::from_proto(request.channel_id);
3971 let RespondToChannelInvite {
3972 membership_update,
3973 notifications,
3974 } = db
3975 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3976 .await?;
3977
3978 let mut connection_pool = session.connection_pool().await;
3979 if let Some(membership_update) = membership_update {
3980 notify_membership_updated(
3981 &mut connection_pool,
3982 membership_update,
3983 session.user_id(),
3984 &session.peer,
3985 );
3986 } else {
3987 let update = proto::UpdateChannels {
3988 remove_channel_invitations: vec![channel_id.to_proto()],
3989 ..Default::default()
3990 };
3991
3992 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3993 session.peer.send(connection_id, update.clone())?;
3994 }
3995 };
3996
3997 send_notifications(&connection_pool, &session.peer, notifications);
3998
3999 response.send(proto::Ack {})?;
4000
4001 Ok(())
4002}
4003
4004/// Join the channels' room
4005async fn join_channel(
4006 request: proto::JoinChannel,
4007 response: Response<proto::JoinChannel>,
4008 session: UserSession,
4009) -> Result<()> {
4010 let channel_id = ChannelId::from_proto(request.channel_id);
4011 join_channel_internal(channel_id, Box::new(response), session).await
4012}
4013
/// Abstracts over the response handles that can answer a channel join.
///
/// `join_channel_internal` only needs to send a `proto::JoinRoomResponse`. This trait lets
/// it do so for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`.
trait JoinChannelInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Explicit type path: the intent is presumably the inherent `Response::send`, not this
        // trait method (which would recurse). Keep the path form when editing.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same explicit-path delegation as the `JoinChannel` impl above.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
4027
/// Shared implementation for joining the room that belongs to `channel_id`.
///
/// `response` is any `JoinChannelInternalResponse`, which is implemented for both
/// `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`.
///
/// Steps:
/// 1. If the database reports a stale room connection for this user, it is cleaned up via
///    `leave_room_for_session`. The database guard is dropped around that call and
///    re-acquired afterwards.
/// 2. The database joins the channel. It returns the joined room, an optional membership
///    change, and the user's role.
/// 3. If a LiveKit client is configured, credentials are created:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`;
///    - other roles get a room token with `can_publish: true`;
///    - a token error is logged by `trace_err` and leaves the info as `None`.
/// 4. The `JoinRoomResponse` is sent, any membership change is propagated, and
///    `room_updated` broadcasts the room.
/// 5. After the inner scope ends, which drops the `db` and `connection_pool` guards,
///    `channel_updated` runs with a freshly locked pool and the user's contacts are refreshed.
///
/// Returns an error if the join result carries no channel. NOTE(review): that check runs
/// after the response was already sent in step 4; confirm this ordering is intended.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the `db` and `connection_pool` guards are dropped before the pool is locked
    // again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving; presumably `leave_room_for_session`
            // acquires it itself.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when token creation fails
        // (`trace_err()?` logs the error and short-circuits the closure).
        let live_kit_connection_info =
            session
                .app_state
                .live_kit_client
                .as_ref()
                .and_then(|live_kit| {
                    // Guests get a guest token and may not publish; other roles get a room token.
                    let (can_publish, token) = if role == ChannelRole::Guest {
                        (
                            false,
                            live_kit
                                .guest_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    } else {
                        (
                            true,
                            live_kit
                                .room_token(
                                    &joined_room.room.live_kit_room,
                                    &session.user_id().to_string(),
                                )
                                .trace_err()?,
                        )
                    };

                    Some(LiveKitConnectionInfo {
                        server_url: live_kit.url().into(),
                        token,
                        can_publish,
                    })
                });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above have been dropped at this point.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    // Refresh contacts; presumably the user's busy status changed by joining a room.
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4123
4124/// Start editing the channel notes
4125async fn join_channel_buffer(
4126 request: proto::JoinChannelBuffer,
4127 response: Response<proto::JoinChannelBuffer>,
4128 session: UserSession,
4129) -> Result<()> {
4130 let db = session.db().await;
4131 let channel_id = ChannelId::from_proto(request.channel_id);
4132
4133 let open_response = db
4134 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4135 .await?;
4136
4137 let collaborators = open_response.collaborators.clone();
4138 response.send(open_response)?;
4139
4140 let update = UpdateChannelBufferCollaborators {
4141 channel_id: channel_id.to_proto(),
4142 collaborators: collaborators.clone(),
4143 };
4144 channel_buffer_updated(
4145 session.connection_id,
4146 collaborators
4147 .iter()
4148 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4149 &update,
4150 &session.peer,
4151 );
4152
4153 Ok(())
4154}
4155
4156/// Edit the channel notes
4157async fn update_channel_buffer(
4158 request: proto::UpdateChannelBuffer,
4159 session: UserSession,
4160) -> Result<()> {
4161 let db = session.db().await;
4162 let channel_id = ChannelId::from_proto(request.channel_id);
4163
4164 let (collaborators, epoch, version) = db
4165 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4166 .await?;
4167
4168 channel_buffer_updated(
4169 session.connection_id,
4170 collaborators.clone(),
4171 &proto::UpdateChannelBuffer {
4172 channel_id: channel_id.to_proto(),
4173 operations: request.operations,
4174 },
4175 &session.peer,
4176 );
4177
4178 let pool = &*session.connection_pool().await;
4179
4180 let non_collaborators =
4181 pool.channel_connection_ids(channel_id)
4182 .filter_map(|(connection_id, _)| {
4183 if collaborators.contains(&connection_id) {
4184 None
4185 } else {
4186 Some(connection_id)
4187 }
4188 });
4189
4190 broadcast(None, non_collaborators, |peer_id| {
4191 session.peer.send(
4192 peer_id,
4193 proto::UpdateChannels {
4194 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4195 channel_id: channel_id.to_proto(),
4196 epoch: epoch as u64,
4197 version: version.clone(),
4198 }],
4199 ..Default::default()
4200 },
4201 )
4202 });
4203
4204 Ok(())
4205}
4206
4207/// Rejoin the channel notes after a connection blip
4208async fn rejoin_channel_buffers(
4209 request: proto::RejoinChannelBuffers,
4210 response: Response<proto::RejoinChannelBuffers>,
4211 session: UserSession,
4212) -> Result<()> {
4213 let db = session.db().await;
4214 let buffers = db
4215 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4216 .await?;
4217
4218 for rejoined_buffer in &buffers {
4219 let collaborators_to_notify = rejoined_buffer
4220 .buffer
4221 .collaborators
4222 .iter()
4223 .filter_map(|c| Some(c.peer_id?.into()));
4224 channel_buffer_updated(
4225 session.connection_id,
4226 collaborators_to_notify,
4227 &proto::UpdateChannelBufferCollaborators {
4228 channel_id: rejoined_buffer.buffer.channel_id,
4229 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4230 },
4231 &session.peer,
4232 );
4233 }
4234
4235 response.send(proto::RejoinChannelBuffersResponse {
4236 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4237 })?;
4238
4239 Ok(())
4240}
4241
4242/// Stop editing the channel notes
4243async fn leave_channel_buffer(
4244 request: proto::LeaveChannelBuffer,
4245 response: Response<proto::LeaveChannelBuffer>,
4246 session: UserSession,
4247) -> Result<()> {
4248 let db = session.db().await;
4249 let channel_id = ChannelId::from_proto(request.channel_id);
4250
4251 let left_buffer = db
4252 .leave_channel_buffer(channel_id, session.connection_id)
4253 .await?;
4254
4255 response.send(Ack {})?;
4256
4257 channel_buffer_updated(
4258 session.connection_id,
4259 left_buffer.connections,
4260 &proto::UpdateChannelBufferCollaborators {
4261 channel_id: channel_id.to_proto(),
4262 collaborators: left_buffer.collaborators,
4263 },
4264 &session.peer,
4265 );
4266
4267 Ok(())
4268}
4269
4270fn channel_buffer_updated<T: EnvelopedMessage>(
4271 sender_id: ConnectionId,
4272 collaborators: impl IntoIterator<Item = ConnectionId>,
4273 message: &T,
4274 peer: &Peer,
4275) {
4276 broadcast(Some(sender_id), collaborators, |peer_id| {
4277 peer.send(peer_id, message.clone())
4278 });
4279}
4280
4281fn send_notifications(
4282 connection_pool: &ConnectionPool,
4283 peer: &Peer,
4284 notifications: db::NotificationBatch,
4285) {
4286 for (user_id, notification) in notifications {
4287 for connection_id in connection_pool.user_connection_ids(user_id) {
4288 if let Err(error) = peer.send(
4289 connection_id,
4290 proto::AddNotification {
4291 notification: Some(notification.clone()),
4292 },
4293 ) {
4294 tracing::error!(
4295 "failed to send notification to {:?} {}",
4296 connection_id,
4297 error
4298 );
4299 }
4300 }
4301 }
4302}
4303
4304/// Send a message to the channel
4305async fn send_channel_message(
4306 request: proto::SendChannelMessage,
4307 response: Response<proto::SendChannelMessage>,
4308 session: UserSession,
4309) -> Result<()> {
4310 // Validate the message body.
4311 let body = request.body.trim().to_string();
4312 if body.len() > MAX_MESSAGE_LEN {
4313 return Err(anyhow!("message is too long"))?;
4314 }
4315 if body.is_empty() {
4316 return Err(anyhow!("message can't be blank"))?;
4317 }
4318
4319 // TODO: adjust mentions if body is trimmed
4320
4321 let timestamp = OffsetDateTime::now_utc();
4322 let nonce = request
4323 .nonce
4324 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4325
4326 let channel_id = ChannelId::from_proto(request.channel_id);
4327 let CreatedChannelMessage {
4328 message_id,
4329 participant_connection_ids,
4330 notifications,
4331 } = session
4332 .db()
4333 .await
4334 .create_channel_message(
4335 channel_id,
4336 session.user_id(),
4337 &body,
4338 &request.mentions,
4339 timestamp,
4340 nonce.clone().into(),
4341 match request.reply_to_message_id {
4342 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4343 None => None,
4344 },
4345 )
4346 .await?;
4347
4348 let message = proto::ChannelMessage {
4349 sender_id: session.user_id().to_proto(),
4350 id: message_id.to_proto(),
4351 body,
4352 mentions: request.mentions,
4353 timestamp: timestamp.unix_timestamp() as u64,
4354 nonce: Some(nonce),
4355 reply_to_message_id: request.reply_to_message_id,
4356 edited_at: None,
4357 };
4358 broadcast(
4359 Some(session.connection_id),
4360 participant_connection_ids.clone(),
4361 |connection| {
4362 session.peer.send(
4363 connection,
4364 proto::ChannelMessageSent {
4365 channel_id: channel_id.to_proto(),
4366 message: Some(message.clone()),
4367 },
4368 )
4369 },
4370 );
4371 response.send(proto::SendChannelMessageResponse {
4372 message: Some(message),
4373 })?;
4374
4375 let pool = &*session.connection_pool().await;
4376 let non_participants =
4377 pool.channel_connection_ids(channel_id)
4378 .filter_map(|(connection_id, _)| {
4379 if participant_connection_ids.contains(&connection_id) {
4380 None
4381 } else {
4382 Some(connection_id)
4383 }
4384 });
4385 broadcast(None, non_participants, |peer_id| {
4386 session.peer.send(
4387 peer_id,
4388 proto::UpdateChannels {
4389 latest_channel_message_ids: vec![proto::ChannelMessageId {
4390 channel_id: channel_id.to_proto(),
4391 message_id: message_id.to_proto(),
4392 }],
4393 ..Default::default()
4394 },
4395 )
4396 });
4397 send_notifications(pool, &session.peer, notifications);
4398
4399 Ok(())
4400}
4401
4402/// Delete a channel message
4403async fn remove_channel_message(
4404 request: proto::RemoveChannelMessage,
4405 response: Response<proto::RemoveChannelMessage>,
4406 session: UserSession,
4407) -> Result<()> {
4408 let channel_id = ChannelId::from_proto(request.channel_id);
4409 let message_id = MessageId::from_proto(request.message_id);
4410 let (connection_ids, existing_notification_ids) = session
4411 .db()
4412 .await
4413 .remove_channel_message(channel_id, message_id, session.user_id())
4414 .await?;
4415
4416 broadcast(
4417 Some(session.connection_id),
4418 connection_ids,
4419 move |connection| {
4420 session.peer.send(connection, request.clone())?;
4421
4422 for notification_id in &existing_notification_ids {
4423 session.peer.send(
4424 connection,
4425 proto::DeleteNotification {
4426 notification_id: (*notification_id).to_proto(),
4427 },
4428 )?;
4429 }
4430
4431 Ok(())
4432 },
4433 );
4434 response.send(proto::Ack {})?;
4435 Ok(())
4436}
4437
4438async fn update_channel_message(
4439 request: proto::UpdateChannelMessage,
4440 response: Response<proto::UpdateChannelMessage>,
4441 session: UserSession,
4442) -> Result<()> {
4443 let channel_id = ChannelId::from_proto(request.channel_id);
4444 let message_id = MessageId::from_proto(request.message_id);
4445 let updated_at = OffsetDateTime::now_utc();
4446 let UpdatedChannelMessage {
4447 message_id,
4448 participant_connection_ids,
4449 notifications,
4450 reply_to_message_id,
4451 timestamp,
4452 deleted_mention_notification_ids,
4453 updated_mention_notifications,
4454 } = session
4455 .db()
4456 .await
4457 .update_channel_message(
4458 channel_id,
4459 message_id,
4460 session.user_id(),
4461 request.body.as_str(),
4462 &request.mentions,
4463 updated_at,
4464 )
4465 .await?;
4466
4467 let nonce = request
4468 .nonce
4469 .clone()
4470 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4471
4472 let message = proto::ChannelMessage {
4473 sender_id: session.user_id().to_proto(),
4474 id: message_id.to_proto(),
4475 body: request.body.clone(),
4476 mentions: request.mentions.clone(),
4477 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4478 nonce: Some(nonce),
4479 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4480 edited_at: Some(updated_at.unix_timestamp() as u64),
4481 };
4482
4483 response.send(proto::Ack {})?;
4484
4485 let pool = &*session.connection_pool().await;
4486 broadcast(
4487 Some(session.connection_id),
4488 participant_connection_ids,
4489 |connection| {
4490 session.peer.send(
4491 connection,
4492 proto::ChannelMessageUpdate {
4493 channel_id: channel_id.to_proto(),
4494 message: Some(message.clone()),
4495 },
4496 )?;
4497
4498 for notification_id in &deleted_mention_notification_ids {
4499 session.peer.send(
4500 connection,
4501 proto::DeleteNotification {
4502 notification_id: (*notification_id).to_proto(),
4503 },
4504 )?;
4505 }
4506
4507 for notification in &updated_mention_notifications {
4508 session.peer.send(
4509 connection,
4510 proto::UpdateNotification {
4511 notification: Some(notification.clone()),
4512 },
4513 )?;
4514 }
4515
4516 Ok(())
4517 },
4518 );
4519
4520 send_notifications(pool, &session.peer, notifications);
4521
4522 Ok(())
4523}
4524
4525/// Mark a channel message as read
4526async fn acknowledge_channel_message(
4527 request: proto::AckChannelMessage,
4528 session: UserSession,
4529) -> Result<()> {
4530 let channel_id = ChannelId::from_proto(request.channel_id);
4531 let message_id = MessageId::from_proto(request.message_id);
4532 let notifications = session
4533 .db()
4534 .await
4535 .observe_channel_message(channel_id, session.user_id(), message_id)
4536 .await?;
4537 send_notifications(
4538 &*session.connection_pool().await,
4539 &session.peer,
4540 notifications,
4541 );
4542 Ok(())
4543}
4544
4545/// Mark a buffer version as synced
4546async fn acknowledge_buffer_version(
4547 request: proto::AckBufferOperation,
4548 session: UserSession,
4549) -> Result<()> {
4550 let buffer_id = BufferId::from_proto(request.buffer_id);
4551 session
4552 .db()
4553 .await
4554 .observe_buffer_version(
4555 buffer_id,
4556 session.user_id(),
4557 request.epoch as i32,
4558 &request.version,
4559 )
4560 .await?;
4561 Ok(())
4562}
4563
4564struct ZedProCompleteWithLanguageModelRateLimit;
4565
4566impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
4567 fn capacity(&self) -> usize {
4568 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4569 .ok()
4570 .and_then(|v| v.parse().ok())
4571 .unwrap_or(120) // Picked arbitrarily
4572 }
4573
4574 fn refill_duration(&self) -> chrono::Duration {
4575 chrono::Duration::hours(1)
4576 }
4577
4578 fn db_name(&self) -> &'static str {
4579 "zed-pro:complete-with-language-model"
4580 }
4581}
4582
4583struct FreeCompleteWithLanguageModelRateLimit;
4584
4585impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
4586 fn capacity(&self) -> usize {
4587 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
4588 .ok()
4589 .and_then(|v| v.parse().ok())
4590 .unwrap_or(120 / 10) // Picked arbitrarily
4591 }
4592
4593 fn refill_duration(&self) -> chrono::Duration {
4594 chrono::Duration::hours(1)
4595 }
4596
4597 fn db_name(&self) -> &'static str {
4598 "free:complete-with-language-model"
4599 }
4600}
4601
4602async fn complete_with_language_model(
4603 request: proto::CompleteWithLanguageModel,
4604 response: Response<proto::CompleteWithLanguageModel>,
4605 session: Session,
4606 config: &Config,
4607) -> Result<()> {
4608 let Some(session) = session.for_user() else {
4609 return Err(anyhow!("user not found"))?;
4610 };
4611 authorize_access_to_language_models(&session).await?;
4612
4613 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4614 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4615 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4616 };
4617
4618 session
4619 .app_state
4620 .rate_limiter
4621 .check(&*rate_limit, session.user_id())
4622 .await?;
4623
4624 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4625 Some(proto::LanguageModelProvider::Anthropic) => {
4626 let api_key = config
4627 .anthropic_api_key
4628 .as_ref()
4629 .context("no Anthropic AI API key configured on the server")?;
4630 anthropic::complete(
4631 session.http_client.as_ref(),
4632 anthropic::ANTHROPIC_API_URL,
4633 api_key,
4634 serde_json::from_str(&request.request)?,
4635 )
4636 .await?
4637 }
4638 _ => return Err(anyhow!("unsupported provider"))?,
4639 };
4640
4641 response.send(proto::CompleteWithLanguageModelResponse {
4642 completion: serde_json::to_string(&result)?,
4643 })?;
4644
4645 Ok(())
4646}
4647
4648async fn stream_complete_with_language_model(
4649 request: proto::StreamCompleteWithLanguageModel,
4650 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4651 session: Session,
4652 config: &Config,
4653) -> Result<()> {
4654 let Some(session) = session.for_user() else {
4655 return Err(anyhow!("user not found"))?;
4656 };
4657 authorize_access_to_language_models(&session).await?;
4658
4659 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4660 proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
4661 proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
4662 };
4663
4664 session
4665 .app_state
4666 .rate_limiter
4667 .check(&*rate_limit, session.user_id())
4668 .await?;
4669
4670 match proto::LanguageModelProvider::from_i32(request.provider) {
4671 Some(proto::LanguageModelProvider::Anthropic) => {
4672 let api_key = config
4673 .anthropic_api_key
4674 .as_ref()
4675 .context("no Anthropic AI API key configured on the server")?;
4676 let mut chunks = anthropic::stream_completion(
4677 session.http_client.as_ref(),
4678 anthropic::ANTHROPIC_API_URL,
4679 api_key,
4680 serde_json::from_str(&request.request)?,
4681 None,
4682 )
4683 .await?;
4684 while let Some(event) = chunks.next().await {
4685 let chunk = event?;
4686 response.send(proto::StreamCompleteWithLanguageModelResponse {
4687 event: serde_json::to_string(&chunk)?,
4688 })?;
4689 }
4690 }
4691 Some(proto::LanguageModelProvider::OpenAi) => {
4692 let api_key = config
4693 .openai_api_key
4694 .as_ref()
4695 .context("no OpenAI API key configured on the server")?;
4696 let mut events = open_ai::stream_completion(
4697 session.http_client.as_ref(),
4698 open_ai::OPEN_AI_API_URL,
4699 api_key,
4700 serde_json::from_str(&request.request)?,
4701 None,
4702 )
4703 .await?;
4704 while let Some(event) = events.next().await {
4705 let event = event?;
4706 response.send(proto::StreamCompleteWithLanguageModelResponse {
4707 event: serde_json::to_string(&event)?,
4708 })?;
4709 }
4710 }
4711 Some(proto::LanguageModelProvider::Google) => {
4712 let api_key = config
4713 .google_ai_api_key
4714 .as_ref()
4715 .context("no Google AI API key configured on the server")?;
4716 let mut events = google_ai::stream_generate_content(
4717 session.http_client.as_ref(),
4718 google_ai::API_URL,
4719 api_key,
4720 serde_json::from_str(&request.request)?,
4721 )
4722 .await?;
4723 while let Some(event) = events.next().await {
4724 let event = event?;
4725 response.send(proto::StreamCompleteWithLanguageModelResponse {
4726 event: serde_json::to_string(&event)?,
4727 })?;
4728 }
4729 }
4730 Some(proto::LanguageModelProvider::Zed) => {
4731 let api_key = config
4732 .qwen2_7b_api_key
4733 .as_ref()
4734 .context("no Qwen2-7B API key configured on the server")?;
4735 let api_url = config
4736 .qwen2_7b_api_url
4737 .as_ref()
4738 .context("no Qwen2-7B URL configured on the server")?;
4739 let mut events = open_ai::stream_completion(
4740 session.http_client.as_ref(),
4741 &api_url,
4742 api_key,
4743 serde_json::from_str(&request.request)?,
4744 None,
4745 )
4746 .await?;
4747 while let Some(event) = events.next().await {
4748 let event = event?;
4749 response.send(proto::StreamCompleteWithLanguageModelResponse {
4750 event: serde_json::to_string(&event)?,
4751 })?;
4752 }
4753 }
4754 None => return Err(anyhow!("unknown provider"))?,
4755 }
4756
4757 Ok(())
4758}
4759
4760async fn count_language_model_tokens(
4761 request: proto::CountLanguageModelTokens,
4762 response: Response<proto::CountLanguageModelTokens>,
4763 session: Session,
4764 config: &Config,
4765) -> Result<()> {
4766 let Some(session) = session.for_user() else {
4767 return Err(anyhow!("user not found"))?;
4768 };
4769 authorize_access_to_language_models(&session).await?;
4770
4771 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4772 proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
4773 proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
4774 };
4775
4776 session
4777 .app_state
4778 .rate_limiter
4779 .check(&*rate_limit, session.user_id())
4780 .await?;
4781
4782 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4783 Some(proto::LanguageModelProvider::Google) => {
4784 let api_key = config
4785 .google_ai_api_key
4786 .as_ref()
4787 .context("no Google AI API key configured on the server")?;
4788 google_ai::count_tokens(
4789 session.http_client.as_ref(),
4790 google_ai::API_URL,
4791 api_key,
4792 serde_json::from_str(&request.request)?,
4793 )
4794 .await?
4795 }
4796 _ => return Err(anyhow!("unsupported provider"))?,
4797 };
4798
4799 response.send(proto::CountLanguageModelTokensResponse {
4800 token_count: result.total_tokens as u32,
4801 })?;
4802
4803 Ok(())
4804}
4805
4806struct ZedProCountLanguageModelTokensRateLimit;
4807
4808impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
4809 fn capacity(&self) -> usize {
4810 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4811 .ok()
4812 .and_then(|v| v.parse().ok())
4813 .unwrap_or(600) // Picked arbitrarily
4814 }
4815
4816 fn refill_duration(&self) -> chrono::Duration {
4817 chrono::Duration::hours(1)
4818 }
4819
4820 fn db_name(&self) -> &'static str {
4821 "zed-pro:count-language-model-tokens"
4822 }
4823}
4824
4825struct FreeCountLanguageModelTokensRateLimit;
4826
4827impl RateLimit for FreeCountLanguageModelTokensRateLimit {
4828 fn capacity(&self) -> usize {
4829 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
4830 .ok()
4831 .and_then(|v| v.parse().ok())
4832 .unwrap_or(600 / 10) // Picked arbitrarily
4833 }
4834
4835 fn refill_duration(&self) -> chrono::Duration {
4836 chrono::Duration::hours(1)
4837 }
4838
4839 fn db_name(&self) -> &'static str {
4840 "free:count-language-model-tokens"
4841 }
4842}
4843
4844struct ZedProComputeEmbeddingsRateLimit;
4845
4846impl RateLimit for ZedProComputeEmbeddingsRateLimit {
4847 fn capacity(&self) -> usize {
4848 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4849 .ok()
4850 .and_then(|v| v.parse().ok())
4851 .unwrap_or(5000) // Picked arbitrarily
4852 }
4853
4854 fn refill_duration(&self) -> chrono::Duration {
4855 chrono::Duration::hours(1)
4856 }
4857
4858 fn db_name(&self) -> &'static str {
4859 "zed-pro:compute-embeddings"
4860 }
4861}
4862
4863struct FreeComputeEmbeddingsRateLimit;
4864
4865impl RateLimit for FreeComputeEmbeddingsRateLimit {
4866 fn capacity(&self) -> usize {
4867 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
4868 .ok()
4869 .and_then(|v| v.parse().ok())
4870 .unwrap_or(5000 / 10) // Picked arbitrarily
4871 }
4872
4873 fn refill_duration(&self) -> chrono::Duration {
4874 chrono::Duration::hours(1)
4875 }
4876
4877 fn db_name(&self) -> &'static str {
4878 "free:compute-embeddings"
4879 }
4880}
4881
4882async fn compute_embeddings(
4883 request: proto::ComputeEmbeddings,
4884 response: Response<proto::ComputeEmbeddings>,
4885 session: UserSession,
4886 api_key: Option<Arc<str>>,
4887) -> Result<()> {
4888 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4889 authorize_access_to_language_models(&session).await?;
4890
4891 let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
4892 proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
4893 proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
4894 };
4895
4896 session
4897 .app_state
4898 .rate_limiter
4899 .check(&*rate_limit, session.user_id())
4900 .await?;
4901
4902 let embeddings = match request.model.as_str() {
4903 "openai/text-embedding-3-small" => {
4904 open_ai::embed(
4905 session.http_client.as_ref(),
4906 OPEN_AI_API_URL,
4907 &api_key,
4908 OpenAiEmbeddingModel::TextEmbedding3Small,
4909 request.texts.iter().map(|text| text.as_str()),
4910 )
4911 .await?
4912 }
4913 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4914 };
4915
4916 let embeddings = request
4917 .texts
4918 .iter()
4919 .map(|text| {
4920 let mut hasher = sha2::Sha256::new();
4921 hasher.update(text.as_bytes());
4922 let result = hasher.finalize();
4923 result.to_vec()
4924 })
4925 .zip(
4926 embeddings
4927 .data
4928 .into_iter()
4929 .map(|embedding| embedding.embedding),
4930 )
4931 .collect::<HashMap<_, _>>();
4932
4933 let db = session.db().await;
4934 db.save_embeddings(&request.model, &embeddings)
4935 .await
4936 .context("failed to save embeddings")
4937 .trace_err();
4938
4939 response.send(proto::ComputeEmbeddingsResponse {
4940 embeddings: embeddings
4941 .into_iter()
4942 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4943 .collect(),
4944 })?;
4945 Ok(())
4946}
4947
4948async fn get_cached_embeddings(
4949 request: proto::GetCachedEmbeddings,
4950 response: Response<proto::GetCachedEmbeddings>,
4951 session: UserSession,
4952) -> Result<()> {
4953 authorize_access_to_language_models(&session).await?;
4954
4955 let db = session.db().await;
4956 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4957
4958 response.send(proto::GetCachedEmbeddingsResponse {
4959 embeddings: embeddings
4960 .into_iter()
4961 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4962 .collect(),
4963 })?;
4964 Ok(())
4965}
4966
4967async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4968 let db = session.db().await;
4969 let flags = db.get_user_flags(session.user_id()).await?;
4970 if flags.iter().any(|flag| flag == "language-models") {
4971 return Ok(());
4972 }
4973
4974 Err(anyhow!("permission denied"))?
4975}
4976
4977/// Get a Supermaven API key for the user
4978async fn get_supermaven_api_key(
4979 _request: proto::GetSupermavenApiKey,
4980 response: Response<proto::GetSupermavenApiKey>,
4981 session: UserSession,
4982) -> Result<()> {
4983 let user_id: String = session.user_id().to_string();
4984 if !session.is_staff() {
4985 return Err(anyhow!("supermaven not enabled for this account"))?;
4986 }
4987
4988 let email = session
4989 .email()
4990 .ok_or_else(|| anyhow!("user must have an email"))?;
4991
4992 let supermaven_admin_api = session
4993 .supermaven_client
4994 .as_ref()
4995 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4996
4997 let result = supermaven_admin_api
4998 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4999 .await?;
5000
5001 response.send(proto::GetSupermavenApiKeyResponse {
5002 api_key: result.api_key,
5003 })?;
5004
5005 Ok(())
5006}
5007
5008/// Start receiving chat updates for a channel
5009async fn join_channel_chat(
5010 request: proto::JoinChannelChat,
5011 response: Response<proto::JoinChannelChat>,
5012 session: UserSession,
5013) -> Result<()> {
5014 let channel_id = ChannelId::from_proto(request.channel_id);
5015
5016 let db = session.db().await;
5017 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5018 .await?;
5019 let messages = db
5020 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5021 .await?;
5022 response.send(proto::JoinChannelChatResponse {
5023 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5024 messages,
5025 })?;
5026 Ok(())
5027}
5028
5029/// Stop receiving chat updates for a channel
5030async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5031 let channel_id = ChannelId::from_proto(request.channel_id);
5032 session
5033 .db()
5034 .await
5035 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5036 .await?;
5037 Ok(())
5038}
5039
5040/// Retrieve the chat history for a channel
5041async fn get_channel_messages(
5042 request: proto::GetChannelMessages,
5043 response: Response<proto::GetChannelMessages>,
5044 session: UserSession,
5045) -> Result<()> {
5046 let channel_id = ChannelId::from_proto(request.channel_id);
5047 let messages = session
5048 .db()
5049 .await
5050 .get_channel_messages(
5051 channel_id,
5052 session.user_id(),
5053 MESSAGE_COUNT_PER_PAGE,
5054 Some(MessageId::from_proto(request.before_message_id)),
5055 )
5056 .await?;
5057 response.send(proto::GetChannelMessagesResponse {
5058 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5059 messages,
5060 })?;
5061 Ok(())
5062}
5063
5064/// Retrieve specific chat messages
5065async fn get_channel_messages_by_id(
5066 request: proto::GetChannelMessagesById,
5067 response: Response<proto::GetChannelMessagesById>,
5068 session: UserSession,
5069) -> Result<()> {
5070 let message_ids = request
5071 .message_ids
5072 .iter()
5073 .map(|id| MessageId::from_proto(*id))
5074 .collect::<Vec<_>>();
5075 let messages = session
5076 .db()
5077 .await
5078 .get_channel_messages_by_id(session.user_id(), &message_ids)
5079 .await?;
5080 response.send(proto::GetChannelMessagesResponse {
5081 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5082 messages,
5083 })?;
5084 Ok(())
5085}
5086
5087/// Retrieve the current users notifications
5088async fn get_notifications(
5089 request: proto::GetNotifications,
5090 response: Response<proto::GetNotifications>,
5091 session: UserSession,
5092) -> Result<()> {
5093 let notifications = session
5094 .db()
5095 .await
5096 .get_notifications(
5097 session.user_id(),
5098 NOTIFICATION_COUNT_PER_PAGE,
5099 request
5100 .before_id
5101 .map(|id| db::NotificationId::from_proto(id)),
5102 )
5103 .await?;
5104 response.send(proto::GetNotificationsResponse {
5105 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5106 notifications,
5107 })?;
5108 Ok(())
5109}
5110
5111/// Mark notifications as read
5112async fn mark_notification_as_read(
5113 request: proto::MarkNotificationRead,
5114 response: Response<proto::MarkNotificationRead>,
5115 session: UserSession,
5116) -> Result<()> {
5117 let database = &session.db().await;
5118 let notifications = database
5119 .mark_notification_as_read_by_id(
5120 session.user_id(),
5121 NotificationId::from_proto(request.notification_id),
5122 )
5123 .await?;
5124 send_notifications(
5125 &*session.connection_pool().await,
5126 &session.peer,
5127 notifications,
5128 );
5129 response.send(proto::Ack {})?;
5130 Ok(())
5131}
5132
5133/// Get the current users information
5134async fn get_private_user_info(
5135 _request: proto::GetPrivateUserInfo,
5136 response: Response<proto::GetPrivateUserInfo>,
5137 session: UserSession,
5138) -> Result<()> {
5139 let db = session.db().await;
5140
5141 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5142 let user = db
5143 .get_user_by_id(session.user_id())
5144 .await?
5145 .ok_or_else(|| anyhow!("user not found"))?;
5146 let flags = db.get_user_flags(session.user_id()).await?;
5147
5148 response.send(proto::GetPrivateUserInfoResponse {
5149 metrics_id,
5150 staff: user.admin,
5151 flags,
5152 })?;
5153 Ok(())
5154}
5155
5156async fn get_llm_api_token(
5157 _request: proto::GetLlmToken,
5158 response: Response<proto::GetLlmToken>,
5159 session: UserSession,
5160) -> Result<()> {
5161 if !session.is_staff() {
5162 Err(anyhow!("permission denied"))?
5163 }
5164
5165 let token = LlmTokenClaims::create(
5166 session.user_id(),
5167 session.is_staff(),
5168 session.current_plan().await?,
5169 &session.app_state.config,
5170 )?;
5171 response.send(proto::GetLlmTokenResponse { token })?;
5172 Ok(())
5173}
5174
5175fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5176 let message = match message {
5177 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5178 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5179 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5180 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5181 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5182 code: frame.code.into(),
5183 reason: frame.reason,
5184 })),
5185 // We should never receive a frame while reading the message, according
5186 // to the `tungstenite` maintainers:
5187 //
5188 // > It cannot occur when you read messages from the WebSocket, but it
5189 // > can be used when you want to send the raw frames (e.g. you want to
5190 // > send the frames to the WebSocket without composing the full message first).
5191 // >
5192 // > — https://github.com/snapview/tungstenite-rs/issues/268
5193 TungsteniteMessage::Frame(_) => {
5194 bail!("received an unexpected frame while reading the message")
5195 }
5196 };
5197
5198 Ok(message)
5199}
5200
5201fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5202 match message {
5203 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5204 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5205 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5206 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5207 AxumMessage::Close(frame) => {
5208 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5209 code: frame.code.into(),
5210 reason: frame.reason,
5211 }))
5212 }
5213 }
5214}
5215
5216fn notify_membership_updated(
5217 connection_pool: &mut ConnectionPool,
5218 result: MembershipUpdated,
5219 user_id: UserId,
5220 peer: &Peer,
5221) {
5222 for membership in &result.new_channels.channel_memberships {
5223 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5224 }
5225 for channel_id in &result.removed_channels {
5226 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5227 }
5228
5229 let user_channels_update = proto::UpdateUserChannels {
5230 channel_memberships: result
5231 .new_channels
5232 .channel_memberships
5233 .iter()
5234 .map(|cm| proto::ChannelMembership {
5235 channel_id: cm.channel_id.to_proto(),
5236 role: cm.role.into(),
5237 })
5238 .collect(),
5239 ..Default::default()
5240 };
5241
5242 let mut update = build_channels_update(result.new_channels);
5243 update.delete_channels = result
5244 .removed_channels
5245 .into_iter()
5246 .map(|id| id.to_proto())
5247 .collect();
5248 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5249
5250 for connection_id in connection_pool.user_connection_ids(user_id) {
5251 peer.send(connection_id, user_channels_update.clone())
5252 .trace_err();
5253 peer.send(connection_id, update.clone()).trace_err();
5254 }
5255}
5256
5257fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5258 proto::UpdateUserChannels {
5259 channel_memberships: channels
5260 .channel_memberships
5261 .iter()
5262 .map(|m| proto::ChannelMembership {
5263 channel_id: m.channel_id.to_proto(),
5264 role: m.role.into(),
5265 })
5266 .collect(),
5267 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5268 observed_channel_message_id: channels.observed_channel_messages.clone(),
5269 }
5270}
5271
5272fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5273 let mut update = proto::UpdateChannels::default();
5274
5275 for channel in channels.channels {
5276 update.channels.push(channel.to_proto());
5277 }
5278
5279 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5280 update.latest_channel_message_ids = channels.latest_channel_messages;
5281
5282 for (channel_id, participants) in channels.channel_participants {
5283 update
5284 .channel_participants
5285 .push(proto::ChannelParticipants {
5286 channel_id: channel_id.to_proto(),
5287 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5288 });
5289 }
5290
5291 for channel in channels.invited_channels {
5292 update.channel_invitations.push(channel.to_proto());
5293 }
5294
5295 update.hosted_projects = channels.hosted_projects;
5296 update
5297}
5298
5299fn build_initial_contacts_update(
5300 contacts: Vec<db::Contact>,
5301 pool: &ConnectionPool,
5302) -> proto::UpdateContacts {
5303 let mut update = proto::UpdateContacts::default();
5304
5305 for contact in contacts {
5306 match contact {
5307 db::Contact::Accepted { user_id, busy } => {
5308 update.contacts.push(contact_for_user(user_id, busy, &pool));
5309 }
5310 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5311 db::Contact::Incoming { user_id } => {
5312 update
5313 .incoming_requests
5314 .push(proto::IncomingContactRequest {
5315 requester_id: user_id.to_proto(),
5316 })
5317 }
5318 }
5319 }
5320
5321 update
5322}
5323
5324fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5325 proto::Contact {
5326 user_id: user_id.to_proto(),
5327 online: pool.is_user_online(user_id),
5328 busy,
5329 }
5330}
5331
5332fn room_updated(room: &proto::Room, peer: &Peer) {
5333 broadcast(
5334 None,
5335 room.participants
5336 .iter()
5337 .filter_map(|participant| Some(participant.peer_id?.into())),
5338 |peer_id| {
5339 peer.send(
5340 peer_id,
5341 proto::RoomUpdated {
5342 room: Some(room.clone()),
5343 },
5344 )
5345 },
5346 );
5347}
5348
5349fn channel_updated(
5350 channel: &db::channel::Model,
5351 room: &proto::Room,
5352 peer: &Peer,
5353 pool: &ConnectionPool,
5354) {
5355 let participants = room
5356 .participants
5357 .iter()
5358 .map(|p| p.user_id)
5359 .collect::<Vec<_>>();
5360
5361 broadcast(
5362 None,
5363 pool.channel_connection_ids(channel.root_id())
5364 .filter_map(|(channel_id, role)| {
5365 role.can_see_channel(channel.visibility).then(|| channel_id)
5366 }),
5367 |peer_id| {
5368 peer.send(
5369 peer_id,
5370 proto::UpdateChannels {
5371 channel_participants: vec![proto::ChannelParticipants {
5372 channel_id: channel.id.to_proto(),
5373 participant_user_ids: participants.clone(),
5374 }],
5375 ..Default::default()
5376 },
5377 )
5378 },
5379 );
5380}
5381
5382async fn send_dev_server_projects_update(
5383 user_id: UserId,
5384 mut status: proto::DevServerProjectsUpdate,
5385 session: &Session,
5386) {
5387 let pool = session.connection_pool().await;
5388 for dev_server in &mut status.dev_servers {
5389 dev_server.status =
5390 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5391 }
5392 let connections = pool.user_connection_ids(user_id);
5393 for connection_id in connections {
5394 session.peer.send(connection_id, status.clone()).trace_err();
5395 }
5396}
5397
5398async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5399 let db = session.db().await;
5400
5401 let contacts = db.get_contacts(user_id).await?;
5402 let busy = db.is_user_busy(user_id).await?;
5403
5404 let pool = session.connection_pool().await;
5405 let updated_contact = contact_for_user(user_id, busy, &pool);
5406 for contact in contacts {
5407 if let db::Contact::Accepted {
5408 user_id: contact_user_id,
5409 ..
5410 } = contact
5411 {
5412 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5413 session
5414 .peer
5415 .send(
5416 contact_conn_id,
5417 proto::UpdateContacts {
5418 contacts: vec![updated_contact.clone()],
5419 remove_contacts: Default::default(),
5420 incoming_requests: Default::default(),
5421 remove_incoming_requests: Default::default(),
5422 outgoing_requests: Default::default(),
5423 remove_outgoing_requests: Default::default(),
5424 },
5425 )
5426 .trace_err();
5427 }
5428 }
5429 }
5430 Ok(())
5431}
5432
5433async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5434 log::info!("lost dev server connection, unsharing projects");
5435 let project_ids = session
5436 .db()
5437 .await
5438 .get_stale_dev_server_projects(session.connection_id)
5439 .await?;
5440
5441 for project_id in project_ids {
5442 // not unshare re-checks the connection ids match, so we get away with no transaction
5443 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5444 }
5445
5446 let user_id = session.dev_server().user_id;
5447 let update = session
5448 .db()
5449 .await
5450 .dev_server_projects_update(user_id)
5451 .await?;
5452
5453 send_dev_server_projects_update(user_id, update, session).await;
5454
5455 Ok(())
5456}
5457
/// Remove `connection_id` from its current room, if any, and propagate the
/// resulting changes.
///
/// When the database reports that a room was left:
/// - `project_left` is called for every project the connection left;
/// - the updated room is broadcast to its participants (`room_updated`);
/// - if the room belongs to a channel, the new participant list is broadcast
///   to the channel's subscribers (`channel_updated`);
/// - `CallCanceled` is sent to every connection of each user in
///   `canceled_calls_to_user_ids`;
/// - the contact entries of the leaving user and of those users are
///   refreshed;
/// - the user is removed from the LiveKit room, which is also deleted when
///   the database marked it as deleted. LiveKit failures are traced, not
///   returned.
///
/// Returns `Ok(())` without doing anything when the connection was not in a
/// room. Errors from the database and from `update_user_contacts` propagate.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and assigned inside the `if let` below so the values
    // moved out of `left_room` remain usable after it.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE: the value produced by `session.db().await` is a temporary of the
    // `if let` scrutinee, so it stays alive until the end of this `if let`/`else`.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // Taken before `room` itself, so the `room` broadcast below carries the
        // default value for `live_kit_room`.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool handle is dropped before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.app_state.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5531
5532async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5533 let left_channel_buffers = session
5534 .db()
5535 .await
5536 .leave_channel_buffers(session.connection_id)
5537 .await?;
5538
5539 for left_buffer in left_channel_buffers {
5540 channel_buffer_updated(
5541 session.connection_id,
5542 left_buffer.connections,
5543 &proto::UpdateChannelBufferCollaborators {
5544 channel_id: left_buffer.channel_id.to_proto(),
5545 collaborators: left_buffer.collaborators,
5546 },
5547 &session.peer,
5548 );
5549 }
5550
5551 Ok(())
5552}
5553
5554fn project_left(project: &db::LeftProject, session: &UserSession) {
5555 for connection_id in &project.connection_ids {
5556 if project.should_unshare {
5557 session
5558 .peer
5559 .send(
5560 *connection_id,
5561 proto::UnshareProject {
5562 project_id: project.id.to_proto(),
5563 },
5564 )
5565 .trace_err();
5566 } else {
5567 session
5568 .peer
5569 .send(
5570 *connection_id,
5571 proto::RemoveProjectCollaborator {
5572 project_id: project.id.to_proto(),
5573 peer_id: Some(session.connection_id.into()),
5574 },
5575 )
5576 .trace_err();
5577 }
5578 }
5579}
5580
/// Extension for fallible values whose error should be logged rather than propagated.
pub trait ResultExt {
    /// The success type carried by the implementing value.
    type Ok;

    /// Converts `self` into an `Option`, logging the error (if any) and returning `None`
    /// in that case, or `Some(value)` on success.
    fn trace_err(self) -> Option<Self::Ok>;
}
5586
5587impl<T, E> ResultExt for Result<T, E>
5588where
5589 E: std::fmt::Debug,
5590{
5591 type Ok = T;
5592
5593 #[track_caller]
5594 fn trace_err(self) -> Option<T> {
5595 match self {
5596 Ok(value) => Some(value),
5597 Err(error) => {
5598 tracing::error!("{:?}", error);
5599 None
5600 }
5601 }
5602 }
5603}