1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Config, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period for clients.
///
/// NOTE(review): not referenced in this chunk; presumably how long a disconnected client may
/// take to reconnect before its state is discarded — confirm at the call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before a newly started server cleans up resources left behind by stale servers
/// (see `Server::start`, which sleeps for this long before querying for stale rooms and
/// channel buffers).
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. After they're gone, we can
/// clean up old resources; 15s leaves a margin beyond that 10s window.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are not used in this chunk; their meanings below are
// inferred from their names and should be confirmed at the call sites.
// Presumably the page size when fetching channel chat messages.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
// Presumably the maximum length accepted for a channel chat message.
const MAX_MESSAGE_LEN: usize = 1024;
// Presumably the page size when fetching notifications.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased message handler stored in `Server::handlers`, keyed by the message's `TypeId`.
///
/// It receives the boxed envelope plus the connection's `Session` and returns a boxed future
/// yielding `()`. Handler errors are logged inside the future built by `Server::add_handler`
/// rather than returned to the caller.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed (by clone) to every message handler.
///
/// Built in `Server::handle_connection` for each accepted connection.
#[derive(Clone)]
struct Session {
    // Who is on the other end of this connection.
    principal: Principal,
    connection_id: ConnectionId,
    // A fresh mutex is created per connection in `handle_connection`, so access through
    // `Session::db()` is serialized across this connection's clones of the session.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    // Shared with `Server::connection_pool`; lock it through `Session::connection_pool()`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    // `Some` only when the config provides a Supermaven admin API key.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    // Created per connection in `handle_connection`.
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    // Not read in this chunk; kept alive alongside the session.
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// message handlers.
pub struct Server {
    // Behind a mutex so the test-only `reset` can assign a new id.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    // Shared with every `Session` created for this server.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    // Keyed by `TypeId::of::<M>()` of the handled message type; at most one handler per type.
    handlers: HashMap<TypeId, MessageHandler>,
    // `teardown()` sends `true` and the test-only `reset` sends `false`. Connections
    // subscribe to this in `handle_connection`.
    teardown: watch::Sender<bool>,
}

/// Lock guard for the shared [`ConnectionPool`].
///
/// `PhantomData<Rc<()>>` makes the guard `!Send` (`Rc` is `!Send`), so it cannot be held
/// across an `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}

/// Serializable view of the server state (the peer and the locked connection pool).
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard itself is not `Serialize`; serialize the pool it dereferences to.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates the server and registers every RPC handler.
    ///
    /// Handler wrappers used below:
    /// - `user_handler` / `user_message_handler`: reject sessions whose principal is not a
    ///   user with an `Error::Internal`.
    /// - `dev_server_handler`: rejects sessions whose principal is not a dev server.
    /// - Unwrapped handlers receive the raw `Session`, whatever its principal.
    ///
    /// Registering two handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` silently truncates if the server id does not fit in
            // a u32 — confirm ids stay in range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped; connections call `subscribe()` later.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            // Only callable by dev-server principals.
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another peer via the generic `forward_*` helpers
            // (defined elsewhere), instantiated per message type.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // The language-model handlers below need the config, so each closure captures its
            // own `Arc<AppState>` clone. They are not wrapped in `user_handler` here.
            // NOTE(review): presumably principal checks happen inside these functions — confirm.
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        complete_with_language_model(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        stream_complete_with_language_model(
                            request,
                            response,
                            session,
                            &app_state.config,
                        )
                        .await
                    }
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    let app_state = app_state.clone();
                    async move {
                        count_language_model_tokens(request, response, session, &app_state.config)
                            .await
                    }
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
665
    /// Schedules cleanup of stale server resources and returns immediately with `Ok(())`.
    ///
    /// Spawns a detached task that waits for [`CLEANUP_TIMEOUT`], then fetches the stale room
    /// and channel ids for this environment and `server_id` (presumably resources owned by
    /// server instances that are gone — confirm against `stale_server_resource_ids`), and:
    /// - for each channel: clears stale channel-buffer collaborators and sends the refreshed
    ///   collaborator list to the remaining connections;
    /// - for each room: clears stale participants, broadcasts room/channel updates, sends
    ///   `CallCanceled` to connections of users whose calls were canceled, pushes refreshed
    ///   contact entries to those users' accepted contacts, and deletes the LiveKit room when
    ///   no participants remain;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Every fallible step goes through `trace_err()` (which yields an `Option`), so a
    /// failure skips only that step and the cleanup continues.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when refreshing the room fails: nothing to notify and
                        // the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Both removed participants and users whose incoming call was
                            // canceled get their contact entry refreshed below.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the `.await`s below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Skip this user unless both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // No `.await` occurs while this guard is held.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
811
812 pub fn teardown(&self) {
813 self.peer.teardown();
814 self.connection_pool.lock().reset();
815 let _ = self.teardown.send(true);
816 }
817
818 #[cfg(test)]
819 pub fn reset(&self, id: ServerId) {
820 self.teardown();
821 *self.id.lock() = id;
822 self.peer.reset(id.0 as u32);
823 let _ = self.teardown.send(false);
824 }
825
826 #[cfg(test)]
827 pub fn id(&self) -> ServerId {
828 *self.id.lock()
829 }
830
831 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
832 where
833 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
834 Fut: 'static + Send + Future<Output = Result<()>>,
835 M: EnvelopedMessage,
836 {
837 let prev_handler = self.handlers.insert(
838 TypeId::of::<M>(),
839 Box::new(move |envelope, session| {
840 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
841 let received_at = envelope.received_at;
842 tracing::info!("message received");
843 let start_time = Instant::now();
844 let future = (handler)(*envelope, session);
845 async move {
846 let result = future.await;
847 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
848 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
849 let queue_duration_ms = total_duration_ms - processing_duration_ms;
850 let payload_type = M::NAME;
851
852 match result {
853 Err(error) => {
854 tracing::error!(
855 ?error,
856 total_duration_ms,
857 processing_duration_ms,
858 queue_duration_ms,
859 payload_type,
860 "error handling message"
861 )
862 }
863 Ok(()) => tracing::info!(
864 total_duration_ms,
865 processing_duration_ms,
866 queue_duration_ms,
867 "finished handling message"
868 ),
869 }
870 }
871 .boxed()
872 }),
873 );
874 if prev_handler.is_some() {
875 panic!("registered a handler for the same message twice");
876 }
877 self
878 }
879
880 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
881 where
882 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
883 Fut: 'static + Send + Future<Output = Result<()>>,
884 M: EnvelopedMessage,
885 {
886 self.add_handler(move |envelope, session| handler(envelope.payload, session));
887 self
888 }
889
890 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
891 where
892 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
893 Fut: Send + Future<Output = Result<()>>,
894 M: RequestMessage,
895 {
896 let handler = Arc::new(handler);
897 self.add_handler(move |envelope, session| {
898 let receipt = envelope.receipt();
899 let handler = handler.clone();
900 async move {
901 let peer = session.peer.clone();
902 let responded = Arc::new(AtomicBool::default());
903 let response = Response {
904 peer: peer.clone(),
905 responded: responded.clone(),
906 receipt,
907 };
908 match (handler)(envelope.payload, response, session).await {
909 Ok(()) => {
910 if responded.load(std::sync::atomic::Ordering::SeqCst) {
911 Ok(())
912 } else {
913 Err(anyhow!("handler did not send a response"))?
914 }
915 }
916 Err(error) => {
917 let proto_err = match &error {
918 Error::Internal(err) => err.to_proto(),
919 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
920 };
921 peer.respond_with_error(receipt, proto_err)?;
922 Err(error)
923 }
924 }
925 }
926 })
927 }
928
929 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
930 where
931 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
932 Fut: Send + Future<Output = Result<()>>,
933 M: RequestMessage,
934 {
935 let handler = Arc::new(handler);
936 self.add_handler(move |envelope, session| {
937 let receipt = envelope.receipt();
938 let handler = handler.clone();
939 async move {
940 let peer = session.peer.clone();
941 let response = StreamingResponse {
942 peer: peer.clone(),
943 receipt,
944 };
945 match (handler)(envelope.payload, response, session).await {
946 Ok(()) => {
947 peer.end_stream(receipt)?;
948 Ok(())
949 }
950 Err(error) => {
951 let proto_err = match &error {
952 Error::Internal(err) => err.to_proto(),
953 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
954 };
955 peer.respond_with_error(receipt, proto_err)?;
956 Err(error)
957 }
958 }
959 }
960 })
961 }
962
    /// Drives a single client connection from registration to sign-out.
    ///
    /// Registers `connection` with the peer, builds the per-connection [`Session`],
    /// sends the initial client update, then dispatches incoming messages to the
    /// registered handlers. The loop ends when the I/O future completes, when the
    /// client's message stream ends, or when teardown is signalled.
    ///
    /// * `address` – remote address, recorded on the tracing spans.
    /// * `principal` – the authenticated user, impersonated user, or dev server.
    /// * `zed_version` – client version, forwarded to the initial client update.
    /// * `send_connection_id` – if present, receives the assigned `ConnectionId`
    ///   once the `Hello` message has been sent.
    /// * `executor` – used for sleeps, detached background handlers and sign-out.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared `Empty` are filled in later: `connection_id` once the peer
        // assigns one, the principal fields by `update_span`.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        // Subscribed outside the async block; the current value is checked first
        // thing when the returned future runs.
        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return and the one after
            // `send_initial_client_update` below never reach `connection_lost`.
            // The connection registered by `add_connection` above (and possibly its
            // connection-pool entry) is therefore not cleaned up here. Confirm
            // whether that is intended.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // The Supermaven admin client is only built when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // At most 256 handlers (foreground + background) hold a permit at once.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is acquired *before* the next message is read, so reading
                // pauses while 256 handlers are in flight. The semaphore is never
                // closed in this function, so `acquire_owned` cannot fail here.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // `select_biased!` polls arms in order: teardown first, then I/O,
                // then progress on in-flight foreground handlers, then new messages.
                futures::select_biased! {
                    // Teardown returns immediately and skips `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                // Leave the span before storing/spawning the future; it is
                                // re-entered on every poll via `instrument` below.
                                drop(span_enter);

                                // The permit is held until the handler finishes.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                // The permit is released when this arm ends.
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Dropping the set cancels any foreground handlers still in flight;
            // detached background handlers keep running.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1109
    /// Sends the messages a freshly connected client needs before its regular
    /// message loop starts, and registers the connection in the connection pool.
    ///
    /// Always sends `Hello` first, then reports `connection_id` through
    /// `send_connection_id` if one was supplied. The rest depends on the principal:
    ///
    /// * Users (including impersonated ones):
    ///   * `ShowContacts` is sent and the DB flag is set on first connect.
    ///   * Contacts and dev-server project status are loaded concurrently, then
    ///     the connection is added to the pool and the initial contacts update sent.
    ///   * The user is subscribed to channels if the version allows it, and
    ///     dev-server project status is pushed.
    ///   * Any incoming call is forwarded, and the user's contacts are refreshed.
    /// * Dev servers:
    ///   * A previous connection for the same dev server id is told to shut down
    ///     and removed, and the new connection is registered.
    ///   * The server's projects are sent as `DevServerInstructions`, and the
    ///     owning user's dev-server project status is pushed.
    ///
    /// # Errors
    /// Returns the first send or database error. Everything done before that point
    /// (including adding the connection to the pool) is not rolled back here.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiver may already be gone; that is not an error here.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                let (contacts, dev_server_projects) = future::try_join(
                    self.app_state.db.get_contacts(user.id),
                    self.app_state.db.dev_server_projects_update(user.id),
                )
                .await?;

                // The pool guard is confined to this block, so it is dropped before
                // the next `.await`.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                }

