1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Config, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period.
///
/// NOTE(review): not referenced in this part of the file; presumably how long a
/// disconnected client has to reconnect before its state is released — confirm.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long the background task spawned by `Server::start` waits before it
/// starts cleaning up resources recorded against stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're
/// gone, we can clean up old resources, so this is set above that window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// NOTE(review): not referenced in this part of the file; presumably the page
/// size used when fetching channel messages — confirm.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// NOTE(review): not referenced in this part of the file; presumably the maximum
/// accepted length of a channel message — confirm the unit (bytes vs. chars).
const MAX_MESSAGE_LEN: usize = 1024;
/// NOTE(review): not referenced in this part of the file; presumably the page
/// size used when fetching notifications — confirm.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the
/// `TypeId` of its message type.
///
/// It receives the boxed envelope plus the sender's `Session` and returns a
/// boxed `'static` future. `Server::add_handler` builds these.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The RPC server: owns the `Peer`, the shared connection pool, and the table
/// of registered message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Shared with every `Session` created for a connection.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Registered handlers, keyed by the `TypeId` of their message type.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` sends `true`; connection loops subscribe and stop when it changes.
    teardown: watch::Sender<bool>,
}

/// Lock guard for the connection pool, returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker ensures the guard is `!Send`, so it cannot be
/// held across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and connection pool.
///
/// The snapshot holds the connection pool lock (through `ConnectionPoolGuard`)
/// for as long as it is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself isn't serialized; its `Deref` target (the pool) is.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
/// Creates a server with the given `id` and registers every RPC handler.
///
/// The peer is created from `id.0 as u32`. The connection pool and handler table
/// start empty. The teardown flag starts as `false`.
///
/// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions
/// without a user. Handlers wrapped in `dev_server_handler` reject non-dev-server
/// sessions. Unwrapped handlers receive any `Session`.
///
/// # Panics
///
/// Panics (via `add_handler`) if two handlers are registered for the same
/// message type.
pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
    let mut server = Self {
        id: parking_lot::Mutex::new(id),
        peer: Peer::new(id.0 as u32),
        app_state: app_state.clone(),
        connection_pool: Default::default(),
        handlers: Default::default(),
        // The initial receiver is dropped; connections obtain receivers via `subscribe()`.
        teardown: watch::channel(false).0,
    };

    server
        .add_request_handler(ping)
        // Rooms and calls.
        .add_request_handler(user_handler(create_room))
        .add_request_handler(user_handler(join_room))
        .add_request_handler(user_handler(rejoin_room))
        .add_request_handler(user_handler(leave_room))
        .add_request_handler(user_handler(set_room_participant_role))
        .add_request_handler(user_handler(call))
        .add_request_handler(user_handler(cancel_call))
        .add_message_handler(user_message_handler(decline_call))
        .add_request_handler(user_handler(update_participant_location))
        // Projects and dev servers.
        .add_request_handler(user_handler(share_project))
        .add_message_handler(unshare_project)
        .add_request_handler(user_handler(join_project))
        .add_request_handler(user_handler(join_hosted_project))
        .add_request_handler(user_handler(rejoin_dev_server_projects))
        .add_request_handler(user_handler(create_dev_server_project))
        .add_request_handler(user_handler(update_dev_server_project))
        .add_request_handler(user_handler(delete_dev_server_project))
        .add_request_handler(user_handler(create_dev_server))
        .add_request_handler(user_handler(regenerate_dev_server_token))
        .add_request_handler(user_handler(rename_dev_server))
        .add_request_handler(user_handler(delete_dev_server))
        .add_request_handler(user_handler(list_remote_directory))
        .add_request_handler(dev_server_handler(share_dev_server_project))
        .add_request_handler(dev_server_handler(shutdown_dev_server))
        .add_request_handler(dev_server_handler(reconnect_dev_server))
        .add_message_handler(user_message_handler(leave_project))
        .add_request_handler(update_project)
        .add_request_handler(update_worktree)
        .add_message_handler(start_language_server)
        .add_message_handler(update_language_server)
        .add_message_handler(update_diagnostic_summary)
        .add_message_handler(update_worktree_settings)
        // Project requests routed through the `forward_*_project_request` helpers.
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskContextForLocation>,
        ))
        .add_request_handler(user_handler(
            forward_project_request_for_owner::<proto::TaskTemplates>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetHover>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetTypeDefinition>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetReferences>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SearchProject>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetDocumentHighlights>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::GetProjectSymbols>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferForSymbol>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferById>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::SynchronizeBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::InlayHints>,
        ))
        .add_request_handler(user_handler(
            forward_read_only_project_request::<proto::OpenBufferByPath>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCompletions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::GetCodeActions>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ApplyCodeAction>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PrepareRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::PerformRename>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ReloadBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::FormatBuffers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RenameProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CopyProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::DeleteProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::ExpandProjectEntry>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OnTypeFormatting>,
        ))
        .add_request_handler(user_handler(
            forward_versioned_mutating_project_request::<proto::SaveBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::BlameBuffer>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::MultiLspQuery>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::RestartLanguageServers>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::LinkedEditingRange>,
        ))
        // Buffers.
        .add_message_handler(create_buffer_for_peer)
        .add_request_handler(update_buffer)
        .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
        .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
        .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
        // Users and contacts.
        .add_request_handler(get_users)
        .add_request_handler(user_handler(fuzzy_search_users))
        .add_request_handler(user_handler(request_contact))
        .add_request_handler(user_handler(remove_contact))
        .add_request_handler(user_handler(respond_to_contact_request))
        // Channels, channel buffers, chat and notifications.
        .add_message_handler(subscribe_to_channels)
        .add_request_handler(user_handler(create_channel))
        .add_request_handler(user_handler(delete_channel))
        .add_request_handler(user_handler(invite_channel_member))
        .add_request_handler(user_handler(remove_channel_member))
        .add_request_handler(user_handler(set_channel_member_role))
        .add_request_handler(user_handler(set_channel_visibility))
        .add_request_handler(user_handler(rename_channel))
        .add_request_handler(user_handler(join_channel_buffer))
        .add_request_handler(user_handler(leave_channel_buffer))
        .add_message_handler(user_message_handler(update_channel_buffer))
        .add_request_handler(user_handler(rejoin_channel_buffers))
        .add_request_handler(user_handler(get_channel_members))
        .add_request_handler(user_handler(respond_to_channel_invite))
        .add_request_handler(user_handler(join_channel))
        .add_request_handler(user_handler(join_channel_chat))
        .add_message_handler(user_message_handler(leave_channel_chat))
        .add_request_handler(user_handler(send_channel_message))
        .add_request_handler(user_handler(remove_channel_message))
        .add_request_handler(user_handler(update_channel_message))
        .add_request_handler(user_handler(get_channel_messages))
        .add_request_handler(user_handler(get_channel_messages_by_id))
        .add_request_handler(user_handler(get_notifications))
        .add_request_handler(user_handler(mark_notification_as_read))
        .add_request_handler(user_handler(move_channel))
        // Following.
        .add_request_handler(user_handler(follow))
        .add_message_handler(user_message_handler(unfollow))
        .add_message_handler(user_message_handler(update_followers))
        .add_request_handler(user_handler(get_private_user_info))
        .add_message_handler(user_message_handler(acknowledge_channel_message))
        .add_message_handler(user_message_handler(acknowledge_buffer_version))
        .add_request_handler(user_handler(get_supermaven_api_key))
        // Contexts.
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::OpenContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::CreateContext>,
        ))
        .add_request_handler(user_handler(
            forward_mutating_project_request::<proto::SynchronizeContexts>,
        ))
        .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
        .add_message_handler(update_context)
        // Language model and embedding handlers. These capture a clone of
        // `app_state` so they can read its config.
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    complete_with_language_model(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_streaming_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    stream_complete_with_language_model(
                        request,
                        response,
                        session,
                        &app_state.config,
                    )
                    .await
                }
            }
        })
        .add_request_handler({
            let app_state = app_state.clone();
            move |request, response, session| {
                let app_state = app_state.clone();
                async move {
                    count_language_model_tokens(request, response, session, &app_state.config)
                        .await
                }
            }
        })
        .add_request_handler({
            user_handler(move |request, response, session| {
                get_cached_embeddings(request, response, session)
            })
        })
        .add_request_handler({
            let app_state = app_state.clone();
            user_handler(move |request, response, session| {
                compute_embeddings(
                    request,
                    response,
                    session,
                    app_state.config.openai_api_key.clone(),
                )
            })
        });

    Arc::new(server)
}
665
/// Spawns this server's background cleanup task and returns immediately.
///
/// The detached task (instrumented with the "start server" span):
/// 1. waits `CLEANUP_TIMEOUT`;
/// 2. fetches room and channel ids from `stale_server_resource_ids`, for the
///    configured `zed_environment` and this server's id;
/// 3. for each channel id, clears stale channel-buffer collaborators and sends
///    the refreshed collaborator list to the buffer's remaining connections;
/// 4. for each room id, clears stale participants, broadcasts room (and channel)
///    updates, sends `CallCanceled` to users whose calls were canceled, pushes a
///    refreshed contact entry for each stale participant or canceled-call user to
///    their accepted contacts, and deletes the LiveKit room if no participants
///    remain;
/// 5. calls `delete_stale_servers`.
///
/// Database and send failures are logged via `trace_err` and skipped rather than
/// aborting the task. The returned `Result` is always `Ok(())`.
pub async fn start(&self) -> Result<()> {
    let server_id = *self.id.lock();
    let app_state = self.app_state.clone();
    let peer = self.peer.clone();
    // Created here, awaited inside the spawned task.
    let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
    let pool = self.connection_pool.clone();
    let live_kit_client = self.app_state.live_kit_client.clone();

    let span = info_span!("start server");
    self.app_state.executor.spawn_detached(
        async move {
            tracing::info!("waiting for cleanup timeout");
            timeout.await;
            tracing::info!("cleanup timeout expired, retrieving stale rooms");
            if let Some((room_ids, channel_ids)) = app_state
                .db
                .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                .await
                .trace_err()
            {
                tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                tracing::info!(
                    stale_channel_buffer_count = channel_ids.len(),
                    "retrieved stale channel buffers"
                );

                // Channel buffers: drop stale collaborators and tell the remaining
                // connections who is still collaborating.
                for channel_id in channel_ids {
                    if let Some(refreshed_channel_buffer) = app_state
                        .db
                        .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                        .await
                        .trace_err()
                    {
                        for connection_id in refreshed_channel_buffer.connection_ids {
                            peer.send(
                                connection_id,
                                proto::UpdateChannelBufferCollaborators {
                                    channel_id: channel_id.to_proto(),
                                    collaborators: refreshed_channel_buffer
                                        .collaborators
                                        .clone(),
                                },
                            )
                            .trace_err();
                        }
                    }
                }

                // Rooms: drop stale participants, then notify everyone affected.
                for room_id in room_ids {
                    let mut contacts_to_update = HashSet::default();
                    let mut canceled_calls_to_user_ids = Vec::new();
                    let mut live_kit_room = String::new();
                    let mut delete_live_kit_room = false;

                    if let Some(mut refreshed_room) = app_state
                        .db
                        .clear_stale_room_participants(room_id, server_id)
                        .await
                        .trace_err()
                    {
                        tracing::info!(
                            room_id = room_id.0,
                            new_participant_count = refreshed_room.room.participants.len(),
                            "refreshed room"
                        );
                        room_updated(&refreshed_room.room, &peer);
                        if let Some(channel) = refreshed_room.channel.as_ref() {
                            channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                        }
                        contacts_to_update
                            .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                        contacts_to_update
                            .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                        canceled_calls_to_user_ids =
                            mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                        live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                        // Only delete the LiveKit room once nobody is left in it.
                        delete_live_kit_room = refreshed_room.room.participants.is_empty();
                    }

                    // Scoped so the pool lock is released before the awaits below.
                    {
                        let pool = pool.lock();
                        for canceled_user_id in canceled_calls_to_user_ids {
                            for connection_id in pool.user_connection_ids(canceled_user_id) {
                                peer.send(
                                    connection_id,
                                    proto::CallCanceled {
                                        room_id: room_id.to_proto(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Refresh each affected user's contact entry for their accepted contacts.
                    for user_id in contacts_to_update {
                        let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                        let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                        if let Some((busy, contacts)) = busy.zip(contacts) {
                            // Locked only after both awaits above have completed.
                            let pool = pool.lock();
                            let updated_contact = contact_for_user(user_id, busy, &pool);
                            for contact in contacts {
                                if let db::Contact::Accepted {
                                    user_id: contact_user_id,
                                    ..
                                } = contact
                                {
                                    for contact_conn_id in
                                        pool.user_connection_ids(contact_user_id)
                                    {
                                        peer.send(
                                            contact_conn_id,
                                            proto::UpdateContacts {
                                                contacts: vec![updated_contact.clone()],
                                                remove_contacts: Default::default(),
                                                incoming_requests: Default::default(),
                                                remove_incoming_requests: Default::default(),
                                                outgoing_requests: Default::default(),
                                                remove_outgoing_requests: Default::default(),
                                            },
                                        )
                                        .trace_err();
                                    }
                                }
                            }
                        }
                    }

                    if let Some(live_kit) = live_kit_client.as_ref() {
                        if delete_live_kit_room {
                            live_kit.delete_room(live_kit_room).await.trace_err();
                        }
                    }
                }
            }

            app_state
                .db
                .delete_stale_servers(&app_state.config.zed_environment, server_id)
                .await
                .trace_err();
        }
        .instrument(span),
    );
    Ok(())
}
811
812 pub fn teardown(&self) {
813 self.peer.teardown();
814 self.connection_pool.lock().reset();
815 let _ = self.teardown.send(true);
816 }
817
/// Test-only: tears the server down, then revives it under a new `id`.
///
/// The peer is reset to `id.0 as u32`, and the teardown flag is cleared so that
/// `handle_connection` accepts connections again.
#[cfg(test)]
pub fn reset(&self, id: ServerId) {
    self.teardown();
    {
        let mut current_id = self.id.lock();
        *current_id = id;
    }
    self.peer.reset(id.0 as u32);
    let _ = self.teardown.send(false);
}
825
/// Test-only accessor for the server's current id.
#[cfg(test)]
pub fn id(&self) -> ServerId {
    let guard = self.id.lock();
    *guard
}
830
831 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
832 where
833 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
834 Fut: 'static + Send + Future<Output = Result<()>>,
835 M: EnvelopedMessage,
836 {
837 let prev_handler = self.handlers.insert(
838 TypeId::of::<M>(),
839 Box::new(move |envelope, session| {
840 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
841 let received_at = envelope.received_at;
842 tracing::info!("message received");
843 let start_time = Instant::now();
844 let future = (handler)(*envelope, session);
845 async move {
846 let result = future.await;
847 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
848 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
849 let queue_duration_ms = total_duration_ms - processing_duration_ms;
850 let payload_type = M::NAME;
851
852 match result {
853 Err(error) => {
854 tracing::error!(
855 ?error,
856 total_duration_ms,
857 processing_duration_ms,
858 queue_duration_ms,
859 payload_type,
860 "error handling message"
861 )
862 }
863 Ok(()) => tracing::info!(
864 total_duration_ms,
865 processing_duration_ms,
866 queue_duration_ms,
867 "finished handling message"
868 ),
869 }
870 }
871 .boxed()
872 }),
873 );
874 if prev_handler.is_some() {
875 panic!("registered a handler for the same message twice");
876 }
877 self
878 }
879
880 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
881 where
882 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
883 Fut: 'static + Send + Future<Output = Result<()>>,
884 M: EnvelopedMessage,
885 {
886 self.add_handler(move |envelope, session| handler(envelope.payload, session));
887 self
888 }
889
890 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
891 where
892 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
893 Fut: Send + Future<Output = Result<()>>,
894 M: RequestMessage,
895 {
896 let handler = Arc::new(handler);
897 self.add_handler(move |envelope, session| {
898 let receipt = envelope.receipt();
899 let handler = handler.clone();
900 async move {
901 let peer = session.peer.clone();
902 let responded = Arc::new(AtomicBool::default());
903 let response = Response {
904 peer: peer.clone(),
905 responded: responded.clone(),
906 receipt,
907 };
908 match (handler)(envelope.payload, response, session).await {
909 Ok(()) => {
910 if responded.load(std::sync::atomic::Ordering::SeqCst) {
911 Ok(())
912 } else {
913 Err(anyhow!("handler did not send a response"))?
914 }
915 }
916 Err(error) => {
917 let proto_err = match &error {
918 Error::Internal(err) => err.to_proto(),
919 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
920 };
921 peer.respond_with_error(receipt, proto_err)?;
922 Err(error)
923 }
924 }
925 }
926 })
927 }
928
929 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
930 where
931 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
932 Fut: Send + Future<Output = Result<()>>,
933 M: RequestMessage,
934 {
935 let handler = Arc::new(handler);
936 self.add_handler(move |envelope, session| {
937 let receipt = envelope.receipt();
938 let handler = handler.clone();
939 async move {
940 let peer = session.peer.clone();
941 let response = StreamingResponse {
942 peer: peer.clone(),
943 receipt,
944 };
945 match (handler)(envelope.payload, response, session).await {
946 Ok(()) => {
947 peer.end_stream(receipt)?;
948 Ok(())
949 }
950 Err(error) => {
951 let proto_err = match &error {
952 Error::Internal(err) => err.to_proto(),
953 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
954 };
955 peer.respond_with_error(receipt, proto_err)?;
956 Err(error)
957 }
958 }
959 }
960 })
961 }
962
963 #[allow(clippy::too_many_arguments)]
964 pub fn handle_connection(
965 self: &Arc<Self>,
966 connection: Connection,
967 address: String,
968 principal: Principal,
969 zed_version: ZedVersion,
970 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
971 executor: Executor,
972 ) -> impl Future<Output = ()> {
973 let this = self.clone();
974 let span = info_span!("handle connection", %address,
975 connection_id=field::Empty,
976 user_id=field::Empty,
977 login=field::Empty,
978 impersonator=field::Empty,
979 dev_server_id=field::Empty
980 );
981 principal.update_span(&span);
982
983 let mut teardown = self.teardown.subscribe();
984 async move {
985 if *teardown.borrow() {
986 tracing::error!("server is tearing down");
987 return
988 }
989 let (connection_id, handle_io, mut incoming_rx) = this
990 .peer
991 .add_connection(connection, {
992 let executor = executor.clone();
993 move |duration| executor.sleep(duration)
994 });
995 tracing::Span::current().record("connection_id", format!("{}", connection_id));
996 tracing::info!("connection opened");
997
998 let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
999 let http_client = match IsahcHttpClient::builder().default_header("User-Agent", user_agent).build() {
1000 Ok(http_client) => Arc::new(http_client),
1001 Err(error) => {
1002 tracing::error!(?error, "failed to create HTTP client");
1003 return;
1004 }
1005 };
1006
1007 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
1008 Some(Arc::new(SupermavenAdminApi::new(
1009 supermaven_admin_api_key.to_string(),
1010 http_client.clone(),
1011 )))
1012 } else {
1013 None
1014 };
1015
1016 let session = Session {
1017 principal: principal.clone(),
1018 connection_id,
1019 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
1020 peer: this.peer.clone(),
1021 connection_pool: this.connection_pool.clone(),
1022 live_kit_client: this.app_state.live_kit_client.clone(),
1023 http_client,
1024 rate_limiter: this.app_state.rate_limiter.clone(),
1025 _executor: executor.clone(),
1026 supermaven_client,
1027 };
1028
1029 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1030 tracing::error!(?error, "failed to send initial client update");
1031 return;
1032 }
1033
1034 let handle_io = handle_io.fuse();
1035 futures::pin_mut!(handle_io);
1036
1037 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1038 // This prevents deadlocks when e.g., client A performs a request to client B and
1039 // client B performs a request to client A. If both clients stop processing further
1040 // messages until their respective request completes, they won't have a chance to
1041 // respond to the other client's request and cause a deadlock.
1042 //
1043 // This arrangement ensures we will attempt to process earlier messages first, but fall
1044 // back to processing messages arrived later in the spirit of making progress.
1045 let mut foreground_message_handlers = FuturesUnordered::new();
1046 let concurrent_handlers = Arc::new(Semaphore::new(256));
1047 loop {
1048 let next_message = async {
1049 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1050 let message = incoming_rx.next().await;
1051 (permit, message)
1052 }.fuse();
1053 futures::pin_mut!(next_message);
1054 futures::select_biased! {
1055 _ = teardown.changed().fuse() => return,
1056 result = handle_io => {
1057 if let Err(error) = result {
1058 tracing::error!(?error, "error handling I/O");
1059 }
1060 break;
1061 }
1062 _ = foreground_message_handlers.next() => {}
1063 next_message = next_message => {
1064 let (permit, message) = next_message;
1065 if let Some(message) = message {
1066 let type_name = message.payload_type_name();
1067 // note: we copy all the fields from the parent span so we can query them in the logs.
1068 // (https://github.com/tokio-rs/tracing/issues/2670).
1069 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1070 user_id=field::Empty,
1071 login=field::Empty,
1072 impersonator=field::Empty,
1073 dev_server_id=field::Empty
1074 );
1075 principal.update_span(&span);
1076 let span_enter = span.enter();
1077 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1078 let is_background = message.is_background();
1079 let handle_message = (handler)(message, session.clone());
1080 drop(span_enter);
1081
1082 let handle_message = async move {
1083 handle_message.await;
1084 drop(permit);
1085 }.instrument(span);
1086 if is_background {
1087 executor.spawn_detached(handle_message);
1088 } else {
1089 foreground_message_handlers.push(handle_message);
1090 }
1091 } else {
1092 tracing::error!("no message handler");
1093 }
1094 } else {
1095 tracing::info!("connection closed");
1096 break;
1097 }
1098 }
1099 }
1100 }
1101
1102 drop(foreground_message_handlers);
1103 tracing::info!("signing out");
1104 if let Err(error) = connection_lost(session, teardown, executor).await {
1105 tracing::error!(?error, "error signing out");
1106 }
1107
1108 }.instrument(span)
1109 }
1110
1111 async fn send_initial_client_update(
1112 &self,
1113 connection_id: ConnectionId,
1114 principal: &Principal,
1115 zed_version: ZedVersion,
1116 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1117 session: &Session,
1118 ) -> Result<()> {
1119 self.peer.send(
1120 connection_id,
1121 proto::Hello {
1122 peer_id: Some(connection_id.into()),
1123 },
1124 )?;
1125 tracing::info!("sent hello message");
1126 if let Some(send_connection_id) = send_connection_id.take() {
1127 let _ = send_connection_id.send(connection_id);
1128 }
1129
1130 match principal {
1131 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1132 if !user.connected_once {
1133 self.peer.send(connection_id, proto::ShowContacts {})?;
1134 self.app_state
1135 .db
1136 .set_user_connected_once(user.id, true)
1137 .await?;
1138 }
1139
1140 let (contacts, dev_server_projects) = future::try_join(
1141 self.app_state.db.get_contacts(user.id),
1142 self.app_state.db.dev_server_projects_update(user.id),
1143 )
1144 .await?;
1145
1146 {
1147 let mut pool = self.connection_pool.lock();
1148 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1149 self.peer.send(
1150 connection_id,
1151 build_initial_contacts_update(contacts, &pool),
1152 )?;
1153 }
1154
1155 if should_auto_subscribe_to_channels(zed_version) {
1156 subscribe_user_to_channels(user.id, session).await?;
1157 }
1158
1159 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1160
1161 if let Some(incoming_call) =
1162 self.app_state.db.incoming_call_for_user(user.id).await?
1163 {
1164 self.peer.send(connection_id, incoming_call)?;
1165 }
1166
1167 update_user_contacts(user.id, &session).await?;
1168 }
1169 Principal::DevServer(dev_server) => {
1170 {
1171 let mut pool = self.connection_pool.lock();
1172 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1173 {
1174 self.peer.send(
1175 stale_connection_id,
1176 proto::ShutdownDevServer {
1177 reason: Some(
1178 "another dev server connected with the same token".to_string(),
1179 ),
1180 },
1181 )?;
1182 pool.remove_connection(stale_connection_id)?;
1183 };
1184 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1185 }
1186
1187 let projects = self
1188 .app_state
1189 .db
1190 .get_projects_for_dev_server(dev_server.id)
1191 .await?;
1192 self.peer
1193 .send(connection_id, proto::DevServerInstructions { projects })?;
1194
1195 let status = self
1196 .app_state
1197 .db
1198 .dev_server_projects_update(dev_server.user_id)
1199 .await?;
1200 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1201 }
1202 }
1203
1204 Ok(())
1205 }
1206
1207 pub async fn invite_code_redeemed(
1208 self: &Arc<Self>,
1209 inviter_id: UserId,
1210 invitee_id: UserId,
1211 ) -> Result<()> {
1212 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1213 if let Some(code) = &user.invite_code {
1214 let pool = self.connection_pool.lock();
1215 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1216 for connection_id in pool.user_connection_ids(inviter_id) {
1217 self.peer.send(
1218 connection_id,
1219 proto::UpdateContacts {
1220 contacts: vec![invitee_contact.clone()],
1221 ..Default::default()
1222 },
1223 )?;
1224 self.peer.send(
1225 connection_id,
1226 proto::UpdateInviteInfo {
1227 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1228 count: user.invite_count as u32,
1229 },
1230 )?;
1231 }
1232 }
1233 }
1234 Ok(())
1235 }
1236
1237 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1238 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1239 if let Some(invite_code) = &user.invite_code {
1240 let pool = self.connection_pool.lock();
1241 for connection_id in pool.user_connection_ids(user_id) {
1242 self.peer.send(
1243 connection_id,
1244 proto::UpdateInviteInfo {
1245 url: format!(
1246 "{}{}",
1247 self.app_state.config.invite_link_prefix, invite_code
1248 ),
1249 count: user.invite_count as u32,
1250 },
1251 )?;
1252 }
1253 }
1254 }
1255 Ok(())
1256 }
1257
1258 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1259 ServerSnapshot {
1260 connection_pool: ConnectionPoolGuard {
1261 guard: self.connection_pool.lock(),
1262 _not_send: PhantomData,
1263 },
1264 peer: &self.peer,
1265 }
1266 }
1267}
1268
1269impl<'a> Deref for ConnectionPoolGuard<'a> {
1270 type Target = ConnectionPool;
1271
1272 fn deref(&self) -> &Self::Target {
1273 &self.guard
1274 }
1275}
1276
1277impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1278 fn deref_mut(&mut self) -> &mut Self::Target {
1279 &mut self.guard
1280 }
1281}
1282
1283impl<'a> Drop for ConnectionPoolGuard<'a> {
1284 fn drop(&mut self) {
1285 #[cfg(test)]
1286 self.check_invariants();
1287 }
1288}
1289
1290fn broadcast<F>(
1291 sender_id: Option<ConnectionId>,
1292 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1293 mut f: F,
1294) where
1295 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1296{
1297 for receiver_id in receiver_ids {
1298 if Some(receiver_id) != sender_id {
1299 if let Err(error) = f(receiver_id) {
1300 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1301 }
1302 }
1303 }
1304}
1305
1306pub struct ProtocolVersion(u32);
1307
1308impl Header for ProtocolVersion {
1309 fn name() -> &'static HeaderName {
1310 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1311 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1312 }
1313
1314 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1315 where
1316 Self: Sized,
1317 I: Iterator<Item = &'i axum::http::HeaderValue>,
1318 {
1319 let version = values
1320 .next()
1321 .ok_or_else(axum::headers::Error::invalid)?
1322 .to_str()
1323 .map_err(|_| axum::headers::Error::invalid())?
1324 .parse()
1325 .map_err(|_| axum::headers::Error::invalid())?;
1326 Ok(Self(version))
1327 }
1328
1329 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1330 values.extend([self.0.to_string().parse().unwrap()]);
1331 }
1332}
1333
1334pub struct AppVersionHeader(SemanticVersion);
1335impl Header for AppVersionHeader {
1336 fn name() -> &'static HeaderName {
1337 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1338 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1339 }
1340
1341 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1342 where
1343 Self: Sized,
1344 I: Iterator<Item = &'i axum::http::HeaderValue>,
1345 {
1346 let version = values
1347 .next()
1348 .ok_or_else(axum::headers::Error::invalid)?
1349 .to_str()
1350 .map_err(|_| axum::headers::Error::invalid())?
1351 .parse()
1352 .map_err(|_| axum::headers::Error::invalid())?;
1353 Ok(Self(version))
1354 }
1355
1356 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1357 values.extend([self.0.to_string().parse().unwrap()]);
1358 }
1359}
1360
1361pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1362 Router::new()
1363 .route("/rpc", get(handle_websocket_request))
1364 .layer(
1365 ServiceBuilder::new()
1366 .layer(Extension(server.app_state.clone()))
1367 .layer(middleware::from_fn(auth::validate_header)),
1368 )
1369 .route("/metrics", get(handle_metrics))
1370 .layer(Extension(server))
1371}
1372
1373pub async fn handle_websocket_request(
1374 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1375 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1376 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1377 Extension(server): Extension<Arc<Server>>,
1378 Extension(principal): Extension<Principal>,
1379 ws: WebSocketUpgrade,
1380) -> axum::response::Response {
1381 if protocol_version != rpc::PROTOCOL_VERSION {
1382 return (
1383 StatusCode::UPGRADE_REQUIRED,
1384 "client must be upgraded".to_string(),
1385 )
1386 .into_response();
1387 }
1388
1389 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1390 return (
1391 StatusCode::UPGRADE_REQUIRED,
1392 "no version header found".to_string(),
1393 )
1394 .into_response();
1395 };
1396
1397 if !version.can_collaborate() {
1398 return (
1399 StatusCode::UPGRADE_REQUIRED,
1400 "client must be upgraded".to_string(),
1401 )
1402 .into_response();
1403 }
1404
1405 let socket_address = socket_address.to_string();
1406 ws.on_upgrade(move |socket| {
1407 let socket = socket
1408 .map_ok(to_tungstenite_message)
1409 .err_into()
1410 .with(|message| async move { to_axum_message(message) });
1411 let connection = Connection::new(Box::pin(socket));
1412 async move {
1413 server
1414 .handle_connection(
1415 connection,
1416 socket_address,
1417 principal,
1418 version,
1419 None,
1420 Executor::Production,
1421 )
1422 .await;
1423 }
1424 })
1425}
1426
1427pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1428 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1429 let connections_metric = CONNECTIONS_METRIC
1430 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1431
1432 let connections = server
1433 .connection_pool
1434 .lock()
1435 .connections()
1436 .filter(|connection| !connection.admin)
1437 .count();
1438 connections_metric.set(connections as _);
1439
1440 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1441 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1442 register_int_gauge!(
1443 "shared_projects",
1444 "number of open projects with one or more guests"
1445 )
1446 .unwrap()
1447 });
1448
1449 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1450 shared_projects_metric.set(shared_projects as _);
1451
1452 let encoder = prometheus::TextEncoder::new();
1453 let metric_families = prometheus::gather();
1454 let encoded_metrics = encoder
1455 .encode_to_string(&metric_families)
1456 .map_err(|err| anyhow!("{}", err))?;
1457 Ok(encoded_metrics)
1458}
1459
1460#[instrument(err, skip(executor))]
1461async fn connection_lost(
1462 session: Session,
1463 mut teardown: watch::Receiver<bool>,
1464 executor: Executor,
1465) -> Result<()> {
1466 session.peer.disconnect(session.connection_id);
1467 session
1468 .connection_pool()
1469 .await
1470 .remove_connection(session.connection_id)?;
1471
1472 session
1473 .db()
1474 .await
1475 .connection_lost(session.connection_id)
1476 .await
1477 .trace_err();
1478
1479 futures::select_biased! {
1480 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1481 match &session.principal {
1482 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1483 let session = session.for_user().unwrap();
1484
1485 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1486 leave_room_for_session(&session, session.connection_id).await.trace_err();
1487 leave_channel_buffers_for_session(&session)
1488 .await
1489 .trace_err();
1490
1491 if !session
1492 .connection_pool()
1493 .await
1494 .is_user_online(session.user_id())
1495 {
1496 let db = session.db().await;
1497 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1498 room_updated(&room, &session.peer);
1499 }
1500 }
1501
1502 update_user_contacts(session.user_id(), &session).await?;
1503 },
1504 Principal::DevServer(_) => {
1505 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1506 },
1507 }
1508 },
1509 _ = teardown.changed().fuse() => {}
1510 }
1511
1512 Ok(())
1513}
1514
1515/// Acknowledges a ping from a client, used to keep the connection alive.
1516async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1517 response.send(proto::Ack {})?;
1518 Ok(())
1519}
1520
1521/// Creates a new room for calling (outside of channels)
1522async fn create_room(
1523 _request: proto::CreateRoom,
1524 response: Response<proto::CreateRoom>,
1525 session: UserSession,
1526) -> Result<()> {
1527 let live_kit_room = nanoid::nanoid!(30);
1528
1529 let live_kit_connection_info = util::maybe!(async {
1530 let live_kit = session.live_kit_client.as_ref();
1531 let live_kit = live_kit?;
1532 let user_id = session.user_id().to_string();
1533
1534 let token = live_kit
1535 .room_token(&live_kit_room, &user_id.to_string())
1536 .trace_err()?;
1537
1538 Some(proto::LiveKitConnectionInfo {
1539 server_url: live_kit.url().into(),
1540 token,
1541 can_publish: true,
1542 })
1543 })
1544 .await;
1545
1546 let room = session
1547 .db()
1548 .await
1549 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1550 .await?;
1551
1552 response.send(proto::CreateRoomResponse {
1553 room: Some(room.clone()),
1554 live_kit_connection_info,
1555 })?;
1556
1557 update_user_contacts(session.user_id(), &session).await?;
1558 Ok(())
1559}
1560
1561/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1562async fn join_room(
1563 request: proto::JoinRoom,
1564 response: Response<proto::JoinRoom>,
1565 session: UserSession,
1566) -> Result<()> {
1567 let room_id = RoomId::from_proto(request.id);
1568
1569 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1570
1571 if let Some(channel_id) = channel_id {
1572 return join_channel_internal(channel_id, Box::new(response), session).await;
1573 }
1574
1575 let joined_room = {
1576 let room = session
1577 .db()
1578 .await
1579 .join_room(room_id, session.user_id(), session.connection_id)
1580 .await?;
1581 room_updated(&room.room, &session.peer);
1582 room.into_inner()
1583 };
1584
1585 for connection_id in session
1586 .connection_pool()
1587 .await
1588 .user_connection_ids(session.user_id())
1589 {
1590 session
1591 .peer
1592 .send(
1593 connection_id,
1594 proto::CallCanceled {
1595 room_id: room_id.to_proto(),
1596 },
1597 )
1598 .trace_err();
1599 }
1600
1601 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1602 if let Some(token) = live_kit
1603 .room_token(
1604 &joined_room.room.live_kit_room,
1605 &session.user_id().to_string(),
1606 )
1607 .trace_err()
1608 {
1609 Some(proto::LiveKitConnectionInfo {
1610 server_url: live_kit.url().into(),
1611 token,
1612 can_publish: true,
1613 })
1614 } else {
1615 None
1616 }
1617 } else {
1618 None
1619 };
1620
1621 response.send(proto::JoinRoomResponse {
1622 room: Some(joined_room.room),
1623 channel_id: None,
1624 live_kit_connection_info,
1625 })?;
1626
1627 update_user_contacts(session.user_id(), &session).await?;
1628 Ok(())
1629}
1630
1631/// Rejoin room is used to reconnect to a room after connection errors.
1632async fn rejoin_room(
1633 request: proto::RejoinRoom,
1634 response: Response<proto::RejoinRoom>,
1635 session: UserSession,
1636) -> Result<()> {
1637 let room;
1638 let channel;
1639 {
1640 let mut rejoined_room = session
1641 .db()
1642 .await
1643 .rejoin_room(request, session.user_id(), session.connection_id)
1644 .await?;
1645
1646 response.send(proto::RejoinRoomResponse {
1647 room: Some(rejoined_room.room.clone()),
1648 reshared_projects: rejoined_room
1649 .reshared_projects
1650 .iter()
1651 .map(|project| proto::ResharedProject {
1652 id: project.id.to_proto(),
1653 collaborators: project
1654 .collaborators
1655 .iter()
1656 .map(|collaborator| collaborator.to_proto())
1657 .collect(),
1658 })
1659 .collect(),
1660 rejoined_projects: rejoined_room
1661 .rejoined_projects
1662 .iter()
1663 .map(|rejoined_project| rejoined_project.to_proto())
1664 .collect(),
1665 })?;
1666 room_updated(&rejoined_room.room, &session.peer);
1667
1668 for project in &rejoined_room.reshared_projects {
1669 for collaborator in &project.collaborators {
1670 session
1671 .peer
1672 .send(
1673 collaborator.connection_id,
1674 proto::UpdateProjectCollaborator {
1675 project_id: project.id.to_proto(),
1676 old_peer_id: Some(project.old_connection_id.into()),
1677 new_peer_id: Some(session.connection_id.into()),
1678 },
1679 )
1680 .trace_err();
1681 }
1682
1683 broadcast(
1684 Some(session.connection_id),
1685 project
1686 .collaborators
1687 .iter()
1688 .map(|collaborator| collaborator.connection_id),
1689 |connection_id| {
1690 session.peer.forward_send(
1691 session.connection_id,
1692 connection_id,
1693 proto::UpdateProject {
1694 project_id: project.id.to_proto(),
1695 worktrees: project.worktrees.clone(),
1696 },
1697 )
1698 },
1699 );
1700 }
1701
1702 notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;
1703
1704 let rejoined_room = rejoined_room.into_inner();
1705
1706 room = rejoined_room.room;
1707 channel = rejoined_room.channel;
1708 }
1709
1710 if let Some(channel) = channel {
1711 channel_updated(
1712 &channel,
1713 &room,
1714 &session.peer,
1715 &*session.connection_pool().await,
1716 );
1717 }
1718
1719 update_user_contacts(session.user_id(), &session).await?;
1720 Ok(())
1721}
1722
1723fn notify_rejoined_projects(
1724 rejoined_projects: &mut Vec<RejoinedProject>,
1725 session: &UserSession,
1726) -> Result<()> {
1727 for project in rejoined_projects.iter() {
1728 for collaborator in &project.collaborators {
1729 session
1730 .peer
1731 .send(
1732 collaborator.connection_id,
1733 proto::UpdateProjectCollaborator {
1734 project_id: project.id.to_proto(),
1735 old_peer_id: Some(project.old_connection_id.into()),
1736 new_peer_id: Some(session.connection_id.into()),
1737 },
1738 )
1739 .trace_err();
1740 }
1741 }
1742
1743 for project in rejoined_projects {
1744 for worktree in mem::take(&mut project.worktrees) {
1745 #[cfg(any(test, feature = "test-support"))]
1746 const MAX_CHUNK_SIZE: usize = 2;
1747 #[cfg(not(any(test, feature = "test-support")))]
1748 const MAX_CHUNK_SIZE: usize = 256;
1749
1750 // Stream this worktree's entries.
1751 let message = proto::UpdateWorktree {
1752 project_id: project.id.to_proto(),
1753 worktree_id: worktree.id,
1754 abs_path: worktree.abs_path.clone(),
1755 root_name: worktree.root_name,
1756 updated_entries: worktree.updated_entries,
1757 removed_entries: worktree.removed_entries,
1758 scan_id: worktree.scan_id,
1759 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1760 updated_repositories: worktree.updated_repositories,
1761 removed_repositories: worktree.removed_repositories,
1762 };
1763 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1764 session.peer.send(session.connection_id, update.clone())?;
1765 }
1766
1767 // Stream this worktree's diagnostics.
1768 for summary in worktree.diagnostic_summaries {
1769 session.peer.send(
1770 session.connection_id,
1771 proto::UpdateDiagnosticSummary {
1772 project_id: project.id.to_proto(),
1773 worktree_id: worktree.id,
1774 summary: Some(summary),
1775 },
1776 )?;
1777 }
1778
1779 for settings_file in worktree.settings_files {
1780 session.peer.send(
1781 session.connection_id,
1782 proto::UpdateWorktreeSettings {
1783 project_id: project.id.to_proto(),
1784 worktree_id: worktree.id,
1785 path: settings_file.path,
1786 content: Some(settings_file.content),
1787 },
1788 )?;
1789 }
1790 }
1791
1792 for language_server in &project.language_servers {
1793 session.peer.send(
1794 session.connection_id,
1795 proto::UpdateLanguageServer {
1796 project_id: project.id.to_proto(),
1797 language_server_id: language_server.id,
1798 variant: Some(
1799 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1800 proto::LspDiskBasedDiagnosticsUpdated {},
1801 ),
1802 ),
1803 },
1804 )?;
1805 }
1806 }
1807 Ok(())
1808}
1809
1810/// leave room disconnects from the room.
1811async fn leave_room(
1812 _: proto::LeaveRoom,
1813 response: Response<proto::LeaveRoom>,
1814 session: UserSession,
1815) -> Result<()> {
1816 leave_room_for_session(&session, session.connection_id).await?;
1817 response.send(proto::Ack {})?;
1818 Ok(())
1819}
1820
1821/// Updates the permissions of someone else in the room.
1822async fn set_room_participant_role(
1823 request: proto::SetRoomParticipantRole,
1824 response: Response<proto::SetRoomParticipantRole>,
1825 session: UserSession,
1826) -> Result<()> {
1827 let user_id = UserId::from_proto(request.user_id);
1828 let role = ChannelRole::from(request.role());
1829
1830 let (live_kit_room, can_publish) = {
1831 let room = session
1832 .db()
1833 .await
1834 .set_room_participant_role(
1835 session.user_id(),
1836 RoomId::from_proto(request.room_id),
1837 user_id,
1838 role,
1839 )
1840 .await?;
1841
1842 let live_kit_room = room.live_kit_room.clone();
1843 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1844 room_updated(&room, &session.peer);
1845 (live_kit_room, can_publish)
1846 };
1847
1848 if let Some(live_kit) = session.live_kit_client.as_ref() {
1849 live_kit
1850 .update_participant(
1851 live_kit_room.clone(),
1852 request.user_id.to_string(),
1853 live_kit_server::proto::ParticipantPermission {
1854 can_subscribe: true,
1855 can_publish,
1856 can_publish_data: can_publish,
1857 hidden: false,
1858 recorder: false,
1859 },
1860 )
1861 .await
1862 .trace_err();
1863 }
1864
1865 response.send(proto::Ack {})?;
1866 Ok(())
1867}
1868
1869/// Call someone else into the current room
1870async fn call(
1871 request: proto::Call,
1872 response: Response<proto::Call>,
1873 session: UserSession,
1874) -> Result<()> {
1875 let room_id = RoomId::from_proto(request.room_id);
1876 let calling_user_id = session.user_id();
1877 let calling_connection_id = session.connection_id;
1878 let called_user_id = UserId::from_proto(request.called_user_id);
1879 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1880 if !session
1881 .db()
1882 .await
1883 .has_contact(calling_user_id, called_user_id)
1884 .await?
1885 {
1886 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1887 }
1888
1889 let incoming_call = {
1890 let (room, incoming_call) = &mut *session
1891 .db()
1892 .await
1893 .call(
1894 room_id,
1895 calling_user_id,
1896 calling_connection_id,
1897 called_user_id,
1898 initial_project_id,
1899 )
1900 .await?;
1901 room_updated(&room, &session.peer);
1902 mem::take(incoming_call)
1903 };
1904 update_user_contacts(called_user_id, &session).await?;
1905
1906 let mut calls = session
1907 .connection_pool()
1908 .await
1909 .user_connection_ids(called_user_id)
1910 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1911 .collect::<FuturesUnordered<_>>();
1912
1913 while let Some(call_response) = calls.next().await {
1914 match call_response.as_ref() {
1915 Ok(_) => {
1916 response.send(proto::Ack {})?;
1917 return Ok(());
1918 }
1919 Err(_) => {
1920 call_response.trace_err();
1921 }
1922 }
1923 }
1924
1925 {
1926 let room = session
1927 .db()
1928 .await
1929 .call_failed(room_id, called_user_id)
1930 .await?;
1931 room_updated(&room, &session.peer);
1932 }
1933 update_user_contacts(called_user_id, &session).await?;
1934
1935 Err(anyhow!("failed to ring user"))?
1936}
1937
1938/// Cancel an outgoing call.
1939async fn cancel_call(
1940 request: proto::CancelCall,
1941 response: Response<proto::CancelCall>,
1942 session: UserSession,
1943) -> Result<()> {
1944 let called_user_id = UserId::from_proto(request.called_user_id);
1945 let room_id = RoomId::from_proto(request.room_id);
1946 {
1947 let room = session
1948 .db()
1949 .await
1950 .cancel_call(room_id, session.connection_id, called_user_id)
1951 .await?;
1952 room_updated(&room, &session.peer);
1953 }
1954
1955 for connection_id in session
1956 .connection_pool()
1957 .await
1958 .user_connection_ids(called_user_id)
1959 {
1960 session
1961 .peer
1962 .send(
1963 connection_id,
1964 proto::CallCanceled {
1965 room_id: room_id.to_proto(),
1966 },
1967 )
1968 .trace_err();
1969 }
1970 response.send(proto::Ack {})?;
1971
1972 update_user_contacts(called_user_id, &session).await?;
1973 Ok(())
1974}
1975
1976/// Decline an incoming call.
1977async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1978 let room_id = RoomId::from_proto(message.room_id);
1979 {
1980 let room = session
1981 .db()
1982 .await
1983 .decline_call(Some(room_id), session.user_id())
1984 .await?
1985 .ok_or_else(|| anyhow!("failed to decline call"))?;
1986 room_updated(&room, &session.peer);
1987 }
1988
1989 for connection_id in session
1990 .connection_pool()
1991 .await
1992 .user_connection_ids(session.user_id())
1993 {
1994 session
1995 .peer
1996 .send(
1997 connection_id,
1998 proto::CallCanceled {
1999 room_id: room_id.to_proto(),
2000 },
2001 )
2002 .trace_err();
2003 }
2004 update_user_contacts(session.user_id(), &session).await?;
2005 Ok(())
2006}
2007
2008/// Updates other participants in the room with your current location.
2009async fn update_participant_location(
2010 request: proto::UpdateParticipantLocation,
2011 response: Response<proto::UpdateParticipantLocation>,
2012 session: UserSession,
2013) -> Result<()> {
2014 let room_id = RoomId::from_proto(request.room_id);
2015 let location = request
2016 .location
2017 .ok_or_else(|| anyhow!("invalid location"))?;
2018
2019 let db = session.db().await;
2020 let room = db
2021 .update_room_participant_location(room_id, session.connection_id, location)
2022 .await?;
2023
2024 room_updated(&room, &session.peer);
2025 response.send(proto::Ack {})?;
2026 Ok(())
2027}
2028
2029/// Share a project into the room.
2030async fn share_project(
2031 request: proto::ShareProject,
2032 response: Response<proto::ShareProject>,
2033 session: UserSession,
2034) -> Result<()> {
2035 let (project_id, room) = &*session
2036 .db()
2037 .await
2038 .share_project(
2039 RoomId::from_proto(request.room_id),
2040 session.connection_id,
2041 &request.worktrees,
2042 request
2043 .dev_server_project_id
2044 .map(|id| DevServerProjectId::from_proto(id)),
2045 )
2046 .await?;
2047 response.send(proto::ShareProjectResponse {
2048 project_id: project_id.to_proto(),
2049 })?;
2050 room_updated(&room, &session.peer);
2051
2052 Ok(())
2053}
2054
2055/// Unshare a project from the room.
2056async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2057 let project_id = ProjectId::from_proto(message.project_id);
2058 unshare_project_internal(
2059 project_id,
2060 session.connection_id,
2061 session.user_id(),
2062 &session,
2063 )
2064 .await
2065}
2066
2067async fn unshare_project_internal(
2068 project_id: ProjectId,
2069 connection_id: ConnectionId,
2070 user_id: Option<UserId>,
2071 session: &Session,
2072) -> Result<()> {
2073 let delete = {
2074 let room_guard = session
2075 .db()
2076 .await
2077 .unshare_project(project_id, connection_id, user_id)
2078 .await?;
2079
2080 let (delete, room, guest_connection_ids) = &*room_guard;
2081
2082 let message = proto::UnshareProject {
2083 project_id: project_id.to_proto(),
2084 };
2085
2086 broadcast(
2087 Some(connection_id),
2088 guest_connection_ids.iter().copied(),
2089 |conn_id| session.peer.send(conn_id, message.clone()),
2090 );
2091 if let Some(room) = room {
2092 room_updated(room, &session.peer);
2093 }
2094
2095 *delete
2096 };
2097
2098 if delete {
2099 let db = session.db().await;
2100 db.delete_project(project_id).await?;
2101 }
2102
2103 Ok(())
2104}
2105
2106/// DevServer makes a project available online
2107async fn share_dev_server_project(
2108 request: proto::ShareDevServerProject,
2109 response: Response<proto::ShareDevServerProject>,
2110 session: DevServerSession,
2111) -> Result<()> {
2112 let (dev_server_project, user_id, status) = session
2113 .db()
2114 .await
2115 .share_dev_server_project(
2116 DevServerProjectId::from_proto(request.dev_server_project_id),
2117 session.dev_server_id(),
2118 session.connection_id,
2119 &request.worktrees,
2120 )
2121 .await?;
2122 let Some(project_id) = dev_server_project.project_id else {
2123 return Err(anyhow!("failed to share remote project"))?;
2124 };
2125
2126 send_dev_server_projects_update(user_id, status, &session).await;
2127
2128 response.send(proto::ShareProjectResponse { project_id })?;
2129
2130 Ok(())
2131}
2132
/// Join someone else's shared project.
///
/// Registers this connection as a guest of `request.project_id` in the database, then
/// hands the rest of the handshake (notifying collaborators, replying, and streaming
/// worktree state) to [`join_project_internal`].
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `&mut *` on the returned guard extends the guard's lifetime to the end of this
    // function (temporary lifetime extension) while exposing its contents mutably.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle explicitly; the guard returned by `join_project`
    // remains alive while the synchronous `join_project_internal` runs.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2151