                if should_auto_subscribe_to_channels(zed_version) {
                    subscribe_user_to_channels(user.id, session).await?;
                }

                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                {
                    let mut pool = self.connection_pool.lock();
                    // Only one connection per dev server id is kept: an existing one is
                    // asked to shut down and dropped from the pool.
                    // NOTE(review): because of `?`, a failed send to the stale
                    // connection returns before it is removed and before the new
                    // connection is added. Confirm that is the desired outcome.
                    if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
                    {
                        self.peer.send(
                            stale_connection_id,
                            proto::ShutdownDevServer {
                                reason: Some(
                                    "another dev server connected with the same token".to_string(),
                                ),
                            },
                        )?;
                        pool.remove_connection(stale_connection_id)?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1205
1206 pub async fn invite_code_redeemed(
1207 self: &Arc<Self>,
1208 inviter_id: UserId,
1209 invitee_id: UserId,
1210 ) -> Result<()> {
1211 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1212 if let Some(code) = &user.invite_code {
1213 let pool = self.connection_pool.lock();
1214 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1215 for connection_id in pool.user_connection_ids(inviter_id) {
1216 self.peer.send(
1217 connection_id,
1218 proto::UpdateContacts {
1219 contacts: vec![invitee_contact.clone()],
1220 ..Default::default()
1221 },
1222 )?;
1223 self.peer.send(
1224 connection_id,
1225 proto::UpdateInviteInfo {
1226 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1227 count: user.invite_count as u32,
1228 },
1229 )?;
1230 }
1231 }
1232 }
1233 Ok(())
1234 }
1235
1236 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1237 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1238 if let Some(invite_code) = &user.invite_code {
1239 let pool = self.connection_pool.lock();
1240 for connection_id in pool.user_connection_ids(user_id) {
1241 self.peer.send(
1242 connection_id,
1243 proto::UpdateInviteInfo {
1244 url: format!(
1245 "{}{}",
1246 self.app_state.config.invite_link_prefix, invite_code
1247 ),
1248 count: user.invite_count as u32,
1249 },
1250 )?;
1251 }
1252 }
1253 }
1254 Ok(())
1255 }
1256
1257 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1258 ServerSnapshot {
1259 connection_pool: ConnectionPoolGuard {
1260 guard: self.connection_pool.lock(),
1261 _not_send: PhantomData,
1262 },
1263 peer: &self.peer,
1264 }
1265 }
1266}
1267
1268impl<'a> Deref for ConnectionPoolGuard<'a> {
1269 type Target = ConnectionPool;
1270
1271 fn deref(&self) -> &Self::Target {
1272 &self.guard
1273 }
1274}
1275
1276impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1277 fn deref_mut(&mut self) -> &mut Self::Target {
1278 &mut self.guard
1279 }
1280}
1281
1282impl<'a> Drop for ConnectionPoolGuard<'a> {
1283 fn drop(&mut self) {
1284 #[cfg(test)]
1285 self.check_invariants();
1286 }
1287}
1288
1289fn broadcast<F>(
1290 sender_id: Option<ConnectionId>,
1291 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1292 mut f: F,
1293) where
1294 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1295{
1296 for receiver_id in receiver_ids {
1297 if Some(receiver_id) != sender_id {
1298 if let Err(error) = f(receiver_id) {
1299 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1300 }
1301 }
1302 }
1303}
1304
1305pub struct ProtocolVersion(u32);
1306
1307impl Header for ProtocolVersion {
1308 fn name() -> &'static HeaderName {
1309 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1310 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1311 }
1312
1313 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1314 where
1315 Self: Sized,
1316 I: Iterator<Item = &'i axum::http::HeaderValue>,
1317 {
1318 let version = values
1319 .next()
1320 .ok_or_else(axum::headers::Error::invalid)?
1321 .to_str()
1322 .map_err(|_| axum::headers::Error::invalid())?
1323 .parse()
1324 .map_err(|_| axum::headers::Error::invalid())?;
1325 Ok(Self(version))
1326 }
1327
1328 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1329 values.extend([self.0.to_string().parse().unwrap()]);
1330 }
1331}
1332
1333pub struct AppVersionHeader(SemanticVersion);
1334impl Header for AppVersionHeader {
1335 fn name() -> &'static HeaderName {
1336 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1337 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1338 }
1339
1340 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1341 where
1342 Self: Sized,
1343 I: Iterator<Item = &'i axum::http::HeaderValue>,
1344 {
1345 let version = values
1346 .next()
1347 .ok_or_else(axum::headers::Error::invalid)?
1348 .to_str()
1349 .map_err(|_| axum::headers::Error::invalid())?
1350 .parse()
1351 .map_err(|_| axum::headers::Error::invalid())?;
1352 Ok(Self(version))
1353 }
1354
1355 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1356 values.extend([self.0.to_string().parse().unwrap()]);
1357 }
1358}
1359
1360pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1361 Router::new()
1362 .route("/rpc", get(handle_websocket_request))
1363 .layer(
1364 ServiceBuilder::new()
1365 .layer(Extension(server.app_state.clone()))
1366 .layer(middleware::from_fn(auth::validate_header)),
1367 )
1368 .route("/metrics", get(handle_metrics))
1369 .layer(Extension(server))
1370}
1371
1372pub async fn handle_websocket_request(
1373 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1374 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1375 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1376 Extension(server): Extension<Arc<Server>>,
1377 Extension(principal): Extension<Principal>,
1378 ws: WebSocketUpgrade,
1379) -> axum::response::Response {
1380 if protocol_version != rpc::PROTOCOL_VERSION {
1381 return (
1382 StatusCode::UPGRADE_REQUIRED,
1383 "client must be upgraded".to_string(),
1384 )
1385 .into_response();
1386 }
1387
1388 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1389 return (
1390 StatusCode::UPGRADE_REQUIRED,
1391 "no version header found".to_string(),
1392 )
1393 .into_response();
1394 };
1395
1396 if !version.can_collaborate() {
1397 return (
1398 StatusCode::UPGRADE_REQUIRED,
1399 "client must be upgraded".to_string(),
1400 )
1401 .into_response();
1402 }
1403
1404 let socket_address = socket_address.to_string();
1405 ws.on_upgrade(move |socket| {
1406 let socket = socket
1407 .map_ok(to_tungstenite_message)
1408 .err_into()
1409 .with(|message| async move { to_axum_message(message) });
1410 let connection = Connection::new(Box::pin(socket));
1411 async move {
1412 server
1413 .handle_connection(
1414 connection,
1415 socket_address,
1416 principal,
1417 version,
1418 None,
1419 Executor::Production,
1420 )
1421 .await;
1422 }
1423 })
1424}
1425
1426pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1427 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1428 let connections_metric = CONNECTIONS_METRIC
1429 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1430
1431 let connections = server
1432 .connection_pool
1433 .lock()
1434 .connections()
1435 .filter(|connection| !connection.admin)
1436 .count();
1437 connections_metric.set(connections as _);
1438
1439 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1440 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1441 register_int_gauge!(
1442 "shared_projects",
1443 "number of open projects with one or more guests"
1444 )
1445 .unwrap()
1446 });
1447
1448 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1449 shared_projects_metric.set(shared_projects as _);
1450
1451 let encoder = prometheus::TextEncoder::new();
1452 let metric_families = prometheus::gather();
1453 let encoded_metrics = encoder
1454 .encode_to_string(&metric_families)
1455 .map_err(|err| anyhow!("{}", err))?;
1456 Ok(encoded_metrics)
1457}
1458
1459#[instrument(err, skip(executor))]
1460async fn connection_lost(
1461 session: Session,
1462 mut teardown: watch::Receiver<bool>,
1463 executor: Executor,
1464) -> Result<()> {
1465 session.peer.disconnect(session.connection_id);
1466 session
1467 .connection_pool()
1468 .await
1469 .remove_connection(session.connection_id)?;
1470
1471 session
1472 .db()
1473 .await
1474 .connection_lost(session.connection_id)
1475 .await
1476 .trace_err();
1477
1478 futures::select_biased! {
1479 _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
1480 match &session.principal {
1481 Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
1482 let session = session.for_user().unwrap();
1483
1484 log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
1485 leave_room_for_session(&session, session.connection_id).await.trace_err();
1486 leave_channel_buffers_for_session(&session)
1487 .await
1488 .trace_err();
1489
1490 if !session
1491 .connection_pool()
1492 .await
1493 .is_user_online(session.user_id())
1494 {
1495 let db = session.db().await;
1496 if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
1497 room_updated(&room, &session.peer);
1498 }
1499 }
1500
1501 update_user_contacts(session.user_id(), &session).await?;
1502 },
1503 Principal::DevServer(_) => {
1504 lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
1505 },
1506 }
1507 },
1508 _ = teardown.changed().fuse() => {}
1509 }
1510
1511 Ok(())
1512}
1513
1514/// Acknowledges a ping from a client, used to keep the connection alive.
1515async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1516 response.send(proto::Ack {})?;
1517 Ok(())
1518}
1519
1520/// Creates a new room for calling (outside of channels)
1521async fn create_room(
1522 _request: proto::CreateRoom,
1523 response: Response<proto::CreateRoom>,
1524 session: UserSession,
1525) -> Result<()> {
1526 let live_kit_room = nanoid::nanoid!(30);
1527
1528 let live_kit_connection_info = util::maybe!(async {
1529 let live_kit = session.live_kit_client.as_ref();
1530 let live_kit = live_kit?;
1531 let user_id = session.user_id().to_string();
1532
1533 let token = live_kit
1534 .room_token(&live_kit_room, &user_id.to_string())
1535 .trace_err()?;
1536
1537 Some(proto::LiveKitConnectionInfo {
1538 server_url: live_kit.url().into(),
1539 token,
1540 can_publish: true,
1541 })
1542 })
1543 .await;
1544
1545 let room = session
1546 .db()
1547 .await
1548 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1549 .await?;
1550
1551 response.send(proto::CreateRoomResponse {
1552 room: Some(room.clone()),
1553 live_kit_connection_info,
1554 })?;
1555
1556 update_user_contacts(session.user_id(), &session).await?;
1557 Ok(())
1558}
1559
1560/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1561async fn join_room(
1562 request: proto::JoinRoom,
1563 response: Response<proto::JoinRoom>,
1564 session: UserSession,
1565) -> Result<()> {
1566 let room_id = RoomId::from_proto(request.id);
1567
1568 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1569
1570 if let Some(channel_id) = channel_id {
1571 return join_channel_internal(channel_id, Box::new(response), session).await;
1572 }
1573
1574 let joined_room = {
1575 let room = session
1576 .db()
1577 .await
1578 .join_room(room_id, session.user_id(), session.connection_id)
1579 .await?;
1580 room_updated(&room.room, &session.peer);
1581 room.into_inner()
1582 };
1583
1584 for connection_id in session
1585 .connection_pool()
1586 .await
1587 .user_connection_ids(session.user_id())
1588 {
1589 session
1590 .peer
1591 .send(
1592 connection_id,
1593 proto::CallCanceled {
1594 room_id: room_id.to_proto(),
1595 },
1596 )
1597 .trace_err();
1598 }
1599
1600 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1601 if let Some(token) = live_kit
1602 .room_token(
1603 &joined_room.room.live_kit_room,
1604 &session.user_id().to_string(),
1605 )
1606 .trace_err()
1607 {
1608 Some(proto::LiveKitConnectionInfo {
1609 server_url: live_kit.url().into(),
1610 token,
1611 can_publish: true,
1612 })
1613 } else {
1614 None
1615 }
1616 } else {
1617 None
1618 };
1619
1620 response.send(proto::JoinRoomResponse {
1621 room: Some(joined_room.room),
1622 channel_id: None,
1623 live_kit_connection_info,
1624 })?;
1625
1626 update_user_contacts(session.user_id(), &session).await?;
1627 Ok(())
1628}
1629
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Restores this connection's place in the room, in this order:
/// * Replies with the room, the reshared projects (with their collaborators)
///   and the rejoined projects, then re-broadcasts the room.
/// * For each reshared project, tells its collaborators that
///   `old_connection_id` is now this connection (failures logged), and forwards
///   the project's current worktrees to every collaborator except this one.
/// * Streams rejoined-project state back via `notify_rejoined_projects`.
/// * Notifies the room's channel, if any, and refreshes the user's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scoped so that `rejoined_room` is consumed (via `into_inner`) before the
    // connection pool is locked for `channel_updated` below.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-map this connection's old peer id to the new one for every collaborator.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Resend the project's worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1721
1722fn notify_rejoined_projects(
1723 rejoined_projects: &mut Vec<RejoinedProject>,
1724 session: &UserSession,
1725) -> Result<()> {
1726 for project in rejoined_projects.iter() {
1727 for collaborator in &project.collaborators {
1728 session
1729 .peer
1730 .send(
1731 collaborator.connection_id,
1732 proto::UpdateProjectCollaborator {
1733 project_id: project.id.to_proto(),
1734 old_peer_id: Some(project.old_connection_id.into()),
1735 new_peer_id: Some(session.connection_id.into()),
1736 },
1737 )
1738 .trace_err();
1739 }
1740 }
1741
1742 for project in rejoined_projects {
1743 for worktree in mem::take(&mut project.worktrees) {
1744 #[cfg(any(test, feature = "test-support"))]
1745 const MAX_CHUNK_SIZE: usize = 2;
1746 #[cfg(not(any(test, feature = "test-support")))]
1747 const MAX_CHUNK_SIZE: usize = 256;
1748
1749 // Stream this worktree's entries.
1750 let message = proto::UpdateWorktree {
1751 project_id: project.id.to_proto(),
1752 worktree_id: worktree.id,
1753 abs_path: worktree.abs_path.clone(),
1754 root_name: worktree.root_name,
1755 updated_entries: worktree.updated_entries,
1756 removed_entries: worktree.removed_entries,
1757 scan_id: worktree.scan_id,
1758 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1759 updated_repositories: worktree.updated_repositories,
1760 removed_repositories: worktree.removed_repositories,
1761 };
1762 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1763 session.peer.send(session.connection_id, update.clone())?;
1764 }
1765
1766 // Stream this worktree's diagnostics.
1767 for summary in worktree.diagnostic_summaries {
1768 session.peer.send(
1769 session.connection_id,
1770 proto::UpdateDiagnosticSummary {
1771 project_id: project.id.to_proto(),
1772 worktree_id: worktree.id,
1773 summary: Some(summary),
1774 },
1775 )?;
1776 }
1777
1778 for settings_file in worktree.settings_files {
1779 session.peer.send(
1780 session.connection_id,
1781 proto::UpdateWorktreeSettings {
1782 project_id: project.id.to_proto(),
1783 worktree_id: worktree.id,
1784 path: settings_file.path,
1785 content: Some(settings_file.content),
1786 },
1787 )?;
1788 }
1789 }
1790
1791 for language_server in &project.language_servers {
1792 session.peer.send(
1793 session.connection_id,
1794 proto::UpdateLanguageServer {
1795 project_id: project.id.to_proto(),
1796 language_server_id: language_server.id,
1797 variant: Some(
1798 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1799 proto::LspDiskBasedDiagnosticsUpdated {},
1800 ),
1801 ),
1802 },
1803 )?;
1804 }
1805 }
1806 Ok(())
1807}
1808
1809/// leave room disconnects from the room.
1810async fn leave_room(
1811 _: proto::LeaveRoom,
1812 response: Response<proto::LeaveRoom>,
1813 session: UserSession,
1814) -> Result<()> {
1815 leave_room_for_session(&session, session.connection_id).await?;
1816 response.send(proto::Ack {})?;
1817 Ok(())
1818}
1819
1820/// Updates the permissions of someone else in the room.
1821async fn set_room_participant_role(
1822 request: proto::SetRoomParticipantRole,
1823 response: Response<proto::SetRoomParticipantRole>,
1824 session: UserSession,
1825) -> Result<()> {
1826 let user_id = UserId::from_proto(request.user_id);
1827 let role = ChannelRole::from(request.role());
1828
1829 let (live_kit_room, can_publish) = {
1830 let room = session
1831 .db()
1832 .await
1833 .set_room_participant_role(
1834 session.user_id(),
1835 RoomId::from_proto(request.room_id),
1836 user_id,
1837 role,
1838 )
1839 .await?;
1840
1841 let live_kit_room = room.live_kit_room.clone();
1842 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1843 room_updated(&room, &session.peer);
1844 (live_kit_room, can_publish)
1845 };
1846
1847 if let Some(live_kit) = session.live_kit_client.as_ref() {
1848 live_kit
1849 .update_participant(
1850 live_kit_room.clone(),
1851 request.user_id.to_string(),
1852 live_kit_server::proto::ParticipantPermission {
1853 can_subscribe: true,
1854 can_publish,
1855 can_publish_data: can_publish,
1856 hidden: false,
1857 recorder: false,
1858 },
1859 )
1860 .await
1861 .trace_err();
1862 }
1863
1864 response.send(proto::Ack {})?;
1865 Ok(())
1866}
1867
1868/// Call someone else into the current room
1869async fn call(
1870 request: proto::Call,
1871 response: Response<proto::Call>,
1872 session: UserSession,
1873) -> Result<()> {
1874 let room_id = RoomId::from_proto(request.room_id);
1875 let calling_user_id = session.user_id();
1876 let calling_connection_id = session.connection_id;
1877 let called_user_id = UserId::from_proto(request.called_user_id);
1878 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1879 if !session
1880 .db()
1881 .await
1882 .has_contact(calling_user_id, called_user_id)
1883 .await?
1884 {
1885 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1886 }
1887
1888 let incoming_call = {
1889 let (room, incoming_call) = &mut *session
1890 .db()
1891 .await
1892 .call(
1893 room_id,
1894 calling_user_id,
1895 calling_connection_id,
1896 called_user_id,
1897 initial_project_id,
1898 )
1899 .await?;
1900 room_updated(&room, &session.peer);
1901 mem::take(incoming_call)
1902 };
1903 update_user_contacts(called_user_id, &session).await?;
1904
1905 let mut calls = session
1906 .connection_pool()
1907 .await
1908 .user_connection_ids(called_user_id)
1909 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1910 .collect::<FuturesUnordered<_>>();
1911
1912 while let Some(call_response) = calls.next().await {
1913 match call_response.as_ref() {
1914 Ok(_) => {
1915 response.send(proto::Ack {})?;
1916 return Ok(());
1917 }
1918 Err(_) => {
1919 call_response.trace_err();
1920 }
1921 }
1922 }
1923
1924 {
1925 let room = session
1926 .db()
1927 .await
1928 .call_failed(room_id, called_user_id)
1929 .await?;
1930 room_updated(&room, &session.peer);
1931 }
1932 update_user_contacts(called_user_id, &session).await?;
1933
1934 Err(anyhow!("failed to ring user"))?
1935}
1936
1937/// Cancel an outgoing call.
1938async fn cancel_call(
1939 request: proto::CancelCall,
1940 response: Response<proto::CancelCall>,
1941 session: UserSession,
1942) -> Result<()> {
1943 let called_user_id = UserId::from_proto(request.called_user_id);
1944 let room_id = RoomId::from_proto(request.room_id);
1945 {
1946 let room = session
1947 .db()
1948 .await
1949 .cancel_call(room_id, session.connection_id, called_user_id)
1950 .await?;
1951 room_updated(&room, &session.peer);
1952 }
1953
1954 for connection_id in session
1955 .connection_pool()
1956 .await
1957 .user_connection_ids(called_user_id)
1958 {
1959 session
1960 .peer
1961 .send(
1962 connection_id,
1963 proto::CallCanceled {
1964 room_id: room_id.to_proto(),
1965 },
1966 )
1967 .trace_err();
1968 }
1969 response.send(proto::Ack {})?;
1970
1971 update_user_contacts(called_user_id, &session).await?;
1972 Ok(())
1973}
1974
1975/// Decline an incoming call.
1976async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1977 let room_id = RoomId::from_proto(message.room_id);
1978 {
1979 let room = session
1980 .db()
1981 .await
1982 .decline_call(Some(room_id), session.user_id())
1983 .await?
1984 .ok_or_else(|| anyhow!("failed to decline call"))?;
1985 room_updated(&room, &session.peer);
1986 }
1987
1988 for connection_id in session
1989 .connection_pool()
1990 .await
1991 .user_connection_ids(session.user_id())
1992 {
1993 session
1994 .peer
1995 .send(
1996 connection_id,
1997 proto::CallCanceled {
1998 room_id: room_id.to_proto(),
1999 },
2000 )
2001 .trace_err();
2002 }
2003 update_user_contacts(session.user_id(), &session).await?;
2004 Ok(())
2005}
2006
2007/// Updates other participants in the room with your current location.
2008async fn update_participant_location(
2009 request: proto::UpdateParticipantLocation,
2010 response: Response<proto::UpdateParticipantLocation>,
2011 session: UserSession,
2012) -> Result<()> {
2013 let room_id = RoomId::from_proto(request.room_id);
2014 let location = request
2015 .location
2016 .ok_or_else(|| anyhow!("invalid location"))?;
2017
2018 let db = session.db().await;
2019 let room = db
2020 .update_room_participant_location(room_id, session.connection_id, location)
2021 .await?;
2022
2023 room_updated(&room, &session.peer);
2024 response.send(proto::Ack {})?;
2025 Ok(())
2026}
2027
2028/// Share a project into the room.
2029async fn share_project(
2030 request: proto::ShareProject,
2031 response: Response<proto::ShareProject>,
2032 session: UserSession,
2033) -> Result<()> {
2034 let (project_id, room) = &*session
2035 .db()
2036 .await
2037 .share_project(
2038 RoomId::from_proto(request.room_id),
2039 session.connection_id,
2040 &request.worktrees,
2041 request
2042 .dev_server_project_id
2043 .map(|id| DevServerProjectId::from_proto(id)),
2044 )
2045 .await?;
2046 response.send(proto::ShareProjectResponse {
2047 project_id: project_id.to_proto(),
2048 })?;
2049 room_updated(&room, &session.peer);
2050
2051 Ok(())
2052}
2053
2054/// Unshare a project from the room.
2055async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2056 let project_id = ProjectId::from_proto(message.project_id);
2057 unshare_project_internal(
2058 project_id,
2059 session.connection_id,
2060 session.user_id(),
2061 &session,
2062 )
2063 .await
2064}
2065
/// Unshares `project_id` on behalf of `connection_id` (and `user_id`, when known).
///
/// Every guest connection except `connection_id` receives `UnshareProject`
/// (failures are logged by `broadcast`), and the project's room, if any, is
/// re-broadcast. If the database reports that the project should be deleted, it
/// is deleted afterwards, once the room guard from the first step has been dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Scoped so `room_guard` is dropped before the optional delete below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2104
2105/// DevServer makes a project available online
2106async fn share_dev_server_project(
2107 request: proto::ShareDevServerProject,
2108 response: Response<proto::ShareDevServerProject>,
2109 session: DevServerSession,
2110) -> Result<()> {
2111 let (dev_server_project, user_id, status) = session
2112 .db()
2113 .await
2114 .share_dev_server_project(
2115 DevServerProjectId::from_proto(request.dev_server_project_id),
2116 session.dev_server_id(),
2117 session.connection_id,
2118 &request.worktrees,
2119 )
2120 .await?;
2121 let Some(project_id) = dev_server_project.project_id else {
2122 return Err(anyhow!("failed to share remote project"))?;
2123 };
2124
2125 send_dev_server_projects_update(user_id, status, &session).await;
2126
2127 response.send(proto::ShareProjectResponse { project_id })?;
2128
2129 Ok(())
2130}
2131
/// Joins someone else's shared project.
///
/// Records the join in the database, then delegates to the synchronous
/// `join_project_internal` with the joined project and this user's replica id.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // `project` and `replica_id` borrow from the guard returned by `join_project`,
    // which outlives this statement; the database handle itself is released by
    // `drop(db)` before the rest of the join runs.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2150