/// Abstracts over the response handles of `JoinProject` and `JoinHostedProject`, both of
/// which reply with a `proto::JoinProjectResponse`, so [`join_project_internal`] can serve
/// either request.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the originating request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The type-qualified path resolves to the inherent `Response::send` (inherent
        // associated functions take precedence), not back to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // The type-qualified path resolves to the inherent `Response::send` (inherent
        // associated functions take precedence), not back to this trait method.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2165
2166fn join_project_internal(
2167 response: impl JoinProjectInternalResponse,
2168 session: UserSession,
2169 project: &mut Project,
2170 replica_id: &ReplicaId,
2171) -> Result<()> {
2172 let collaborators = project
2173 .collaborators
2174 .iter()
2175 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2176 .map(|collaborator| collaborator.to_proto())
2177 .collect::<Vec<_>>();
2178 let project_id = project.id;
2179 let guest_user_id = session.user_id();
2180
2181 let worktrees = project
2182 .worktrees
2183 .iter()
2184 .map(|(id, worktree)| proto::WorktreeMetadata {
2185 id: *id,
2186 root_name: worktree.root_name.clone(),
2187 visible: worktree.visible,
2188 abs_path: worktree.abs_path.clone(),
2189 })
2190 .collect::<Vec<_>>();
2191
2192 let add_project_collaborator = proto::AddProjectCollaborator {
2193 project_id: project_id.to_proto(),
2194 collaborator: Some(proto::Collaborator {
2195 peer_id: Some(session.connection_id.into()),
2196 replica_id: replica_id.0 as u32,
2197 user_id: guest_user_id.to_proto(),
2198 }),
2199 };
2200
2201 for collaborator in &collaborators {
2202 session
2203 .peer
2204 .send(
2205 collaborator.peer_id.unwrap().into(),
2206 add_project_collaborator.clone(),
2207 )
2208 .trace_err();
2209 }
2210
2211 // First, we send the metadata associated with each worktree.
2212 response.send(proto::JoinProjectResponse {
2213 project_id: project.id.0 as u64,
2214 worktrees: worktrees.clone(),
2215 replica_id: replica_id.0 as u32,
2216 collaborators: collaborators.clone(),
2217 language_servers: project.language_servers.clone(),
2218 role: project.role.into(),
2219 dev_server_project_id: project
2220 .dev_server_project_id
2221 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2222 })?;
2223
2224 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2225 #[cfg(any(test, feature = "test-support"))]
2226 const MAX_CHUNK_SIZE: usize = 2;
2227 #[cfg(not(any(test, feature = "test-support")))]
2228 const MAX_CHUNK_SIZE: usize = 256;
2229
2230 // Stream this worktree's entries.
2231 let message = proto::UpdateWorktree {
2232 project_id: project_id.to_proto(),
2233 worktree_id,
2234 abs_path: worktree.abs_path.clone(),
2235 root_name: worktree.root_name,
2236 updated_entries: worktree.entries,
2237 removed_entries: Default::default(),
2238 scan_id: worktree.scan_id,
2239 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2240 updated_repositories: worktree.repository_entries.into_values().collect(),
2241 removed_repositories: Default::default(),
2242 };
2243 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2244 session.peer.send(session.connection_id, update.clone())?;
2245 }
2246
2247 // Stream this worktree's diagnostics.
2248 for summary in worktree.diagnostic_summaries {
2249 session.peer.send(
2250 session.connection_id,
2251 proto::UpdateDiagnosticSummary {
2252 project_id: project_id.to_proto(),
2253 worktree_id: worktree.id,
2254 summary: Some(summary),
2255 },
2256 )?;
2257 }
2258
2259 for settings_file in worktree.settings_files {
2260 session.peer.send(
2261 session.connection_id,
2262 proto::UpdateWorktreeSettings {
2263 project_id: project_id.to_proto(),
2264 worktree_id: worktree.id,
2265 path: settings_file.path,
2266 content: Some(settings_file.content),
2267 },
2268 )?;
2269 }
2270 }
2271
2272 for language_server in &project.language_servers {
2273 session.peer.send(
2274 session.connection_id,
2275 proto::UpdateLanguageServer {
2276 project_id: project_id.to_proto(),
2277 language_server_id: language_server.id,
2278 variant: Some(
2279 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2280 proto::LspDiskBasedDiagnosticsUpdated {},
2281 ),
2282 ),
2283 },
2284 )?;
2285 }
2286
2287 Ok(())
2288}
2289
2290/// Leave someone elses shared project.
2291async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2292 let sender_id = session.connection_id;
2293 let project_id = ProjectId::from_proto(request.project_id);
2294 let db = session.db().await;
2295 if db.is_hosted_project(project_id).await? {
2296 let project = db.leave_hosted_project(project_id, sender_id).await?;
2297 project_left(&project, &session);
2298 return Ok(());
2299 }
2300
2301 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2302 tracing::info!(
2303 %project_id,
2304 "leave project"
2305 );
2306
2307 project_left(&project, &session);
2308 if let Some(room) = room {
2309 room_updated(&room, &session.peer);
2310 }
2311
2312 Ok(())
2313}
2314
2315async fn join_hosted_project(
2316 request: proto::JoinHostedProject,
2317 response: Response<proto::JoinHostedProject>,
2318 session: UserSession,
2319) -> Result<()> {
2320 let (mut project, replica_id) = session
2321 .db()
2322 .await
2323 .join_hosted_project(
2324 ProjectId(request.project_id as i32),
2325 session.user_id(),
2326 session.connection_id,
2327 )
2328 .await?;
2329
2330 join_project_internal(response, session, &mut project, &replica_id)
2331}
2332
2333async fn list_remote_directory(
2334 request: proto::ListRemoteDirectory,
2335 response: Response<proto::ListRemoteDirectory>,
2336 session: UserSession,
2337) -> Result<()> {
2338 let dev_server_id = DevServerId(request.dev_server_id as i32);
2339 let dev_server_connection_id = session
2340 .connection_pool()
2341 .await
2342 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2343
2344 session
2345 .db()
2346 .await
2347 .get_dev_server_for_user(dev_server_id, session.user_id())
2348 .await?;
2349
2350 response.send(
2351 session
2352 .peer
2353 .forward_request(session.connection_id, dev_server_connection_id, request)
2354 .await?,
2355 )?;
2356 Ok(())
2357}
2358
/// Changes the paths of one of the user's dev server projects.
///
/// Persists the new paths, sends the dev server its full project list as
/// `DevServerInstructions`, pushes the updated dev server project status to the user, and
/// acknowledges the request.
async fn update_dev_server_project(
    request: proto::UpdateDevServerProject,
    response: Response<proto::UpdateDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);

    let (dev_server_project, update) = session
        .db()
        .await
        .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
        .await?;

    // The dev server's project list, re-sent below as `DevServerInstructions`.
    let projects = session
        .db()
        .await
        .get_projects_for_dev_server(dev_server_project.dev_server_id)
        .await?;

    // NOTE(review): the database update above has already been applied when this lookup
    // runs. If the dev server is offline or too old, this returns an error, and neither the
    // user's status update nor the Ack below is sent. Confirm this is intended.
    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id_supporting(
            dev_server_project.dev_server_id,
            ZedVersion::with_list_directory(),
        )?;

    session.peer.send(
        dev_server_connection_id,
        proto::DevServerInstructions { projects },
    )?;

    send_dev_server_projects_update(session.user_id(), update, &session).await;

    response.send(proto::Ack {})
}
2395
2396async fn create_dev_server_project(
2397 request: proto::CreateDevServerProject,
2398 response: Response<proto::CreateDevServerProject>,
2399 session: UserSession,
2400) -> Result<()> {
2401 let dev_server_id = DevServerId(request.dev_server_id as i32);
2402 let dev_server_connection_id = session
2403 .connection_pool()
2404 .await
2405 .dev_server_connection_id(dev_server_id);
2406 let Some(dev_server_connection_id) = dev_server_connection_id else {
2407 Err(ErrorCode::DevServerOffline
2408 .message("Cannot create a remote project when the dev server is offline".to_string())
2409 .anyhow())?
2410 };
2411
2412 let path = request.path.clone();
2413 //Check that the path exists on the dev server
2414 session
2415 .peer
2416 .forward_request(
2417 session.connection_id,
2418 dev_server_connection_id,
2419 proto::ValidateDevServerProjectRequest { path: path.clone() },
2420 )
2421 .await?;
2422
2423 let (dev_server_project, update) = session
2424 .db()
2425 .await
2426 .create_dev_server_project(
2427 DevServerId(request.dev_server_id as i32),
2428 &request.path,
2429 session.user_id(),
2430 )
2431 .await?;
2432
2433 let projects = session
2434 .db()
2435 .await
2436 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2437 .await?;
2438
2439 session.peer.send(
2440 dev_server_connection_id,
2441 proto::DevServerInstructions { projects },
2442 )?;
2443
2444 send_dev_server_projects_update(session.user_id(), update, &session).await;
2445
2446 response.send(proto::CreateDevServerProjectResponse {
2447 dev_server_project: Some(dev_server_project.to_proto(None)),
2448 })?;
2449 Ok(())
2450}
2451
2452async fn create_dev_server(
2453 request: proto::CreateDevServer,
2454 response: Response<proto::CreateDevServer>,
2455 session: UserSession,
2456) -> Result<()> {
2457 let access_token = auth::random_token();
2458 let hashed_access_token = auth::hash_access_token(&access_token);
2459
2460 if request.name.is_empty() {
2461 return Err(proto::ErrorCode::Forbidden
2462 .message("Dev server name cannot be empty".to_string())
2463 .anyhow())?;
2464 }
2465
2466 let (dev_server, status) = session
2467 .db()
2468 .await
2469 .create_dev_server(
2470 &request.name,
2471 request.ssh_connection_string.as_deref(),
2472 &hashed_access_token,
2473 session.user_id(),
2474 )
2475 .await?;
2476
2477 send_dev_server_projects_update(session.user_id(), status, &session).await;
2478
2479 response.send(proto::CreateDevServerResponse {
2480 dev_server_id: dev_server.id.0 as u64,
2481 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2482 name: request.name,
2483 })?;
2484 Ok(())
2485}
2486
2487async fn regenerate_dev_server_token(
2488 request: proto::RegenerateDevServerToken,
2489 response: Response<proto::RegenerateDevServerToken>,
2490 session: UserSession,
2491) -> Result<()> {
2492 let dev_server_id = DevServerId(request.dev_server_id as i32);
2493 let access_token = auth::random_token();
2494 let hashed_access_token = auth::hash_access_token(&access_token);
2495
2496 let connection_id = session
2497 .connection_pool()
2498 .await
2499 .dev_server_connection_id(dev_server_id);
2500 if let Some(connection_id) = connection_id {
2501 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2502 session.peer.send(
2503 connection_id,
2504 proto::ShutdownDevServer {
2505 reason: Some("dev server token was regenerated".to_string()),
2506 },
2507 )?;
2508 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2509 }
2510
2511 let status = session
2512 .db()
2513 .await
2514 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2515 .await?;
2516
2517 send_dev_server_projects_update(session.user_id(), status, &session).await;
2518
2519 response.send(proto::RegenerateDevServerTokenResponse {
2520 dev_server_id: dev_server_id.to_proto(),
2521 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2522 })?;
2523 Ok(())
2524}
2525
2526async fn rename_dev_server(
2527 request: proto::RenameDevServer,
2528 response: Response<proto::RenameDevServer>,
2529 session: UserSession,
2530) -> Result<()> {
2531 if request.name.trim().is_empty() {
2532 return Err(proto::ErrorCode::Forbidden
2533 .message("Dev server name cannot be empty".to_string())
2534 .anyhow())?;
2535 }
2536
2537 let dev_server_id = DevServerId(request.dev_server_id as i32);
2538 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2539 if dev_server.user_id != session.user_id() {
2540 return Err(anyhow!(ErrorCode::Forbidden))?;
2541 }
2542
2543 let status = session
2544 .db()
2545 .await
2546 .rename_dev_server(
2547 dev_server_id,
2548 &request.name,
2549 request.ssh_connection_string.as_deref(),
2550 session.user_id(),
2551 )
2552 .await?;
2553
2554 send_dev_server_projects_update(session.user_id(), status, &session).await;
2555
2556 response.send(proto::Ack {})?;
2557 Ok(())
2558}
2559
2560async fn delete_dev_server(
2561 request: proto::DeleteDevServer,
2562 response: Response<proto::DeleteDevServer>,
2563 session: UserSession,
2564) -> Result<()> {
2565 let dev_server_id = DevServerId(request.dev_server_id as i32);
2566 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2567 if dev_server.user_id != session.user_id() {
2568 return Err(anyhow!(ErrorCode::Forbidden))?;
2569 }
2570
2571 let connection_id = session
2572 .connection_pool()
2573 .await
2574 .dev_server_connection_id(dev_server_id);
2575 if let Some(connection_id) = connection_id {
2576 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2577 session.peer.send(
2578 connection_id,
2579 proto::ShutdownDevServer {
2580 reason: Some("dev server was deleted".to_string()),
2581 },
2582 )?;
2583 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2584 }
2585
2586 let status = session
2587 .db()
2588 .await
2589 .delete_dev_server(dev_server_id, session.user_id())
2590 .await?;
2591
2592 send_dev_server_projects_update(session.user_id(), status, &session).await;
2593
2594 response.send(proto::Ack {})?;
2595 Ok(())
2596}
2597
/// Deletes one of the current user's dev server projects.
///
/// Only the owner of the dev server hosting the project may do this. If the dev server is
/// online and a live project exists for it, that project is unshared first. The project is
/// then deleted from the database. The dev server receives the project list returned by the
/// deletion as `DevServerInstructions` (when online), and the user receives the refreshed
/// dev server project status.
async fn delete_dev_server_project(
    request: proto::DeleteDevServerProject,
    response: Response<proto::DeleteDevServerProject>,
    session: UserSession,
) -> Result<()> {
    let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
    let dev_server_project = session
        .db()
        .await
        .get_dev_server_project(dev_server_project_id)
        .await?;

    // Ownership is decided by the dev server the project belongs to.
    let dev_server = session
        .db()
        .await
        .get_dev_server(dev_server_project.dev_server_id)
        .await?;
    if dev_server.user_id != session.user_id() {
        return Err(anyhow!(ErrorCode::Forbidden))?;
    }

    let dev_server_connection_id = session
        .connection_pool()
        .await
        .dev_server_connection_id(dev_server.id);

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        let project = session
            .db()
            .await
            .find_dev_server_project(dev_server_project_id)
            .await;
        // A lookup error is deliberately ignored and unsharing is skipped. Presumably
        // this means the project is not currently shared — TODO confirm.
        if let Ok(project) = project {
            unshare_project_internal(
                project.id,
                dev_server_connection_id,
                Some(session.user_id()),
                &session,
            )
            .await?;
        }
    }

    let (projects, status) = session
        .db()
        .await
        .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
        .await?;

    if let Some(dev_server_connection_id) = dev_server_connection_id {
        session.peer.send(
            dev_server_connection_id,
            proto::DevServerInstructions { projects },
        )?;
    }

    send_dev_server_projects_update(session.user_id(), status, &session).await;

    response.send(proto::Ack {})?;
    Ok(())
}
2659
2660async fn rejoin_dev_server_projects(
2661 request: proto::RejoinRemoteProjects,
2662 response: Response<proto::RejoinRemoteProjects>,
2663 session: UserSession,
2664) -> Result<()> {
2665 let mut rejoined_projects = {
2666 let db = session.db().await;
2667 db.rejoin_dev_server_projects(
2668 &request.rejoined_projects,
2669 session.user_id(),
2670 session.0.connection_id,
2671 )
2672 .await?
2673 };
2674 response.send(proto::RejoinRemoteProjectsResponse {
2675 rejoined_projects: rejoined_projects
2676 .iter()
2677 .map(|project| project.to_proto())
2678 .collect(),
2679 })?;
2680 notify_rejoined_projects(&mut rejoined_projects, &session)
2681}
2682
/// Handles a dev server reconnecting under a new connection.
///
/// Re-shares the projects listed in `request.reshared_projects` in the database, then for
/// each project:
/// - tells every collaborator that the host's peer id changed (`UpdateProjectCollaborator`),
/// - forwards the project's current worktrees to the collaborators (`UpdateProject`).
///
/// Finally it replies with each reshared project's id and collaborators.
async fn reconnect_dev_server(
    request: proto::ReconnectDevServer,
    response: Response<proto::ReconnectDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // The database handle is released at the end of this block.
    let reshared_projects = {
        let db = session.db().await;
        db.reshare_dev_server_projects(
            &request.reshared_projects,
            session.dev_server_id(),
            session.0.connection_id,
        )
        .await?
    };

    for project in &reshared_projects {
        // Point collaborators at the dev server's new connection. Send failures are logged only.
        for collaborator in &project.collaborators {
            session
                .peer
                .send(
                    collaborator.connection_id,
                    proto::UpdateProjectCollaborator {
                        project_id: project.id.to_proto(),
                        old_peer_id: Some(project.old_connection_id.into()),
                        new_peer_id: Some(session.connection_id.into()),
                    },
                )
                .trace_err();
        }

        broadcast(
            Some(session.connection_id),
            project
                .collaborators
                .iter()
                .map(|collaborator| collaborator.connection_id),
            |connection_id| {
                session.peer.forward_send(
                    session.connection_id,
                    connection_id,
                    proto::UpdateProject {
                        project_id: project.id.to_proto(),
                        worktrees: project.worktrees.clone(),
                    },
                )
            },
        );
    }

    response.send(proto::ReconnectDevServerResponse {
        reshared_projects: reshared_projects
            .iter()
            .map(|project| proto::ResharedProject {
                id: project.id.to_proto(),
                collaborators: project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.to_proto())
                    .collect(),
            })
            .collect(),
    })?;

    Ok(())
}
2748
/// Handles a dev server announcing that it is shutting down.
///
/// The request is acknowledged before any cleanup runs. The dev server's projects are then
/// unshared and it is marked offline (`shutdown_dev_server_internal`), and finally its
/// connection is removed from the pool.
async fn shutdown_dev_server(
    _: proto::ShutdownDevServer,
    response: Response<proto::ShutdownDevServer>,
    session: DevServerSession,
) -> Result<()> {
    // Ack first; presumably so the dev server is not left waiting on cleanup — confirm.
    response.send(proto::Ack {})?;
    shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
    remove_dev_server_connection(session.dev_server_id(), &session).await
}
2758
2759async fn shutdown_dev_server_internal(
2760 dev_server_id: DevServerId,
2761 connection_id: ConnectionId,
2762 session: &Session,
2763) -> Result<()> {
2764 let (dev_server_projects, dev_server) = {
2765 let db = session.db().await;
2766 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2767 let dev_server = db.get_dev_server(dev_server_id).await?;
2768 (dev_server_projects, dev_server)
2769 };
2770
2771 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2772 unshare_project_internal(
2773 ProjectId::from_proto(project_id),
2774 connection_id,
2775 None,
2776 session,
2777 )
2778 .await?;
2779 }
2780
2781 session
2782 .connection_pool()
2783 .await
2784 .set_dev_server_offline(dev_server_id);
2785
2786 let status = session
2787 .db()
2788 .await
2789 .dev_server_projects_update(dev_server.user_id)
2790 .await?;
2791 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2792
2793 Ok(())
2794}
2795
2796async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2797 let dev_server_connection = session
2798 .connection_pool()
2799 .await
2800 .dev_server_connection_id(dev_server_id);
2801
2802 if let Some(dev_server_connection) = dev_server_connection {
2803 session
2804 .connection_pool()
2805 .await
2806 .remove_connection(dev_server_connection)?;
2807 }
2808 Ok(())
2809}
2810
2811/// Updates other participants with changes to the project
2812async fn update_project(
2813 request: proto::UpdateProject,
2814 response: Response<proto::UpdateProject>,
2815 session: Session,
2816) -> Result<()> {
2817 let project_id = ProjectId::from_proto(request.project_id);
2818 let (room, guest_connection_ids) = &*session
2819 .db()
2820 .await
2821 .update_project(project_id, session.connection_id, &request.worktrees)
2822 .await?;
2823 broadcast(
2824 Some(session.connection_id),
2825 guest_connection_ids.iter().copied(),
2826 |connection_id| {
2827 session
2828 .peer
2829 .forward_send(session.connection_id, connection_id, request.clone())
2830 },
2831 );
2832 if let Some(room) = room {
2833 room_updated(&room, &session.peer);
2834 }
2835 response.send(proto::Ack {})?;
2836
2837 Ok(())
2838}
2839
2840/// Updates other participants with changes to the worktree
2841async fn update_worktree(
2842 request: proto::UpdateWorktree,
2843 response: Response<proto::UpdateWorktree>,
2844 session: Session,
2845) -> Result<()> {
2846 let guest_connection_ids = session
2847 .db()
2848 .await
2849 .update_worktree(&request, session.connection_id)
2850 .await?;
2851
2852 broadcast(
2853 Some(session.connection_id),
2854 guest_connection_ids.iter().copied(),
2855 |connection_id| {
2856 session
2857 .peer
2858 .forward_send(session.connection_id, connection_id, request.clone())
2859 },
2860 );
2861 response.send(proto::Ack {})?;
2862 Ok(())
2863}
2864
2865/// Updates other participants with changes to the diagnostics
2866async fn update_diagnostic_summary(
2867 message: proto::UpdateDiagnosticSummary,
2868 session: Session,
2869) -> Result<()> {
2870 let guest_connection_ids = session
2871 .db()
2872 .await
2873 .update_diagnostic_summary(&message, session.connection_id)
2874 .await?;
2875
2876 broadcast(
2877 Some(session.connection_id),
2878 guest_connection_ids.iter().copied(),
2879 |connection_id| {
2880 session
2881 .peer
2882 .forward_send(session.connection_id, connection_id, message.clone())
2883 },
2884 );
2885
2886 Ok(())
2887}
2888
2889/// Updates other participants with changes to the worktree settings
2890async fn update_worktree_settings(
2891 message: proto::UpdateWorktreeSettings,
2892 session: Session,
2893) -> Result<()> {
2894 let guest_connection_ids = session
2895 .db()
2896 .await
2897 .update_worktree_settings(&message, session.connection_id)
2898 .await?;
2899
2900 broadcast(
2901 Some(session.connection_id),
2902 guest_connection_ids.iter().copied(),
2903 |connection_id| {
2904 session
2905 .peer
2906 .forward_send(session.connection_id, connection_id, message.clone())
2907 },
2908 );
2909
2910 Ok(())
2911}
2912
2913/// Notify other participants that a language server has started.
2914async fn start_language_server(
2915 request: proto::StartLanguageServer,
2916 session: Session,
2917) -> Result<()> {
2918 let guest_connection_ids = session
2919 .db()
2920 .await
2921 .start_language_server(&request, session.connection_id)
2922 .await?;
2923
2924 broadcast(
2925 Some(session.connection_id),
2926 guest_connection_ids.iter().copied(),
2927 |connection_id| {
2928 session
2929 .peer
2930 .forward_send(session.connection_id, connection_id, request.clone())
2931 },
2932 );
2933 Ok(())
2934}
2935
2936/// Notify other participants that a language server has changed.
2937async fn update_language_server(
2938 request: proto::UpdateLanguageServer,
2939 session: Session,
2940) -> Result<()> {
2941 let project_id = ProjectId::from_proto(request.project_id);
2942 let project_connection_ids = session
2943 .db()
2944 .await
2945 .project_connection_ids(project_id, session.connection_id, true)
2946 .await?;
2947 broadcast(
2948 Some(session.connection_id),
2949 project_connection_ids.iter().copied(),
2950 |connection_id| {
2951 session
2952 .peer
2953 .forward_send(session.connection_id, connection_id, request.clone())
2954 },
2955 );
2956 Ok(())
2957}
2958
2959/// forward a project request to the host. These requests should be read only
2960/// as guests are allowed to send them.
2961async fn forward_read_only_project_request<T>(
2962 request: T,
2963 response: Response<T>,
2964 session: UserSession,
2965) -> Result<()>
2966where
2967 T: EntityMessage + RequestMessage,
2968{
2969 let project_id = ProjectId::from_proto(request.remote_entity_id());
2970 let host_connection_id = session
2971 .db()
2972 .await
2973 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2974 .await?;
2975 let payload = session
2976 .peer
2977 .forward_request(session.connection_id, host_connection_id, request)
2978 .await?;
2979 response.send(payload)?;
2980 Ok(())
2981}
2982
2983/// forward a project request to the dev server. Only allowed
2984/// if it's your dev server.
2985async fn forward_project_request_for_owner<T>(
2986 request: T,
2987 response: Response<T>,
2988 session: UserSession,
2989) -> Result<()>
2990where
2991 T: EntityMessage + RequestMessage,
2992{
2993 let project_id = ProjectId::from_proto(request.remote_entity_id());
2994
2995 let host_connection_id = session
2996 .db()
2997 .await
2998 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2999 .await?;
3000 let payload = session
3001 .peer
3002 .forward_request(session.connection_id, host_connection_id, request)
3003 .await?;
3004 response.send(payload)?;
3005 Ok(())
3006}
3007
3008/// forward a project request to the host. These requests are disallowed
3009/// for guests.
3010async fn forward_mutating_project_request<T>(
3011 request: T,
3012 response: Response<T>,
3013 session: UserSession,
3014) -> Result<()>
3015where
3016 T: EntityMessage + RequestMessage,
3017{
3018 let project_id = ProjectId::from_proto(request.remote_entity_id());
3019
3020 let host_connection_id = session
3021 .db()
3022 .await
3023 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3024 .await?;
3025 let payload = session
3026 .peer
3027 .forward_request(session.connection_id, host_connection_id, request)
3028 .await?;
3029 response.send(payload)?;
3030 Ok(())
3031}
3032
3033/// forward a project request to the host. These requests are disallowed
3034/// for guests.
3035async fn forward_versioned_mutating_project_request<T>(
3036 request: T,
3037 response: Response<T>,
3038 session: UserSession,
3039) -> Result<()>
3040where
3041 T: EntityMessage + RequestMessage + VersionedMessage,
3042{
3043 let project_id = ProjectId::from_proto(request.remote_entity_id());
3044
3045 let host_connection_id = session
3046 .db()
3047 .await
3048 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3049 .await?;
3050 if let Some(host_version) = session
3051 .connection_pool()
3052 .await
3053 .connection(host_connection_id)
3054 .map(|c| c.zed_version)
3055 {
3056 if let Some(min_required_version) = request.required_host_version() {
3057 if min_required_version > host_version {
3058 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3059 .with_tag("required", &min_required_version.to_string())))?;
3060 }
3061 }
3062 }
3063
3064 let payload = session
3065 .peer
3066 .forward_request(session.connection_id, host_connection_id, request)
3067 .await?;
3068 response.send(payload)?;
3069 Ok(())
3070}
3071
3072/// Notify other participants that a new buffer has been created
3073async fn create_buffer_for_peer(
3074 request: proto::CreateBufferForPeer,
3075 session: Session,
3076) -> Result<()> {
3077 session
3078 .db()
3079 .await
3080 .check_user_is_project_host(
3081 ProjectId::from_proto(request.project_id),
3082 session.connection_id,
3083 )
3084 .await?;
3085 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3086 session
3087 .peer
3088 .forward_send(session.connection_id, peer_id.into(), request)?;
3089 Ok(())
3090}
3091
3092/// Notify other participants that a buffer has been updated. This is
3093/// allowed for guests as long as the update is limited to selections.
3094async fn update_buffer(
3095 request: proto::UpdateBuffer,
3096 response: Response<proto::UpdateBuffer>,
3097 session: Session,
3098) -> Result<()> {
3099 let project_id = ProjectId::from_proto(request.project_id);
3100 let mut capability = Capability::ReadOnly;
3101
3102 for op in request.operations.iter() {
3103 match op.variant {
3104 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3105 Some(_) => capability = Capability::ReadWrite,
3106 }
3107 }
3108
3109 let host = {
3110 let guard = session
3111 .db()
3112 .await
3113 .connections_for_buffer_update(
3114 project_id,
3115 session.principal_id(),
3116 session.connection_id,
3117 capability,
3118 )
3119 .await?;
3120
3121 let (host, guests) = &*guard;
3122
3123 broadcast(
3124 Some(session.connection_id),
3125 guests.clone(),
3126 |connection_id| {
3127 session
3128 .peer
3129 .forward_send(session.connection_id, connection_id, request.clone())
3130 },
3131 );
3132
3133 *host
3134 };
3135
3136 if host != session.connection_id {
3137 session
3138 .peer
3139 .forward_request(session.connection_id, host, request.clone())
3140 .await?;
3141 }
3142
3143 response.send(proto::Ack {})?;
3144 Ok(())
3145}
3146
3147async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3148 let project_id = ProjectId::from_proto(message.project_id);
3149
3150 let operation = message.operation.as_ref().context("invalid operation")?;
3151 let capability = match operation.variant.as_ref() {
3152 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3153 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3154 match buffer_op.variant {
3155 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3156 Capability::ReadOnly
3157 }
3158 _ => Capability::ReadWrite,
3159 }
3160 } else {
3161 Capability::ReadWrite
3162 }
3163 }
3164 Some(_) => Capability::ReadWrite,
3165 None => Capability::ReadOnly,
3166 };
3167
3168 let guard = session
3169 .db()
3170 .await
3171 .connections_for_buffer_update(
3172 project_id,
3173 session.principal_id(),
3174 session.connection_id,
3175 capability,
3176 )
3177 .await?;
3178
3179 let (host, guests) = &*guard;
3180
3181 broadcast(
3182 Some(session.connection_id),
3183 guests.iter().chain([host]).copied(),
3184 |connection_id| {
3185 session
3186 .peer
3187 .forward_send(session.connection_id, connection_id, message.clone())
3188 },
3189 );
3190
3191 Ok(())
3192}
3193
3194/// Notify other participants that a project has been updated.
3195async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3196 request: T,
3197 session: Session,
3198) -> Result<()> {
3199 let project_id = ProjectId::from_proto(request.remote_entity_id());
3200 let project_connection_ids = session
3201 .db()
3202 .await
3203 .project_connection_ids(project_id, session.connection_id, false)
3204 .await?;
3205
3206 broadcast(
3207 Some(session.connection_id),
3208 project_connection_ids.iter().copied(),
3209 |connection_id| {
3210 session
3211 .peer
3212 .forward_send(session.connection_id, connection_id, request.clone())
3213 },
3214 );
3215 Ok(())
3216}
3217
3218/// Start following another user in a call.
3219async fn follow(
3220 request: proto::Follow,
3221 response: Response<proto::Follow>,
3222 session: UserSession,
3223) -> Result<()> {
3224 let room_id = RoomId::from_proto(request.room_id);
3225 let project_id = request.project_id.map(ProjectId::from_proto);
3226 let leader_id = request
3227 .leader_id
3228 .ok_or_else(|| anyhow!("invalid leader id"))?
3229 .into();
3230 let follower_id = session.connection_id;
3231
3232 session
3233 .db()
3234 .await
3235 .check_room_participants(room_id, leader_id, session.connection_id)
3236 .await?;
3237
3238 let response_payload = session
3239 .peer
3240 .forward_request(session.connection_id, leader_id, request)
3241 .await?;
3242 response.send(response_payload)?;
3243
3244 if let Some(project_id) = project_id {
3245 let room = session
3246 .db()
3247 .await
3248 .follow(room_id, project_id, leader_id, follower_id)
3249 .await?;
3250 room_updated(&room, &session.peer);
3251 }
3252
3253 Ok(())
3254}
3255
3256/// Stop following another user in a call.
3257async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3258 let room_id = RoomId::from_proto(request.room_id);
3259 let project_id = request.project_id.map(ProjectId::from_proto);
3260 let leader_id = request
3261 .leader_id
3262 .ok_or_else(|| anyhow!("invalid leader id"))?
3263 .into();
3264 let follower_id = session.connection_id;
3265
3266 session
3267 .db()
3268 .await
3269 .check_room_participants(room_id, leader_id, session.connection_id)
3270 .await?;
3271
3272 session
3273 .peer
3274 .forward_send(session.connection_id, leader_id, request)?;
3275
3276 if let Some(project_id) = project_id {
3277 let room = session
3278 .db()
3279 .await
3280 .unfollow(room_id, project_id, leader_id, follower_id)
3281 .await?;
3282 room_updated(&room, &session.peer);
3283 }
3284
3285 Ok(())
3286}
3287
3288/// Notify everyone following you of your current location.
3289async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3290 let room_id = RoomId::from_proto(request.room_id);
3291 let database = session.db.lock().await;
3292
3293 let connection_ids = if let Some(project_id) = request.project_id {
3294 let project_id = ProjectId::from_proto(project_id);
3295 database
3296 .project_connection_ids(project_id, session.connection_id, true)
3297 .await?
3298 } else {
3299 database
3300 .room_connection_ids(room_id, session.connection_id)
3301 .await?
3302 };
3303
3304 // For now, don't send view update messages back to that view's current leader.
3305 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3306 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3307 _ => None,
3308 });
3309
3310 for connection_id in connection_ids.iter().cloned() {
3311 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3312 session
3313 .peer
3314 .forward_send(session.connection_id, connection_id, request.clone())?;
3315 }
3316 }
3317 Ok(())
3318}
3319
3320/// Get public data about users.
3321async fn get_users(
3322 request: proto::GetUsers,
3323 response: Response<proto::GetUsers>,
3324 session: Session,
3325) -> Result<()> {
3326 let user_ids = request
3327 .user_ids
3328 .into_iter()
3329 .map(UserId::from_proto)
3330 .collect();
3331 let users = session
3332 .db()
3333 .await
3334 .get_users_by_ids(user_ids)
3335 .await?
3336 .into_iter()
3337 .map(|user| proto::User {
3338 id: user.id.to_proto(),
3339 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3340 github_login: user.github_login,
3341 })
3342 .collect();
3343 response.send(proto::UsersResponse { users })?;
3344 Ok(())
3345}
3346
3347/// Search for users (to invite) buy Github login
3348async fn fuzzy_search_users(
3349 request: proto::FuzzySearchUsers,
3350 response: Response<proto::FuzzySearchUsers>,
3351 session: UserSession,
3352) -> Result<()> {
3353 let query = request.query;
3354 let users = match query.len() {
3355 0 => vec![],
3356 1 | 2 => session
3357 .db()
3358 .await
3359 .get_user_by_github_login(&query)
3360 .await?
3361 .into_iter()
3362 .collect(),
3363 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3364 };
3365 let users = users
3366 .into_iter()
3367 .filter(|user| user.id != session.user_id())
3368 .map(|user| proto::User {
3369 id: user.id.to_proto(),
3370 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3371 github_login: user.github_login,
3372 })
3373 .collect();
3374 response.send(proto::UsersResponse { users })?;
3375 Ok(())
3376}
3377
3378/// Send a contact request to another user.
3379async fn request_contact(
3380 request: proto::RequestContact,
3381 response: Response<proto::RequestContact>,
3382 session: UserSession,
3383) -> Result<()> {
3384 let requester_id = session.user_id();
3385 let responder_id = UserId::from_proto(request.responder_id);
3386 if requester_id == responder_id {
3387 return Err(anyhow!("cannot add yourself as a contact"))?;
3388 }
3389
3390 let notifications = session
3391 .db()
3392 .await
3393 .send_contact_request(requester_id, responder_id)
3394 .await?;
3395
3396 // Update outgoing contact requests of requester
3397 let mut update = proto::UpdateContacts::default();
3398 update.outgoing_requests.push(responder_id.to_proto());
3399 for connection_id in session
3400 .connection_pool()
3401 .await
3402 .user_connection_ids(requester_id)
3403 {
3404 session.peer.send(connection_id, update.clone())?;
3405 }
3406
3407 // Update incoming contact requests of responder
3408 let mut update = proto::UpdateContacts::default();
3409 update
3410 .incoming_requests
3411 .push(proto::IncomingContactRequest {
3412 requester_id: requester_id.to_proto(),
3413 });
3414 let connection_pool = session.connection_pool().await;
3415 for connection_id in connection_pool.user_connection_ids(responder_id) {
3416 session.peer.send(connection_id, update.clone())?;
3417 }
3418
3419 send_notifications(&connection_pool, &session.peer, notifications);
3420
3421 response.send(proto::Ack {})?;
3422 Ok(())
3423}
3424
3425/// Accept or decline a contact request
3426async fn respond_to_contact_request(
3427 request: proto::RespondToContactRequest,
3428 response: Response<proto::RespondToContactRequest>,
3429 session: UserSession,
3430) -> Result<()> {
3431 let responder_id = session.user_id();
3432 let requester_id = UserId::from_proto(request.requester_id);
3433 let db = session.db().await;
3434 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3435 db.dismiss_contact_notification(responder_id, requester_id)
3436 .await?;
3437 } else {
3438 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3439
3440 let notifications = db
3441 .respond_to_contact_request(responder_id, requester_id, accept)
3442 .await?;
3443 let requester_busy = db.is_user_busy(requester_id).await?;
3444 let responder_busy = db.is_user_busy(responder_id).await?;
3445
3446 let pool = session.connection_pool().await;
3447 // Update responder with new contact
3448 let mut update = proto::UpdateContacts::default();
3449 if accept {
3450 update
3451 .contacts
3452 .push(contact_for_user(requester_id, requester_busy, &pool));
3453 }
3454 update
3455 .remove_incoming_requests
3456 .push(requester_id.to_proto());
3457 for connection_id in pool.user_connection_ids(responder_id) {
3458 session.peer.send(connection_id, update.clone())?;
3459 }
3460
3461 // Update requester with new contact
3462 let mut update = proto::UpdateContacts::default();
3463 if accept {
3464 update
3465 .contacts
3466 .push(contact_for_user(responder_id, responder_busy, &pool));
3467 }
3468 update
3469 .remove_outgoing_requests
3470 .push(responder_id.to_proto());
3471
3472 for connection_id in pool.user_connection_ids(requester_id) {
3473 session.peer.send(connection_id, update.clone())?;
3474 }
3475
3476 send_notifications(&pool, &session.peer, notifications);
3477 }
3478
3479 response.send(proto::Ack {})?;
3480 Ok(())
3481}
3482
3483/// Remove a contact.
3484async fn remove_contact(
3485 request: proto::RemoveContact,
3486 response: Response<proto::RemoveContact>,
3487 session: UserSession,
3488) -> Result<()> {
3489 let requester_id = session.user_id();
3490 let responder_id = UserId::from_proto(request.user_id);
3491 let db = session.db().await;
3492 let (contact_accepted, deleted_notification_id) =
3493 db.remove_contact(requester_id, responder_id).await?;
3494
3495 let pool = session.connection_pool().await;
3496 // Update outgoing contact requests of requester
3497 let mut update = proto::UpdateContacts::default();
3498 if contact_accepted {
3499 update.remove_contacts.push(responder_id.to_proto());
3500 } else {
3501 update
3502 .remove_outgoing_requests
3503 .push(responder_id.to_proto());
3504 }
3505 for connection_id in pool.user_connection_ids(requester_id) {
3506 session.peer.send(connection_id, update.clone())?;
3507 }
3508
3509 // Update incoming contact requests of responder
3510 let mut update = proto::UpdateContacts::default();
3511 if contact_accepted {
3512 update.remove_contacts.push(requester_id.to_proto());
3513 } else {
3514 update
3515 .remove_incoming_requests
3516 .push(requester_id.to_proto());
3517 }
3518 for connection_id in pool.user_connection_ids(responder_id) {
3519 session.peer.send(connection_id, update.clone())?;
3520 if let Some(notification_id) = deleted_notification_id {
3521 session.peer.send(
3522 connection_id,
3523 proto::DeleteNotification {
3524 notification_id: notification_id.to_proto(),
3525 },
3526 )?;
3527 }
3528 }
3529
3530 response.send(proto::Ack {})?;
3531 Ok(())
3532}
3533
3534fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3535 version.0.minor() < 139
3536}
3537
3538async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3539 subscribe_user_to_channels(
3540 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3541 &session,
3542 )
3543 .await?;
3544 Ok(())
3545}
3546
3547async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3548 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3549 let mut pool = session.connection_pool().await;
3550 for membership in &channels_for_user.channel_memberships {
3551 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3552 }
3553 session.peer.send(
3554 session.connection_id,
3555 build_update_user_channels(&channels_for_user),
3556 )?;
3557 session.peer.send(
3558 session.connection_id,
3559 build_channels_update(channels_for_user),
3560 )?;
3561 Ok(())
3562}
3563
3564/// Creates a new channel.
3565async fn create_channel(
3566 request: proto::CreateChannel,
3567 response: Response<proto::CreateChannel>,
3568 session: UserSession,
3569) -> Result<()> {
3570 let db = session.db().await;
3571
3572 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3573 let (channel, membership) = db
3574 .create_channel(&request.name, parent_id, session.user_id())
3575 .await?;
3576
3577 let root_id = channel.root_id();
3578 let channel = Channel::from_model(channel);
3579
3580 response.send(proto::CreateChannelResponse {
3581 channel: Some(channel.to_proto()),
3582 parent_id: request.parent_id,
3583 })?;
3584
3585 let mut connection_pool = session.connection_pool().await;
3586 if let Some(membership) = membership {
3587 connection_pool.subscribe_to_channel(
3588 membership.user_id,
3589 membership.channel_id,
3590 membership.role,
3591 );
3592 let update = proto::UpdateUserChannels {
3593 channel_memberships: vec![proto::ChannelMembership {
3594 channel_id: membership.channel_id.to_proto(),
3595 role: membership.role.into(),
3596 }],
3597 ..Default::default()
3598 };
3599 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3600 session.peer.send(connection_id, update.clone())?;
3601 }
3602 }
3603
3604 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3605 if !role.can_see_channel(channel.visibility) {
3606 continue;
3607 }
3608
3609 let update = proto::UpdateChannels {
3610 channels: vec![channel.to_proto()],
3611 ..Default::default()
3612 };
3613 session.peer.send(connection_id, update.clone())?;
3614 }
3615
3616 Ok(())
3617}
3618
3619/// Delete a channel
3620async fn delete_channel(
3621 request: proto::DeleteChannel,
3622 response: Response<proto::DeleteChannel>,
3623 session: UserSession,
3624) -> Result<()> {
3625 let db = session.db().await;
3626
3627 let channel_id = request.channel_id;
3628 let (root_channel, removed_channels) = db
3629 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3630 .await?;
3631 response.send(proto::Ack {})?;
3632
3633 // Notify members of removed channels
3634 let mut update = proto::UpdateChannels::default();
3635 update
3636 .delete_channels
3637 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3638
3639 let connection_pool = session.connection_pool().await;
3640 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3641 session.peer.send(connection_id, update.clone())?;
3642 }
3643
3644 Ok(())
3645}
3646
3647/// Invite someone to join a channel.
3648async fn invite_channel_member(
3649 request: proto::InviteChannelMember,
3650 response: Response<proto::InviteChannelMember>,
3651 session: UserSession,
3652) -> Result<()> {
3653 let db = session.db().await;
3654 let channel_id = ChannelId::from_proto(request.channel_id);
3655 let invitee_id = UserId::from_proto(request.user_id);
3656 let InviteMemberResult {
3657 channel,
3658 notifications,
3659 } = db
3660 .invite_channel_member(
3661 channel_id,
3662 invitee_id,
3663 session.user_id(),
3664 request.role().into(),
3665 )
3666 .await?;
3667
3668 let update = proto::UpdateChannels {
3669 channel_invitations: vec![channel.to_proto()],
3670 ..Default::default()
3671 };
3672
3673 let connection_pool = session.connection_pool().await;
3674 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3675 session.peer.send(connection_id, update.clone())?;
3676 }
3677
3678 send_notifications(&connection_pool, &session.peer, notifications);
3679
3680 response.send(proto::Ack {})?;
3681 Ok(())
3682}
3683
3684/// remove someone from a channel
3685async fn remove_channel_member(
3686 request: proto::RemoveChannelMember,
3687 response: Response<proto::RemoveChannelMember>,
3688 session: UserSession,
3689) -> Result<()> {
3690 let db = session.db().await;
3691 let channel_id = ChannelId::from_proto(request.channel_id);
3692 let member_id = UserId::from_proto(request.user_id);
3693
3694 let RemoveChannelMemberResult {
3695 membership_update,
3696 notification_id,
3697 } = db
3698 .remove_channel_member(channel_id, member_id, session.user_id())
3699 .await?;
3700
3701 let mut connection_pool = session.connection_pool().await;
3702 notify_membership_updated(
3703 &mut connection_pool,
3704 membership_update,
3705 member_id,
3706 &session.peer,
3707 );
3708 for connection_id in connection_pool.user_connection_ids(member_id) {
3709 if let Some(notification_id) = notification_id {
3710 session
3711 .peer
3712 .send(
3713 connection_id,
3714 proto::DeleteNotification {
3715 notification_id: notification_id.to_proto(),
3716 },
3717 )
3718 .trace_err();
3719 }
3720 }
3721
3722 response.send(proto::Ack {})?;
3723 Ok(())
3724}
3725
3726/// Toggle the channel between public and private.
3727/// Care is taken to maintain the invariant that public channels only descend from public channels,
3728/// (though members-only channels can appear at any point in the hierarchy).
3729async fn set_channel_visibility(
3730 request: proto::SetChannelVisibility,
3731 response: Response<proto::SetChannelVisibility>,
3732 session: UserSession,
3733) -> Result<()> {
3734 let db = session.db().await;
3735 let channel_id = ChannelId::from_proto(request.channel_id);
3736 let visibility = request.visibility().into();
3737
3738 let channel_model = db
3739 .set_channel_visibility(channel_id, visibility, session.user_id())
3740 .await?;
3741 let root_id = channel_model.root_id();
3742 let channel = Channel::from_model(channel_model);
3743
3744 let mut connection_pool = session.connection_pool().await;
3745 for (user_id, role) in connection_pool
3746 .channel_user_ids(root_id)
3747 .collect::<Vec<_>>()
3748 .into_iter()
3749 {
3750 let update = if role.can_see_channel(channel.visibility) {
3751 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3752 proto::UpdateChannels {
3753 channels: vec![channel.to_proto()],
3754 ..Default::default()
3755 }
3756 } else {
3757 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3758 proto::UpdateChannels {
3759 delete_channels: vec![channel.id.to_proto()],
3760 ..Default::default()
3761 }
3762 };
3763
3764 for connection_id in connection_pool.user_connection_ids(user_id) {
3765 session.peer.send(connection_id, update.clone())?;
3766 }
3767 }
3768
3769 response.send(proto::Ack {})?;
3770 Ok(())
3771}
3772
3773/// Alter the role for a user in the channel.
3774async fn set_channel_member_role(
3775 request: proto::SetChannelMemberRole,
3776 response: Response<proto::SetChannelMemberRole>,
3777 session: UserSession,
3778) -> Result<()> {
3779 let db = session.db().await;
3780 let channel_id = ChannelId::from_proto(request.channel_id);
3781 let member_id = UserId::from_proto(request.user_id);
3782 let result = db
3783 .set_channel_member_role(
3784 channel_id,
3785 session.user_id(),
3786 member_id,
3787 request.role().into(),
3788 )
3789 .await?;
3790
3791 match result {
3792 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3793 let mut connection_pool = session.connection_pool().await;
3794 notify_membership_updated(
3795 &mut connection_pool,
3796 membership_update,
3797 member_id,
3798 &session.peer,
3799 )
3800 }
3801 db::SetMemberRoleResult::InviteUpdated(channel) => {
3802 let update = proto::UpdateChannels {
3803 channel_invitations: vec![channel.to_proto()],
3804 ..Default::default()
3805 };
3806
3807 for connection_id in session
3808 .connection_pool()
3809 .await
3810 .user_connection_ids(member_id)
3811 {
3812 session.peer.send(connection_id, update.clone())?;
3813 }
3814 }
3815 }
3816
3817 response.send(proto::Ack {})?;
3818 Ok(())
3819}
3820
3821/// Change the name of a channel
3822async fn rename_channel(
3823 request: proto::RenameChannel,
3824 response: Response<proto::RenameChannel>,
3825 session: UserSession,
3826) -> Result<()> {
3827 let db = session.db().await;
3828 let channel_id = ChannelId::from_proto(request.channel_id);
3829 let channel_model = db
3830 .rename_channel(channel_id, session.user_id(), &request.name)
3831 .await?;
3832 let root_id = channel_model.root_id();
3833 let channel = Channel::from_model(channel_model);
3834
3835 response.send(proto::RenameChannelResponse {
3836 channel: Some(channel.to_proto()),
3837 })?;
3838
3839 let connection_pool = session.connection_pool().await;
3840 let update = proto::UpdateChannels {
3841 channels: vec![channel.to_proto()],
3842 ..Default::default()
3843 };
3844 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3845 if role.can_see_channel(channel.visibility) {
3846 session.peer.send(connection_id, update.clone())?;
3847 }
3848 }
3849
3850 Ok(())
3851}
3852
3853/// Move a channel to a new parent.
3854async fn move_channel(
3855 request: proto::MoveChannel,
3856 response: Response<proto::MoveChannel>,
3857 session: UserSession,
3858) -> Result<()> {
3859 let channel_id = ChannelId::from_proto(request.channel_id);
3860 let to = ChannelId::from_proto(request.to);
3861
3862 let (root_id, channels) = session
3863 .db()
3864 .await
3865 .move_channel(channel_id, to, session.user_id())
3866 .await?;
3867
3868 let connection_pool = session.connection_pool().await;
3869 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3870 let channels = channels
3871 .iter()
3872 .filter_map(|channel| {
3873 if role.can_see_channel(channel.visibility) {
3874 Some(channel.to_proto())
3875 } else {
3876 None
3877 }
3878 })
3879 .collect::<Vec<_>>();
3880 if channels.is_empty() {
3881 continue;
3882 }
3883
3884 let update = proto::UpdateChannels {
3885 channels,
3886 ..Default::default()
3887 };
3888
3889 session.peer.send(connection_id, update.clone())?;
3890 }
3891
3892 response.send(Ack {})?;
3893 Ok(())
3894}
3895
3896/// Get the list of channel members
3897async fn get_channel_members(
3898 request: proto::GetChannelMembers,
3899 response: Response<proto::GetChannelMembers>,
3900 session: UserSession,
3901) -> Result<()> {
3902 let db = session.db().await;
3903 let channel_id = ChannelId::from_proto(request.channel_id);
3904 let limit = if request.limit == 0 {
3905 u16::MAX as u64
3906 } else {
3907 request.limit
3908 };
3909 let (members, users) = db
3910 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3911 .await?;
3912 response.send(proto::GetChannelMembersResponse { members, users })?;
3913 Ok(())
3914}
3915
3916/// Accept or decline a channel invitation.
3917async fn respond_to_channel_invite(
3918 request: proto::RespondToChannelInvite,
3919 response: Response<proto::RespondToChannelInvite>,
3920 session: UserSession,
3921) -> Result<()> {
3922 let db = session.db().await;
3923 let channel_id = ChannelId::from_proto(request.channel_id);
3924 let RespondToChannelInvite {
3925 membership_update,
3926 notifications,
3927 } = db
3928 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3929 .await?;
3930
3931 let mut connection_pool = session.connection_pool().await;
3932 if let Some(membership_update) = membership_update {
3933 notify_membership_updated(
3934 &mut connection_pool,
3935 membership_update,
3936 session.user_id(),
3937 &session.peer,
3938 );
3939 } else {
3940 let update = proto::UpdateChannels {
3941 remove_channel_invitations: vec![channel_id.to_proto()],
3942 ..Default::default()
3943 };
3944
3945 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3946 session.peer.send(connection_id, update.clone())?;
3947 }
3948 };
3949
3950 send_notifications(&connection_pool, &session.peer, notifications);
3951
3952 response.send(proto::Ack {})?;
3953
3954 Ok(())
3955}
3956
3957/// Join the channels' room
3958async fn join_channel(
3959 request: proto::JoinChannel,
3960 response: Response<proto::JoinChannel>,
3961 session: UserSession,
3962) -> Result<()> {
3963 let channel_id = ChannelId::from_proto(request.channel_id);
3964 join_channel_internal(channel_id, Box::new(response), session).await
3965}
3966
/// A response handle that `join_channel_internal` can reply through.
///
/// Both the `JoinChannel` and `JoinRoom` request handles below accept a
/// `proto::JoinRoomResponse`, so this trait lets the shared join logic send that payload
/// without knowing which request it is serving.
trait JoinChannelInternalResponse {
    /// Consume the handle and send `result` as the reply.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // The qualified path is meant to reach `Response`'s own `send` (presumably the
        // inherent method defined elsewhere in this file), not this trait method.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above, for `JoinRoom` requests.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3980
/// Shared implementation for joining a channel's room.
///
/// It is used by `join_channel`, and presumably also by the `JoinRoom` handler, given the
/// `JoinChannelInternalResponse` impl for `Response<proto::JoinRoom>`.
///
/// Steps:
/// 1. If the user still has a stale room connection, leave that room first.
/// 2. Join the channel in the database. This yields the joined room, an optional membership
///    update and the user's role.
/// 3. Reply with the room, the channel id and (when a LiveKit client is configured)
///    LiveKit connection info.
/// 4. Propagate any membership update and broadcast the room update.
/// 5. Broadcast the channel update and refresh the user's contacts.
///
/// Note that the reply is sent before step 5, so an error there (e.g. "channel not returned")
/// surfaces after the client has already received its response.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so the database and connection-pool guards taken inside are released before
    // `channel_updated` below locks the connection pool again.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database lock around `leave_room_for_session` and re-acquire it
            // afterwards (presumably because that function takes the lock itself).
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a guest token and may not publish; everyone else gets a room token and
        // may publish. If token generation fails, the error is traced and no LiveKit info is
        // included in the reply.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4071
4072/// Start editing the channel notes
4073async fn join_channel_buffer(
4074 request: proto::JoinChannelBuffer,
4075 response: Response<proto::JoinChannelBuffer>,
4076 session: UserSession,
4077) -> Result<()> {
4078 let db = session.db().await;
4079 let channel_id = ChannelId::from_proto(request.channel_id);
4080
4081 let open_response = db
4082 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4083 .await?;
4084
4085 let collaborators = open_response.collaborators.clone();
4086 response.send(open_response)?;
4087
4088 let update = UpdateChannelBufferCollaborators {
4089 channel_id: channel_id.to_proto(),
4090 collaborators: collaborators.clone(),
4091 };
4092 channel_buffer_updated(
4093 session.connection_id,
4094 collaborators
4095 .iter()
4096 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4097 &update,
4098 &session.peer,
4099 );
4100
4101 Ok(())
4102}
4103
4104/// Edit the channel notes
4105async fn update_channel_buffer(
4106 request: proto::UpdateChannelBuffer,
4107 session: UserSession,
4108) -> Result<()> {
4109 let db = session.db().await;
4110 let channel_id = ChannelId::from_proto(request.channel_id);
4111
4112 let (collaborators, epoch, version) = db
4113 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4114 .await?;
4115
4116 channel_buffer_updated(
4117 session.connection_id,
4118 collaborators.clone(),
4119 &proto::UpdateChannelBuffer {
4120 channel_id: channel_id.to_proto(),
4121 operations: request.operations,
4122 },
4123 &session.peer,
4124 );
4125
4126 let pool = &*session.connection_pool().await;
4127
4128 let non_collaborators =
4129 pool.channel_connection_ids(channel_id)
4130 .filter_map(|(connection_id, _)| {
4131 if collaborators.contains(&connection_id) {
4132 None
4133 } else {
4134 Some(connection_id)
4135 }
4136 });
4137
4138 broadcast(None, non_collaborators, |peer_id| {
4139 session.peer.send(
4140 peer_id,
4141 proto::UpdateChannels {
4142 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4143 channel_id: channel_id.to_proto(),
4144 epoch: epoch as u64,
4145 version: version.clone(),
4146 }],
4147 ..Default::default()
4148 },
4149 )
4150 });
4151
4152 Ok(())
4153}
4154
4155/// Rejoin the channel notes after a connection blip
4156async fn rejoin_channel_buffers(
4157 request: proto::RejoinChannelBuffers,
4158 response: Response<proto::RejoinChannelBuffers>,
4159 session: UserSession,
4160) -> Result<()> {
4161 let db = session.db().await;
4162 let buffers = db
4163 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4164 .await?;
4165
4166 for rejoined_buffer in &buffers {
4167 let collaborators_to_notify = rejoined_buffer
4168 .buffer
4169 .collaborators
4170 .iter()
4171 .filter_map(|c| Some(c.peer_id?.into()));
4172 channel_buffer_updated(
4173 session.connection_id,
4174 collaborators_to_notify,
4175 &proto::UpdateChannelBufferCollaborators {
4176 channel_id: rejoined_buffer.buffer.channel_id,
4177 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4178 },
4179 &session.peer,
4180 );
4181 }
4182
4183 response.send(proto::RejoinChannelBuffersResponse {
4184 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4185 })?;
4186
4187 Ok(())
4188}
4189
4190/// Stop editing the channel notes
4191async fn leave_channel_buffer(
4192 request: proto::LeaveChannelBuffer,
4193 response: Response<proto::LeaveChannelBuffer>,
4194 session: UserSession,
4195) -> Result<()> {
4196 let db = session.db().await;
4197 let channel_id = ChannelId::from_proto(request.channel_id);
4198
4199 let left_buffer = db
4200 .leave_channel_buffer(channel_id, session.connection_id)
4201 .await?;
4202
4203 response.send(Ack {})?;
4204
4205 channel_buffer_updated(
4206 session.connection_id,
4207 left_buffer.connections,
4208 &proto::UpdateChannelBufferCollaborators {
4209 channel_id: channel_id.to_proto(),
4210 collaborators: left_buffer.collaborators,
4211 },
4212 &session.peer,
4213 );
4214
4215 Ok(())
4216}
4217
4218fn channel_buffer_updated<T: EnvelopedMessage>(
4219 sender_id: ConnectionId,
4220 collaborators: impl IntoIterator<Item = ConnectionId>,
4221 message: &T,
4222 peer: &Peer,
4223) {
4224 broadcast(Some(sender_id), collaborators, |peer_id| {
4225 peer.send(peer_id, message.clone())
4226 });
4227}
4228
4229fn send_notifications(
4230 connection_pool: &ConnectionPool,
4231 peer: &Peer,
4232 notifications: db::NotificationBatch,
4233) {
4234 for (user_id, notification) in notifications {
4235 for connection_id in connection_pool.user_connection_ids(user_id) {
4236 if let Err(error) = peer.send(
4237 connection_id,
4238 proto::AddNotification {
4239 notification: Some(notification.clone()),
4240 },
4241 ) {
4242 tracing::error!(
4243 "failed to send notification to {:?} {}",
4244 connection_id,
4245 error
4246 );
4247 }
4248 }
4249 }
4250}
4251
4252/// Send a message to the channel
4253async fn send_channel_message(
4254 request: proto::SendChannelMessage,
4255 response: Response<proto::SendChannelMessage>,
4256 session: UserSession,
4257) -> Result<()> {
4258 // Validate the message body.
4259 let body = request.body.trim().to_string();
4260 if body.len() > MAX_MESSAGE_LEN {
4261 return Err(anyhow!("message is too long"))?;
4262 }
4263 if body.is_empty() {
4264 return Err(anyhow!("message can't be blank"))?;
4265 }
4266
4267 // TODO: adjust mentions if body is trimmed
4268
4269 let timestamp = OffsetDateTime::now_utc();
4270 let nonce = request
4271 .nonce
4272 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4273
4274 let channel_id = ChannelId::from_proto(request.channel_id);
4275 let CreatedChannelMessage {
4276 message_id,
4277 participant_connection_ids,
4278 notifications,
4279 } = session
4280 .db()
4281 .await
4282 .create_channel_message(
4283 channel_id,
4284 session.user_id(),
4285 &body,
4286 &request.mentions,
4287 timestamp,
4288 nonce.clone().into(),
4289 match request.reply_to_message_id {
4290 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4291 None => None,
4292 },
4293 )
4294 .await?;
4295
4296 let message = proto::ChannelMessage {
4297 sender_id: session.user_id().to_proto(),
4298 id: message_id.to_proto(),
4299 body,
4300 mentions: request.mentions,
4301 timestamp: timestamp.unix_timestamp() as u64,
4302 nonce: Some(nonce),
4303 reply_to_message_id: request.reply_to_message_id,
4304 edited_at: None,
4305 };
4306 broadcast(
4307 Some(session.connection_id),
4308 participant_connection_ids.clone(),
4309 |connection| {
4310 session.peer.send(
4311 connection,
4312 proto::ChannelMessageSent {
4313 channel_id: channel_id.to_proto(),
4314 message: Some(message.clone()),
4315 },
4316 )
4317 },
4318 );
4319 response.send(proto::SendChannelMessageResponse {
4320 message: Some(message),
4321 })?;
4322
4323 let pool = &*session.connection_pool().await;
4324 let non_participants =
4325 pool.channel_connection_ids(channel_id)
4326 .filter_map(|(connection_id, _)| {
4327 if participant_connection_ids.contains(&connection_id) {
4328 None
4329 } else {
4330 Some(connection_id)
4331 }
4332 });
4333 broadcast(None, non_participants, |peer_id| {
4334 session.peer.send(
4335 peer_id,
4336 proto::UpdateChannels {
4337 latest_channel_message_ids: vec![proto::ChannelMessageId {
4338 channel_id: channel_id.to_proto(),
4339 message_id: message_id.to_proto(),
4340 }],
4341 ..Default::default()
4342 },
4343 )
4344 });
4345 send_notifications(pool, &session.peer, notifications);
4346
4347 Ok(())
4348}
4349
4350/// Delete a channel message
4351async fn remove_channel_message(
4352 request: proto::RemoveChannelMessage,
4353 response: Response<proto::RemoveChannelMessage>,
4354 session: UserSession,
4355) -> Result<()> {
4356 let channel_id = ChannelId::from_proto(request.channel_id);
4357 let message_id = MessageId::from_proto(request.message_id);
4358 let (connection_ids, existing_notification_ids) = session
4359 .db()
4360 .await
4361 .remove_channel_message(channel_id, message_id, session.user_id())
4362 .await?;
4363
4364 broadcast(
4365 Some(session.connection_id),
4366 connection_ids,
4367 move |connection| {
4368 session.peer.send(connection, request.clone())?;
4369
4370 for notification_id in &existing_notification_ids {
4371 session.peer.send(
4372 connection,
4373 proto::DeleteNotification {
4374 notification_id: (*notification_id).to_proto(),
4375 },
4376 )?;
4377 }
4378
4379 Ok(())
4380 },
4381 );
4382 response.send(proto::Ack {})?;
4383 Ok(())
4384}
4385
4386async fn update_channel_message(
4387 request: proto::UpdateChannelMessage,
4388 response: Response<proto::UpdateChannelMessage>,
4389 session: UserSession,
4390) -> Result<()> {
4391 let channel_id = ChannelId::from_proto(request.channel_id);
4392 let message_id = MessageId::from_proto(request.message_id);
4393 let updated_at = OffsetDateTime::now_utc();
4394 let UpdatedChannelMessage {
4395 message_id,
4396 participant_connection_ids,
4397 notifications,
4398 reply_to_message_id,
4399 timestamp,
4400 deleted_mention_notification_ids,
4401 updated_mention_notifications,
4402 } = session
4403 .db()
4404 .await
4405 .update_channel_message(
4406 channel_id,
4407 message_id,
4408 session.user_id(),
4409 request.body.as_str(),
4410 &request.mentions,
4411 updated_at,
4412 )
4413 .await?;
4414
4415 let nonce = request
4416 .nonce
4417 .clone()
4418 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4419
4420 let message = proto::ChannelMessage {
4421 sender_id: session.user_id().to_proto(),
4422 id: message_id.to_proto(),
4423 body: request.body.clone(),
4424 mentions: request.mentions.clone(),
4425 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4426 nonce: Some(nonce),
4427 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4428 edited_at: Some(updated_at.unix_timestamp() as u64),
4429 };
4430
4431 response.send(proto::Ack {})?;
4432
4433 let pool = &*session.connection_pool().await;
4434 broadcast(
4435 Some(session.connection_id),
4436 participant_connection_ids,
4437 |connection| {
4438 session.peer.send(
4439 connection,
4440 proto::ChannelMessageUpdate {
4441 channel_id: channel_id.to_proto(),
4442 message: Some(message.clone()),
4443 },
4444 )?;
4445
4446 for notification_id in &deleted_mention_notification_ids {
4447 session.peer.send(
4448 connection,
4449 proto::DeleteNotification {
4450 notification_id: (*notification_id).to_proto(),
4451 },
4452 )?;
4453 }
4454
4455 for notification in &updated_mention_notifications {
4456 session.peer.send(
4457 connection,
4458 proto::UpdateNotification {
4459 notification: Some(notification.clone()),
4460 },
4461 )?;
4462 }
4463
4464 Ok(())
4465 },
4466 );
4467
4468 send_notifications(pool, &session.peer, notifications);
4469
4470 Ok(())
4471}
4472
4473/// Mark a channel message as read
4474async fn acknowledge_channel_message(
4475 request: proto::AckChannelMessage,
4476 session: UserSession,
4477) -> Result<()> {
4478 let channel_id = ChannelId::from_proto(request.channel_id);
4479 let message_id = MessageId::from_proto(request.message_id);
4480 let notifications = session
4481 .db()
4482 .await
4483 .observe_channel_message(channel_id, session.user_id(), message_id)
4484 .await?;
4485 send_notifications(
4486 &*session.connection_pool().await,
4487 &session.peer,
4488 notifications,
4489 );
4490 Ok(())
4491}
4492
4493/// Mark a buffer version as synced
4494async fn acknowledge_buffer_version(
4495 request: proto::AckBufferOperation,
4496 session: UserSession,
4497) -> Result<()> {
4498 let buffer_id = BufferId::from_proto(request.buffer_id);
4499 session
4500 .db()
4501 .await
4502 .observe_buffer_version(
4503 buffer_id,
4504 session.user_id(),
4505 request.epoch as i32,
4506 &request.version,
4507 )
4508 .await?;
4509 Ok(())
4510}
4511
4512struct CompleteWithLanguageModelRateLimit;
4513
4514impl RateLimit for CompleteWithLanguageModelRateLimit {
4515 fn capacity() -> usize {
4516 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4517 .ok()
4518 .and_then(|v| v.parse().ok())
4519 .unwrap_or(120) // Picked arbitrarily
4520 }
4521
4522 fn refill_duration() -> chrono::Duration {
4523 chrono::Duration::hours(1)
4524 }
4525
4526 fn db_name() -> &'static str {
4527 "complete-with-language-model"
4528 }
4529}
4530
4531async fn complete_with_language_model(
4532 request: proto::CompleteWithLanguageModel,
4533 response: Response<proto::CompleteWithLanguageModel>,
4534 session: Session,
4535 config: &Config,
4536) -> Result<()> {
4537 let Some(session) = session.for_user() else {
4538 return Err(anyhow!("user not found"))?;
4539 };
4540 authorize_access_to_language_models(&session).await?;
4541
4542 session
4543 .rate_limiter
4544 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4545 .await?;
4546
4547 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4548 Some(proto::LanguageModelProvider::Anthropic) => {
4549 let api_key = config
4550 .anthropic_api_key
4551 .as_ref()
4552 .context("no Anthropic AI API key configured on the server")?;
4553 anthropic::complete(
4554 session.http_client.as_ref(),
4555 anthropic::ANTHROPIC_API_URL,
4556 api_key,
4557 serde_json::from_str(&request.request)?,
4558 )
4559 .await?
4560 }
4561 _ => return Err(anyhow!("unsupported provider"))?,
4562 };
4563
4564 response.send(proto::CompleteWithLanguageModelResponse {
4565 completion: serde_json::to_string(&result)?,
4566 })?;
4567
4568 Ok(())
4569}
4570
4571async fn stream_complete_with_language_model(
4572 request: proto::StreamCompleteWithLanguageModel,
4573 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4574 session: Session,
4575 config: &Config,
4576) -> Result<()> {
4577 let Some(session) = session.for_user() else {
4578 return Err(anyhow!("user not found"))?;
4579 };
4580 authorize_access_to_language_models(&session).await?;
4581
4582 session
4583 .rate_limiter
4584 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4585 .await?;
4586
4587 match proto::LanguageModelProvider::from_i32(request.provider) {
4588 Some(proto::LanguageModelProvider::Anthropic) => {
4589 let api_key = config
4590 .anthropic_api_key
4591 .as_ref()
4592 .context("no Anthropic AI API key configured on the server")?;
4593 let mut chunks = anthropic::stream_completion(
4594 session.http_client.as_ref(),
4595 anthropic::ANTHROPIC_API_URL,
4596 api_key,
4597 serde_json::from_str(&request.request)?,
4598 None,
4599 )
4600 .await?;
4601 while let Some(event) = chunks.next().await {
4602 let chunk = event?;
4603 response.send(proto::StreamCompleteWithLanguageModelResponse {
4604 event: serde_json::to_string(&chunk)?,
4605 })?;
4606 }
4607 }
4608 Some(proto::LanguageModelProvider::OpenAi) => {
4609 let api_key = config
4610 .openai_api_key
4611 .as_ref()
4612 .context("no OpenAI API key configured on the server")?;
4613 let mut events = open_ai::stream_completion(
4614 session.http_client.as_ref(),
4615 open_ai::OPEN_AI_API_URL,
4616 api_key,
4617 serde_json::from_str(&request.request)?,
4618 None,
4619 )
4620 .await?;
4621 while let Some(event) = events.next().await {
4622 let event = event?;
4623 response.send(proto::StreamCompleteWithLanguageModelResponse {
4624 event: serde_json::to_string(&event)?,
4625 })?;
4626 }
4627 }
4628 Some(proto::LanguageModelProvider::Google) => {
4629 let api_key = config
4630 .google_ai_api_key
4631 .as_ref()
4632 .context("no Google AI API key configured on the server")?;
4633 let mut events = google_ai::stream_generate_content(
4634 session.http_client.as_ref(),
4635 google_ai::API_URL,
4636 api_key,
4637 serde_json::from_str(&request.request)?,
4638 )
4639 .await?;
4640 while let Some(event) = events.next().await {
4641 let event = event?;
4642 response.send(proto::StreamCompleteWithLanguageModelResponse {
4643 event: serde_json::to_string(&event)?,
4644 })?;
4645 }
4646 }
4647 None => return Err(anyhow!("unknown provider"))?,
4648 }
4649
4650 Ok(())
4651}
4652
4653async fn count_language_model_tokens(
4654 request: proto::CountLanguageModelTokens,
4655 response: Response<proto::CountLanguageModelTokens>,
4656 session: Session,
4657 config: &Config,
4658) -> Result<()> {
4659 let Some(session) = session.for_user() else {
4660 return Err(anyhow!("user not found"))?;
4661 };
4662 authorize_access_to_language_models(&session).await?;
4663
4664 session
4665 .rate_limiter
4666 .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4667 .await?;
4668
4669 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4670 Some(proto::LanguageModelProvider::Google) => {
4671 let api_key = config
4672 .google_ai_api_key
4673 .as_ref()
4674 .context("no Google AI API key configured on the server")?;
4675 google_ai::count_tokens(
4676 session.http_client.as_ref(),
4677 google_ai::API_URL,
4678 api_key,
4679 serde_json::from_str(&request.request)?,
4680 )
4681 .await?
4682 }
4683 _ => return Err(anyhow!("unsupported provider"))?,
4684 };
4685
4686 response.send(proto::CountLanguageModelTokensResponse {
4687 token_count: result.total_tokens as u32,
4688 })?;
4689
4690 Ok(())
4691}
4692
4693struct CountLanguageModelTokensRateLimit;
4694
4695impl RateLimit for CountLanguageModelTokensRateLimit {
4696 fn capacity() -> usize {
4697 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4698 .ok()
4699 .and_then(|v| v.parse().ok())
4700 .unwrap_or(600) // Picked arbitrarily
4701 }
4702
4703 fn refill_duration() -> chrono::Duration {
4704 chrono::Duration::hours(1)
4705 }
4706
4707 fn db_name() -> &'static str {
4708 "count-language-model-tokens"
4709 }
4710}
4711
4712struct ComputeEmbeddingsRateLimit;
4713
4714impl RateLimit for ComputeEmbeddingsRateLimit {
4715 fn capacity() -> usize {
4716 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4717 .ok()
4718 .and_then(|v| v.parse().ok())
4719 .unwrap_or(5000) // Picked arbitrarily
4720 }
4721
4722 fn refill_duration() -> chrono::Duration {
4723 chrono::Duration::hours(1)
4724 }
4725
4726 fn db_name() -> &'static str {
4727 "compute-embeddings"
4728 }
4729}
4730
4731async fn compute_embeddings(
4732 request: proto::ComputeEmbeddings,
4733 response: Response<proto::ComputeEmbeddings>,
4734 session: UserSession,
4735 api_key: Option<Arc<str>>,
4736) -> Result<()> {
4737 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4738 authorize_access_to_language_models(&session).await?;
4739
4740 session
4741 .rate_limiter
4742 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4743 .await?;
4744
4745 let embeddings = match request.model.as_str() {
4746 "openai/text-embedding-3-small" => {
4747 open_ai::embed(
4748 session.http_client.as_ref(),
4749 OPEN_AI_API_URL,
4750 &api_key,
4751 OpenAiEmbeddingModel::TextEmbedding3Small,
4752 request.texts.iter().map(|text| text.as_str()),
4753 )
4754 .await?
4755 }
4756 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4757 };
4758
4759 let embeddings = request
4760 .texts
4761 .iter()
4762 .map(|text| {
4763 let mut hasher = sha2::Sha256::new();
4764 hasher.update(text.as_bytes());
4765 let result = hasher.finalize();
4766 result.to_vec()
4767 })
4768 .zip(
4769 embeddings
4770 .data
4771 .into_iter()
4772 .map(|embedding| embedding.embedding),
4773 )
4774 .collect::<HashMap<_, _>>();
4775
4776 let db = session.db().await;
4777 db.save_embeddings(&request.model, &embeddings)
4778 .await
4779 .context("failed to save embeddings")
4780 .trace_err();
4781
4782 response.send(proto::ComputeEmbeddingsResponse {
4783 embeddings: embeddings
4784 .into_iter()
4785 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4786 .collect(),
4787 })?;
4788 Ok(())
4789}
4790
4791async fn get_cached_embeddings(
4792 request: proto::GetCachedEmbeddings,
4793 response: Response<proto::GetCachedEmbeddings>,
4794 session: UserSession,
4795) -> Result<()> {
4796 authorize_access_to_language_models(&session).await?;
4797
4798 let db = session.db().await;
4799 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4800
4801 response.send(proto::GetCachedEmbeddingsResponse {
4802 embeddings: embeddings
4803 .into_iter()
4804 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4805 .collect(),
4806 })?;
4807 Ok(())
4808}
4809
4810async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4811 let db = session.db().await;
4812 let flags = db.get_user_flags(session.user_id()).await?;
4813 if flags.iter().any(|flag| flag == "language-models") {
4814 Ok(())
4815 } else {
4816 Err(anyhow!("permission denied"))?
4817 }
4818}
4819
4820/// Get a Supermaven API key for the user
4821async fn get_supermaven_api_key(
4822 _request: proto::GetSupermavenApiKey,
4823 response: Response<proto::GetSupermavenApiKey>,
4824 session: UserSession,
4825) -> Result<()> {
4826 let user_id: String = session.user_id().to_string();
4827 if !session.is_staff() {
4828 return Err(anyhow!("supermaven not enabled for this account"))?;
4829 }
4830
4831 let email = session
4832 .email()
4833 .ok_or_else(|| anyhow!("user must have an email"))?;
4834
4835 let supermaven_admin_api = session
4836 .supermaven_client
4837 .as_ref()
4838 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4839
4840 let result = supermaven_admin_api
4841 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4842 .await?;
4843
4844 response.send(proto::GetSupermavenApiKeyResponse {
4845 api_key: result.api_key,
4846 })?;
4847
4848 Ok(())
4849}
4850
4851/// Start receiving chat updates for a channel
4852async fn join_channel_chat(
4853 request: proto::JoinChannelChat,
4854 response: Response<proto::JoinChannelChat>,
4855 session: UserSession,
4856) -> Result<()> {
4857 let channel_id = ChannelId::from_proto(request.channel_id);
4858
4859 let db = session.db().await;
4860 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4861 .await?;
4862 let messages = db
4863 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4864 .await?;
4865 response.send(proto::JoinChannelChatResponse {
4866 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4867 messages,
4868 })?;
4869 Ok(())
4870}
4871
4872/// Stop receiving chat updates for a channel
4873async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4874 let channel_id = ChannelId::from_proto(request.channel_id);
4875 session
4876 .db()
4877 .await
4878 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4879 .await?;
4880 Ok(())
4881}
4882
4883/// Retrieve the chat history for a channel
4884async fn get_channel_messages(
4885 request: proto::GetChannelMessages,
4886 response: Response<proto::GetChannelMessages>,
4887 session: UserSession,
4888) -> Result<()> {
4889 let channel_id = ChannelId::from_proto(request.channel_id);
4890 let messages = session
4891 .db()
4892 .await
4893 .get_channel_messages(
4894 channel_id,
4895 session.user_id(),
4896 MESSAGE_COUNT_PER_PAGE,
4897 Some(MessageId::from_proto(request.before_message_id)),
4898 )
4899 .await?;
4900 response.send(proto::GetChannelMessagesResponse {
4901 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4902 messages,
4903 })?;
4904 Ok(())
4905}
4906
4907/// Retrieve specific chat messages
4908async fn get_channel_messages_by_id(
4909 request: proto::GetChannelMessagesById,
4910 response: Response<proto::GetChannelMessagesById>,
4911 session: UserSession,
4912) -> Result<()> {
4913 let message_ids = request
4914 .message_ids
4915 .iter()
4916 .map(|id| MessageId::from_proto(*id))
4917 .collect::<Vec<_>>();
4918 let messages = session
4919 .db()
4920 .await
4921 .get_channel_messages_by_id(session.user_id(), &message_ids)
4922 .await?;
4923 response.send(proto::GetChannelMessagesResponse {
4924 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4925 messages,
4926 })?;
4927 Ok(())
4928}
4929
4930/// Retrieve the current users notifications
4931async fn get_notifications(
4932 request: proto::GetNotifications,
4933 response: Response<proto::GetNotifications>,
4934 session: UserSession,
4935) -> Result<()> {
4936 let notifications = session
4937 .db()
4938 .await
4939 .get_notifications(
4940 session.user_id(),
4941 NOTIFICATION_COUNT_PER_PAGE,
4942 request
4943 .before_id
4944 .map(|id| db::NotificationId::from_proto(id)),
4945 )
4946 .await?;
4947 response.send(proto::GetNotificationsResponse {
4948 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4949 notifications,
4950 })?;
4951 Ok(())
4952}
4953
4954/// Mark notifications as read
4955async fn mark_notification_as_read(
4956 request: proto::MarkNotificationRead,
4957 response: Response<proto::MarkNotificationRead>,
4958 session: UserSession,
4959) -> Result<()> {
4960 let database = &session.db().await;
4961 let notifications = database
4962 .mark_notification_as_read_by_id(
4963 session.user_id(),
4964 NotificationId::from_proto(request.notification_id),
4965 )
4966 .await?;
4967 send_notifications(
4968 &*session.connection_pool().await,
4969 &session.peer,
4970 notifications,
4971 );
4972 response.send(proto::Ack {})?;
4973 Ok(())
4974}
4975
4976/// Get the current users information
4977async fn get_private_user_info(
4978 _request: proto::GetPrivateUserInfo,
4979 response: Response<proto::GetPrivateUserInfo>,
4980 session: UserSession,
4981) -> Result<()> {
4982 let db = session.db().await;
4983
4984 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4985 let user = db
4986 .get_user_by_id(session.user_id())
4987 .await?
4988 .ok_or_else(|| anyhow!("user not found"))?;
4989 let flags = db.get_user_flags(session.user_id()).await?;
4990
4991 response.send(proto::GetPrivateUserInfoResponse {
4992 metrics_id,
4993 staff: user.admin,
4994 flags,
4995 })?;
4996 Ok(())
4997}
4998
4999fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
5000 let message = match message {
5001 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5002 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5003 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5004 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5005 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5006 code: frame.code.into(),
5007 reason: frame.reason,
5008 })),
5009 // We should never receive a frame while reading the message, according
5010 // to the `tungstenite` maintainers:
5011 //
5012 // > It cannot occur when you read messages from the WebSocket, but it
5013 // > can be used when you want to send the raw frames (e.g. you want to
5014 // > send the frames to the WebSocket without composing the full message first).
5015 // >
5016 // > — https://github.com/snapview/tungstenite-rs/issues/268
5017 TungsteniteMessage::Frame(_) => {
5018 bail!("received an unexpected frame while reading the message")
5019 }
5020 };
5021
5022 Ok(message)
5023}
5024
5025fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5026 match message {
5027 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5028 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5029 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5030 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5031 AxumMessage::Close(frame) => {
5032 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5033 code: frame.code.into(),
5034 reason: frame.reason,
5035 }))
5036 }
5037 }
5038}
5039
5040fn notify_membership_updated(
5041 connection_pool: &mut ConnectionPool,
5042 result: MembershipUpdated,
5043 user_id: UserId,
5044 peer: &Peer,
5045) {
5046 for membership in &result.new_channels.channel_memberships {
5047 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5048 }
5049 for channel_id in &result.removed_channels {
5050 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5051 }
5052
5053 let user_channels_update = proto::UpdateUserChannels {
5054 channel_memberships: result
5055 .new_channels
5056 .channel_memberships
5057 .iter()
5058 .map(|cm| proto::ChannelMembership {
5059 channel_id: cm.channel_id.to_proto(),
5060 role: cm.role.into(),
5061 })
5062 .collect(),
5063 ..Default::default()
5064 };
5065
5066 let mut update = build_channels_update(result.new_channels);
5067 update.delete_channels = result
5068 .removed_channels
5069 .into_iter()
5070 .map(|id| id.to_proto())
5071 .collect();
5072 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5073
5074 for connection_id in connection_pool.user_connection_ids(user_id) {
5075 peer.send(connection_id, user_channels_update.clone())
5076 .trace_err();
5077 peer.send(connection_id, update.clone()).trace_err();
5078 }
5079}
5080
5081fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5082 proto::UpdateUserChannels {
5083 channel_memberships: channels
5084 .channel_memberships
5085 .iter()
5086 .map(|m| proto::ChannelMembership {
5087 channel_id: m.channel_id.to_proto(),
5088 role: m.role.into(),
5089 })
5090 .collect(),
5091 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5092 observed_channel_message_id: channels.observed_channel_messages.clone(),
5093 }
5094}
5095
5096fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5097 let mut update = proto::UpdateChannels::default();
5098
5099 for channel in channels.channels {
5100 update.channels.push(channel.to_proto());
5101 }
5102
5103 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5104 update.latest_channel_message_ids = channels.latest_channel_messages;
5105
5106 for (channel_id, participants) in channels.channel_participants {
5107 update
5108 .channel_participants
5109 .push(proto::ChannelParticipants {
5110 channel_id: channel_id.to_proto(),
5111 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5112 });
5113 }
5114
5115 for channel in channels.invited_channels {
5116 update.channel_invitations.push(channel.to_proto());
5117 }
5118
5119 update.hosted_projects = channels.hosted_projects;
5120 update
5121}
5122
5123fn build_initial_contacts_update(
5124 contacts: Vec<db::Contact>,
5125 pool: &ConnectionPool,
5126) -> proto::UpdateContacts {
5127 let mut update = proto::UpdateContacts::default();
5128
5129 for contact in contacts {
5130 match contact {
5131 db::Contact::Accepted { user_id, busy } => {
5132 update.contacts.push(contact_for_user(user_id, busy, &pool));
5133 }
5134 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5135 db::Contact::Incoming { user_id } => {
5136 update
5137 .incoming_requests
5138 .push(proto::IncomingContactRequest {
5139 requester_id: user_id.to_proto(),
5140 })
5141 }
5142 }
5143 }
5144
5145 update
5146}
5147
5148fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5149 proto::Contact {
5150 user_id: user_id.to_proto(),
5151 online: pool.is_user_online(user_id),
5152 busy,
5153 }
5154}
5155
5156fn room_updated(room: &proto::Room, peer: &Peer) {
5157 broadcast(
5158 None,
5159 room.participants
5160 .iter()
5161 .filter_map(|participant| Some(participant.peer_id?.into())),
5162 |peer_id| {
5163 peer.send(
5164 peer_id,
5165 proto::RoomUpdated {
5166 room: Some(room.clone()),
5167 },
5168 )
5169 },
5170 );
5171}
5172
5173fn channel_updated(
5174 channel: &db::channel::Model,
5175 room: &proto::Room,
5176 peer: &Peer,
5177 pool: &ConnectionPool,
5178) {
5179 let participants = room
5180 .participants
5181 .iter()
5182 .map(|p| p.user_id)
5183 .collect::<Vec<_>>();
5184
5185 broadcast(
5186 None,
5187 pool.channel_connection_ids(channel.root_id())
5188 .filter_map(|(channel_id, role)| {
5189 role.can_see_channel(channel.visibility).then(|| channel_id)
5190 }),
5191 |peer_id| {
5192 peer.send(
5193 peer_id,
5194 proto::UpdateChannels {
5195 channel_participants: vec![proto::ChannelParticipants {
5196 channel_id: channel.id.to_proto(),
5197 participant_user_ids: participants.clone(),
5198 }],
5199 ..Default::default()
5200 },
5201 )
5202 },
5203 );
5204}
5205
5206async fn send_dev_server_projects_update(
5207 user_id: UserId,
5208 mut status: proto::DevServerProjectsUpdate,
5209 session: &Session,
5210) {
5211 let pool = session.connection_pool().await;
5212 for dev_server in &mut status.dev_servers {
5213 dev_server.status =
5214 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5215 }
5216 let connections = pool.user_connection_ids(user_id);
5217 for connection_id in connections {
5218 session.peer.send(connection_id, status.clone()).trace_err();
5219 }
5220}
5221
5222async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5223 let db = session.db().await;
5224
5225 let contacts = db.get_contacts(user_id).await?;
5226 let busy = db.is_user_busy(user_id).await?;
5227
5228 let pool = session.connection_pool().await;
5229 let updated_contact = contact_for_user(user_id, busy, &pool);
5230 for contact in contacts {
5231 if let db::Contact::Accepted {
5232 user_id: contact_user_id,
5233 ..
5234 } = contact
5235 {
5236 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5237 session
5238 .peer
5239 .send(
5240 contact_conn_id,
5241 proto::UpdateContacts {
5242 contacts: vec![updated_contact.clone()],
5243 remove_contacts: Default::default(),
5244 incoming_requests: Default::default(),
5245 remove_incoming_requests: Default::default(),
5246 outgoing_requests: Default::default(),
5247 remove_outgoing_requests: Default::default(),
5248 },
5249 )
5250 .trace_err();
5251 }
5252 }
5253 }
5254 Ok(())
5255}
5256
5257async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5258 log::info!("lost dev server connection, unsharing projects");
5259 let project_ids = session
5260 .db()
5261 .await
5262 .get_stale_dev_server_projects(session.connection_id)
5263 .await?;
5264
5265 for project_id in project_ids {
5266 // not unshare re-checks the connection ids match, so we get away with no transaction
5267 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5268 }
5269
5270 let user_id = session.dev_server().user_id;
5271 let update = session
5272 .db()
5273 .await
5274 .dev_server_projects_update(user_id)
5275 .await?;
5276
5277 send_dev_server_projects_update(user_id, update, session).await;
5278
5279 Ok(())
5280}
5281
/// Remove `connection_id` from whatever room it is in and notify everyone affected.
///
/// Does nothing (returns `Ok`) if the database reports that the connection was
/// not in a room. Otherwise it:
/// - notifies collaborators of each project the user left (`project_left`);
/// - broadcasts the updated room, and the channel participants if the room
///   belongs to a channel;
/// - sends `CallCanceled` to users whose pending calls were canceled;
/// - refreshes contact status for the leaving user and those callees;
/// - removes the user from the LiveKit room, deleting it if the database marked
///   the room as deleted.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this takes `live_kit_room` out of `left_room.room` before
        // the room itself is taken, so the `RoomUpdated` broadcast below carries a
        // default (presumably empty) `live_kit_room`. Confirm this is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the `.await`s below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5355