/// Abstracts over the response handles of the requests that join a project
/// (`proto::JoinProject` and `proto::JoinHostedProject`). Both reply with a
/// `proto::JoinProjectResponse`, so `join_project_internal` can serve either.
trait JoinProjectInternalResponse {
    /// Sends the join result back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
2154impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2155 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2156 Response::<proto::JoinProject>::send(self, result)
2157 }
2158}
2159impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2160 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2161 Response::<proto::JoinHostedProject>::send(self, result)
2162 }
2163}
2164
2165fn join_project_internal(
2166 response: impl JoinProjectInternalResponse,
2167 session: UserSession,
2168 project: &mut Project,
2169 replica_id: &ReplicaId,
2170) -> Result<()> {
2171 let collaborators = project
2172 .collaborators
2173 .iter()
2174 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2175 .map(|collaborator| collaborator.to_proto())
2176 .collect::<Vec<_>>();
2177 let project_id = project.id;
2178 let guest_user_id = session.user_id();
2179
2180 let worktrees = project
2181 .worktrees
2182 .iter()
2183 .map(|(id, worktree)| proto::WorktreeMetadata {
2184 id: *id,
2185 root_name: worktree.root_name.clone(),
2186 visible: worktree.visible,
2187 abs_path: worktree.abs_path.clone(),
2188 })
2189 .collect::<Vec<_>>();
2190
2191 let add_project_collaborator = proto::AddProjectCollaborator {
2192 project_id: project_id.to_proto(),
2193 collaborator: Some(proto::Collaborator {
2194 peer_id: Some(session.connection_id.into()),
2195 replica_id: replica_id.0 as u32,
2196 user_id: guest_user_id.to_proto(),
2197 }),
2198 };
2199
2200 for collaborator in &collaborators {
2201 session
2202 .peer
2203 .send(
2204 collaborator.peer_id.unwrap().into(),
2205 add_project_collaborator.clone(),
2206 )
2207 .trace_err();
2208 }
2209
2210 // First, we send the metadata associated with each worktree.
2211 response.send(proto::JoinProjectResponse {
2212 project_id: project.id.0 as u64,
2213 worktrees: worktrees.clone(),
2214 replica_id: replica_id.0 as u32,
2215 collaborators: collaborators.clone(),
2216 language_servers: project.language_servers.clone(),
2217 role: project.role.into(),
2218 dev_server_project_id: project
2219 .dev_server_project_id
2220 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2221 })?;
2222
2223 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2224 #[cfg(any(test, feature = "test-support"))]
2225 const MAX_CHUNK_SIZE: usize = 2;
2226 #[cfg(not(any(test, feature = "test-support")))]
2227 const MAX_CHUNK_SIZE: usize = 256;
2228
2229 // Stream this worktree's entries.
2230 let message = proto::UpdateWorktree {
2231 project_id: project_id.to_proto(),
2232 worktree_id,
2233 abs_path: worktree.abs_path.clone(),
2234 root_name: worktree.root_name,
2235 updated_entries: worktree.entries,
2236 removed_entries: Default::default(),
2237 scan_id: worktree.scan_id,
2238 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2239 updated_repositories: worktree.repository_entries.into_values().collect(),
2240 removed_repositories: Default::default(),
2241 };
2242 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2243 session.peer.send(session.connection_id, update.clone())?;
2244 }
2245
2246 // Stream this worktree's diagnostics.
2247 for summary in worktree.diagnostic_summaries {
2248 session.peer.send(
2249 session.connection_id,
2250 proto::UpdateDiagnosticSummary {
2251 project_id: project_id.to_proto(),
2252 worktree_id: worktree.id,
2253 summary: Some(summary),
2254 },
2255 )?;
2256 }
2257
2258 for settings_file in worktree.settings_files {
2259 session.peer.send(
2260 session.connection_id,
2261 proto::UpdateWorktreeSettings {
2262 project_id: project_id.to_proto(),
2263 worktree_id: worktree.id,
2264 path: settings_file.path,
2265 content: Some(settings_file.content),
2266 },
2267 )?;
2268 }
2269 }
2270
2271 for language_server in &project.language_servers {
2272 session.peer.send(
2273 session.connection_id,
2274 proto::UpdateLanguageServer {
2275 project_id: project_id.to_proto(),
2276 language_server_id: language_server.id,
2277 variant: Some(
2278 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2279 proto::LspDiskBasedDiagnosticsUpdated {},
2280 ),
2281 ),
2282 },
2283 )?;
2284 }
2285
2286 Ok(())
2287}
2288
2289/// Leave someone elses shared project.
2290async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2291 let sender_id = session.connection_id;
2292 let project_id = ProjectId::from_proto(request.project_id);
2293 let db = session.db().await;
2294 if db.is_hosted_project(project_id).await? {
2295 let project = db.leave_hosted_project(project_id, sender_id).await?;
2296 project_left(&project, &session);
2297 return Ok(());
2298 }
2299
2300 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2301 tracing::info!(
2302 %project_id,
2303 "leave project"
2304 );
2305
2306 project_left(&project, &session);
2307 if let Some(room) = room {
2308 room_updated(&room, &session.peer);
2309 }
2310
2311 Ok(())
2312}
2313
2314async fn join_hosted_project(
2315 request: proto::JoinHostedProject,
2316 response: Response<proto::JoinHostedProject>,
2317 session: UserSession,
2318) -> Result<()> {
2319 let (mut project, replica_id) = session
2320 .db()
2321 .await
2322 .join_hosted_project(
2323 ProjectId(request.project_id as i32),
2324 session.user_id(),
2325 session.connection_id,
2326 )
2327 .await?;
2328
2329 join_project_internal(response, session, &mut project, &replica_id)
2330}
2331
2332async fn list_remote_directory(
2333 request: proto::ListRemoteDirectory,
2334 response: Response<proto::ListRemoteDirectory>,
2335 session: UserSession,
2336) -> Result<()> {
2337 let dev_server_id = DevServerId(request.dev_server_id as i32);
2338 let dev_server_connection_id = session
2339 .connection_pool()
2340 .await
2341 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2342
2343 session
2344 .db()
2345 .await
2346 .get_dev_server_for_user(dev_server_id, session.user_id())
2347 .await?;
2348
2349 response.send(
2350 session
2351 .peer
2352 .forward_request(session.connection_id, dev_server_connection_id, request)
2353 .await?,
2354 )?;
2355 Ok(())
2356}
2357
2358async fn update_dev_server_project(
2359 request: proto::UpdateDevServerProject,
2360 response: Response<proto::UpdateDevServerProject>,
2361 session: UserSession,
2362) -> Result<()> {
2363 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2364
2365 let (dev_server_project, update) = session
2366 .db()
2367 .await
2368 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2369 .await?;
2370
2371 let projects = session
2372 .db()
2373 .await
2374 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2375 .await?;
2376
2377 let dev_server_connection_id = session
2378 .connection_pool()
2379 .await
2380 .dev_server_connection_id_supporting(
2381 dev_server_project.dev_server_id,
2382 ZedVersion::with_list_directory(),
2383 )?;
2384
2385 session.peer.send(
2386 dev_server_connection_id,
2387 proto::DevServerInstructions { projects },
2388 )?;
2389
2390 send_dev_server_projects_update(session.user_id(), update, &session).await;
2391
2392 response.send(proto::Ack {})
2393}
2394
2395async fn create_dev_server_project(
2396 request: proto::CreateDevServerProject,
2397 response: Response<proto::CreateDevServerProject>,
2398 session: UserSession,
2399) -> Result<()> {
2400 let dev_server_id = DevServerId(request.dev_server_id as i32);
2401 let dev_server_connection_id = session
2402 .connection_pool()
2403 .await
2404 .dev_server_connection_id(dev_server_id);
2405 let Some(dev_server_connection_id) = dev_server_connection_id else {
2406 Err(ErrorCode::DevServerOffline
2407 .message("Cannot create a remote project when the dev server is offline".to_string())
2408 .anyhow())?
2409 };
2410
2411 let path = request.path.clone();
2412 //Check that the path exists on the dev server
2413 session
2414 .peer
2415 .forward_request(
2416 session.connection_id,
2417 dev_server_connection_id,
2418 proto::ValidateDevServerProjectRequest { path: path.clone() },
2419 )
2420 .await?;
2421
2422 let (dev_server_project, update) = session
2423 .db()
2424 .await
2425 .create_dev_server_project(
2426 DevServerId(request.dev_server_id as i32),
2427 &request.path,
2428 session.user_id(),
2429 )
2430 .await?;
2431
2432 let projects = session
2433 .db()
2434 .await
2435 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2436 .await?;
2437
2438 session.peer.send(
2439 dev_server_connection_id,
2440 proto::DevServerInstructions { projects },
2441 )?;
2442
2443 send_dev_server_projects_update(session.user_id(), update, &session).await;
2444
2445 response.send(proto::CreateDevServerProjectResponse {
2446 dev_server_project: Some(dev_server_project.to_proto(None)),
2447 })?;
2448 Ok(())
2449}
2450
2451async fn create_dev_server(
2452 request: proto::CreateDevServer,
2453 response: Response<proto::CreateDevServer>,
2454 session: UserSession,
2455) -> Result<()> {
2456 let access_token = auth::random_token();
2457 let hashed_access_token = auth::hash_access_token(&access_token);
2458
2459 if request.name.is_empty() {
2460 return Err(proto::ErrorCode::Forbidden
2461 .message("Dev server name cannot be empty".to_string())
2462 .anyhow())?;
2463 }
2464
2465 let (dev_server, status) = session
2466 .db()
2467 .await
2468 .create_dev_server(
2469 &request.name,
2470 request.ssh_connection_string.as_deref(),
2471 &hashed_access_token,
2472 session.user_id(),
2473 )
2474 .await?;
2475
2476 send_dev_server_projects_update(session.user_id(), status, &session).await;
2477
2478 response.send(proto::CreateDevServerResponse {
2479 dev_server_id: dev_server.id.0 as u64,
2480 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2481 name: request.name,
2482 })?;
2483 Ok(())
2484}
2485
2486async fn regenerate_dev_server_token(
2487 request: proto::RegenerateDevServerToken,
2488 response: Response<proto::RegenerateDevServerToken>,
2489 session: UserSession,
2490) -> Result<()> {
2491 let dev_server_id = DevServerId(request.dev_server_id as i32);
2492 let access_token = auth::random_token();
2493 let hashed_access_token = auth::hash_access_token(&access_token);
2494
2495 let connection_id = session
2496 .connection_pool()
2497 .await
2498 .dev_server_connection_id(dev_server_id);
2499 if let Some(connection_id) = connection_id {
2500 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2501 session.peer.send(
2502 connection_id,
2503 proto::ShutdownDevServer {
2504 reason: Some("dev server token was regenerated".to_string()),
2505 },
2506 )?;
2507 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2508 }
2509
2510 let status = session
2511 .db()
2512 .await
2513 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2514 .await?;
2515
2516 send_dev_server_projects_update(session.user_id(), status, &session).await;
2517
2518 response.send(proto::RegenerateDevServerTokenResponse {
2519 dev_server_id: dev_server_id.to_proto(),
2520 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2521 })?;
2522 Ok(())
2523}
2524
2525async fn rename_dev_server(
2526 request: proto::RenameDevServer,
2527 response: Response<proto::RenameDevServer>,
2528 session: UserSession,
2529) -> Result<()> {
2530 if request.name.trim().is_empty() {
2531 return Err(proto::ErrorCode::Forbidden
2532 .message("Dev server name cannot be empty".to_string())
2533 .anyhow())?;
2534 }
2535
2536 let dev_server_id = DevServerId(request.dev_server_id as i32);
2537 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2538 if dev_server.user_id != session.user_id() {
2539 return Err(anyhow!(ErrorCode::Forbidden))?;
2540 }
2541
2542 let status = session
2543 .db()
2544 .await
2545 .rename_dev_server(
2546 dev_server_id,
2547 &request.name,
2548 request.ssh_connection_string.as_deref(),
2549 session.user_id(),
2550 )
2551 .await?;
2552
2553 send_dev_server_projects_update(session.user_id(), status, &session).await;
2554
2555 response.send(proto::Ack {})?;
2556 Ok(())
2557}
2558
2559async fn delete_dev_server(
2560 request: proto::DeleteDevServer,
2561 response: Response<proto::DeleteDevServer>,
2562 session: UserSession,
2563) -> Result<()> {
2564 let dev_server_id = DevServerId(request.dev_server_id as i32);
2565 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2566 if dev_server.user_id != session.user_id() {
2567 return Err(anyhow!(ErrorCode::Forbidden))?;
2568 }
2569
2570 let connection_id = session
2571 .connection_pool()
2572 .await
2573 .dev_server_connection_id(dev_server_id);
2574 if let Some(connection_id) = connection_id {
2575 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2576 session.peer.send(
2577 connection_id,
2578 proto::ShutdownDevServer {
2579 reason: Some("dev server was deleted".to_string()),
2580 },
2581 )?;
2582 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2583 }
2584
2585 let status = session
2586 .db()
2587 .await
2588 .delete_dev_server(dev_server_id, session.user_id())
2589 .await?;
2590
2591 send_dev_server_projects_update(session.user_id(), status, &session).await;
2592
2593 response.send(proto::Ack {})?;
2594 Ok(())
2595}
2596
2597async fn delete_dev_server_project(
2598 request: proto::DeleteDevServerProject,
2599 response: Response<proto::DeleteDevServerProject>,
2600 session: UserSession,
2601) -> Result<()> {
2602 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2603 let dev_server_project = session
2604 .db()
2605 .await
2606 .get_dev_server_project(dev_server_project_id)
2607 .await?;
2608
2609 let dev_server = session
2610 .db()
2611 .await
2612 .get_dev_server(dev_server_project.dev_server_id)
2613 .await?;
2614 if dev_server.user_id != session.user_id() {
2615 return Err(anyhow!(ErrorCode::Forbidden))?;
2616 }
2617
2618 let dev_server_connection_id = session
2619 .connection_pool()
2620 .await
2621 .dev_server_connection_id(dev_server.id);
2622
2623 if let Some(dev_server_connection_id) = dev_server_connection_id {
2624 let project = session
2625 .db()
2626 .await
2627 .find_dev_server_project(dev_server_project_id)
2628 .await;
2629 if let Ok(project) = project {
2630 unshare_project_internal(
2631 project.id,
2632 dev_server_connection_id,
2633 Some(session.user_id()),
2634 &session,
2635 )
2636 .await?;
2637 }
2638 }
2639
2640 let (projects, status) = session
2641 .db()
2642 .await
2643 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2644 .await?;
2645
2646 if let Some(dev_server_connection_id) = dev_server_connection_id {
2647 session.peer.send(
2648 dev_server_connection_id,
2649 proto::DevServerInstructions { projects },
2650 )?;
2651 }
2652
2653 send_dev_server_projects_update(session.user_id(), status, &session).await;
2654
2655 response.send(proto::Ack {})?;
2656 Ok(())
2657}
2658
2659async fn rejoin_dev_server_projects(
2660 request: proto::RejoinRemoteProjects,
2661 response: Response<proto::RejoinRemoteProjects>,
2662 session: UserSession,
2663) -> Result<()> {
2664 let mut rejoined_projects = {
2665 let db = session.db().await;
2666 db.rejoin_dev_server_projects(
2667 &request.rejoined_projects,
2668 session.user_id(),
2669 session.0.connection_id,
2670 )
2671 .await?
2672 };
2673 response.send(proto::RejoinRemoteProjectsResponse {
2674 rejoined_projects: rejoined_projects
2675 .iter()
2676 .map(|project| project.to_proto())
2677 .collect(),
2678 })?;
2679 notify_rejoined_projects(&mut rejoined_projects, &session)
2680}
2681
2682async fn reconnect_dev_server(
2683 request: proto::ReconnectDevServer,
2684 response: Response<proto::ReconnectDevServer>,
2685 session: DevServerSession,
2686) -> Result<()> {
2687 let reshared_projects = {
2688 let db = session.db().await;
2689 db.reshare_dev_server_projects(
2690 &request.reshared_projects,
2691 session.dev_server_id(),
2692 session.0.connection_id,
2693 )
2694 .await?
2695 };
2696
2697 for project in &reshared_projects {
2698 for collaborator in &project.collaborators {
2699 session
2700 .peer
2701 .send(
2702 collaborator.connection_id,
2703 proto::UpdateProjectCollaborator {
2704 project_id: project.id.to_proto(),
2705 old_peer_id: Some(project.old_connection_id.into()),
2706 new_peer_id: Some(session.connection_id.into()),
2707 },
2708 )
2709 .trace_err();
2710 }
2711
2712 broadcast(
2713 Some(session.connection_id),
2714 project
2715 .collaborators
2716 .iter()
2717 .map(|collaborator| collaborator.connection_id),
2718 |connection_id| {
2719 session.peer.forward_send(
2720 session.connection_id,
2721 connection_id,
2722 proto::UpdateProject {
2723 project_id: project.id.to_proto(),
2724 worktrees: project.worktrees.clone(),
2725 },
2726 )
2727 },
2728 );
2729 }
2730
2731 response.send(proto::ReconnectDevServerResponse {
2732 reshared_projects: reshared_projects
2733 .iter()
2734 .map(|project| proto::ResharedProject {
2735 id: project.id.to_proto(),
2736 collaborators: project
2737 .collaborators
2738 .iter()
2739 .map(|collaborator| collaborator.to_proto())
2740 .collect(),
2741 })
2742 .collect(),
2743 })?;
2744
2745 Ok(())
2746}
2747
2748async fn shutdown_dev_server(
2749 _: proto::ShutdownDevServer,
2750 response: Response<proto::ShutdownDevServer>,
2751 session: DevServerSession,
2752) -> Result<()> {
2753 response.send(proto::Ack {})?;
2754 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2755 remove_dev_server_connection(session.dev_server_id(), &session).await
2756}
2757
2758async fn shutdown_dev_server_internal(
2759 dev_server_id: DevServerId,
2760 connection_id: ConnectionId,
2761 session: &Session,
2762) -> Result<()> {
2763 let (dev_server_projects, dev_server) = {
2764 let db = session.db().await;
2765 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2766 let dev_server = db.get_dev_server(dev_server_id).await?;
2767 (dev_server_projects, dev_server)
2768 };
2769
2770 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2771 unshare_project_internal(
2772 ProjectId::from_proto(project_id),
2773 connection_id,
2774 None,
2775 session,
2776 )
2777 .await?;
2778 }
2779
2780 session
2781 .connection_pool()
2782 .await
2783 .set_dev_server_offline(dev_server_id);
2784
2785 let status = session
2786 .db()
2787 .await
2788 .dev_server_projects_update(dev_server.user_id)
2789 .await?;
2790 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2791
2792 Ok(())
2793}
2794
2795async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2796 let dev_server_connection = session
2797 .connection_pool()
2798 .await
2799 .dev_server_connection_id(dev_server_id);
2800
2801 if let Some(dev_server_connection) = dev_server_connection {
2802 session
2803 .connection_pool()
2804 .await
2805 .remove_connection(dev_server_connection)?;
2806 }
2807 Ok(())
2808}
2809
2810/// Updates other participants with changes to the project
2811async fn update_project(
2812 request: proto::UpdateProject,
2813 response: Response<proto::UpdateProject>,
2814 session: Session,
2815) -> Result<()> {
2816 let project_id = ProjectId::from_proto(request.project_id);
2817 let (room, guest_connection_ids) = &*session
2818 .db()
2819 .await
2820 .update_project(project_id, session.connection_id, &request.worktrees)
2821 .await?;
2822 broadcast(
2823 Some(session.connection_id),
2824 guest_connection_ids.iter().copied(),
2825 |connection_id| {
2826 session
2827 .peer
2828 .forward_send(session.connection_id, connection_id, request.clone())
2829 },
2830 );
2831 if let Some(room) = room {
2832 room_updated(&room, &session.peer);
2833 }
2834 response.send(proto::Ack {})?;
2835
2836 Ok(())
2837}
2838
2839/// Updates other participants with changes to the worktree
2840async fn update_worktree(
2841 request: proto::UpdateWorktree,
2842 response: Response<proto::UpdateWorktree>,
2843 session: Session,
2844) -> Result<()> {
2845 let guest_connection_ids = session
2846 .db()
2847 .await
2848 .update_worktree(&request, session.connection_id)
2849 .await?;
2850
2851 broadcast(
2852 Some(session.connection_id),
2853 guest_connection_ids.iter().copied(),
2854 |connection_id| {
2855 session
2856 .peer
2857 .forward_send(session.connection_id, connection_id, request.clone())
2858 },
2859 );
2860 response.send(proto::Ack {})?;
2861 Ok(())
2862}
2863
2864/// Updates other participants with changes to the diagnostics
2865async fn update_diagnostic_summary(
2866 message: proto::UpdateDiagnosticSummary,
2867 session: Session,
2868) -> Result<()> {
2869 let guest_connection_ids = session
2870 .db()
2871 .await
2872 .update_diagnostic_summary(&message, session.connection_id)
2873 .await?;
2874
2875 broadcast(
2876 Some(session.connection_id),
2877 guest_connection_ids.iter().copied(),
2878 |connection_id| {
2879 session
2880 .peer
2881 .forward_send(session.connection_id, connection_id, message.clone())
2882 },
2883 );
2884
2885 Ok(())
2886}
2887
2888/// Updates other participants with changes to the worktree settings
2889async fn update_worktree_settings(
2890 message: proto::UpdateWorktreeSettings,
2891 session: Session,
2892) -> Result<()> {
2893 let guest_connection_ids = session
2894 .db()
2895 .await
2896 .update_worktree_settings(&message, session.connection_id)
2897 .await?;
2898
2899 broadcast(
2900 Some(session.connection_id),
2901 guest_connection_ids.iter().copied(),
2902 |connection_id| {
2903 session
2904 .peer
2905 .forward_send(session.connection_id, connection_id, message.clone())
2906 },
2907 );
2908
2909 Ok(())
2910}
2911
2912/// Notify other participants that a language server has started.
2913async fn start_language_server(
2914 request: proto::StartLanguageServer,
2915 session: Session,
2916) -> Result<()> {
2917 let guest_connection_ids = session
2918 .db()
2919 .await
2920 .start_language_server(&request, session.connection_id)
2921 .await?;
2922
2923 broadcast(
2924 Some(session.connection_id),
2925 guest_connection_ids.iter().copied(),
2926 |connection_id| {
2927 session
2928 .peer
2929 .forward_send(session.connection_id, connection_id, request.clone())
2930 },
2931 );
2932 Ok(())
2933}
2934
2935/// Notify other participants that a language server has changed.
2936async fn update_language_server(
2937 request: proto::UpdateLanguageServer,
2938 session: Session,
2939) -> Result<()> {
2940 let project_id = ProjectId::from_proto(request.project_id);
2941 let project_connection_ids = session
2942 .db()
2943 .await
2944 .project_connection_ids(project_id, session.connection_id, true)
2945 .await?;
2946 broadcast(
2947 Some(session.connection_id),
2948 project_connection_ids.iter().copied(),
2949 |connection_id| {
2950 session
2951 .peer
2952 .forward_send(session.connection_id, connection_id, request.clone())
2953 },
2954 );
2955 Ok(())
2956}
2957
2958/// forward a project request to the host. These requests should be read only
2959/// as guests are allowed to send them.
2960async fn forward_read_only_project_request<T>(
2961 request: T,
2962 response: Response<T>,
2963 session: UserSession,
2964) -> Result<()>
2965where
2966 T: EntityMessage + RequestMessage,
2967{
2968 let project_id = ProjectId::from_proto(request.remote_entity_id());
2969 let host_connection_id = session
2970 .db()
2971 .await
2972 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2973 .await?;
2974 let payload = session
2975 .peer
2976 .forward_request(session.connection_id, host_connection_id, request)
2977 .await?;
2978 response.send(payload)?;
2979 Ok(())
2980}
2981
2982/// forward a project request to the dev server. Only allowed
2983/// if it's your dev server.
2984async fn forward_project_request_for_owner<T>(
2985 request: T,
2986 response: Response<T>,
2987 session: UserSession,
2988) -> Result<()>
2989where
2990 T: EntityMessage + RequestMessage,
2991{
2992 let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994 let host_connection_id = session
2995 .db()
2996 .await
2997 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2998 .await?;
2999 let payload = session
3000 .peer
3001 .forward_request(session.connection_id, host_connection_id, request)
3002 .await?;
3003 response.send(payload)?;
3004 Ok(())
3005}
3006
3007/// forward a project request to the host. These requests are disallowed
3008/// for guests.
3009async fn forward_mutating_project_request<T>(
3010 request: T,
3011 response: Response<T>,
3012 session: UserSession,
3013) -> Result<()>
3014where
3015 T: EntityMessage + RequestMessage,
3016{
3017 let project_id = ProjectId::from_proto(request.remote_entity_id());
3018
3019 let host_connection_id = session
3020 .db()
3021 .await
3022 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3023 .await?;
3024 let payload = session
3025 .peer
3026 .forward_request(session.connection_id, host_connection_id, request)
3027 .await?;
3028 response.send(payload)?;
3029 Ok(())
3030}
3031
3032/// forward a project request to the host. These requests are disallowed
3033/// for guests.
3034async fn forward_versioned_mutating_project_request<T>(
3035 request: T,
3036 response: Response<T>,
3037 session: UserSession,
3038) -> Result<()>
3039where
3040 T: EntityMessage + RequestMessage + VersionedMessage,
3041{
3042 let project_id = ProjectId::from_proto(request.remote_entity_id());
3043
3044 let host_connection_id = session
3045 .db()
3046 .await
3047 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3048 .await?;
3049 if let Some(host_version) = session
3050 .connection_pool()
3051 .await
3052 .connection(host_connection_id)
3053 .map(|c| c.zed_version)
3054 {
3055 if let Some(min_required_version) = request.required_host_version() {
3056 if min_required_version > host_version {
3057 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3058 .with_tag("required", &min_required_version.to_string())))?;
3059 }
3060 }
3061 }
3062
3063 let payload = session
3064 .peer
3065 .forward_request(session.connection_id, host_connection_id, request)
3066 .await?;
3067 response.send(payload)?;
3068 Ok(())
3069}
3070
3071/// Notify other participants that a new buffer has been created
3072async fn create_buffer_for_peer(
3073 request: proto::CreateBufferForPeer,
3074 session: Session,
3075) -> Result<()> {
3076 session
3077 .db()
3078 .await
3079 .check_user_is_project_host(
3080 ProjectId::from_proto(request.project_id),
3081 session.connection_id,
3082 )
3083 .await?;
3084 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3085 session
3086 .peer
3087 .forward_send(session.connection_id, peer_id.into(), request)?;
3088 Ok(())
3089}
3090
3091/// Notify other participants that a buffer has been updated. This is
3092/// allowed for guests as long as the update is limited to selections.
3093async fn update_buffer(
3094 request: proto::UpdateBuffer,
3095 response: Response<proto::UpdateBuffer>,
3096 session: Session,
3097) -> Result<()> {
3098 let project_id = ProjectId::from_proto(request.project_id);
3099 let mut capability = Capability::ReadOnly;
3100
3101 for op in request.operations.iter() {
3102 match op.variant {
3103 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3104 Some(_) => capability = Capability::ReadWrite,
3105 }
3106 }
3107
3108 let host = {
3109 let guard = session
3110 .db()
3111 .await
3112 .connections_for_buffer_update(
3113 project_id,
3114 session.principal_id(),
3115 session.connection_id,
3116 capability,
3117 )
3118 .await?;
3119
3120 let (host, guests) = &*guard;
3121
3122 broadcast(
3123 Some(session.connection_id),
3124 guests.clone(),
3125 |connection_id| {
3126 session
3127 .peer
3128 .forward_send(session.connection_id, connection_id, request.clone())
3129 },
3130 );
3131
3132 *host
3133 };
3134
3135 if host != session.connection_id {
3136 session
3137 .peer
3138 .forward_request(session.connection_id, host, request.clone())
3139 .await?;
3140 }
3141
3142 response.send(proto::Ack {})?;
3143 Ok(())
3144}
3145
3146async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3147 let project_id = ProjectId::from_proto(message.project_id);
3148
3149 let operation = message.operation.as_ref().context("invalid operation")?;
3150 let capability = match operation.variant.as_ref() {
3151 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3152 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3153 match buffer_op.variant {
3154 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3155 Capability::ReadOnly
3156 }
3157 _ => Capability::ReadWrite,
3158 }
3159 } else {
3160 Capability::ReadWrite
3161 }
3162 }
3163 Some(_) => Capability::ReadWrite,
3164 None => Capability::ReadOnly,
3165 };
3166
3167 let guard = session
3168 .db()
3169 .await
3170 .connections_for_buffer_update(
3171 project_id,
3172 session.principal_id(),
3173 session.connection_id,
3174 capability,
3175 )
3176 .await?;
3177
3178 let (host, guests) = &*guard;
3179
3180 broadcast(
3181 Some(session.connection_id),
3182 guests.iter().chain([host]).copied(),
3183 |connection_id| {
3184 session
3185 .peer
3186 .forward_send(session.connection_id, connection_id, message.clone())
3187 },
3188 );
3189
3190 Ok(())
3191}
3192
3193/// Notify other participants that a project has been updated.
3194async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3195 request: T,
3196 session: Session,
3197) -> Result<()> {
3198 let project_id = ProjectId::from_proto(request.remote_entity_id());
3199 let project_connection_ids = session
3200 .db()
3201 .await
3202 .project_connection_ids(project_id, session.connection_id, false)
3203 .await?;
3204
3205 broadcast(
3206 Some(session.connection_id),
3207 project_connection_ids.iter().copied(),
3208 |connection_id| {
3209 session
3210 .peer
3211 .forward_send(session.connection_id, connection_id, request.clone())
3212 },
3213 );
3214 Ok(())
3215}
3216
3217/// Start following another user in a call.
3218async fn follow(
3219 request: proto::Follow,
3220 response: Response<proto::Follow>,
3221 session: UserSession,
3222) -> Result<()> {
3223 let room_id = RoomId::from_proto(request.room_id);
3224 let project_id = request.project_id.map(ProjectId::from_proto);
3225 let leader_id = request
3226 .leader_id
3227 .ok_or_else(|| anyhow!("invalid leader id"))?
3228 .into();
3229 let follower_id = session.connection_id;
3230
3231 session
3232 .db()
3233 .await
3234 .check_room_participants(room_id, leader_id, session.connection_id)
3235 .await?;
3236
3237 let response_payload = session
3238 .peer
3239 .forward_request(session.connection_id, leader_id, request)
3240 .await?;
3241 response.send(response_payload)?;
3242
3243 if let Some(project_id) = project_id {
3244 let room = session
3245 .db()
3246 .await
3247 .follow(room_id, project_id, leader_id, follower_id)
3248 .await?;
3249 room_updated(&room, &session.peer);
3250 }
3251
3252 Ok(())
3253}
3254
3255/// Stop following another user in a call.
3256async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3257 let room_id = RoomId::from_proto(request.room_id);
3258 let project_id = request.project_id.map(ProjectId::from_proto);
3259 let leader_id = request
3260 .leader_id
3261 .ok_or_else(|| anyhow!("invalid leader id"))?
3262 .into();
3263 let follower_id = session.connection_id;
3264
3265 session
3266 .db()
3267 .await
3268 .check_room_participants(room_id, leader_id, session.connection_id)
3269 .await?;
3270
3271 session
3272 .peer
3273 .forward_send(session.connection_id, leader_id, request)?;
3274
3275 if let Some(project_id) = project_id {
3276 let room = session
3277 .db()
3278 .await
3279 .unfollow(room_id, project_id, leader_id, follower_id)
3280 .await?;
3281 room_updated(&room, &session.peer);
3282 }
3283
3284 Ok(())
3285}
3286
3287/// Notify everyone following you of your current location.
3288async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3289 let room_id = RoomId::from_proto(request.room_id);
3290 let database = session.db.lock().await;
3291
3292 let connection_ids = if let Some(project_id) = request.project_id {
3293 let project_id = ProjectId::from_proto(project_id);
3294 database
3295 .project_connection_ids(project_id, session.connection_id, true)
3296 .await?
3297 } else {
3298 database
3299 .room_connection_ids(room_id, session.connection_id)
3300 .await?
3301 };
3302
3303 // For now, don't send view update messages back to that view's current leader.
3304 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3305 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3306 _ => None,
3307 });
3308
3309 for connection_id in connection_ids.iter().cloned() {
3310 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3311 session
3312 .peer
3313 .forward_send(session.connection_id, connection_id, request.clone())?;
3314 }
3315 }
3316 Ok(())
3317}
3318
3319/// Get public data about users.
3320async fn get_users(
3321 request: proto::GetUsers,
3322 response: Response<proto::GetUsers>,
3323 session: Session,
3324) -> Result<()> {
3325 let user_ids = request
3326 .user_ids
3327 .into_iter()
3328 .map(UserId::from_proto)
3329 .collect();
3330 let users = session
3331 .db()
3332 .await
3333 .get_users_by_ids(user_ids)
3334 .await?
3335 .into_iter()
3336 .map(|user| proto::User {
3337 id: user.id.to_proto(),
3338 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3339 github_login: user.github_login,
3340 })
3341 .collect();
3342 response.send(proto::UsersResponse { users })?;
3343 Ok(())
3344}
3345
3346/// Search for users (to invite) buy Github login
3347async fn fuzzy_search_users(
3348 request: proto::FuzzySearchUsers,
3349 response: Response<proto::FuzzySearchUsers>,
3350 session: UserSession,
3351) -> Result<()> {
3352 let query = request.query;
3353 let users = match query.len() {
3354 0 => vec![],
3355 1 | 2 => session
3356 .db()
3357 .await
3358 .get_user_by_github_login(&query)
3359 .await?
3360 .into_iter()
3361 .collect(),
3362 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3363 };
3364 let users = users
3365 .into_iter()
3366 .filter(|user| user.id != session.user_id())
3367 .map(|user| proto::User {
3368 id: user.id.to_proto(),
3369 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3370 github_login: user.github_login,
3371 })
3372 .collect();
3373 response.send(proto::UsersResponse { users })?;
3374 Ok(())
3375}
3376
3377/// Send a contact request to another user.
3378async fn request_contact(
3379 request: proto::RequestContact,
3380 response: Response<proto::RequestContact>,
3381 session: UserSession,
3382) -> Result<()> {
3383 let requester_id = session.user_id();
3384 let responder_id = UserId::from_proto(request.responder_id);
3385 if requester_id == responder_id {
3386 return Err(anyhow!("cannot add yourself as a contact"))?;
3387 }
3388
3389 let notifications = session
3390 .db()
3391 .await
3392 .send_contact_request(requester_id, responder_id)
3393 .await?;
3394
3395 // Update outgoing contact requests of requester
3396 let mut update = proto::UpdateContacts::default();
3397 update.outgoing_requests.push(responder_id.to_proto());
3398 for connection_id in session
3399 .connection_pool()
3400 .await
3401 .user_connection_ids(requester_id)
3402 {
3403 session.peer.send(connection_id, update.clone())?;
3404 }
3405
3406 // Update incoming contact requests of responder
3407 let mut update = proto::UpdateContacts::default();
3408 update
3409 .incoming_requests
3410 .push(proto::IncomingContactRequest {
3411 requester_id: requester_id.to_proto(),
3412 });
3413 let connection_pool = session.connection_pool().await;
3414 for connection_id in connection_pool.user_connection_ids(responder_id) {
3415 session.peer.send(connection_id, update.clone())?;
3416 }
3417
3418 send_notifications(&connection_pool, &session.peer, notifications);
3419
3420 response.send(proto::Ack {})?;
3421 Ok(())
3422}
3423
3424/// Accept or decline a contact request
3425async fn respond_to_contact_request(
3426 request: proto::RespondToContactRequest,
3427 response: Response<proto::RespondToContactRequest>,
3428 session: UserSession,
3429) -> Result<()> {
3430 let responder_id = session.user_id();
3431 let requester_id = UserId::from_proto(request.requester_id);
3432 let db = session.db().await;
3433 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3434 db.dismiss_contact_notification(responder_id, requester_id)
3435 .await?;
3436 } else {
3437 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3438
3439 let notifications = db
3440 .respond_to_contact_request(responder_id, requester_id, accept)
3441 .await?;
3442 let requester_busy = db.is_user_busy(requester_id).await?;
3443 let responder_busy = db.is_user_busy(responder_id).await?;
3444
3445 let pool = session.connection_pool().await;
3446 // Update responder with new contact
3447 let mut update = proto::UpdateContacts::default();
3448 if accept {
3449 update
3450 .contacts
3451 .push(contact_for_user(requester_id, requester_busy, &pool));
3452 }
3453 update
3454 .remove_incoming_requests
3455 .push(requester_id.to_proto());
3456 for connection_id in pool.user_connection_ids(responder_id) {
3457 session.peer.send(connection_id, update.clone())?;
3458 }
3459
3460 // Update requester with new contact
3461 let mut update = proto::UpdateContacts::default();
3462 if accept {
3463 update
3464 .contacts
3465 .push(contact_for_user(responder_id, responder_busy, &pool));
3466 }
3467 update
3468 .remove_outgoing_requests
3469 .push(responder_id.to_proto());
3470
3471 for connection_id in pool.user_connection_ids(requester_id) {
3472 session.peer.send(connection_id, update.clone())?;
3473 }
3474
3475 send_notifications(&pool, &session.peer, notifications);
3476 }
3477
3478 response.send(proto::Ack {})?;
3479 Ok(())
3480}
3481
3482/// Remove a contact.
3483async fn remove_contact(
3484 request: proto::RemoveContact,
3485 response: Response<proto::RemoveContact>,
3486 session: UserSession,
3487) -> Result<()> {
3488 let requester_id = session.user_id();
3489 let responder_id = UserId::from_proto(request.user_id);
3490 let db = session.db().await;
3491 let (contact_accepted, deleted_notification_id) =
3492 db.remove_contact(requester_id, responder_id).await?;
3493
3494 let pool = session.connection_pool().await;
3495 // Update outgoing contact requests of requester
3496 let mut update = proto::UpdateContacts::default();
3497 if contact_accepted {
3498 update.remove_contacts.push(responder_id.to_proto());
3499 } else {
3500 update
3501 .remove_outgoing_requests
3502 .push(responder_id.to_proto());
3503 }
3504 for connection_id in pool.user_connection_ids(requester_id) {
3505 session.peer.send(connection_id, update.clone())?;
3506 }
3507
3508 // Update incoming contact requests of responder
3509 let mut update = proto::UpdateContacts::default();
3510 if contact_accepted {
3511 update.remove_contacts.push(requester_id.to_proto());
3512 } else {
3513 update
3514 .remove_incoming_requests
3515 .push(requester_id.to_proto());
3516 }
3517 for connection_id in pool.user_connection_ids(responder_id) {
3518 session.peer.send(connection_id, update.clone())?;
3519 if let Some(notification_id) = deleted_notification_id {
3520 session.peer.send(
3521 connection_id,
3522 proto::DeleteNotification {
3523 notification_id: notification_id.to_proto(),
3524 },
3525 )?;
3526 }
3527 }
3528
3529 response.send(proto::Ack {})?;
3530 Ok(())
3531}
3532
3533fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3534 version.0.minor() < 139
3535}
3536
3537async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3538 subscribe_user_to_channels(
3539 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3540 &session,
3541 )
3542 .await?;
3543 Ok(())
3544}
3545
3546async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3547 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3548 let mut pool = session.connection_pool().await;
3549 for membership in &channels_for_user.channel_memberships {
3550 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3551 }
3552 session.peer.send(
3553 session.connection_id,
3554 build_update_user_channels(&channels_for_user),
3555 )?;
3556 session.peer.send(
3557 session.connection_id,
3558 build_channels_update(channels_for_user),
3559 )?;
3560 Ok(())
3561}
3562
3563/// Creates a new channel.
3564async fn create_channel(
3565 request: proto::CreateChannel,
3566 response: Response<proto::CreateChannel>,
3567 session: UserSession,
3568) -> Result<()> {
3569 let db = session.db().await;
3570
3571 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3572 let (channel, membership) = db
3573 .create_channel(&request.name, parent_id, session.user_id())
3574 .await?;
3575
3576 let root_id = channel.root_id();
3577 let channel = Channel::from_model(channel);
3578
3579 response.send(proto::CreateChannelResponse {
3580 channel: Some(channel.to_proto()),
3581 parent_id: request.parent_id,
3582 })?;
3583
3584 let mut connection_pool = session.connection_pool().await;
3585 if let Some(membership) = membership {
3586 connection_pool.subscribe_to_channel(
3587 membership.user_id,
3588 membership.channel_id,
3589 membership.role,
3590 );
3591 let update = proto::UpdateUserChannels {
3592 channel_memberships: vec![proto::ChannelMembership {
3593 channel_id: membership.channel_id.to_proto(),
3594 role: membership.role.into(),
3595 }],
3596 ..Default::default()
3597 };
3598 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3599 session.peer.send(connection_id, update.clone())?;
3600 }
3601 }
3602
3603 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3604 if !role.can_see_channel(channel.visibility) {
3605 continue;
3606 }
3607
3608 let update = proto::UpdateChannels {
3609 channels: vec![channel.to_proto()],
3610 ..Default::default()
3611 };
3612 session.peer.send(connection_id, update.clone())?;
3613 }
3614
3615 Ok(())
3616}
3617
3618/// Delete a channel
3619async fn delete_channel(
3620 request: proto::DeleteChannel,
3621 response: Response<proto::DeleteChannel>,
3622 session: UserSession,
3623) -> Result<()> {
3624 let db = session.db().await;
3625
3626 let channel_id = request.channel_id;
3627 let (root_channel, removed_channels) = db
3628 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3629 .await?;
3630 response.send(proto::Ack {})?;
3631
3632 // Notify members of removed channels
3633 let mut update = proto::UpdateChannels::default();
3634 update
3635 .delete_channels
3636 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3637
3638 let connection_pool = session.connection_pool().await;
3639 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3640 session.peer.send(connection_id, update.clone())?;
3641 }
3642
3643 Ok(())
3644}
3645
3646/// Invite someone to join a channel.
3647async fn invite_channel_member(
3648 request: proto::InviteChannelMember,
3649 response: Response<proto::InviteChannelMember>,
3650 session: UserSession,
3651) -> Result<()> {
3652 let db = session.db().await;
3653 let channel_id = ChannelId::from_proto(request.channel_id);
3654 let invitee_id = UserId::from_proto(request.user_id);
3655 let InviteMemberResult {
3656 channel,
3657 notifications,
3658 } = db
3659 .invite_channel_member(
3660 channel_id,
3661 invitee_id,
3662 session.user_id(),
3663 request.role().into(),
3664 )
3665 .await?;
3666
3667 let update = proto::UpdateChannels {
3668 channel_invitations: vec![channel.to_proto()],
3669 ..Default::default()
3670 };
3671
3672 let connection_pool = session.connection_pool().await;
3673 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3674 session.peer.send(connection_id, update.clone())?;
3675 }
3676
3677 send_notifications(&connection_pool, &session.peer, notifications);
3678
3679 response.send(proto::Ack {})?;
3680 Ok(())
3681}
3682
3683/// remove someone from a channel
3684async fn remove_channel_member(
3685 request: proto::RemoveChannelMember,
3686 response: Response<proto::RemoveChannelMember>,
3687 session: UserSession,
3688) -> Result<()> {
3689 let db = session.db().await;
3690 let channel_id = ChannelId::from_proto(request.channel_id);
3691 let member_id = UserId::from_proto(request.user_id);
3692
3693 let RemoveChannelMemberResult {
3694 membership_update,
3695 notification_id,
3696 } = db
3697 .remove_channel_member(channel_id, member_id, session.user_id())
3698 .await?;
3699
3700 let mut connection_pool = session.connection_pool().await;
3701 notify_membership_updated(
3702 &mut connection_pool,
3703 membership_update,
3704 member_id,
3705 &session.peer,
3706 );
3707 for connection_id in connection_pool.user_connection_ids(member_id) {
3708 if let Some(notification_id) = notification_id {
3709 session
3710 .peer
3711 .send(
3712 connection_id,
3713 proto::DeleteNotification {
3714 notification_id: notification_id.to_proto(),
3715 },
3716 )
3717 .trace_err();
3718 }
3719 }
3720
3721 response.send(proto::Ack {})?;
3722 Ok(())
3723}
3724
3725/// Toggle the channel between public and private.
3726/// Care is taken to maintain the invariant that public channels only descend from public channels,
3727/// (though members-only channels can appear at any point in the hierarchy).
3728async fn set_channel_visibility(
3729 request: proto::SetChannelVisibility,
3730 response: Response<proto::SetChannelVisibility>,
3731 session: UserSession,
3732) -> Result<()> {
3733 let db = session.db().await;
3734 let channel_id = ChannelId::from_proto(request.channel_id);
3735 let visibility = request.visibility().into();
3736
3737 let channel_model = db
3738 .set_channel_visibility(channel_id, visibility, session.user_id())
3739 .await?;
3740 let root_id = channel_model.root_id();
3741 let channel = Channel::from_model(channel_model);
3742
3743 let mut connection_pool = session.connection_pool().await;
3744 for (user_id, role) in connection_pool
3745 .channel_user_ids(root_id)
3746 .collect::<Vec<_>>()
3747 .into_iter()
3748 {
3749 let update = if role.can_see_channel(channel.visibility) {
3750 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3751 proto::UpdateChannels {
3752 channels: vec![channel.to_proto()],
3753 ..Default::default()
3754 }
3755 } else {
3756 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3757 proto::UpdateChannels {
3758 delete_channels: vec![channel.id.to_proto()],
3759 ..Default::default()
3760 }
3761 };
3762
3763 for connection_id in connection_pool.user_connection_ids(user_id) {
3764 session.peer.send(connection_id, update.clone())?;
3765 }
3766 }
3767
3768 response.send(proto::Ack {})?;
3769 Ok(())
3770}
3771
3772/// Alter the role for a user in the channel.
3773async fn set_channel_member_role(
3774 request: proto::SetChannelMemberRole,
3775 response: Response<proto::SetChannelMemberRole>,
3776 session: UserSession,
3777) -> Result<()> {
3778 let db = session.db().await;
3779 let channel_id = ChannelId::from_proto(request.channel_id);
3780 let member_id = UserId::from_proto(request.user_id);
3781 let result = db
3782 .set_channel_member_role(
3783 channel_id,
3784 session.user_id(),
3785 member_id,
3786 request.role().into(),
3787 )
3788 .await?;
3789
3790 match result {
3791 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3792 let mut connection_pool = session.connection_pool().await;
3793 notify_membership_updated(
3794 &mut connection_pool,
3795 membership_update,
3796 member_id,
3797 &session.peer,
3798 )
3799 }
3800 db::SetMemberRoleResult::InviteUpdated(channel) => {
3801 let update = proto::UpdateChannels {
3802 channel_invitations: vec![channel.to_proto()],
3803 ..Default::default()
3804 };
3805
3806 for connection_id in session
3807 .connection_pool()
3808 .await
3809 .user_connection_ids(member_id)
3810 {
3811 session.peer.send(connection_id, update.clone())?;
3812 }
3813 }
3814 }
3815
3816 response.send(proto::Ack {})?;
3817 Ok(())
3818}
3819
3820/// Change the name of a channel
3821async fn rename_channel(
3822 request: proto::RenameChannel,
3823 response: Response<proto::RenameChannel>,
3824 session: UserSession,
3825) -> Result<()> {
3826 let db = session.db().await;
3827 let channel_id = ChannelId::from_proto(request.channel_id);
3828 let channel_model = db
3829 .rename_channel(channel_id, session.user_id(), &request.name)
3830 .await?;
3831 let root_id = channel_model.root_id();
3832 let channel = Channel::from_model(channel_model);
3833
3834 response.send(proto::RenameChannelResponse {
3835 channel: Some(channel.to_proto()),
3836 })?;
3837
3838 let connection_pool = session.connection_pool().await;
3839 let update = proto::UpdateChannels {
3840 channels: vec![channel.to_proto()],
3841 ..Default::default()
3842 };
3843 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3844 if role.can_see_channel(channel.visibility) {
3845 session.peer.send(connection_id, update.clone())?;
3846 }
3847 }
3848
3849 Ok(())
3850}
3851
3852/// Move a channel to a new parent.
3853async fn move_channel(
3854 request: proto::MoveChannel,
3855 response: Response<proto::MoveChannel>,
3856 session: UserSession,
3857) -> Result<()> {
3858 let channel_id = ChannelId::from_proto(request.channel_id);
3859 let to = ChannelId::from_proto(request.to);
3860
3861 let (root_id, channels) = session
3862 .db()
3863 .await
3864 .move_channel(channel_id, to, session.user_id())
3865 .await?;
3866
3867 let connection_pool = session.connection_pool().await;
3868 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3869 let channels = channels
3870 .iter()
3871 .filter_map(|channel| {
3872 if role.can_see_channel(channel.visibility) {
3873 Some(channel.to_proto())
3874 } else {
3875 None
3876 }
3877 })
3878 .collect::<Vec<_>>();
3879 if channels.is_empty() {
3880 continue;
3881 }
3882
3883 let update = proto::UpdateChannels {
3884 channels,
3885 ..Default::default()
3886 };
3887
3888 session.peer.send(connection_id, update.clone())?;
3889 }
3890
3891 response.send(Ack {})?;
3892 Ok(())
3893}
3894
3895/// Get the list of channel members
3896async fn get_channel_members(
3897 request: proto::GetChannelMembers,
3898 response: Response<proto::GetChannelMembers>,
3899 session: UserSession,
3900) -> Result<()> {
3901 let db = session.db().await;
3902 let channel_id = ChannelId::from_proto(request.channel_id);
3903 let limit = if request.limit == 0 {
3904 u16::MAX as u64
3905 } else {
3906 request.limit
3907 };
3908 let (members, users) = db
3909 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3910 .await?;
3911 response.send(proto::GetChannelMembersResponse { members, users })?;
3912 Ok(())
3913}
3914
3915/// Accept or decline a channel invitation.
3916async fn respond_to_channel_invite(
3917 request: proto::RespondToChannelInvite,
3918 response: Response<proto::RespondToChannelInvite>,
3919 session: UserSession,
3920) -> Result<()> {
3921 let db = session.db().await;
3922 let channel_id = ChannelId::from_proto(request.channel_id);
3923 let RespondToChannelInvite {
3924 membership_update,
3925 notifications,
3926 } = db
3927 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3928 .await?;
3929
3930 let mut connection_pool = session.connection_pool().await;
3931 if let Some(membership_update) = membership_update {
3932 notify_membership_updated(
3933 &mut connection_pool,
3934 membership_update,
3935 session.user_id(),
3936 &session.peer,
3937 );
3938 } else {
3939 let update = proto::UpdateChannels {
3940 remove_channel_invitations: vec![channel_id.to_proto()],
3941 ..Default::default()
3942 };
3943
3944 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3945 session.peer.send(connection_id, update.clone())?;
3946 }
3947 };
3948
3949 send_notifications(&connection_pool, &session.peer, notifications);
3950
3951 response.send(proto::Ack {})?;
3952
3953 Ok(())
3954}
3955
3956/// Join the channels' room
3957async fn join_channel(
3958 request: proto::JoinChannel,
3959 response: Response<proto::JoinChannel>,
3960 session: UserSession,
3961) -> Result<()> {
3962 let channel_id = ChannelId::from_proto(request.channel_id);
3963 join_channel_internal(channel_id, Box::new(response), session).await
3964}
3965
3966trait JoinChannelInternalResponse {
3967 fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
3968}
3969impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
3970 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3971 Response::<proto::JoinChannel>::send(self, result)
3972 }
3973}
3974impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
3975 fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
3976 Response::<proto::JoinRoom>::send(self, result)
3977 }
3978}
3979
/// Join the call room of `channel_id`. Shared by every handler whose response type implements
/// [`JoinChannelInternalResponse`].
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
/// 2. Join the channel's room in the database.
/// 3. When a LiveKit client is configured, create a token: `Guest`s get a guest token with
///    `can_publish: false`, and all other roles get a room token with `can_publish: true`. If
///    token creation fails, the error goes through `trace_err` and the response carries no
///    LiveKit info.
/// 4. Reply to the requester with the room, the channel id and the LiveKit info.
/// 5. Propagate any membership change and broadcast the updated room.
/// 6. Broadcast the updated channel and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of these fail: the stale-connection lookup or cleanup, the database
/// join, sending the response, or `update_user_contacts`. Also returns "channel not returned"
/// when the joined room has no associated channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block scopes both the database guard and the connection-pool guard, so both are
    // released before `channel_updated` re-acquires the pool below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably `leave_room_for_session` locks the database itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or token creation failed (the `?` after
        // `trace_err` short-circuits this closure).
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Both guards from the block above are dropped by now; take a fresh pool guard for this call.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4070
4071/// Start editing the channel notes
4072async fn join_channel_buffer(
4073 request: proto::JoinChannelBuffer,
4074 response: Response<proto::JoinChannelBuffer>,
4075 session: UserSession,
4076) -> Result<()> {
4077 let db = session.db().await;
4078 let channel_id = ChannelId::from_proto(request.channel_id);
4079
4080 let open_response = db
4081 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4082 .await?;
4083
4084 let collaborators = open_response.collaborators.clone();
4085 response.send(open_response)?;
4086
4087 let update = UpdateChannelBufferCollaborators {
4088 channel_id: channel_id.to_proto(),
4089 collaborators: collaborators.clone(),
4090 };
4091 channel_buffer_updated(
4092 session.connection_id,
4093 collaborators
4094 .iter()
4095 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4096 &update,
4097 &session.peer,
4098 );
4099
4100 Ok(())
4101}
4102
4103/// Edit the channel notes
4104async fn update_channel_buffer(
4105 request: proto::UpdateChannelBuffer,
4106 session: UserSession,
4107) -> Result<()> {
4108 let db = session.db().await;
4109 let channel_id = ChannelId::from_proto(request.channel_id);
4110
4111 let (collaborators, epoch, version) = db
4112 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4113 .await?;
4114
4115 channel_buffer_updated(
4116 session.connection_id,
4117 collaborators.clone(),
4118 &proto::UpdateChannelBuffer {
4119 channel_id: channel_id.to_proto(),
4120 operations: request.operations,
4121 },
4122 &session.peer,
4123 );
4124
4125 let pool = &*session.connection_pool().await;
4126
4127 let non_collaborators =
4128 pool.channel_connection_ids(channel_id)
4129 .filter_map(|(connection_id, _)| {
4130 if collaborators.contains(&connection_id) {
4131 None
4132 } else {
4133 Some(connection_id)
4134 }
4135 });
4136
4137 broadcast(None, non_collaborators, |peer_id| {
4138 session.peer.send(
4139 peer_id,
4140 proto::UpdateChannels {
4141 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4142 channel_id: channel_id.to_proto(),
4143 epoch: epoch as u64,
4144 version: version.clone(),
4145 }],
4146 ..Default::default()
4147 },
4148 )
4149 });
4150
4151 Ok(())
4152}
4153
4154/// Rejoin the channel notes after a connection blip
4155async fn rejoin_channel_buffers(
4156 request: proto::RejoinChannelBuffers,
4157 response: Response<proto::RejoinChannelBuffers>,
4158 session: UserSession,
4159) -> Result<()> {
4160 let db = session.db().await;
4161 let buffers = db
4162 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4163 .await?;
4164
4165 for rejoined_buffer in &buffers {
4166 let collaborators_to_notify = rejoined_buffer
4167 .buffer
4168 .collaborators
4169 .iter()
4170 .filter_map(|c| Some(c.peer_id?.into()));
4171 channel_buffer_updated(
4172 session.connection_id,
4173 collaborators_to_notify,
4174 &proto::UpdateChannelBufferCollaborators {
4175 channel_id: rejoined_buffer.buffer.channel_id,
4176 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4177 },
4178 &session.peer,
4179 );
4180 }
4181
4182 response.send(proto::RejoinChannelBuffersResponse {
4183 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4184 })?;
4185
4186 Ok(())
4187}
4188
4189/// Stop editing the channel notes
4190async fn leave_channel_buffer(
4191 request: proto::LeaveChannelBuffer,
4192 response: Response<proto::LeaveChannelBuffer>,
4193 session: UserSession,
4194) -> Result<()> {
4195 let db = session.db().await;
4196 let channel_id = ChannelId::from_proto(request.channel_id);
4197
4198 let left_buffer = db
4199 .leave_channel_buffer(channel_id, session.connection_id)
4200 .await?;
4201
4202 response.send(Ack {})?;
4203
4204 channel_buffer_updated(
4205 session.connection_id,
4206 left_buffer.connections,
4207 &proto::UpdateChannelBufferCollaborators {
4208 channel_id: channel_id.to_proto(),
4209 collaborators: left_buffer.collaborators,
4210 },
4211 &session.peer,
4212 );
4213
4214 Ok(())
4215}
4216
4217fn channel_buffer_updated<T: EnvelopedMessage>(
4218 sender_id: ConnectionId,
4219 collaborators: impl IntoIterator<Item = ConnectionId>,
4220 message: &T,
4221 peer: &Peer,
4222) {
4223 broadcast(Some(sender_id), collaborators, |peer_id| {
4224 peer.send(peer_id, message.clone())
4225 });
4226}
4227
4228fn send_notifications(
4229 connection_pool: &ConnectionPool,
4230 peer: &Peer,
4231 notifications: db::NotificationBatch,
4232) {
4233 for (user_id, notification) in notifications {
4234 for connection_id in connection_pool.user_connection_ids(user_id) {
4235 if let Err(error) = peer.send(
4236 connection_id,
4237 proto::AddNotification {
4238 notification: Some(notification.clone()),
4239 },
4240 ) {
4241 tracing::error!(
4242 "failed to send notification to {:?} {}",
4243 connection_id,
4244 error
4245 );
4246 }
4247 }
4248 }
4249}
4250
4251/// Send a message to the channel
4252async fn send_channel_message(
4253 request: proto::SendChannelMessage,
4254 response: Response<proto::SendChannelMessage>,
4255 session: UserSession,
4256) -> Result<()> {
4257 // Validate the message body.
4258 let body = request.body.trim().to_string();
4259 if body.len() > MAX_MESSAGE_LEN {
4260 return Err(anyhow!("message is too long"))?;
4261 }
4262 if body.is_empty() {
4263 return Err(anyhow!("message can't be blank"))?;
4264 }
4265
4266 // TODO: adjust mentions if body is trimmed
4267
4268 let timestamp = OffsetDateTime::now_utc();
4269 let nonce = request
4270 .nonce
4271 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4272
4273 let channel_id = ChannelId::from_proto(request.channel_id);
4274 let CreatedChannelMessage {
4275 message_id,
4276 participant_connection_ids,
4277 notifications,
4278 } = session
4279 .db()
4280 .await
4281 .create_channel_message(
4282 channel_id,
4283 session.user_id(),
4284 &body,
4285 &request.mentions,
4286 timestamp,
4287 nonce.clone().into(),
4288 match request.reply_to_message_id {
4289 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4290 None => None,
4291 },
4292 )
4293 .await?;
4294
4295 let message = proto::ChannelMessage {
4296 sender_id: session.user_id().to_proto(),
4297 id: message_id.to_proto(),
4298 body,
4299 mentions: request.mentions,
4300 timestamp: timestamp.unix_timestamp() as u64,
4301 nonce: Some(nonce),
4302 reply_to_message_id: request.reply_to_message_id,
4303 edited_at: None,
4304 };
4305 broadcast(
4306 Some(session.connection_id),
4307 participant_connection_ids.clone(),
4308 |connection| {
4309 session.peer.send(
4310 connection,
4311 proto::ChannelMessageSent {
4312 channel_id: channel_id.to_proto(),
4313 message: Some(message.clone()),
4314 },
4315 )
4316 },
4317 );
4318 response.send(proto::SendChannelMessageResponse {
4319 message: Some(message),
4320 })?;
4321
4322 let pool = &*session.connection_pool().await;
4323 let non_participants =
4324 pool.channel_connection_ids(channel_id)
4325 .filter_map(|(connection_id, _)| {
4326 if participant_connection_ids.contains(&connection_id) {
4327 None
4328 } else {
4329 Some(connection_id)
4330 }
4331 });
4332 broadcast(None, non_participants, |peer_id| {
4333 session.peer.send(
4334 peer_id,
4335 proto::UpdateChannels {
4336 latest_channel_message_ids: vec![proto::ChannelMessageId {
4337 channel_id: channel_id.to_proto(),
4338 message_id: message_id.to_proto(),
4339 }],
4340 ..Default::default()
4341 },
4342 )
4343 });
4344 send_notifications(pool, &session.peer, notifications);
4345
4346 Ok(())
4347}
4348
4349/// Delete a channel message
4350async fn remove_channel_message(
4351 request: proto::RemoveChannelMessage,
4352 response: Response<proto::RemoveChannelMessage>,
4353 session: UserSession,
4354) -> Result<()> {
4355 let channel_id = ChannelId::from_proto(request.channel_id);
4356 let message_id = MessageId::from_proto(request.message_id);
4357 let (connection_ids, existing_notification_ids) = session
4358 .db()
4359 .await
4360 .remove_channel_message(channel_id, message_id, session.user_id())
4361 .await?;
4362
4363 broadcast(
4364 Some(session.connection_id),
4365 connection_ids,
4366 move |connection| {
4367 session.peer.send(connection, request.clone())?;
4368
4369 for notification_id in &existing_notification_ids {
4370 session.peer.send(
4371 connection,
4372 proto::DeleteNotification {
4373 notification_id: (*notification_id).to_proto(),
4374 },
4375 )?;
4376 }
4377
4378 Ok(())
4379 },
4380 );
4381 response.send(proto::Ack {})?;
4382 Ok(())
4383}
4384
4385async fn update_channel_message(
4386 request: proto::UpdateChannelMessage,
4387 response: Response<proto::UpdateChannelMessage>,
4388 session: UserSession,
4389) -> Result<()> {
4390 let channel_id = ChannelId::from_proto(request.channel_id);
4391 let message_id = MessageId::from_proto(request.message_id);
4392 let updated_at = OffsetDateTime::now_utc();
4393 let UpdatedChannelMessage {
4394 message_id,
4395 participant_connection_ids,
4396 notifications,
4397 reply_to_message_id,
4398 timestamp,
4399 deleted_mention_notification_ids,
4400 updated_mention_notifications,
4401 } = session
4402 .db()
4403 .await
4404 .update_channel_message(
4405 channel_id,
4406 message_id,
4407 session.user_id(),
4408 request.body.as_str(),
4409 &request.mentions,
4410 updated_at,
4411 )
4412 .await?;
4413
4414 let nonce = request
4415 .nonce
4416 .clone()
4417 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4418
4419 let message = proto::ChannelMessage {
4420 sender_id: session.user_id().to_proto(),
4421 id: message_id.to_proto(),
4422 body: request.body.clone(),
4423 mentions: request.mentions.clone(),
4424 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4425 nonce: Some(nonce),
4426 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4427 edited_at: Some(updated_at.unix_timestamp() as u64),
4428 };
4429
4430 response.send(proto::Ack {})?;
4431
4432 let pool = &*session.connection_pool().await;
4433 broadcast(
4434 Some(session.connection_id),
4435 participant_connection_ids,
4436 |connection| {
4437 session.peer.send(
4438 connection,
4439 proto::ChannelMessageUpdate {
4440 channel_id: channel_id.to_proto(),
4441 message: Some(message.clone()),
4442 },
4443 )?;
4444
4445 for notification_id in &deleted_mention_notification_ids {
4446 session.peer.send(
4447 connection,
4448 proto::DeleteNotification {
4449 notification_id: (*notification_id).to_proto(),
4450 },
4451 )?;
4452 }
4453
4454 for notification in &updated_mention_notifications {
4455 session.peer.send(
4456 connection,
4457 proto::UpdateNotification {
4458 notification: Some(notification.clone()),
4459 },
4460 )?;
4461 }
4462
4463 Ok(())
4464 },
4465 );
4466
4467 send_notifications(pool, &session.peer, notifications);
4468
4469 Ok(())
4470}
4471
4472/// Mark a channel message as read
4473async fn acknowledge_channel_message(
4474 request: proto::AckChannelMessage,
4475 session: UserSession,
4476) -> Result<()> {
4477 let channel_id = ChannelId::from_proto(request.channel_id);
4478 let message_id = MessageId::from_proto(request.message_id);
4479 let notifications = session
4480 .db()
4481 .await
4482 .observe_channel_message(channel_id, session.user_id(), message_id)
4483 .await?;
4484 send_notifications(
4485 &*session.connection_pool().await,
4486 &session.peer,
4487 notifications,
4488 );
4489 Ok(())
4490}
4491
4492/// Mark a buffer version as synced
4493async fn acknowledge_buffer_version(
4494 request: proto::AckBufferOperation,
4495 session: UserSession,
4496) -> Result<()> {
4497 let buffer_id = BufferId::from_proto(request.buffer_id);
4498 session
4499 .db()
4500 .await
4501 .observe_buffer_version(
4502 buffer_id,
4503 session.user_id(),
4504 request.epoch as i32,
4505 &request.version,
4506 )
4507 .await?;
4508 Ok(())
4509}
4510
4511struct CompleteWithLanguageModelRateLimit;
4512
4513impl RateLimit for CompleteWithLanguageModelRateLimit {
4514 fn capacity() -> usize {
4515 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4516 .ok()
4517 .and_then(|v| v.parse().ok())
4518 .unwrap_or(120) // Picked arbitrarily
4519 }
4520
4521 fn refill_duration() -> chrono::Duration {
4522 chrono::Duration::hours(1)
4523 }
4524
4525 fn db_name() -> &'static str {
4526 "complete-with-language-model"
4527 }
4528}
4529
4530async fn complete_with_language_model(
4531 request: proto::CompleteWithLanguageModel,
4532 response: Response<proto::CompleteWithLanguageModel>,
4533 session: Session,
4534 config: &Config,
4535) -> Result<()> {
4536 let Some(session) = session.for_user() else {
4537 return Err(anyhow!("user not found"))?;
4538 };
4539 authorize_access_to_language_models(&session).await?;
4540
4541 session
4542 .rate_limiter
4543 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4544 .await?;
4545
4546 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4547 Some(proto::LanguageModelProvider::Anthropic) => {
4548 let api_key = config
4549 .anthropic_api_key
4550 .as_ref()
4551 .context("no Anthropic AI API key configured on the server")?;
4552 anthropic::complete(
4553 session.http_client.as_ref(),
4554 anthropic::ANTHROPIC_API_URL,
4555 api_key,
4556 serde_json::from_str(&request.request)?,
4557 )
4558 .await?
4559 }
4560 _ => return Err(anyhow!("unsupported provider"))?,
4561 };
4562
4563 response.send(proto::CompleteWithLanguageModelResponse {
4564 completion: serde_json::to_string(&result)?,
4565 })?;
4566
4567 Ok(())
4568}
4569
4570async fn stream_complete_with_language_model(
4571 request: proto::StreamCompleteWithLanguageModel,
4572 response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
4573 session: Session,
4574 config: &Config,
4575) -> Result<()> {
4576 let Some(session) = session.for_user() else {
4577 return Err(anyhow!("user not found"))?;
4578 };
4579 authorize_access_to_language_models(&session).await?;
4580
4581 session
4582 .rate_limiter
4583 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4584 .await?;
4585
4586 match proto::LanguageModelProvider::from_i32(request.provider) {
4587 Some(proto::LanguageModelProvider::Anthropic) => {
4588 let api_key = config
4589 .anthropic_api_key
4590 .as_ref()
4591 .context("no Anthropic AI API key configured on the server")?;
4592 let mut chunks = anthropic::stream_completion(
4593 session.http_client.as_ref(),
4594 anthropic::ANTHROPIC_API_URL,
4595 api_key,
4596 serde_json::from_str(&request.request)?,
4597 None,
4598 )
4599 .await?;
4600 while let Some(event) = chunks.next().await {
4601 let chunk = event?;
4602 response.send(proto::StreamCompleteWithLanguageModelResponse {
4603 event: serde_json::to_string(&chunk)?,
4604 })?;
4605 }
4606 }
4607 Some(proto::LanguageModelProvider::OpenAi) => {
4608 let api_key = config
4609 .openai_api_key
4610 .as_ref()
4611 .context("no OpenAI API key configured on the server")?;
4612 let mut events = open_ai::stream_completion(
4613 session.http_client.as_ref(),
4614 open_ai::OPEN_AI_API_URL,
4615 api_key,
4616 serde_json::from_str(&request.request)?,
4617 None,
4618 )
4619 .await?;
4620 while let Some(event) = events.next().await {
4621 let event = event?;
4622 response.send(proto::StreamCompleteWithLanguageModelResponse {
4623 event: serde_json::to_string(&event)?,
4624 })?;
4625 }
4626 }
4627 Some(proto::LanguageModelProvider::Google) => {
4628 let api_key = config
4629 .google_ai_api_key
4630 .as_ref()
4631 .context("no Google AI API key configured on the server")?;
4632 let mut events = google_ai::stream_generate_content(
4633 session.http_client.as_ref(),
4634 google_ai::API_URL,
4635 api_key,
4636 serde_json::from_str(&request.request)?,
4637 )
4638 .await?;
4639 while let Some(event) = events.next().await {
4640 let event = event?;
4641 response.send(proto::StreamCompleteWithLanguageModelResponse {
4642 event: serde_json::to_string(&event)?,
4643 })?;
4644 }
4645 }
4646 None => return Err(anyhow!("unknown provider"))?,
4647 }
4648
4649 Ok(())
4650}
4651
4652async fn count_language_model_tokens(
4653 request: proto::CountLanguageModelTokens,
4654 response: Response<proto::CountLanguageModelTokens>,
4655 session: Session,
4656 config: &Config,
4657) -> Result<()> {
4658 let Some(session) = session.for_user() else {
4659 return Err(anyhow!("user not found"))?;
4660 };
4661 authorize_access_to_language_models(&session).await?;
4662
4663 session
4664 .rate_limiter
4665 .check::<CountLanguageModelTokensRateLimit>(session.user_id())
4666 .await?;
4667
4668 let result = match proto::LanguageModelProvider::from_i32(request.provider) {
4669 Some(proto::LanguageModelProvider::Google) => {
4670 let api_key = config
4671 .google_ai_api_key
4672 .as_ref()
4673 .context("no Google AI API key configured on the server")?;
4674 google_ai::count_tokens(
4675 session.http_client.as_ref(),
4676 google_ai::API_URL,
4677 api_key,
4678 serde_json::from_str(&request.request)?,
4679 )
4680 .await?
4681 }
4682 _ => return Err(anyhow!("unsupported provider"))?,
4683 };
4684
4685 response.send(proto::CountLanguageModelTokensResponse {
4686 token_count: result.total_tokens as u32,
4687 })?;
4688
4689 Ok(())
4690}
4691
4692struct CountLanguageModelTokensRateLimit;
4693
4694impl RateLimit for CountLanguageModelTokensRateLimit {
4695 fn capacity() -> usize {
4696 std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
4697 .ok()
4698 .and_then(|v| v.parse().ok())
4699 .unwrap_or(600) // Picked arbitrarily
4700 }
4701
4702 fn refill_duration() -> chrono::Duration {
4703 chrono::Duration::hours(1)
4704 }
4705
4706 fn db_name() -> &'static str {
4707 "count-language-model-tokens"
4708 }
4709}
4710
4711struct ComputeEmbeddingsRateLimit;
4712
4713impl RateLimit for ComputeEmbeddingsRateLimit {
4714 fn capacity() -> usize {
4715 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4716 .ok()
4717 .and_then(|v| v.parse().ok())
4718 .unwrap_or(5000) // Picked arbitrarily
4719 }
4720
4721 fn refill_duration() -> chrono::Duration {
4722 chrono::Duration::hours(1)
4723 }
4724
4725 fn db_name() -> &'static str {
4726 "compute-embeddings"
4727 }
4728}
4729
4730async fn compute_embeddings(
4731 request: proto::ComputeEmbeddings,
4732 response: Response<proto::ComputeEmbeddings>,
4733 session: UserSession,
4734 api_key: Option<Arc<str>>,
4735) -> Result<()> {
4736 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4737 authorize_access_to_language_models(&session).await?;
4738
4739 session
4740 .rate_limiter
4741 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4742 .await?;
4743
4744 let embeddings = match request.model.as_str() {
4745 "openai/text-embedding-3-small" => {
4746 open_ai::embed(
4747 session.http_client.as_ref(),
4748 OPEN_AI_API_URL,
4749 &api_key,
4750 OpenAiEmbeddingModel::TextEmbedding3Small,
4751 request.texts.iter().map(|text| text.as_str()),
4752 )
4753 .await?
4754 }
4755 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4756 };
4757
4758 let embeddings = request
4759 .texts
4760 .iter()
4761 .map(|text| {
4762 let mut hasher = sha2::Sha256::new();
4763 hasher.update(text.as_bytes());
4764 let result = hasher.finalize();
4765 result.to_vec()
4766 })
4767 .zip(
4768 embeddings
4769 .data
4770 .into_iter()
4771 .map(|embedding| embedding.embedding),
4772 )
4773 .collect::<HashMap<_, _>>();
4774
4775 let db = session.db().await;
4776 db.save_embeddings(&request.model, &embeddings)
4777 .await
4778 .context("failed to save embeddings")
4779 .trace_err();
4780
4781 response.send(proto::ComputeEmbeddingsResponse {
4782 embeddings: embeddings
4783 .into_iter()
4784 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4785 .collect(),
4786 })?;
4787 Ok(())
4788}
4789
4790async fn get_cached_embeddings(
4791 request: proto::GetCachedEmbeddings,
4792 response: Response<proto::GetCachedEmbeddings>,
4793 session: UserSession,
4794) -> Result<()> {
4795 authorize_access_to_language_models(&session).await?;
4796
4797 let db = session.db().await;
4798 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4799
4800 response.send(proto::GetCachedEmbeddingsResponse {
4801 embeddings: embeddings
4802 .into_iter()
4803 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4804 .collect(),
4805 })?;
4806 Ok(())
4807}
4808
4809async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4810 let db = session.db().await;
4811 let flags = db.get_user_flags(session.user_id()).await?;
4812 if flags.iter().any(|flag| flag == "language-models") {
4813 Ok(())
4814 } else {
4815 Err(anyhow!("permission denied"))?
4816 }
4817}
4818
4819/// Get a Supermaven API key for the user
4820async fn get_supermaven_api_key(
4821 _request: proto::GetSupermavenApiKey,
4822 response: Response<proto::GetSupermavenApiKey>,
4823 session: UserSession,
4824) -> Result<()> {
4825 let user_id: String = session.user_id().to_string();
4826 if !session.is_staff() {
4827 return Err(anyhow!("supermaven not enabled for this account"))?;
4828 }
4829
4830 let email = session
4831 .email()
4832 .ok_or_else(|| anyhow!("user must have an email"))?;
4833
4834 let supermaven_admin_api = session
4835 .supermaven_client
4836 .as_ref()
4837 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4838
4839 let result = supermaven_admin_api
4840 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4841 .await?;
4842
4843 response.send(proto::GetSupermavenApiKeyResponse {
4844 api_key: result.api_key,
4845 })?;
4846
4847 Ok(())
4848}
4849
4850/// Start receiving chat updates for a channel
4851async fn join_channel_chat(
4852 request: proto::JoinChannelChat,
4853 response: Response<proto::JoinChannelChat>,
4854 session: UserSession,
4855) -> Result<()> {
4856 let channel_id = ChannelId::from_proto(request.channel_id);
4857
4858 let db = session.db().await;
4859 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4860 .await?;
4861 let messages = db
4862 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4863 .await?;
4864 response.send(proto::JoinChannelChatResponse {
4865 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4866 messages,
4867 })?;
4868 Ok(())
4869}
4870
4871/// Stop receiving chat updates for a channel
4872async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4873 let channel_id = ChannelId::from_proto(request.channel_id);
4874 session
4875 .db()
4876 .await
4877 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4878 .await?;
4879 Ok(())
4880}
4881
4882/// Retrieve the chat history for a channel
4883async fn get_channel_messages(
4884 request: proto::GetChannelMessages,
4885 response: Response<proto::GetChannelMessages>,
4886 session: UserSession,
4887) -> Result<()> {
4888 let channel_id = ChannelId::from_proto(request.channel_id);
4889 let messages = session
4890 .db()
4891 .await
4892 .get_channel_messages(
4893 channel_id,
4894 session.user_id(),
4895 MESSAGE_COUNT_PER_PAGE,
4896 Some(MessageId::from_proto(request.before_message_id)),
4897 )
4898 .await?;
4899 response.send(proto::GetChannelMessagesResponse {
4900 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4901 messages,
4902 })?;
4903 Ok(())
4904}
4905
4906/// Retrieve specific chat messages
4907async fn get_channel_messages_by_id(
4908 request: proto::GetChannelMessagesById,
4909 response: Response<proto::GetChannelMessagesById>,
4910 session: UserSession,
4911) -> Result<()> {
4912 let message_ids = request
4913 .message_ids
4914 .iter()
4915 .map(|id| MessageId::from_proto(*id))
4916 .collect::<Vec<_>>();
4917 let messages = session
4918 .db()
4919 .await
4920 .get_channel_messages_by_id(session.user_id(), &message_ids)
4921 .await?;
4922 response.send(proto::GetChannelMessagesResponse {
4923 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4924 messages,
4925 })?;
4926 Ok(())
4927}
4928
4929/// Retrieve the current users notifications
4930async fn get_notifications(
4931 request: proto::GetNotifications,
4932 response: Response<proto::GetNotifications>,
4933 session: UserSession,
4934) -> Result<()> {
4935 let notifications = session
4936 .db()
4937 .await
4938 .get_notifications(
4939 session.user_id(),
4940 NOTIFICATION_COUNT_PER_PAGE,
4941 request
4942 .before_id
4943 .map(|id| db::NotificationId::from_proto(id)),
4944 )
4945 .await?;
4946 response.send(proto::GetNotificationsResponse {
4947 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4948 notifications,
4949 })?;
4950 Ok(())
4951}
4952
4953/// Mark notifications as read
4954async fn mark_notification_as_read(
4955 request: proto::MarkNotificationRead,
4956 response: Response<proto::MarkNotificationRead>,
4957 session: UserSession,
4958) -> Result<()> {
4959 let database = &session.db().await;
4960 let notifications = database
4961 .mark_notification_as_read_by_id(
4962 session.user_id(),
4963 NotificationId::from_proto(request.notification_id),
4964 )
4965 .await?;
4966 send_notifications(
4967 &*session.connection_pool().await,
4968 &session.peer,
4969 notifications,
4970 );
4971 response.send(proto::Ack {})?;
4972 Ok(())
4973}
4974
4975/// Get the current users information
4976async fn get_private_user_info(
4977 _request: proto::GetPrivateUserInfo,
4978 response: Response<proto::GetPrivateUserInfo>,
4979 session: UserSession,
4980) -> Result<()> {
4981 let db = session.db().await;
4982
4983 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4984 let user = db
4985 .get_user_by_id(session.user_id())
4986 .await?
4987 .ok_or_else(|| anyhow!("user not found"))?;
4988 let flags = db.get_user_flags(session.user_id()).await?;
4989
4990 response.send(proto::GetPrivateUserInfoResponse {
4991 metrics_id,
4992 staff: user.admin,
4993 flags,
4994 })?;
4995 Ok(())
4996}
4997
4998fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4999 let message = match message {
5000 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
5001 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
5002 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
5003 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
5004 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5005 code: frame.code.into(),
5006 reason: frame.reason,
5007 })),
5008 // We should never receive a frame while reading the message, according
5009 // to the `tungstenite` maintainers:
5010 //
5011 // > It cannot occur when you read messages from the WebSocket, but it
5012 // > can be used when you want to send the raw frames (e.g. you want to
5013 // > send the frames to the WebSocket without composing the full message first).
5014 // >
5015 // > — https://github.com/snapview/tungstenite-rs/issues/268
5016 TungsteniteMessage::Frame(_) => {
5017 bail!("received an unexpected frame while reading the message")
5018 }
5019 };
5020
5021 Ok(message)
5022}
5023
5024fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5025 match message {
5026 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5027 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5028 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5029 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5030 AxumMessage::Close(frame) => {
5031 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5032 code: frame.code.into(),
5033 reason: frame.reason,
5034 }))
5035 }
5036 }
5037}
5038
5039fn notify_membership_updated(
5040 connection_pool: &mut ConnectionPool,
5041 result: MembershipUpdated,
5042 user_id: UserId,
5043 peer: &Peer,
5044) {
5045 for membership in &result.new_channels.channel_memberships {
5046 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5047 }
5048 for channel_id in &result.removed_channels {
5049 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5050 }
5051
5052 let user_channels_update = proto::UpdateUserChannels {
5053 channel_memberships: result
5054 .new_channels
5055 .channel_memberships
5056 .iter()
5057 .map(|cm| proto::ChannelMembership {
5058 channel_id: cm.channel_id.to_proto(),
5059 role: cm.role.into(),
5060 })
5061 .collect(),
5062 ..Default::default()
5063 };
5064
5065 let mut update = build_channels_update(result.new_channels);
5066 update.delete_channels = result
5067 .removed_channels
5068 .into_iter()
5069 .map(|id| id.to_proto())
5070 .collect();
5071 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5072
5073 for connection_id in connection_pool.user_connection_ids(user_id) {
5074 peer.send(connection_id, user_channels_update.clone())
5075 .trace_err();
5076 peer.send(connection_id, update.clone()).trace_err();
5077 }
5078}
5079
5080fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5081 proto::UpdateUserChannels {
5082 channel_memberships: channels
5083 .channel_memberships
5084 .iter()
5085 .map(|m| proto::ChannelMembership {
5086 channel_id: m.channel_id.to_proto(),
5087 role: m.role.into(),
5088 })
5089 .collect(),
5090 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5091 observed_channel_message_id: channels.observed_channel_messages.clone(),
5092 }
5093}
5094
5095fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5096 let mut update = proto::UpdateChannels::default();
5097
5098 for channel in channels.channels {
5099 update.channels.push(channel.to_proto());
5100 }
5101
5102 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5103 update.latest_channel_message_ids = channels.latest_channel_messages;
5104
5105 for (channel_id, participants) in channels.channel_participants {
5106 update
5107 .channel_participants
5108 .push(proto::ChannelParticipants {
5109 channel_id: channel_id.to_proto(),
5110 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5111 });
5112 }
5113
5114 for channel in channels.invited_channels {
5115 update.channel_invitations.push(channel.to_proto());
5116 }
5117
5118 update.hosted_projects = channels.hosted_projects;
5119 update
5120}
5121
5122fn build_initial_contacts_update(
5123 contacts: Vec<db::Contact>,
5124 pool: &ConnectionPool,
5125) -> proto::UpdateContacts {
5126 let mut update = proto::UpdateContacts::default();
5127
5128 for contact in contacts {
5129 match contact {
5130 db::Contact::Accepted { user_id, busy } => {
5131 update.contacts.push(contact_for_user(user_id, busy, &pool));
5132 }
5133 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5134 db::Contact::Incoming { user_id } => {
5135 update
5136 .incoming_requests
5137 .push(proto::IncomingContactRequest {
5138 requester_id: user_id.to_proto(),
5139 })
5140 }
5141 }
5142 }
5143
5144 update
5145}
5146
5147fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5148 proto::Contact {
5149 user_id: user_id.to_proto(),
5150 online: pool.is_user_online(user_id),
5151 busy,
5152 }
5153}
5154
5155fn room_updated(room: &proto::Room, peer: &Peer) {
5156 broadcast(
5157 None,
5158 room.participants
5159 .iter()
5160 .filter_map(|participant| Some(participant.peer_id?.into())),
5161 |peer_id| {
5162 peer.send(
5163 peer_id,
5164 proto::RoomUpdated {
5165 room: Some(room.clone()),
5166 },
5167 )
5168 },
5169 );
5170}
5171
5172fn channel_updated(
5173 channel: &db::channel::Model,
5174 room: &proto::Room,
5175 peer: &Peer,
5176 pool: &ConnectionPool,
5177) {
5178 let participants = room
5179 .participants
5180 .iter()
5181 .map(|p| p.user_id)
5182 .collect::<Vec<_>>();
5183
5184 broadcast(
5185 None,
5186 pool.channel_connection_ids(channel.root_id())
5187 .filter_map(|(channel_id, role)| {
5188 role.can_see_channel(channel.visibility).then(|| channel_id)
5189 }),
5190 |peer_id| {
5191 peer.send(
5192 peer_id,
5193 proto::UpdateChannels {
5194 channel_participants: vec![proto::ChannelParticipants {
5195 channel_id: channel.id.to_proto(),
5196 participant_user_ids: participants.clone(),
5197 }],
5198 ..Default::default()
5199 },
5200 )
5201 },
5202 );
5203}
5204
5205async fn send_dev_server_projects_update(
5206 user_id: UserId,
5207 mut status: proto::DevServerProjectsUpdate,
5208 session: &Session,
5209) {
5210 let pool = session.connection_pool().await;
5211 for dev_server in &mut status.dev_servers {
5212 dev_server.status =
5213 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5214 }
5215 let connections = pool.user_connection_ids(user_id);
5216 for connection_id in connections {
5217 session.peer.send(connection_id, status.clone()).trace_err();
5218 }
5219}
5220
5221async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5222 let db = session.db().await;
5223
5224 let contacts = db.get_contacts(user_id).await?;
5225 let busy = db.is_user_busy(user_id).await?;
5226
5227 let pool = session.connection_pool().await;
5228 let updated_contact = contact_for_user(user_id, busy, &pool);
5229 for contact in contacts {
5230 if let db::Contact::Accepted {
5231 user_id: contact_user_id,
5232 ..
5233 } = contact
5234 {
5235 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5236 session
5237 .peer
5238 .send(
5239 contact_conn_id,
5240 proto::UpdateContacts {
5241 contacts: vec![updated_contact.clone()],
5242 remove_contacts: Default::default(),
5243 incoming_requests: Default::default(),
5244 remove_incoming_requests: Default::default(),
5245 outgoing_requests: Default::default(),
5246 remove_outgoing_requests: Default::default(),
5247 },
5248 )
5249 .trace_err();
5250 }
5251 }
5252 }
5253 Ok(())
5254}
5255
5256async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5257 log::info!("lost dev server connection, unsharing projects");
5258 let project_ids = session
5259 .db()
5260 .await
5261 .get_stale_dev_server_projects(session.connection_id)
5262 .await?;
5263
5264 for project_id in project_ids {
5265 // not unshare re-checks the connection ids match, so we get away with no transaction
5266 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5267 }
5268
5269 let user_id = session.dev_server().user_id;
5270 let update = session
5271 .db()
5272 .await
5273 .dev_server_projects_update(user_id)
5274 .await?;
5275
5276 send_dev_server_projects_update(user_id, update, session).await;
5277
5278 Ok(())
5279}
5280
/// Remove `connection_id` from the room it is in (if any) and notify everyone
/// affected.
///
/// Notifications, in order:
/// - projects the user left;
/// - remaining room participants;
/// - the room's channel, if it has one;
/// - users whose pending calls were canceled;
/// - the contacts of this user and of those callees.
///
/// Finally the user is removed from the LiveKit room, and the LiveKit room is
/// deleted when the database reports the room itself was deleted. Returns
/// `Ok(())` without doing anything when the connection was not in a room.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned exactly once in the `if let` below; the `else` branch returns early.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // The LiveKit room name is taken out of `left_room.room` before the room
        // itself is taken, so the `room` broadcast below carries a default
        // `live_kit_room` value.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the pool guard is dropped before `update_user_contacts`, which
    // acquires the connection pool itself.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err`, not propagated.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5354