5356async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5357 let left_channel_buffers = session
5358 .db()
5359 .await
5360 .leave_channel_buffers(session.connection_id)
5361 .await?;
5362
5363 for left_buffer in left_channel_buffers {
5364 channel_buffer_updated(
5365 session.connection_id,
5366 left_buffer.connections,
5367 &proto::UpdateChannelBufferCollaborators {
5368 channel_id: left_buffer.channel_id.to_proto(),
5369 collaborators: left_buffer.collaborators,
5370 },
5371 &session.peer,
5372 );
5373 }
5374
5375 Ok(())
5376}
5377
5378fn project_left(project: &db::LeftProject, session: &UserSession) {
5379 for connection_id in &project.connection_ids {
5380 if project.should_unshare {
5381 session
5382 .peer
5383 .send(
5384 *connection_id,
5385 proto::UnshareProject {
5386 project_id: project.id.to_proto(),
5387 },
5388 )
5389 .trace_err();
5390 } else {
5391 session
5392 .peer
5393 .send(
5394 *connection_id,
5395 proto::RemoveProjectCollaborator {
5396 project_id: project.id.to_proto(),
5397 peer_id: Some(session.connection_id.into()),
5398 },
5399 )
5400 .trace_err();
5401 }
5402 }
5403}
5404
/// Extension for turning a fallible value into an `Option` while reporting the failure.
pub trait ResultExt {
    /// The success type yielded when there is no error.
    type Ok;

    /// Return `Some(value)` on success. On failure, report the error and return
    /// `None`. The `Result` implementation in this file logs it via `tracing`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5410
5411impl<T, E> ResultExt for Result<T, E>
5412where
5413 E: std::fmt::Debug,
5414{
5415 type Ok = T;
5416
5417 #[track_caller]
5418 fn trace_err(self) -> Option<T> {
5419 match self {
5420 Ok(value) => Some(value),
5421 Err(error) => {
5422 tracing::error!("{:?}", error);
5423 None
5424 }
5425 }
5426 }
5427}