5355async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5356 let left_channel_buffers = session
5357 .db()
5358 .await
5359 .leave_channel_buffers(session.connection_id)
5360 .await?;
5361
5362 for left_buffer in left_channel_buffers {
5363 channel_buffer_updated(
5364 session.connection_id,
5365 left_buffer.connections,
5366 &proto::UpdateChannelBufferCollaborators {
5367 channel_id: left_buffer.channel_id.to_proto(),
5368 collaborators: left_buffer.collaborators,
5369 },
5370 &session.peer,
5371 );
5372 }
5373
5374 Ok(())
5375}
5376
5377fn project_left(project: &db::LeftProject, session: &UserSession) {
5378 for connection_id in &project.connection_ids {
5379 if project.should_unshare {
5380 session
5381 .peer
5382 .send(
5383 *connection_id,
5384 proto::UnshareProject {
5385 project_id: project.id.to_proto(),
5386 },
5387 )
5388 .trace_err();
5389 } else {
5390 session
5391 .peer
5392 .send(
5393 *connection_id,
5394 proto::RemoveProjectCollaborator {
5395 project_id: project.id.to_proto(),
5396 peer_id: Some(session.connection_id.into()),
5397 },
5398 )
5399 .trace_err();
5400 }
5401 }
5402}
5403
5404pub trait ResultExt {
5405 type Ok;
5406
5407 fn trace_err(self) -> Option<Self::Ok>;
5408}
5409
5410impl<T, E> ResultExt for Result<T, E>
5411where
5412 E: std::fmt::Debug,
5413{
5414 type Ok = T;
5415
5416 #[track_caller]
5417 fn trace_err(self) -> Option<T> {
5418 match self {
5419 Ok(value) => Some(value),
5420 Err(error) => {
5421 tracing::error!("{:?}", error);
5422 None
5423 }
5424 }
5425 }
5426}