1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Grace period associated with client reconnection.
///
/// NOTE(review): not referenced in this part of the file; presumably how long a disconnected
/// client's state is kept before being discarded. Confirm at the use sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` sweeps rooms and channel buffers left behind by previous
/// server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Waiting a little longer
/// means those pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the three limits below are not referenced in this part of the file; the
// descriptions are inferred from their names and should be confirmed at their use sites.

/// Presumably the page size for channel-message queries.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Presumably the maximum length of a chat message (unit — bytes vs. chars — unconfirmed).
const MAX_MESSAGE_LEN: usize = 1024;
/// Presumably the page size for notification queries.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The RPC server: owns the peer, the shared connection pool and the dispatch table of
/// message handlers.
pub struct Server {
    /// Current server id. Behind a mutex because test builds can replace it via `reset`.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Teardown flag. Every connection loop subscribes to it and exits when it changes.
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`, so it cannot be held across an
/// `.await` inside a future that must be `Send`.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and its (locked) connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through the guard's `Deref` target (presumably the `ConnectionPool` itself).
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with the given `id` and registers a handler for every RPC message it
    /// serves.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject callers that are not
    /// users. Handlers wrapped in `dev_server_handler` reject callers that are not dev servers.
    /// Unwrapped handlers receive the raw [`Session`].
    ///
    /// # Panics
    ///
    /// Panics if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // The initial receiver is dropped right away; connections subscribe on demand.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the host: read-only first, then mutating.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffer traffic and host broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users, contacts, channels, chat and notifications.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Language-model and embedding handlers. Each closure owns a clone of `app_state`
            // so it can read the API keys from the config on every call.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
628
    /// Schedules cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. The task waits
    /// [`CLEANUP_TIMEOUT`], fetches the stale room and channel-buffer ids for this environment,
    /// and then:
    /// - clears stale channel-buffer collaborators and notifies the remaining ones;
    /// - clears stale room participants and broadcasts the updated room (and channel);
    /// - sends `CallCanceled` to the users whose calls were canceled;
    /// - pushes refreshed contact info to the affected users' accepted contacts;
    /// - deletes the LiveKit room when the room ended up empty;
    /// - finally calls `delete_stale_servers`.
    ///
    /// Individual failures are logged via `trace_err` and do not abort the sweep.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here so the delay starts counting when `start` is called.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell the remaining
                    // connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Collected inside the `if let` below and acted on afterwards; they
                        // stay empty/false if refreshing the room failed.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the awaits below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
774
775 pub fn teardown(&self) {
776 self.peer.teardown();
777 self.connection_pool.lock().reset();
778 let _ = self.teardown.send(true);
779 }
780
781 #[cfg(test)]
782 pub fn reset(&self, id: ServerId) {
783 self.teardown();
784 *self.id.lock() = id;
785 self.peer.reset(id.0 as u32);
786 let _ = self.teardown.send(false);
787 }
788
789 #[cfg(test)]
790 pub fn id(&self) -> ServerId {
791 *self.id.lock()
792 }
793
794 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
795 where
796 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
797 Fut: 'static + Send + Future<Output = Result<()>>,
798 M: EnvelopedMessage,
799 {
800 let prev_handler = self.handlers.insert(
801 TypeId::of::<M>(),
802 Box::new(move |envelope, session| {
803 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
804 let received_at = envelope.received_at;
805 tracing::info!("message received");
806 let start_time = Instant::now();
807 let future = (handler)(*envelope, session);
808 async move {
809 let result = future.await;
810 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
811 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
812 let queue_duration_ms = total_duration_ms - processing_duration_ms;
813 let payload_type = M::NAME;
814
815 match result {
816 Err(error) => {
817 tracing::error!(
818 ?error,
819 total_duration_ms,
820 processing_duration_ms,
821 queue_duration_ms,
822 payload_type,
823 "error handling message"
824 )
825 }
826 Ok(()) => tracing::info!(
827 total_duration_ms,
828 processing_duration_ms,
829 queue_duration_ms,
830 "finished handling message"
831 ),
832 }
833 }
834 .boxed()
835 }),
836 );
837 if prev_handler.is_some() {
838 panic!("registered a handler for the same message twice");
839 }
840 self
841 }
842
843 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
844 where
845 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
846 Fut: 'static + Send + Future<Output = Result<()>>,
847 M: EnvelopedMessage,
848 {
849 self.add_handler(move |envelope, session| handler(envelope.payload, session));
850 self
851 }
852
853 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
854 where
855 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
856 Fut: Send + Future<Output = Result<()>>,
857 M: RequestMessage,
858 {
859 let handler = Arc::new(handler);
860 self.add_handler(move |envelope, session| {
861 let receipt = envelope.receipt();
862 let handler = handler.clone();
863 async move {
864 let peer = session.peer.clone();
865 let responded = Arc::new(AtomicBool::default());
866 let response = Response {
867 peer: peer.clone(),
868 responded: responded.clone(),
869 receipt,
870 };
871 match (handler)(envelope.payload, response, session).await {
872 Ok(()) => {
873 if responded.load(std::sync::atomic::Ordering::SeqCst) {
874 Ok(())
875 } else {
876 Err(anyhow!("handler did not send a response"))?
877 }
878 }
879 Err(error) => {
880 let proto_err = match &error {
881 Error::Internal(err) => err.to_proto(),
882 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
883 };
884 peer.respond_with_error(receipt, proto_err)?;
885 Err(error)
886 }
887 }
888 }
889 })
890 }
891
892 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
893 where
894 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
895 Fut: Send + Future<Output = Result<()>>,
896 M: RequestMessage,
897 {
898 let handler = Arc::new(handler);
899 self.add_handler(move |envelope, session| {
900 let receipt = envelope.receipt();
901 let handler = handler.clone();
902 async move {
903 let peer = session.peer.clone();
904 let response = StreamingResponse {
905 peer: peer.clone(),
906 receipt,
907 };
908 match (handler)(envelope.payload, response, session).await {
909 Ok(()) => {
910 peer.end_stream(receipt)?;
911 Ok(())
912 }
913 Err(error) => {
914 let proto_err = match &error {
915 Error::Internal(err) => err.to_proto(),
916 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
917 };
918 peer.respond_with_error(receipt, proto_err)?;
919 Err(error)
920 }
921 }
922 }
923 })
924 }
925
    /// Drives a single client connection from setup to sign-out.
    ///
    /// Refuses the connection if the teardown flag is already set. Otherwise it:
    /// - registers `connection` with the peer and builds the connection's [`Session`];
    /// - sends the initial client update;
    /// - dispatches incoming messages to the registered handlers.
    ///
    /// Dispatching continues until the I/O future finishes, the incoming stream ends, or the
    /// teardown flag changes. Background messages are spawned on `executor`. Foreground ones
    /// are polled concurrently inside this task. A semaphore caps this connection at 256
    /// in-flight messages.
    ///
    /// NOTE(review): several early exits return without calling `connection_lost`: HTTP client
    /// creation failure, initial-update failure, and the teardown signal. These happen after
    /// `add_connection`. Confirm that skipping the sign-out path is intended in those cases.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields are declared empty so `update_span` / `record` can fill them in later.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only available when an admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when, e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages that arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                // A permit is taken before reading the next message, so reading pauses while
                // 256 messages from this connection are still being handled.
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Biased: teardown wins over I/O completion, which wins over running handlers,
                // which win over reading new messages.
                futures::select_biased! {
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                // The permit is released only once the handler has finished.
                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Cancel any foreground handlers still in flight before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1072
    /// Performs the initial handshake for a freshly accepted connection.
    ///
    /// Sends `proto::Hello` carrying the connection's peer id, reports the
    /// connection id through `send_connection_id` (when one is provided), then
    /// registers the connection in the connection pool and pushes initial state:
    ///
    /// - For users (including impersonated users): `ShowContacts` on their first
    ///   ever connection, contacts, channel memberships and invites, dev server
    ///   project status, any pending incoming call, and finally
    ///   `update_user_contacts` for this user.
    /// - For dev servers: the dev server's projects (as
    ///   `proto::DevServerInstructions`), followed by a dev server project status
    ///   update for the owning user.
    ///
    /// # Errors
    ///
    /// Returns an error if any send or database call fails, or an
    /// `ErrorCode::DevServerAlreadyOnline` error when the pool already holds a
    /// connection for this dev server.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        // `Hello` is the first message sent on the connection.
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        if let Some(send_connection_id) = send_connection_id.take() {
            // The receiving side may already be gone; that is deliberately ignored.
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: show contacts and persist the flag.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Fetch the initial state concurrently; the first error aborts all four.
                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.dev_server_projects_update(user.id),
                    )
                    .await?;

                // The pool lock is acquired without `.await`; keep this scope await-free.
                // The contacts update is built while the pool is locked, after this
                // connection has been added to it.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Re-deliver a call that was ringing while the user was offline.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one live connection per dev server is accepted.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Let the dev server's owner see the updated project status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1168
1169 pub async fn invite_code_redeemed(
1170 self: &Arc<Self>,
1171 inviter_id: UserId,
1172 invitee_id: UserId,
1173 ) -> Result<()> {
1174 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1175 if let Some(code) = &user.invite_code {
1176 let pool = self.connection_pool.lock();
1177 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1178 for connection_id in pool.user_connection_ids(inviter_id) {
1179 self.peer.send(
1180 connection_id,
1181 proto::UpdateContacts {
1182 contacts: vec![invitee_contact.clone()],
1183 ..Default::default()
1184 },
1185 )?;
1186 self.peer.send(
1187 connection_id,
1188 proto::UpdateInviteInfo {
1189 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1190 count: user.invite_count as u32,
1191 },
1192 )?;
1193 }
1194 }
1195 }
1196 Ok(())
1197 }
1198
1199 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1200 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1201 if let Some(invite_code) = &user.invite_code {
1202 let pool = self.connection_pool.lock();
1203 for connection_id in pool.user_connection_ids(user_id) {
1204 self.peer.send(
1205 connection_id,
1206 proto::UpdateInviteInfo {
1207 url: format!(
1208 "{}{}",
1209 self.app_state.config.invite_link_prefix, invite_code
1210 ),
1211 count: user.invite_count as u32,
1212 },
1213 )?;
1214 }
1215 }
1216 }
1217 Ok(())
1218 }
1219
1220 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1221 ServerSnapshot {
1222 connection_pool: ConnectionPoolGuard {
1223 guard: self.connection_pool.lock(),
1224 _not_send: PhantomData,
1225 },
1226 peer: &self.peer,
1227 }
1228 }
1229}
1230
1231impl<'a> Deref for ConnectionPoolGuard<'a> {
1232 type Target = ConnectionPool;
1233
1234 fn deref(&self) -> &Self::Target {
1235 &self.guard
1236 }
1237}
1238
1239impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1240 fn deref_mut(&mut self) -> &mut Self::Target {
1241 &mut self.guard
1242 }
1243}
1244
/// In test builds, validates the connection pool's invariants whenever a
/// `ConnectionPoolGuard` is released. Outside `cfg(test)` the body is empty.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // This statement is compiled only under `cfg(test)`.
        #[cfg(test)]
        self.check_invariants();
    }
}
1251
1252fn broadcast<F>(
1253 sender_id: Option<ConnectionId>,
1254 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1255 mut f: F,
1256) where
1257 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1258{
1259 for receiver_id in receiver_ids {
1260 if Some(receiver_id) != sender_id {
1261 if let Err(error) = f(receiver_id) {
1262 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1263 }
1264 }
1265 }
1266}
1267
1268pub struct ProtocolVersion(u32);
1269
1270impl Header for ProtocolVersion {
1271 fn name() -> &'static HeaderName {
1272 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1273 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1274 }
1275
1276 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1277 where
1278 Self: Sized,
1279 I: Iterator<Item = &'i axum::http::HeaderValue>,
1280 {
1281 let version = values
1282 .next()
1283 .ok_or_else(axum::headers::Error::invalid)?
1284 .to_str()
1285 .map_err(|_| axum::headers::Error::invalid())?
1286 .parse()
1287 .map_err(|_| axum::headers::Error::invalid())?;
1288 Ok(Self(version))
1289 }
1290
1291 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1292 values.extend([self.0.to_string().parse().unwrap()]);
1293 }
1294}
1295
1296pub struct AppVersionHeader(SemanticVersion);
1297impl Header for AppVersionHeader {
1298 fn name() -> &'static HeaderName {
1299 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1300 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1301 }
1302
1303 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1304 where
1305 Self: Sized,
1306 I: Iterator<Item = &'i axum::http::HeaderValue>,
1307 {
1308 let version = values
1309 .next()
1310 .ok_or_else(axum::headers::Error::invalid)?
1311 .to_str()
1312 .map_err(|_| axum::headers::Error::invalid())?
1313 .parse()
1314 .map_err(|_| axum::headers::Error::invalid())?;
1315 Ok(Self(version))
1316 }
1317
1318 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1319 values.extend([self.0.to_string().parse().unwrap()]);
1320 }
1321}
1322
1323pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1324 Router::new()
1325 .route("/rpc", get(handle_websocket_request))
1326 .layer(
1327 ServiceBuilder::new()
1328 .layer(Extension(server.app_state.clone()))
1329 .layer(middleware::from_fn(auth::validate_header)),
1330 )
1331 .route("/metrics", get(handle_metrics))
1332 .layer(Extension(server))
1333}
1334
1335pub async fn handle_websocket_request(
1336 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1337 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1338 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1339 Extension(server): Extension<Arc<Server>>,
1340 Extension(principal): Extension<Principal>,
1341 ws: WebSocketUpgrade,
1342) -> axum::response::Response {
1343 if protocol_version != rpc::PROTOCOL_VERSION {
1344 return (
1345 StatusCode::UPGRADE_REQUIRED,
1346 "client must be upgraded".to_string(),
1347 )
1348 .into_response();
1349 }
1350
1351 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1352 return (
1353 StatusCode::UPGRADE_REQUIRED,
1354 "no version header found".to_string(),
1355 )
1356 .into_response();
1357 };
1358
1359 if !version.can_collaborate() {
1360 return (
1361 StatusCode::UPGRADE_REQUIRED,
1362 "client must be upgraded".to_string(),
1363 )
1364 .into_response();
1365 }
1366
1367 let socket_address = socket_address.to_string();
1368 ws.on_upgrade(move |socket| {
1369 let socket = socket
1370 .map_ok(to_tungstenite_message)
1371 .err_into()
1372 .with(|message| async move { Ok(to_axum_message(message)) });
1373 let connection = Connection::new(Box::pin(socket));
1374 async move {
1375 server
1376 .handle_connection(
1377 connection,
1378 socket_address,
1379 principal,
1380 version,
1381 None,
1382 Executor::Production,
1383 )
1384 .await;
1385 }
1386 })
1387}
1388
1389pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1390 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1391 let connections_metric = CONNECTIONS_METRIC
1392 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1393
1394 let connections = server
1395 .connection_pool
1396 .lock()
1397 .connections()
1398 .filter(|connection| !connection.admin)
1399 .count();
1400 connections_metric.set(connections as _);
1401
1402 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1404 register_int_gauge!(
1405 "shared_projects",
1406 "number of open projects with one or more guests"
1407 )
1408 .unwrap()
1409 });
1410
1411 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1412 shared_projects_metric.set(shared_projects as _);
1413
1414 let encoder = prometheus::TextEncoder::new();
1415 let metric_families = prometheus::gather();
1416 let encoded_metrics = encoder
1417 .encode_to_string(&metric_families)
1418 .map_err(|err| anyhow!("{}", err))?;
1419 Ok(encoded_metrics)
1420}
1421
/// Cleans up after a connection closes.
///
/// Immediately disconnects the peer, removes the connection from the pool,
/// and reports the lost connection to the database (that error is only
/// logged). It then waits for either `RECONNECT_TIMEOUT` to elapse or
/// `teardown` to change:
///
/// - Timeout first: the principal's remaining resources are released. For
///   users (including impersonated ones), the session leaves its room and
///   channel buffers. If the user has no other online connection, any pending
///   call is also declined. Contacts are then updated. For dev servers,
///   `lost_dev_server_connection` runs.
/// - `teardown` changes first: no further cleanup is performed here.
///
/// # Errors
///
/// Returns errors from removing the connection from the pool, from
/// `update_user_contacts`, or from `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a failure here is logged, not propagated.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in the order written, so the timeout branch
    // is checked before `teardown` on each poll.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // NOTE(review): relies on `for_user()` returning `Some` for user /
                    // impersonated principals — confirm against its definition.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this
                    // user is still online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
            Principal::DevServer(_) => {
                // NOTE(review): relies on `for_dev_server()` returning `Some` for a
                // dev-server principal — confirm.
                lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
            },
        }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1476
1477/// Acknowledges a ping from a client, used to keep the connection alive.
1478async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1479 response.send(proto::Ack {})?;
1480 Ok(())
1481}
1482
1483/// Creates a new room for calling (outside of channels)
1484async fn create_room(
1485 _request: proto::CreateRoom,
1486 response: Response<proto::CreateRoom>,
1487 session: UserSession,
1488) -> Result<()> {
1489 let live_kit_room = nanoid::nanoid!(30);
1490
1491 let live_kit_connection_info = util::maybe!(async {
1492 let live_kit = session.live_kit_client.as_ref();
1493 let live_kit = live_kit?;
1494 let user_id = session.user_id().to_string();
1495
1496 let token = live_kit
1497 .room_token(&live_kit_room, &user_id.to_string())
1498 .trace_err()?;
1499
1500 Some(proto::LiveKitConnectionInfo {
1501 server_url: live_kit.url().into(),
1502 token,
1503 can_publish: true,
1504 })
1505 })
1506 .await;
1507
1508 let room = session
1509 .db()
1510 .await
1511 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1512 .await?;
1513
1514 response.send(proto::CreateRoomResponse {
1515 room: Some(room.clone()),
1516 live_kit_connection_info,
1517 })?;
1518
1519 update_user_contacts(session.user_id(), &session).await?;
1520 Ok(())
1521}
1522
1523/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1524async fn join_room(
1525 request: proto::JoinRoom,
1526 response: Response<proto::JoinRoom>,
1527 session: UserSession,
1528) -> Result<()> {
1529 let room_id = RoomId::from_proto(request.id);
1530
1531 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1532
1533 if let Some(channel_id) = channel_id {
1534 return join_channel_internal(channel_id, Box::new(response), session).await;
1535 }
1536
1537 let joined_room = {
1538 let room = session
1539 .db()
1540 .await
1541 .join_room(room_id, session.user_id(), session.connection_id)
1542 .await?;
1543 room_updated(&room.room, &session.peer);
1544 room.into_inner()
1545 };
1546
1547 for connection_id in session
1548 .connection_pool()
1549 .await
1550 .user_connection_ids(session.user_id())
1551 {
1552 session
1553 .peer
1554 .send(
1555 connection_id,
1556 proto::CallCanceled {
1557 room_id: room_id.to_proto(),
1558 },
1559 )
1560 .trace_err();
1561 }
1562
1563 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1564 if let Some(token) = live_kit
1565 .room_token(
1566 &joined_room.room.live_kit_room,
1567 &session.user_id().to_string(),
1568 )
1569 .trace_err()
1570 {
1571 Some(proto::LiveKitConnectionInfo {
1572 server_url: live_kit.url().into(),
1573 token,
1574 can_publish: true,
1575 })
1576 } else {
1577 None
1578 }
1579 } else {
1580 None
1581 };
1582
1583 response.send(proto::JoinRoomResponse {
1584 room: Some(joined_room.room),
1585 channel_id: None,
1586 live_kit_connection_info,
1587 })?;
1588
1589 update_user_contacts(session.user_id(), &session).await?;
1590 Ok(())
1591}
1592
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// While holding the guard returned by the database's `rejoin_room`, this
/// function does the following, in order:
/// 1. Replies with the room and the reshared/rejoined projects.
/// 2. Broadcasts the room update.
/// 3. For each reshared project, tells its collaborators that this connection
///    replaced `project.old_connection_id`.
/// 4. Forwards each reshared project's worktrees to its other collaborators.
/// 5. Re-streams rejoined projects to this client via
///    `notify_rejoined_projects`.
///
/// After the guard is released, the room's channel (if any) is updated and
/// the user's contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the scope below, after the rejoin guard is consumed.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators swap the old peer id for this connection's id;
            // failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the reshared worktrees to everyone except this connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1684
1685fn notify_rejoined_projects(
1686 rejoined_projects: &mut Vec<RejoinedProject>,
1687 session: &UserSession,
1688) -> Result<()> {
1689 for project in rejoined_projects.iter() {
1690 for collaborator in &project.collaborators {
1691 session
1692 .peer
1693 .send(
1694 collaborator.connection_id,
1695 proto::UpdateProjectCollaborator {
1696 project_id: project.id.to_proto(),
1697 old_peer_id: Some(project.old_connection_id.into()),
1698 new_peer_id: Some(session.connection_id.into()),
1699 },
1700 )
1701 .trace_err();
1702 }
1703 }
1704
1705 for project in rejoined_projects {
1706 for worktree in mem::take(&mut project.worktrees) {
1707 #[cfg(any(test, feature = "test-support"))]
1708 const MAX_CHUNK_SIZE: usize = 2;
1709 #[cfg(not(any(test, feature = "test-support")))]
1710 const MAX_CHUNK_SIZE: usize = 256;
1711
1712 // Stream this worktree's entries.
1713 let message = proto::UpdateWorktree {
1714 project_id: project.id.to_proto(),
1715 worktree_id: worktree.id,
1716 abs_path: worktree.abs_path.clone(),
1717 root_name: worktree.root_name,
1718 updated_entries: worktree.updated_entries,
1719 removed_entries: worktree.removed_entries,
1720 scan_id: worktree.scan_id,
1721 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1722 updated_repositories: worktree.updated_repositories,
1723 removed_repositories: worktree.removed_repositories,
1724 };
1725 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1726 session.peer.send(session.connection_id, update.clone())?;
1727 }
1728
1729 // Stream this worktree's diagnostics.
1730 for summary in worktree.diagnostic_summaries {
1731 session.peer.send(
1732 session.connection_id,
1733 proto::UpdateDiagnosticSummary {
1734 project_id: project.id.to_proto(),
1735 worktree_id: worktree.id,
1736 summary: Some(summary),
1737 },
1738 )?;
1739 }
1740
1741 for settings_file in worktree.settings_files {
1742 session.peer.send(
1743 session.connection_id,
1744 proto::UpdateWorktreeSettings {
1745 project_id: project.id.to_proto(),
1746 worktree_id: worktree.id,
1747 path: settings_file.path,
1748 content: Some(settings_file.content),
1749 },
1750 )?;
1751 }
1752 }
1753
1754 for language_server in &project.language_servers {
1755 session.peer.send(
1756 session.connection_id,
1757 proto::UpdateLanguageServer {
1758 project_id: project.id.to_proto(),
1759 language_server_id: language_server.id,
1760 variant: Some(
1761 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1762 proto::LspDiskBasedDiagnosticsUpdated {},
1763 ),
1764 ),
1765 },
1766 )?;
1767 }
1768 }
1769 Ok(())
1770}
1771
1772/// leave room disconnects from the room.
1773async fn leave_room(
1774 _: proto::LeaveRoom,
1775 response: Response<proto::LeaveRoom>,
1776 session: UserSession,
1777) -> Result<()> {
1778 leave_room_for_session(&session, session.connection_id).await?;
1779 response.send(proto::Ack {})?;
1780 Ok(())
1781}
1782
1783/// Updates the permissions of someone else in the room.
1784async fn set_room_participant_role(
1785 request: proto::SetRoomParticipantRole,
1786 response: Response<proto::SetRoomParticipantRole>,
1787 session: UserSession,
1788) -> Result<()> {
1789 let user_id = UserId::from_proto(request.user_id);
1790 let role = ChannelRole::from(request.role());
1791
1792 let (live_kit_room, can_publish) = {
1793 let room = session
1794 .db()
1795 .await
1796 .set_room_participant_role(
1797 session.user_id(),
1798 RoomId::from_proto(request.room_id),
1799 user_id,
1800 role,
1801 )
1802 .await?;
1803
1804 let live_kit_room = room.live_kit_room.clone();
1805 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1806 room_updated(&room, &session.peer);
1807 (live_kit_room, can_publish)
1808 };
1809
1810 if let Some(live_kit) = session.live_kit_client.as_ref() {
1811 live_kit
1812 .update_participant(
1813 live_kit_room.clone(),
1814 request.user_id.to_string(),
1815 live_kit_server::proto::ParticipantPermission {
1816 can_subscribe: true,
1817 can_publish,
1818 can_publish_data: can_publish,
1819 hidden: false,
1820 recorder: false,
1821 },
1822 )
1823 .await
1824 .trace_err();
1825 }
1826
1827 response.send(proto::Ack {})?;
1828 Ok(())
1829}
1830
1831/// Call someone else into the current room
1832async fn call(
1833 request: proto::Call,
1834 response: Response<proto::Call>,
1835 session: UserSession,
1836) -> Result<()> {
1837 let room_id = RoomId::from_proto(request.room_id);
1838 let calling_user_id = session.user_id();
1839 let calling_connection_id = session.connection_id;
1840 let called_user_id = UserId::from_proto(request.called_user_id);
1841 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1842 if !session
1843 .db()
1844 .await
1845 .has_contact(calling_user_id, called_user_id)
1846 .await?
1847 {
1848 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1849 }
1850
1851 let incoming_call = {
1852 let (room, incoming_call) = &mut *session
1853 .db()
1854 .await
1855 .call(
1856 room_id,
1857 calling_user_id,
1858 calling_connection_id,
1859 called_user_id,
1860 initial_project_id,
1861 )
1862 .await?;
1863 room_updated(&room, &session.peer);
1864 mem::take(incoming_call)
1865 };
1866 update_user_contacts(called_user_id, &session).await?;
1867
1868 let mut calls = session
1869 .connection_pool()
1870 .await
1871 .user_connection_ids(called_user_id)
1872 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1873 .collect::<FuturesUnordered<_>>();
1874
1875 while let Some(call_response) = calls.next().await {
1876 match call_response.as_ref() {
1877 Ok(_) => {
1878 response.send(proto::Ack {})?;
1879 return Ok(());
1880 }
1881 Err(_) => {
1882 call_response.trace_err();
1883 }
1884 }
1885 }
1886
1887 {
1888 let room = session
1889 .db()
1890 .await
1891 .call_failed(room_id, called_user_id)
1892 .await?;
1893 room_updated(&room, &session.peer);
1894 }
1895 update_user_contacts(called_user_id, &session).await?;
1896
1897 Err(anyhow!("failed to ring user"))?
1898}
1899
1900/// Cancel an outgoing call.
1901async fn cancel_call(
1902 request: proto::CancelCall,
1903 response: Response<proto::CancelCall>,
1904 session: UserSession,
1905) -> Result<()> {
1906 let called_user_id = UserId::from_proto(request.called_user_id);
1907 let room_id = RoomId::from_proto(request.room_id);
1908 {
1909 let room = session
1910 .db()
1911 .await
1912 .cancel_call(room_id, session.connection_id, called_user_id)
1913 .await?;
1914 room_updated(&room, &session.peer);
1915 }
1916
1917 for connection_id in session
1918 .connection_pool()
1919 .await
1920 .user_connection_ids(called_user_id)
1921 {
1922 session
1923 .peer
1924 .send(
1925 connection_id,
1926 proto::CallCanceled {
1927 room_id: room_id.to_proto(),
1928 },
1929 )
1930 .trace_err();
1931 }
1932 response.send(proto::Ack {})?;
1933
1934 update_user_contacts(called_user_id, &session).await?;
1935 Ok(())
1936}
1937
1938/// Decline an incoming call.
1939async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1940 let room_id = RoomId::from_proto(message.room_id);
1941 {
1942 let room = session
1943 .db()
1944 .await
1945 .decline_call(Some(room_id), session.user_id())
1946 .await?
1947 .ok_or_else(|| anyhow!("failed to decline call"))?;
1948 room_updated(&room, &session.peer);
1949 }
1950
1951 for connection_id in session
1952 .connection_pool()
1953 .await
1954 .user_connection_ids(session.user_id())
1955 {
1956 session
1957 .peer
1958 .send(
1959 connection_id,
1960 proto::CallCanceled {
1961 room_id: room_id.to_proto(),
1962 },
1963 )
1964 .trace_err();
1965 }
1966 update_user_contacts(session.user_id(), &session).await?;
1967 Ok(())
1968}
1969
1970/// Updates other participants in the room with your current location.
1971async fn update_participant_location(
1972 request: proto::UpdateParticipantLocation,
1973 response: Response<proto::UpdateParticipantLocation>,
1974 session: UserSession,
1975) -> Result<()> {
1976 let room_id = RoomId::from_proto(request.room_id);
1977 let location = request
1978 .location
1979 .ok_or_else(|| anyhow!("invalid location"))?;
1980
1981 let db = session.db().await;
1982 let room = db
1983 .update_room_participant_location(room_id, session.connection_id, location)
1984 .await?;
1985
1986 room_updated(&room, &session.peer);
1987 response.send(proto::Ack {})?;
1988 Ok(())
1989}
1990
1991/// Share a project into the room.
1992async fn share_project(
1993 request: proto::ShareProject,
1994 response: Response<proto::ShareProject>,
1995 session: UserSession,
1996) -> Result<()> {
1997 let (project_id, room) = &*session
1998 .db()
1999 .await
2000 .share_project(
2001 RoomId::from_proto(request.room_id),
2002 session.connection_id,
2003 &request.worktrees,
2004 request
2005 .dev_server_project_id
2006 .map(|id| DevServerProjectId::from_proto(id)),
2007 )
2008 .await?;
2009 response.send(proto::ShareProjectResponse {
2010 project_id: project_id.to_proto(),
2011 })?;
2012 room_updated(&room, &session.peer);
2013
2014 Ok(())
2015}
2016
2017/// Unshare a project from the room.
2018async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2019 let project_id = ProjectId::from_proto(message.project_id);
2020 unshare_project_internal(
2021 project_id,
2022 session.connection_id,
2023 session.user_id(),
2024 &session,
2025 )
2026 .await
2027}
2028
2029async fn unshare_project_internal(
2030 project_id: ProjectId,
2031 connection_id: ConnectionId,
2032 user_id: Option<UserId>,
2033 session: &Session,
2034) -> Result<()> {
2035 let delete = {
2036 let room_guard = session
2037 .db()
2038 .await
2039 .unshare_project(project_id, connection_id, user_id)
2040 .await?;
2041
2042 let (delete, room, guest_connection_ids) = &*room_guard;
2043
2044 let message = proto::UnshareProject {
2045 project_id: project_id.to_proto(),
2046 };
2047
2048 broadcast(
2049 Some(connection_id),
2050 guest_connection_ids.iter().copied(),
2051 |conn_id| session.peer.send(conn_id, message.clone()),
2052 );
2053 if let Some(room) = room {
2054 room_updated(room, &session.peer);
2055 }
2056
2057 *delete
2058 };
2059
2060 if delete {
2061 let db = session.db().await;
2062 db.delete_project(project_id).await?;
2063 }
2064
2065 Ok(())
2066}
2067
2068/// DevServer makes a project available online
2069async fn share_dev_server_project(
2070 request: proto::ShareDevServerProject,
2071 response: Response<proto::ShareDevServerProject>,
2072 session: DevServerSession,
2073) -> Result<()> {
2074 let (dev_server_project, user_id, status) = session
2075 .db()
2076 .await
2077 .share_dev_server_project(
2078 DevServerProjectId::from_proto(request.dev_server_project_id),
2079 session.dev_server_id(),
2080 session.connection_id,
2081 &request.worktrees,
2082 )
2083 .await?;
2084 let Some(project_id) = dev_server_project.project_id else {
2085 return Err(anyhow!("failed to share remote project"))?;
2086 };
2087
2088 send_dev_server_projects_update(user_id, status, &session).await;
2089
2090 response.send(proto::ShareProjectResponse { project_id })?;
2091
2092 Ok(())
2093}
2094
/// Joins someone else's shared project.
///
/// Calls the database's `join_project` for this connection and user. It then
/// releases the session's database handle and streams the project to the
/// client via `join_project_internal`.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's DB handle before streaming the project. The value
    // returned by `join_project` stays alive through the borrowed bindings above.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2113
/// A response handle that can answer with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
/// Replies to a `JoinProject` request by delegating to the inherent `Response::send`.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
/// Replies to a `JoinHostedProject` request by delegating to the inherent `Response::send`.
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2127
2128fn join_project_internal(
2129 response: impl JoinProjectInternalResponse,
2130 session: UserSession,
2131 project: &mut Project,
2132 replica_id: &ReplicaId,
2133) -> Result<()> {
2134 let collaborators = project
2135 .collaborators
2136 .iter()
2137 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2138 .map(|collaborator| collaborator.to_proto())
2139 .collect::<Vec<_>>();
2140 let project_id = project.id;
2141 let guest_user_id = session.user_id();
2142
2143 let worktrees = project
2144 .worktrees
2145 .iter()
2146 .map(|(id, worktree)| proto::WorktreeMetadata {
2147 id: *id,
2148 root_name: worktree.root_name.clone(),
2149 visible: worktree.visible,
2150 abs_path: worktree.abs_path.clone(),
2151 })
2152 .collect::<Vec<_>>();
2153
2154 let add_project_collaborator = proto::AddProjectCollaborator {
2155 project_id: project_id.to_proto(),
2156 collaborator: Some(proto::Collaborator {
2157 peer_id: Some(session.connection_id.into()),
2158 replica_id: replica_id.0 as u32,
2159 user_id: guest_user_id.to_proto(),
2160 }),
2161 };
2162
2163 for collaborator in &collaborators {
2164 session
2165 .peer
2166 .send(
2167 collaborator.peer_id.unwrap().into(),
2168 add_project_collaborator.clone(),
2169 )
2170 .trace_err();
2171 }
2172
2173 // First, we send the metadata associated with each worktree.
2174 response.send(proto::JoinProjectResponse {
2175 project_id: project.id.0 as u64,
2176 worktrees: worktrees.clone(),
2177 replica_id: replica_id.0 as u32,
2178 collaborators: collaborators.clone(),
2179 language_servers: project.language_servers.clone(),
2180 role: project.role.into(),
2181 dev_server_project_id: project
2182 .dev_server_project_id
2183 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2184 })?;
2185
2186 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2187 #[cfg(any(test, feature = "test-support"))]
2188 const MAX_CHUNK_SIZE: usize = 2;
2189 #[cfg(not(any(test, feature = "test-support")))]
2190 const MAX_CHUNK_SIZE: usize = 256;
2191
2192 // Stream this worktree's entries.
2193 let message = proto::UpdateWorktree {
2194 project_id: project_id.to_proto(),
2195 worktree_id,
2196 abs_path: worktree.abs_path.clone(),
2197 root_name: worktree.root_name,
2198 updated_entries: worktree.entries,
2199 removed_entries: Default::default(),
2200 scan_id: worktree.scan_id,
2201 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2202 updated_repositories: worktree.repository_entries.into_values().collect(),
2203 removed_repositories: Default::default(),
2204 };
2205 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2206 session.peer.send(session.connection_id, update.clone())?;
2207 }
2208
2209 // Stream this worktree's diagnostics.
2210 for summary in worktree.diagnostic_summaries {
2211 session.peer.send(
2212 session.connection_id,
2213 proto::UpdateDiagnosticSummary {
2214 project_id: project_id.to_proto(),
2215 worktree_id: worktree.id,
2216 summary: Some(summary),
2217 },
2218 )?;
2219 }
2220
2221 for settings_file in worktree.settings_files {
2222 session.peer.send(
2223 session.connection_id,
2224 proto::UpdateWorktreeSettings {
2225 project_id: project_id.to_proto(),
2226 worktree_id: worktree.id,
2227 path: settings_file.path,
2228 content: Some(settings_file.content),
2229 },
2230 )?;
2231 }
2232 }
2233
2234 for language_server in &project.language_servers {
2235 session.peer.send(
2236 session.connection_id,
2237 proto::UpdateLanguageServer {
2238 project_id: project_id.to_proto(),
2239 language_server_id: language_server.id,
2240 variant: Some(
2241 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2242 proto::LspDiskBasedDiagnosticsUpdated {},
2243 ),
2244 ),
2245 },
2246 )?;
2247 }
2248
2249 Ok(())
2250}
2251
2252/// Leave someone elses shared project.
2253async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2254 let sender_id = session.connection_id;
2255 let project_id = ProjectId::from_proto(request.project_id);
2256 let db = session.db().await;
2257 if db.is_hosted_project(project_id).await? {
2258 let project = db.leave_hosted_project(project_id, sender_id).await?;
2259 project_left(&project, &session);
2260 return Ok(());
2261 }
2262
2263 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2264 tracing::info!(
2265 %project_id,
2266 "leave project"
2267 );
2268
2269 project_left(&project, &session);
2270 if let Some(room) = room {
2271 room_updated(&room, &session.peer);
2272 }
2273
2274 Ok(())
2275}
2276
2277async fn join_hosted_project(
2278 request: proto::JoinHostedProject,
2279 response: Response<proto::JoinHostedProject>,
2280 session: UserSession,
2281) -> Result<()> {
2282 let (mut project, replica_id) = session
2283 .db()
2284 .await
2285 .join_hosted_project(
2286 ProjectId(request.project_id as i32),
2287 session.user_id(),
2288 session.connection_id,
2289 )
2290 .await?;
2291
2292 join_project_internal(response, session, &mut project, &replica_id)
2293}
2294
2295async fn create_dev_server_project(
2296 request: proto::CreateDevServerProject,
2297 response: Response<proto::CreateDevServerProject>,
2298 session: UserSession,
2299) -> Result<()> {
2300 let dev_server_id = DevServerId(request.dev_server_id as i32);
2301 let dev_server_connection_id = session
2302 .connection_pool()
2303 .await
2304 .dev_server_connection_id(dev_server_id);
2305 let Some(dev_server_connection_id) = dev_server_connection_id else {
2306 Err(ErrorCode::DevServerOffline
2307 .message("Cannot create a remote project when the dev server is offline".to_string())
2308 .anyhow())?
2309 };
2310
2311 let path = request.path.clone();
2312 //Check that the path exists on the dev server
2313 session
2314 .peer
2315 .forward_request(
2316 session.connection_id,
2317 dev_server_connection_id,
2318 proto::ValidateDevServerProjectRequest { path: path.clone() },
2319 )
2320 .await?;
2321
2322 let (dev_server_project, update) = session
2323 .db()
2324 .await
2325 .create_dev_server_project(
2326 DevServerId(request.dev_server_id as i32),
2327 &request.path,
2328 session.user_id(),
2329 )
2330 .await?;
2331
2332 let projects = session
2333 .db()
2334 .await
2335 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2336 .await?;
2337
2338 session.peer.send(
2339 dev_server_connection_id,
2340 proto::DevServerInstructions { projects },
2341 )?;
2342
2343 send_dev_server_projects_update(session.user_id(), update, &session).await;
2344
2345 response.send(proto::CreateDevServerProjectResponse {
2346 dev_server_project: Some(dev_server_project.to_proto(None)),
2347 })?;
2348 Ok(())
2349}
2350
2351async fn create_dev_server(
2352 request: proto::CreateDevServer,
2353 response: Response<proto::CreateDevServer>,
2354 session: UserSession,
2355) -> Result<()> {
2356 let access_token = auth::random_token();
2357 let hashed_access_token = auth::hash_access_token(&access_token);
2358
2359 if request.name.is_empty() {
2360 return Err(proto::ErrorCode::Forbidden
2361 .message("Dev server name cannot be empty".to_string())
2362 .anyhow())?;
2363 }
2364
2365 let (dev_server, status) = session
2366 .db()
2367 .await
2368 .create_dev_server(
2369 &request.name,
2370 request.ssh_connection_string.as_deref(),
2371 &hashed_access_token,
2372 session.user_id(),
2373 )
2374 .await?;
2375
2376 send_dev_server_projects_update(session.user_id(), status, &session).await;
2377
2378 response.send(proto::CreateDevServerResponse {
2379 dev_server_id: dev_server.id.0 as u64,
2380 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2381 name: request.name,
2382 })?;
2383 Ok(())
2384}
2385
2386async fn regenerate_dev_server_token(
2387 request: proto::RegenerateDevServerToken,
2388 response: Response<proto::RegenerateDevServerToken>,
2389 session: UserSession,
2390) -> Result<()> {
2391 let dev_server_id = DevServerId(request.dev_server_id as i32);
2392 let access_token = auth::random_token();
2393 let hashed_access_token = auth::hash_access_token(&access_token);
2394
2395 let connection_id = session
2396 .connection_pool()
2397 .await
2398 .dev_server_connection_id(dev_server_id);
2399 if let Some(connection_id) = connection_id {
2400 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2401 session
2402 .peer
2403 .send(connection_id, proto::ShutdownDevServer {})?;
2404 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2405 }
2406
2407 let status = session
2408 .db()
2409 .await
2410 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2411 .await?;
2412
2413 send_dev_server_projects_update(session.user_id(), status, &session).await;
2414
2415 response.send(proto::RegenerateDevServerTokenResponse {
2416 dev_server_id: dev_server_id.to_proto(),
2417 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2418 })?;
2419 Ok(())
2420}
2421
2422async fn rename_dev_server(
2423 request: proto::RenameDevServer,
2424 response: Response<proto::RenameDevServer>,
2425 session: UserSession,
2426) -> Result<()> {
2427 if request.name.trim().is_empty() {
2428 return Err(proto::ErrorCode::Forbidden
2429 .message("Dev server name cannot be empty".to_string())
2430 .anyhow())?;
2431 }
2432
2433 let dev_server_id = DevServerId(request.dev_server_id as i32);
2434 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2435 if dev_server.user_id != session.user_id() {
2436 return Err(anyhow!(ErrorCode::Forbidden))?;
2437 }
2438
2439 let status = session
2440 .db()
2441 .await
2442 .rename_dev_server(
2443 dev_server_id,
2444 &request.name,
2445 &request.ssh_connection_string,
2446 session.user_id(),
2447 )
2448 .await?;
2449
2450 send_dev_server_projects_update(session.user_id(), status, &session).await;
2451
2452 response.send(proto::Ack {})?;
2453 Ok(())
2454}
2455
2456async fn delete_dev_server(
2457 request: proto::DeleteDevServer,
2458 response: Response<proto::DeleteDevServer>,
2459 session: UserSession,
2460) -> Result<()> {
2461 let dev_server_id = DevServerId(request.dev_server_id as i32);
2462 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2463 if dev_server.user_id != session.user_id() {
2464 return Err(anyhow!(ErrorCode::Forbidden))?;
2465 }
2466
2467 let connection_id = session
2468 .connection_pool()
2469 .await
2470 .dev_server_connection_id(dev_server_id);
2471 if let Some(connection_id) = connection_id {
2472 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2473 session
2474 .peer
2475 .send(connection_id, proto::ShutdownDevServer {})?;
2476 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2477 }
2478
2479 let status = session
2480 .db()
2481 .await
2482 .delete_dev_server(dev_server_id, session.user_id())
2483 .await?;
2484
2485 send_dev_server_projects_update(session.user_id(), status, &session).await;
2486
2487 response.send(proto::Ack {})?;
2488 Ok(())
2489}
2490
2491async fn delete_dev_server_project(
2492 request: proto::DeleteDevServerProject,
2493 response: Response<proto::DeleteDevServerProject>,
2494 session: UserSession,
2495) -> Result<()> {
2496 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2497 let dev_server_project = session
2498 .db()
2499 .await
2500 .get_dev_server_project(dev_server_project_id)
2501 .await?;
2502
2503 let dev_server = session
2504 .db()
2505 .await
2506 .get_dev_server(dev_server_project.dev_server_id)
2507 .await?;
2508 if dev_server.user_id != session.user_id() {
2509 return Err(anyhow!(ErrorCode::Forbidden))?;
2510 }
2511
2512 let dev_server_connection_id = session
2513 .connection_pool()
2514 .await
2515 .dev_server_connection_id(dev_server.id);
2516
2517 if let Some(dev_server_connection_id) = dev_server_connection_id {
2518 let project = session
2519 .db()
2520 .await
2521 .find_dev_server_project(dev_server_project_id)
2522 .await;
2523 if let Ok(project) = project {
2524 unshare_project_internal(
2525 project.id,
2526 dev_server_connection_id,
2527 Some(session.user_id()),
2528 &session,
2529 )
2530 .await?;
2531 }
2532 }
2533
2534 let (projects, status) = session
2535 .db()
2536 .await
2537 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2538 .await?;
2539
2540 if let Some(dev_server_connection_id) = dev_server_connection_id {
2541 session.peer.send(
2542 dev_server_connection_id,
2543 proto::DevServerInstructions { projects },
2544 )?;
2545 }
2546
2547 send_dev_server_projects_update(session.user_id(), status, &session).await;
2548
2549 response.send(proto::Ack {})?;
2550 Ok(())
2551}
2552
2553async fn rejoin_dev_server_projects(
2554 request: proto::RejoinRemoteProjects,
2555 response: Response<proto::RejoinRemoteProjects>,
2556 session: UserSession,
2557) -> Result<()> {
2558 let mut rejoined_projects = {
2559 let db = session.db().await;
2560 db.rejoin_dev_server_projects(
2561 &request.rejoined_projects,
2562 session.user_id(),
2563 session.0.connection_id,
2564 )
2565 .await?
2566 };
2567 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2568
2569 response.send(proto::RejoinRemoteProjectsResponse {
2570 rejoined_projects: rejoined_projects
2571 .into_iter()
2572 .map(|project| project.to_proto())
2573 .collect(),
2574 })
2575}
2576
2577async fn reconnect_dev_server(
2578 request: proto::ReconnectDevServer,
2579 response: Response<proto::ReconnectDevServer>,
2580 session: DevServerSession,
2581) -> Result<()> {
2582 let reshared_projects = {
2583 let db = session.db().await;
2584 db.reshare_dev_server_projects(
2585 &request.reshared_projects,
2586 session.dev_server_id(),
2587 session.0.connection_id,
2588 )
2589 .await?
2590 };
2591
2592 for project in &reshared_projects {
2593 for collaborator in &project.collaborators {
2594 session
2595 .peer
2596 .send(
2597 collaborator.connection_id,
2598 proto::UpdateProjectCollaborator {
2599 project_id: project.id.to_proto(),
2600 old_peer_id: Some(project.old_connection_id.into()),
2601 new_peer_id: Some(session.connection_id.into()),
2602 },
2603 )
2604 .trace_err();
2605 }
2606
2607 broadcast(
2608 Some(session.connection_id),
2609 project
2610 .collaborators
2611 .iter()
2612 .map(|collaborator| collaborator.connection_id),
2613 |connection_id| {
2614 session.peer.forward_send(
2615 session.connection_id,
2616 connection_id,
2617 proto::UpdateProject {
2618 project_id: project.id.to_proto(),
2619 worktrees: project.worktrees.clone(),
2620 },
2621 )
2622 },
2623 );
2624 }
2625
2626 response.send(proto::ReconnectDevServerResponse {
2627 reshared_projects: reshared_projects
2628 .iter()
2629 .map(|project| proto::ResharedProject {
2630 id: project.id.to_proto(),
2631 collaborators: project
2632 .collaborators
2633 .iter()
2634 .map(|collaborator| collaborator.to_proto())
2635 .collect(),
2636 })
2637 .collect(),
2638 })?;
2639
2640 Ok(())
2641}
2642
2643async fn shutdown_dev_server(
2644 _: proto::ShutdownDevServer,
2645 response: Response<proto::ShutdownDevServer>,
2646 session: DevServerSession,
2647) -> Result<()> {
2648 response.send(proto::Ack {})?;
2649 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2650 remove_dev_server_connection(session.dev_server_id(), &session).await
2651}
2652
2653async fn shutdown_dev_server_internal(
2654 dev_server_id: DevServerId,
2655 connection_id: ConnectionId,
2656 session: &Session,
2657) -> Result<()> {
2658 let (dev_server_projects, dev_server) = {
2659 let db = session.db().await;
2660 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2661 let dev_server = db.get_dev_server(dev_server_id).await?;
2662 (dev_server_projects, dev_server)
2663 };
2664
2665 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2666 unshare_project_internal(
2667 ProjectId::from_proto(project_id),
2668 connection_id,
2669 None,
2670 session,
2671 )
2672 .await?;
2673 }
2674
2675 session
2676 .connection_pool()
2677 .await
2678 .set_dev_server_offline(dev_server_id);
2679
2680 let status = session
2681 .db()
2682 .await
2683 .dev_server_projects_update(dev_server.user_id)
2684 .await?;
2685 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2686
2687 Ok(())
2688}
2689
2690async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2691 let dev_server_connection = session
2692 .connection_pool()
2693 .await
2694 .dev_server_connection_id(dev_server_id);
2695
2696 if let Some(dev_server_connection) = dev_server_connection {
2697 session
2698 .connection_pool()
2699 .await
2700 .remove_connection(dev_server_connection)?;
2701 }
2702 Ok(())
2703}
2704
2705/// Updates other participants with changes to the project
2706async fn update_project(
2707 request: proto::UpdateProject,
2708 response: Response<proto::UpdateProject>,
2709 session: Session,
2710) -> Result<()> {
2711 let project_id = ProjectId::from_proto(request.project_id);
2712 let (room, guest_connection_ids) = &*session
2713 .db()
2714 .await
2715 .update_project(project_id, session.connection_id, &request.worktrees)
2716 .await?;
2717 broadcast(
2718 Some(session.connection_id),
2719 guest_connection_ids.iter().copied(),
2720 |connection_id| {
2721 session
2722 .peer
2723 .forward_send(session.connection_id, connection_id, request.clone())
2724 },
2725 );
2726 if let Some(room) = room {
2727 room_updated(&room, &session.peer);
2728 }
2729 response.send(proto::Ack {})?;
2730
2731 Ok(())
2732}
2733
2734/// Updates other participants with changes to the worktree
2735async fn update_worktree(
2736 request: proto::UpdateWorktree,
2737 response: Response<proto::UpdateWorktree>,
2738 session: Session,
2739) -> Result<()> {
2740 let guest_connection_ids = session
2741 .db()
2742 .await
2743 .update_worktree(&request, session.connection_id)
2744 .await?;
2745
2746 broadcast(
2747 Some(session.connection_id),
2748 guest_connection_ids.iter().copied(),
2749 |connection_id| {
2750 session
2751 .peer
2752 .forward_send(session.connection_id, connection_id, request.clone())
2753 },
2754 );
2755 response.send(proto::Ack {})?;
2756 Ok(())
2757}
2758
2759/// Updates other participants with changes to the diagnostics
2760async fn update_diagnostic_summary(
2761 message: proto::UpdateDiagnosticSummary,
2762 session: Session,
2763) -> Result<()> {
2764 let guest_connection_ids = session
2765 .db()
2766 .await
2767 .update_diagnostic_summary(&message, session.connection_id)
2768 .await?;
2769
2770 broadcast(
2771 Some(session.connection_id),
2772 guest_connection_ids.iter().copied(),
2773 |connection_id| {
2774 session
2775 .peer
2776 .forward_send(session.connection_id, connection_id, message.clone())
2777 },
2778 );
2779
2780 Ok(())
2781}
2782
2783/// Updates other participants with changes to the worktree settings
2784async fn update_worktree_settings(
2785 message: proto::UpdateWorktreeSettings,
2786 session: Session,
2787) -> Result<()> {
2788 let guest_connection_ids = session
2789 .db()
2790 .await
2791 .update_worktree_settings(&message, session.connection_id)
2792 .await?;
2793
2794 broadcast(
2795 Some(session.connection_id),
2796 guest_connection_ids.iter().copied(),
2797 |connection_id| {
2798 session
2799 .peer
2800 .forward_send(session.connection_id, connection_id, message.clone())
2801 },
2802 );
2803
2804 Ok(())
2805}
2806
2807/// Notify other participants that a language server has started.
2808async fn start_language_server(
2809 request: proto::StartLanguageServer,
2810 session: Session,
2811) -> Result<()> {
2812 let guest_connection_ids = session
2813 .db()
2814 .await
2815 .start_language_server(&request, session.connection_id)
2816 .await?;
2817
2818 broadcast(
2819 Some(session.connection_id),
2820 guest_connection_ids.iter().copied(),
2821 |connection_id| {
2822 session
2823 .peer
2824 .forward_send(session.connection_id, connection_id, request.clone())
2825 },
2826 );
2827 Ok(())
2828}
2829
2830/// Notify other participants that a language server has changed.
2831async fn update_language_server(
2832 request: proto::UpdateLanguageServer,
2833 session: Session,
2834) -> Result<()> {
2835 let project_id = ProjectId::from_proto(request.project_id);
2836 let project_connection_ids = session
2837 .db()
2838 .await
2839 .project_connection_ids(project_id, session.connection_id, true)
2840 .await?;
2841 broadcast(
2842 Some(session.connection_id),
2843 project_connection_ids.iter().copied(),
2844 |connection_id| {
2845 session
2846 .peer
2847 .forward_send(session.connection_id, connection_id, request.clone())
2848 },
2849 );
2850 Ok(())
2851}
2852
2853/// forward a project request to the host. These requests should be read only
2854/// as guests are allowed to send them.
2855async fn forward_read_only_project_request<T>(
2856 request: T,
2857 response: Response<T>,
2858 session: UserSession,
2859) -> Result<()>
2860where
2861 T: EntityMessage + RequestMessage,
2862{
2863 let project_id = ProjectId::from_proto(request.remote_entity_id());
2864 let host_connection_id = session
2865 .db()
2866 .await
2867 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2868 .await?;
2869 let payload = session
2870 .peer
2871 .forward_request(session.connection_id, host_connection_id, request)
2872 .await?;
2873 response.send(payload)?;
2874 Ok(())
2875}
2876
2877/// forward a project request to the host. These requests are disallowed
2878/// for guests.
2879async fn forward_mutating_project_request<T>(
2880 request: T,
2881 response: Response<T>,
2882 session: UserSession,
2883) -> Result<()>
2884where
2885 T: EntityMessage + RequestMessage,
2886{
2887 let project_id = ProjectId::from_proto(request.remote_entity_id());
2888
2889 let host_connection_id = session
2890 .db()
2891 .await
2892 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2893 .await?;
2894 let payload = session
2895 .peer
2896 .forward_request(session.connection_id, host_connection_id, request)
2897 .await?;
2898 response.send(payload)?;
2899 Ok(())
2900}
2901
2902/// forward a project request to the host. These requests are disallowed
2903/// for guests.
2904async fn forward_versioned_mutating_project_request<T>(
2905 request: T,
2906 response: Response<T>,
2907 session: UserSession,
2908) -> Result<()>
2909where
2910 T: EntityMessage + RequestMessage + VersionedMessage,
2911{
2912 let project_id = ProjectId::from_proto(request.remote_entity_id());
2913
2914 let host_connection_id = session
2915 .db()
2916 .await
2917 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2918 .await?;
2919 if let Some(host_version) = session
2920 .connection_pool()
2921 .await
2922 .connection(host_connection_id)
2923 .map(|c| c.zed_version)
2924 {
2925 if let Some(min_required_version) = request.required_host_version() {
2926 if min_required_version > host_version {
2927 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2928 .with_tag("required", &min_required_version.to_string())))?;
2929 }
2930 }
2931 }
2932
2933 let payload = session
2934 .peer
2935 .forward_request(session.connection_id, host_connection_id, request)
2936 .await?;
2937 response.send(payload)?;
2938 Ok(())
2939}
2940
2941/// Notify other participants that a new buffer has been created
2942async fn create_buffer_for_peer(
2943 request: proto::CreateBufferForPeer,
2944 session: Session,
2945) -> Result<()> {
2946 session
2947 .db()
2948 .await
2949 .check_user_is_project_host(
2950 ProjectId::from_proto(request.project_id),
2951 session.connection_id,
2952 )
2953 .await?;
2954 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2955 session
2956 .peer
2957 .forward_send(session.connection_id, peer_id.into(), request)?;
2958 Ok(())
2959}
2960
2961/// Notify other participants that a buffer has been updated. This is
2962/// allowed for guests as long as the update is limited to selections.
2963async fn update_buffer(
2964 request: proto::UpdateBuffer,
2965 response: Response<proto::UpdateBuffer>,
2966 session: Session,
2967) -> Result<()> {
2968 let project_id = ProjectId::from_proto(request.project_id);
2969 let mut capability = Capability::ReadOnly;
2970
2971 for op in request.operations.iter() {
2972 match op.variant {
2973 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2974 Some(_) => capability = Capability::ReadWrite,
2975 }
2976 }
2977
2978 let host = {
2979 let guard = session
2980 .db()
2981 .await
2982 .connections_for_buffer_update(
2983 project_id,
2984 session.principal_id(),
2985 session.connection_id,
2986 capability,
2987 )
2988 .await?;
2989
2990 let (host, guests) = &*guard;
2991
2992 broadcast(
2993 Some(session.connection_id),
2994 guests.clone(),
2995 |connection_id| {
2996 session
2997 .peer
2998 .forward_send(session.connection_id, connection_id, request.clone())
2999 },
3000 );
3001
3002 *host
3003 };
3004
3005 if host != session.connection_id {
3006 session
3007 .peer
3008 .forward_request(session.connection_id, host, request.clone())
3009 .await?;
3010 }
3011
3012 response.send(proto::Ack {})?;
3013 Ok(())
3014}
3015
3016/// Notify other participants that a project has been updated.
3017async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3018 request: T,
3019 session: Session,
3020) -> Result<()> {
3021 let project_id = ProjectId::from_proto(request.remote_entity_id());
3022 let project_connection_ids = session
3023 .db()
3024 .await
3025 .project_connection_ids(project_id, session.connection_id, false)
3026 .await?;
3027
3028 broadcast(
3029 Some(session.connection_id),
3030 project_connection_ids.iter().copied(),
3031 |connection_id| {
3032 session
3033 .peer
3034 .forward_send(session.connection_id, connection_id, request.clone())
3035 },
3036 );
3037 Ok(())
3038}
3039
3040/// Start following another user in a call.
3041async fn follow(
3042 request: proto::Follow,
3043 response: Response<proto::Follow>,
3044 session: UserSession,
3045) -> Result<()> {
3046 let room_id = RoomId::from_proto(request.room_id);
3047 let project_id = request.project_id.map(ProjectId::from_proto);
3048 let leader_id = request
3049 .leader_id
3050 .ok_or_else(|| anyhow!("invalid leader id"))?
3051 .into();
3052 let follower_id = session.connection_id;
3053
3054 session
3055 .db()
3056 .await
3057 .check_room_participants(room_id, leader_id, session.connection_id)
3058 .await?;
3059
3060 let response_payload = session
3061 .peer
3062 .forward_request(session.connection_id, leader_id, request)
3063 .await?;
3064 response.send(response_payload)?;
3065
3066 if let Some(project_id) = project_id {
3067 let room = session
3068 .db()
3069 .await
3070 .follow(room_id, project_id, leader_id, follower_id)
3071 .await?;
3072 room_updated(&room, &session.peer);
3073 }
3074
3075 Ok(())
3076}
3077
3078/// Stop following another user in a call.
3079async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3080 let room_id = RoomId::from_proto(request.room_id);
3081 let project_id = request.project_id.map(ProjectId::from_proto);
3082 let leader_id = request
3083 .leader_id
3084 .ok_or_else(|| anyhow!("invalid leader id"))?
3085 .into();
3086 let follower_id = session.connection_id;
3087
3088 session
3089 .db()
3090 .await
3091 .check_room_participants(room_id, leader_id, session.connection_id)
3092 .await?;
3093
3094 session
3095 .peer
3096 .forward_send(session.connection_id, leader_id, request)?;
3097
3098 if let Some(project_id) = project_id {
3099 let room = session
3100 .db()
3101 .await
3102 .unfollow(room_id, project_id, leader_id, follower_id)
3103 .await?;
3104 room_updated(&room, &session.peer);
3105 }
3106
3107 Ok(())
3108}
3109
3110/// Notify everyone following you of your current location.
3111async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3112 let room_id = RoomId::from_proto(request.room_id);
3113 let database = session.db.lock().await;
3114
3115 let connection_ids = if let Some(project_id) = request.project_id {
3116 let project_id = ProjectId::from_proto(project_id);
3117 database
3118 .project_connection_ids(project_id, session.connection_id, true)
3119 .await?
3120 } else {
3121 database
3122 .room_connection_ids(room_id, session.connection_id)
3123 .await?
3124 };
3125
3126 // For now, don't send view update messages back to that view's current leader.
3127 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3128 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3129 _ => None,
3130 });
3131
3132 for connection_id in connection_ids.iter().cloned() {
3133 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3134 session
3135 .peer
3136 .forward_send(session.connection_id, connection_id, request.clone())?;
3137 }
3138 }
3139 Ok(())
3140}
3141
3142/// Get public data about users.
3143async fn get_users(
3144 request: proto::GetUsers,
3145 response: Response<proto::GetUsers>,
3146 session: Session,
3147) -> Result<()> {
3148 let user_ids = request
3149 .user_ids
3150 .into_iter()
3151 .map(UserId::from_proto)
3152 .collect();
3153 let users = session
3154 .db()
3155 .await
3156 .get_users_by_ids(user_ids)
3157 .await?
3158 .into_iter()
3159 .map(|user| proto::User {
3160 id: user.id.to_proto(),
3161 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3162 github_login: user.github_login,
3163 })
3164 .collect();
3165 response.send(proto::UsersResponse { users })?;
3166 Ok(())
3167}
3168
3169/// Search for users (to invite) buy Github login
3170async fn fuzzy_search_users(
3171 request: proto::FuzzySearchUsers,
3172 response: Response<proto::FuzzySearchUsers>,
3173 session: UserSession,
3174) -> Result<()> {
3175 let query = request.query;
3176 let users = match query.len() {
3177 0 => vec![],
3178 1 | 2 => session
3179 .db()
3180 .await
3181 .get_user_by_github_login(&query)
3182 .await?
3183 .into_iter()
3184 .collect(),
3185 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3186 };
3187 let users = users
3188 .into_iter()
3189 .filter(|user| user.id != session.user_id())
3190 .map(|user| proto::User {
3191 id: user.id.to_proto(),
3192 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3193 github_login: user.github_login,
3194 })
3195 .collect();
3196 response.send(proto::UsersResponse { users })?;
3197 Ok(())
3198}
3199
3200/// Send a contact request to another user.
3201async fn request_contact(
3202 request: proto::RequestContact,
3203 response: Response<proto::RequestContact>,
3204 session: UserSession,
3205) -> Result<()> {
3206 let requester_id = session.user_id();
3207 let responder_id = UserId::from_proto(request.responder_id);
3208 if requester_id == responder_id {
3209 return Err(anyhow!("cannot add yourself as a contact"))?;
3210 }
3211
3212 let notifications = session
3213 .db()
3214 .await
3215 .send_contact_request(requester_id, responder_id)
3216 .await?;
3217
3218 // Update outgoing contact requests of requester
3219 let mut update = proto::UpdateContacts::default();
3220 update.outgoing_requests.push(responder_id.to_proto());
3221 for connection_id in session
3222 .connection_pool()
3223 .await
3224 .user_connection_ids(requester_id)
3225 {
3226 session.peer.send(connection_id, update.clone())?;
3227 }
3228
3229 // Update incoming contact requests of responder
3230 let mut update = proto::UpdateContacts::default();
3231 update
3232 .incoming_requests
3233 .push(proto::IncomingContactRequest {
3234 requester_id: requester_id.to_proto(),
3235 });
3236 let connection_pool = session.connection_pool().await;
3237 for connection_id in connection_pool.user_connection_ids(responder_id) {
3238 session.peer.send(connection_id, update.clone())?;
3239 }
3240
3241 send_notifications(&connection_pool, &session.peer, notifications);
3242
3243 response.send(proto::Ack {})?;
3244 Ok(())
3245}
3246
/// Accept, decline, or dismiss a contact request.
///
/// - Dismiss: only `dismiss_contact_notification` is called; no contacts
///   update is sent.
/// - Otherwise the request is accepted when `response` is `Accept` and treated
///   as declined for any other value. Both users' connections get a contacts
///   update that removes the pending request and, on acceptance, adds the
///   other user as a contact (with their busy status). Any notifications the
///   database produced are then delivered.
///
/// The database handle is held for the whole function.
async fn respond_to_contact_request(
    request: proto::RespondToContactRequest,
    response: Response<proto::RespondToContactRequest>,
    session: UserSession,
) -> Result<()> {
    let responder_id = session.user_id();
    let requester_id = UserId::from_proto(request.requester_id);
    let db = session.db().await;
    if request.response == proto::ContactRequestResponse::Dismiss as i32 {
        db.dismiss_contact_notification(responder_id, requester_id)
            .await?;
    } else {
        // Any response other than `Accept` (including unknown values) is a decline.
        let accept = request.response == proto::ContactRequestResponse::Accept as i32;

        let notifications = db
            .respond_to_contact_request(responder_id, requester_id, accept)
            .await?;
        // Busy status is included in each contact entry sent below.
        let requester_busy = db.is_user_busy(requester_id).await?;
        let responder_busy = db.is_user_busy(responder_id).await?;

        let pool = session.connection_pool().await;
        // Update responder with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(requester_id, requester_busy, &pool));
        }
        update
            .remove_incoming_requests
            .push(requester_id.to_proto());
        for connection_id in pool.user_connection_ids(responder_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        // Update requester with new contact
        let mut update = proto::UpdateContacts::default();
        if accept {
            update
                .contacts
                .push(contact_for_user(responder_id, responder_busy, &pool));
        }
        update
            .remove_outgoing_requests
            .push(responder_id.to_proto());

        for connection_id in pool.user_connection_ids(requester_id) {
            session.peer.send(connection_id, update.clone())?;
        }

        send_notifications(&pool, &session.peer, notifications);
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3304
3305/// Remove a contact.
3306async fn remove_contact(
3307 request: proto::RemoveContact,
3308 response: Response<proto::RemoveContact>,
3309 session: UserSession,
3310) -> Result<()> {
3311 let requester_id = session.user_id();
3312 let responder_id = UserId::from_proto(request.user_id);
3313 let db = session.db().await;
3314 let (contact_accepted, deleted_notification_id) =
3315 db.remove_contact(requester_id, responder_id).await?;
3316
3317 let pool = session.connection_pool().await;
3318 // Update outgoing contact requests of requester
3319 let mut update = proto::UpdateContacts::default();
3320 if contact_accepted {
3321 update.remove_contacts.push(responder_id.to_proto());
3322 } else {
3323 update
3324 .remove_outgoing_requests
3325 .push(responder_id.to_proto());
3326 }
3327 for connection_id in pool.user_connection_ids(requester_id) {
3328 session.peer.send(connection_id, update.clone())?;
3329 }
3330
3331 // Update incoming contact requests of responder
3332 let mut update = proto::UpdateContacts::default();
3333 if contact_accepted {
3334 update.remove_contacts.push(requester_id.to_proto());
3335 } else {
3336 update
3337 .remove_incoming_requests
3338 .push(requester_id.to_proto());
3339 }
3340 for connection_id in pool.user_connection_ids(responder_id) {
3341 session.peer.send(connection_id, update.clone())?;
3342 if let Some(notification_id) = deleted_notification_id {
3343 session.peer.send(
3344 connection_id,
3345 proto::DeleteNotification {
3346 notification_id: notification_id.to_proto(),
3347 },
3348 )?;
3349 }
3350 }
3351
3352 response.send(proto::Ack {})?;
3353 Ok(())
3354}
3355
3356/// Creates a new channel.
3357async fn create_channel(
3358 request: proto::CreateChannel,
3359 response: Response<proto::CreateChannel>,
3360 session: UserSession,
3361) -> Result<()> {
3362 let db = session.db().await;
3363
3364 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3365 let (channel, membership) = db
3366 .create_channel(&request.name, parent_id, session.user_id())
3367 .await?;
3368
3369 let root_id = channel.root_id();
3370 let channel = Channel::from_model(channel);
3371
3372 response.send(proto::CreateChannelResponse {
3373 channel: Some(channel.to_proto()),
3374 parent_id: request.parent_id,
3375 })?;
3376
3377 let mut connection_pool = session.connection_pool().await;
3378 if let Some(membership) = membership {
3379 connection_pool.subscribe_to_channel(
3380 membership.user_id,
3381 membership.channel_id,
3382 membership.role,
3383 );
3384 let update = proto::UpdateUserChannels {
3385 channel_memberships: vec![proto::ChannelMembership {
3386 channel_id: membership.channel_id.to_proto(),
3387 role: membership.role.into(),
3388 }],
3389 ..Default::default()
3390 };
3391 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3392 session.peer.send(connection_id, update.clone())?;
3393 }
3394 }
3395
3396 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3397 if !role.can_see_channel(channel.visibility) {
3398 continue;
3399 }
3400
3401 let update = proto::UpdateChannels {
3402 channels: vec![channel.to_proto()],
3403 ..Default::default()
3404 };
3405 session.peer.send(connection_id, update.clone())?;
3406 }
3407
3408 Ok(())
3409}
3410
3411/// Delete a channel
3412async fn delete_channel(
3413 request: proto::DeleteChannel,
3414 response: Response<proto::DeleteChannel>,
3415 session: UserSession,
3416) -> Result<()> {
3417 let db = session.db().await;
3418
3419 let channel_id = request.channel_id;
3420 let (root_channel, removed_channels) = db
3421 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3422 .await?;
3423 response.send(proto::Ack {})?;
3424
3425 // Notify members of removed channels
3426 let mut update = proto::UpdateChannels::default();
3427 update
3428 .delete_channels
3429 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3430
3431 let connection_pool = session.connection_pool().await;
3432 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3433 session.peer.send(connection_id, update.clone())?;
3434 }
3435
3436 Ok(())
3437}
3438
3439/// Invite someone to join a channel.
3440async fn invite_channel_member(
3441 request: proto::InviteChannelMember,
3442 response: Response<proto::InviteChannelMember>,
3443 session: UserSession,
3444) -> Result<()> {
3445 let db = session.db().await;
3446 let channel_id = ChannelId::from_proto(request.channel_id);
3447 let invitee_id = UserId::from_proto(request.user_id);
3448 let InviteMemberResult {
3449 channel,
3450 notifications,
3451 } = db
3452 .invite_channel_member(
3453 channel_id,
3454 invitee_id,
3455 session.user_id(),
3456 request.role().into(),
3457 )
3458 .await?;
3459
3460 let update = proto::UpdateChannels {
3461 channel_invitations: vec![channel.to_proto()],
3462 ..Default::default()
3463 };
3464
3465 let connection_pool = session.connection_pool().await;
3466 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3467 session.peer.send(connection_id, update.clone())?;
3468 }
3469
3470 send_notifications(&connection_pool, &session.peer, notifications);
3471
3472 response.send(proto::Ack {})?;
3473 Ok(())
3474}
3475
3476/// remove someone from a channel
3477async fn remove_channel_member(
3478 request: proto::RemoveChannelMember,
3479 response: Response<proto::RemoveChannelMember>,
3480 session: UserSession,
3481) -> Result<()> {
3482 let db = session.db().await;
3483 let channel_id = ChannelId::from_proto(request.channel_id);
3484 let member_id = UserId::from_proto(request.user_id);
3485
3486 let RemoveChannelMemberResult {
3487 membership_update,
3488 notification_id,
3489 } = db
3490 .remove_channel_member(channel_id, member_id, session.user_id())
3491 .await?;
3492
3493 let mut connection_pool = session.connection_pool().await;
3494 notify_membership_updated(
3495 &mut connection_pool,
3496 membership_update,
3497 member_id,
3498 &session.peer,
3499 );
3500 for connection_id in connection_pool.user_connection_ids(member_id) {
3501 if let Some(notification_id) = notification_id {
3502 session
3503 .peer
3504 .send(
3505 connection_id,
3506 proto::DeleteNotification {
3507 notification_id: notification_id.to_proto(),
3508 },
3509 )
3510 .trace_err();
3511 }
3512 }
3513
3514 response.send(proto::Ack {})?;
3515 Ok(())
3516}
3517
3518/// Toggle the channel between public and private.
3519/// Care is taken to maintain the invariant that public channels only descend from public channels,
3520/// (though members-only channels can appear at any point in the hierarchy).
3521async fn set_channel_visibility(
3522 request: proto::SetChannelVisibility,
3523 response: Response<proto::SetChannelVisibility>,
3524 session: UserSession,
3525) -> Result<()> {
3526 let db = session.db().await;
3527 let channel_id = ChannelId::from_proto(request.channel_id);
3528 let visibility = request.visibility().into();
3529
3530 let channel_model = db
3531 .set_channel_visibility(channel_id, visibility, session.user_id())
3532 .await?;
3533 let root_id = channel_model.root_id();
3534 let channel = Channel::from_model(channel_model);
3535
3536 let mut connection_pool = session.connection_pool().await;
3537 for (user_id, role) in connection_pool
3538 .channel_user_ids(root_id)
3539 .collect::<Vec<_>>()
3540 .into_iter()
3541 {
3542 let update = if role.can_see_channel(channel.visibility) {
3543 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3544 proto::UpdateChannels {
3545 channels: vec![channel.to_proto()],
3546 ..Default::default()
3547 }
3548 } else {
3549 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3550 proto::UpdateChannels {
3551 delete_channels: vec![channel.id.to_proto()],
3552 ..Default::default()
3553 }
3554 };
3555
3556 for connection_id in connection_pool.user_connection_ids(user_id) {
3557 session.peer.send(connection_id, update.clone())?;
3558 }
3559 }
3560
3561 response.send(proto::Ack {})?;
3562 Ok(())
3563}
3564
3565/// Alter the role for a user in the channel.
3566async fn set_channel_member_role(
3567 request: proto::SetChannelMemberRole,
3568 response: Response<proto::SetChannelMemberRole>,
3569 session: UserSession,
3570) -> Result<()> {
3571 let db = session.db().await;
3572 let channel_id = ChannelId::from_proto(request.channel_id);
3573 let member_id = UserId::from_proto(request.user_id);
3574 let result = db
3575 .set_channel_member_role(
3576 channel_id,
3577 session.user_id(),
3578 member_id,
3579 request.role().into(),
3580 )
3581 .await?;
3582
3583 match result {
3584 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3585 let mut connection_pool = session.connection_pool().await;
3586 notify_membership_updated(
3587 &mut connection_pool,
3588 membership_update,
3589 member_id,
3590 &session.peer,
3591 )
3592 }
3593 db::SetMemberRoleResult::InviteUpdated(channel) => {
3594 let update = proto::UpdateChannels {
3595 channel_invitations: vec![channel.to_proto()],
3596 ..Default::default()
3597 };
3598
3599 for connection_id in session
3600 .connection_pool()
3601 .await
3602 .user_connection_ids(member_id)
3603 {
3604 session.peer.send(connection_id, update.clone())?;
3605 }
3606 }
3607 }
3608
3609 response.send(proto::Ack {})?;
3610 Ok(())
3611}
3612
3613/// Change the name of a channel
3614async fn rename_channel(
3615 request: proto::RenameChannel,
3616 response: Response<proto::RenameChannel>,
3617 session: UserSession,
3618) -> Result<()> {
3619 let db = session.db().await;
3620 let channel_id = ChannelId::from_proto(request.channel_id);
3621 let channel_model = db
3622 .rename_channel(channel_id, session.user_id(), &request.name)
3623 .await?;
3624 let root_id = channel_model.root_id();
3625 let channel = Channel::from_model(channel_model);
3626
3627 response.send(proto::RenameChannelResponse {
3628 channel: Some(channel.to_proto()),
3629 })?;
3630
3631 let connection_pool = session.connection_pool().await;
3632 let update = proto::UpdateChannels {
3633 channels: vec![channel.to_proto()],
3634 ..Default::default()
3635 };
3636 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3637 if role.can_see_channel(channel.visibility) {
3638 session.peer.send(connection_id, update.clone())?;
3639 }
3640 }
3641
3642 Ok(())
3643}
3644
3645/// Move a channel to a new parent.
3646async fn move_channel(
3647 request: proto::MoveChannel,
3648 response: Response<proto::MoveChannel>,
3649 session: UserSession,
3650) -> Result<()> {
3651 let channel_id = ChannelId::from_proto(request.channel_id);
3652 let to = ChannelId::from_proto(request.to);
3653
3654 let (root_id, channels) = session
3655 .db()
3656 .await
3657 .move_channel(channel_id, to, session.user_id())
3658 .await?;
3659
3660 let connection_pool = session.connection_pool().await;
3661 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3662 let channels = channels
3663 .iter()
3664 .filter_map(|channel| {
3665 if role.can_see_channel(channel.visibility) {
3666 Some(channel.to_proto())
3667 } else {
3668 None
3669 }
3670 })
3671 .collect::<Vec<_>>();
3672 if channels.is_empty() {
3673 continue;
3674 }
3675
3676 let update = proto::UpdateChannels {
3677 channels,
3678 ..Default::default()
3679 };
3680
3681 session.peer.send(connection_id, update.clone())?;
3682 }
3683
3684 response.send(Ack {})?;
3685 Ok(())
3686}
3687
3688/// Get the list of channel members
3689async fn get_channel_members(
3690 request: proto::GetChannelMembers,
3691 response: Response<proto::GetChannelMembers>,
3692 session: UserSession,
3693) -> Result<()> {
3694 let db = session.db().await;
3695 let channel_id = ChannelId::from_proto(request.channel_id);
3696 let limit = if request.limit == 0 {
3697 u16::MAX as u64
3698 } else {
3699 request.limit
3700 };
3701 let (members, users) = db
3702 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3703 .await?;
3704 response.send(proto::GetChannelMembersResponse { members, users })?;
3705 Ok(())
3706}
3707
3708/// Accept or decline a channel invitation.
3709async fn respond_to_channel_invite(
3710 request: proto::RespondToChannelInvite,
3711 response: Response<proto::RespondToChannelInvite>,
3712 session: UserSession,
3713) -> Result<()> {
3714 let db = session.db().await;
3715 let channel_id = ChannelId::from_proto(request.channel_id);
3716 let RespondToChannelInvite {
3717 membership_update,
3718 notifications,
3719 } = db
3720 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3721 .await?;
3722
3723 let mut connection_pool = session.connection_pool().await;
3724 if let Some(membership_update) = membership_update {
3725 notify_membership_updated(
3726 &mut connection_pool,
3727 membership_update,
3728 session.user_id(),
3729 &session.peer,
3730 );
3731 } else {
3732 let update = proto::UpdateChannels {
3733 remove_channel_invitations: vec![channel_id.to_proto()],
3734 ..Default::default()
3735 };
3736
3737 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3738 session.peer.send(connection_id, update.clone())?;
3739 }
3740 };
3741
3742 send_notifications(&connection_pool, &session.peer, notifications);
3743
3744 response.send(proto::Ack {})?;
3745
3746 Ok(())
3747}
3748
3749/// Join the channels' room
3750async fn join_channel(
3751 request: proto::JoinChannel,
3752 response: Response<proto::JoinChannel>,
3753 session: UserSession,
3754) -> Result<()> {
3755 let channel_id = ChannelId::from_proto(request.channel_id);
3756 join_channel_internal(channel_id, Box::new(response), session).await
3757}
3758
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// `join_channel_internal` is generic over this trait, so the same join logic can reply to
/// a `JoinChannel` request. There is also an impl for `Response<proto::JoinRoom>`, which
/// suggests the room-join path reuses it as well (NOTE(review): confirm at the call site).
trait JoinChannelInternalResponse {
    /// Send `result` as the reply to the originating request.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forward to `Response::send` for the concrete request type. The fully-qualified path
// names the target explicitly instead of relying on method-call resolution.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding for `JoinRoom` requests, whose response type is also `JoinRoomResponse`.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3772
/// Join the room that belongs to `channel_id` on behalf of `session`'s user, replying
/// through `response` with a `proto::JoinRoomResponse`.
///
/// In order, this:
/// 1. removes the user from a stale room connection, if the database reports one;
/// 2. joins the channel's room in the database. This also returns an optional membership
///    change and the user's role in the channel;
/// 3. mints LiveKit credentials when a LiveKit client is configured:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`;
///    - any other role gets a room token with `can_publish: true`;
///    - if minting fails, the error goes through `trace_err` and the reply carries no
///      LiveKit info;
/// 4. sends the reply, propagates any membership change via `notify_membership_updated`,
///    and calls `room_updated`;
/// 5. after the database and connection-pool guards from the inner block are dropped,
///    calls `channel_updated` (with a freshly acquired pool) and `update_user_contacts`.
///
/// # Errors
/// Returns an error if any of these fail: a database call, the stale-connection cleanup,
/// sending the reply, or `update_user_contacts`. It also errors if the database did not
/// return the joined channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // This block bounds the lifetime of the database guard and the connection-pool guard;
    // both are dropped before the channel/contact notifications below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            // NOTE(review): presumably because `leave_room_for_session` needs the database
            // itself — confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when no LiveKit client is configured, or when minting a token fails:
        // `trace_err()?` turns the error into `None` inside this Option-returning closure.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests get a guest token and may not publish; every other role gets a room
            // token and may publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Runs after the guards above are dropped; a new connection-pool guard is acquired for
    // this call only.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3863
3864/// Start editing the channel notes
3865async fn join_channel_buffer(
3866 request: proto::JoinChannelBuffer,
3867 response: Response<proto::JoinChannelBuffer>,
3868 session: UserSession,
3869) -> Result<()> {
3870 let db = session.db().await;
3871 let channel_id = ChannelId::from_proto(request.channel_id);
3872
3873 let open_response = db
3874 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3875 .await?;
3876
3877 let collaborators = open_response.collaborators.clone();
3878 response.send(open_response)?;
3879
3880 let update = UpdateChannelBufferCollaborators {
3881 channel_id: channel_id.to_proto(),
3882 collaborators: collaborators.clone(),
3883 };
3884 channel_buffer_updated(
3885 session.connection_id,
3886 collaborators
3887 .iter()
3888 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3889 &update,
3890 &session.peer,
3891 );
3892
3893 Ok(())
3894}
3895
3896/// Edit the channel notes
3897async fn update_channel_buffer(
3898 request: proto::UpdateChannelBuffer,
3899 session: UserSession,
3900) -> Result<()> {
3901 let db = session.db().await;
3902 let channel_id = ChannelId::from_proto(request.channel_id);
3903
3904 let (collaborators, epoch, version) = db
3905 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3906 .await?;
3907
3908 channel_buffer_updated(
3909 session.connection_id,
3910 collaborators.clone(),
3911 &proto::UpdateChannelBuffer {
3912 channel_id: channel_id.to_proto(),
3913 operations: request.operations,
3914 },
3915 &session.peer,
3916 );
3917
3918 let pool = &*session.connection_pool().await;
3919
3920 let non_collaborators =
3921 pool.channel_connection_ids(channel_id)
3922 .filter_map(|(connection_id, _)| {
3923 if collaborators.contains(&connection_id) {
3924 None
3925 } else {
3926 Some(connection_id)
3927 }
3928 });
3929
3930 broadcast(None, non_collaborators, |peer_id| {
3931 session.peer.send(
3932 peer_id,
3933 proto::UpdateChannels {
3934 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3935 channel_id: channel_id.to_proto(),
3936 epoch: epoch as u64,
3937 version: version.clone(),
3938 }],
3939 ..Default::default()
3940 },
3941 )
3942 });
3943
3944 Ok(())
3945}
3946
3947/// Rejoin the channel notes after a connection blip
3948async fn rejoin_channel_buffers(
3949 request: proto::RejoinChannelBuffers,
3950 response: Response<proto::RejoinChannelBuffers>,
3951 session: UserSession,
3952) -> Result<()> {
3953 let db = session.db().await;
3954 let buffers = db
3955 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3956 .await?;
3957
3958 for rejoined_buffer in &buffers {
3959 let collaborators_to_notify = rejoined_buffer
3960 .buffer
3961 .collaborators
3962 .iter()
3963 .filter_map(|c| Some(c.peer_id?.into()));
3964 channel_buffer_updated(
3965 session.connection_id,
3966 collaborators_to_notify,
3967 &proto::UpdateChannelBufferCollaborators {
3968 channel_id: rejoined_buffer.buffer.channel_id,
3969 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3970 },
3971 &session.peer,
3972 );
3973 }
3974
3975 response.send(proto::RejoinChannelBuffersResponse {
3976 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3977 })?;
3978
3979 Ok(())
3980}
3981
3982/// Stop editing the channel notes
3983async fn leave_channel_buffer(
3984 request: proto::LeaveChannelBuffer,
3985 response: Response<proto::LeaveChannelBuffer>,
3986 session: UserSession,
3987) -> Result<()> {
3988 let db = session.db().await;
3989 let channel_id = ChannelId::from_proto(request.channel_id);
3990
3991 let left_buffer = db
3992 .leave_channel_buffer(channel_id, session.connection_id)
3993 .await?;
3994
3995 response.send(Ack {})?;
3996
3997 channel_buffer_updated(
3998 session.connection_id,
3999 left_buffer.connections,
4000 &proto::UpdateChannelBufferCollaborators {
4001 channel_id: channel_id.to_proto(),
4002 collaborators: left_buffer.collaborators,
4003 },
4004 &session.peer,
4005 );
4006
4007 Ok(())
4008}
4009
4010fn channel_buffer_updated<T: EnvelopedMessage>(
4011 sender_id: ConnectionId,
4012 collaborators: impl IntoIterator<Item = ConnectionId>,
4013 message: &T,
4014 peer: &Peer,
4015) {
4016 broadcast(Some(sender_id), collaborators, |peer_id| {
4017 peer.send(peer_id, message.clone())
4018 });
4019}
4020
4021fn send_notifications(
4022 connection_pool: &ConnectionPool,
4023 peer: &Peer,
4024 notifications: db::NotificationBatch,
4025) {
4026 for (user_id, notification) in notifications {
4027 for connection_id in connection_pool.user_connection_ids(user_id) {
4028 if let Err(error) = peer.send(
4029 connection_id,
4030 proto::AddNotification {
4031 notification: Some(notification.clone()),
4032 },
4033 ) {
4034 tracing::error!(
4035 "failed to send notification to {:?} {}",
4036 connection_id,
4037 error
4038 );
4039 }
4040 }
4041 }
4042}
4043
4044/// Send a message to the channel
4045async fn send_channel_message(
4046 request: proto::SendChannelMessage,
4047 response: Response<proto::SendChannelMessage>,
4048 session: UserSession,
4049) -> Result<()> {
4050 // Validate the message body.
4051 let body = request.body.trim().to_string();
4052 if body.len() > MAX_MESSAGE_LEN {
4053 return Err(anyhow!("message is too long"))?;
4054 }
4055 if body.is_empty() {
4056 return Err(anyhow!("message can't be blank"))?;
4057 }
4058
4059 // TODO: adjust mentions if body is trimmed
4060
4061 let timestamp = OffsetDateTime::now_utc();
4062 let nonce = request
4063 .nonce
4064 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4065
4066 let channel_id = ChannelId::from_proto(request.channel_id);
4067 let CreatedChannelMessage {
4068 message_id,
4069 participant_connection_ids,
4070 notifications,
4071 } = session
4072 .db()
4073 .await
4074 .create_channel_message(
4075 channel_id,
4076 session.user_id(),
4077 &body,
4078 &request.mentions,
4079 timestamp,
4080 nonce.clone().into(),
4081 match request.reply_to_message_id {
4082 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4083 None => None,
4084 },
4085 )
4086 .await?;
4087
4088 let message = proto::ChannelMessage {
4089 sender_id: session.user_id().to_proto(),
4090 id: message_id.to_proto(),
4091 body,
4092 mentions: request.mentions,
4093 timestamp: timestamp.unix_timestamp() as u64,
4094 nonce: Some(nonce),
4095 reply_to_message_id: request.reply_to_message_id,
4096 edited_at: None,
4097 };
4098 broadcast(
4099 Some(session.connection_id),
4100 participant_connection_ids.clone(),
4101 |connection| {
4102 session.peer.send(
4103 connection,
4104 proto::ChannelMessageSent {
4105 channel_id: channel_id.to_proto(),
4106 message: Some(message.clone()),
4107 },
4108 )
4109 },
4110 );
4111 response.send(proto::SendChannelMessageResponse {
4112 message: Some(message),
4113 })?;
4114
4115 let pool = &*session.connection_pool().await;
4116 let non_participants =
4117 pool.channel_connection_ids(channel_id)
4118 .filter_map(|(connection_id, _)| {
4119 if participant_connection_ids.contains(&connection_id) {
4120 None
4121 } else {
4122 Some(connection_id)
4123 }
4124 });
4125 broadcast(None, non_participants, |peer_id| {
4126 session.peer.send(
4127 peer_id,
4128 proto::UpdateChannels {
4129 latest_channel_message_ids: vec![proto::ChannelMessageId {
4130 channel_id: channel_id.to_proto(),
4131 message_id: message_id.to_proto(),
4132 }],
4133 ..Default::default()
4134 },
4135 )
4136 });
4137 send_notifications(pool, &session.peer, notifications);
4138
4139 Ok(())
4140}
4141
4142/// Delete a channel message
4143async fn remove_channel_message(
4144 request: proto::RemoveChannelMessage,
4145 response: Response<proto::RemoveChannelMessage>,
4146 session: UserSession,
4147) -> Result<()> {
4148 let channel_id = ChannelId::from_proto(request.channel_id);
4149 let message_id = MessageId::from_proto(request.message_id);
4150 let (connection_ids, existing_notification_ids) = session
4151 .db()
4152 .await
4153 .remove_channel_message(channel_id, message_id, session.user_id())
4154 .await?;
4155
4156 broadcast(
4157 Some(session.connection_id),
4158 connection_ids,
4159 move |connection| {
4160 session.peer.send(connection, request.clone())?;
4161
4162 for notification_id in &existing_notification_ids {
4163 session.peer.send(
4164 connection,
4165 proto::DeleteNotification {
4166 notification_id: (*notification_id).to_proto(),
4167 },
4168 )?;
4169 }
4170
4171 Ok(())
4172 },
4173 );
4174 response.send(proto::Ack {})?;
4175 Ok(())
4176}
4177
4178async fn update_channel_message(
4179 request: proto::UpdateChannelMessage,
4180 response: Response<proto::UpdateChannelMessage>,
4181 session: UserSession,
4182) -> Result<()> {
4183 let channel_id = ChannelId::from_proto(request.channel_id);
4184 let message_id = MessageId::from_proto(request.message_id);
4185 let updated_at = OffsetDateTime::now_utc();
4186 let UpdatedChannelMessage {
4187 message_id,
4188 participant_connection_ids,
4189 notifications,
4190 reply_to_message_id,
4191 timestamp,
4192 deleted_mention_notification_ids,
4193 updated_mention_notifications,
4194 } = session
4195 .db()
4196 .await
4197 .update_channel_message(
4198 channel_id,
4199 message_id,
4200 session.user_id(),
4201 request.body.as_str(),
4202 &request.mentions,
4203 updated_at,
4204 )
4205 .await?;
4206
4207 let nonce = request
4208 .nonce
4209 .clone()
4210 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4211
4212 let message = proto::ChannelMessage {
4213 sender_id: session.user_id().to_proto(),
4214 id: message_id.to_proto(),
4215 body: request.body.clone(),
4216 mentions: request.mentions.clone(),
4217 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4218 nonce: Some(nonce),
4219 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4220 edited_at: Some(updated_at.unix_timestamp() as u64),
4221 };
4222
4223 response.send(proto::Ack {})?;
4224
4225 let pool = &*session.connection_pool().await;
4226 broadcast(
4227 Some(session.connection_id),
4228 participant_connection_ids,
4229 |connection| {
4230 session.peer.send(
4231 connection,
4232 proto::ChannelMessageUpdate {
4233 channel_id: channel_id.to_proto(),
4234 message: Some(message.clone()),
4235 },
4236 )?;
4237
4238 for notification_id in &deleted_mention_notification_ids {
4239 session.peer.send(
4240 connection,
4241 proto::DeleteNotification {
4242 notification_id: (*notification_id).to_proto(),
4243 },
4244 )?;
4245 }
4246
4247 for notification in &updated_mention_notifications {
4248 session.peer.send(
4249 connection,
4250 proto::UpdateNotification {
4251 notification: Some(notification.clone()),
4252 },
4253 )?;
4254 }
4255
4256 Ok(())
4257 },
4258 );
4259
4260 send_notifications(pool, &session.peer, notifications);
4261
4262 Ok(())
4263}
4264
4265/// Mark a channel message as read
4266async fn acknowledge_channel_message(
4267 request: proto::AckChannelMessage,
4268 session: UserSession,
4269) -> Result<()> {
4270 let channel_id = ChannelId::from_proto(request.channel_id);
4271 let message_id = MessageId::from_proto(request.message_id);
4272 let notifications = session
4273 .db()
4274 .await
4275 .observe_channel_message(channel_id, session.user_id(), message_id)
4276 .await?;
4277 send_notifications(
4278 &*session.connection_pool().await,
4279 &session.peer,
4280 notifications,
4281 );
4282 Ok(())
4283}
4284
4285/// Mark a buffer version as synced
4286async fn acknowledge_buffer_version(
4287 request: proto::AckBufferOperation,
4288 session: UserSession,
4289) -> Result<()> {
4290 let buffer_id = BufferId::from_proto(request.buffer_id);
4291 session
4292 .db()
4293 .await
4294 .observe_buffer_version(
4295 buffer_id,
4296 session.user_id(),
4297 request.epoch as i32,
4298 &request.version,
4299 )
4300 .await?;
4301 Ok(())
4302}
4303
4304struct CompleteWithLanguageModelRateLimit;
4305
4306impl RateLimit for CompleteWithLanguageModelRateLimit {
4307 fn capacity() -> usize {
4308 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4309 .ok()
4310 .and_then(|v| v.parse().ok())
4311 .unwrap_or(120) // Picked arbitrarily
4312 }
4313
4314 fn refill_duration() -> chrono::Duration {
4315 chrono::Duration::hours(1)
4316 }
4317
4318 fn db_name() -> &'static str {
4319 "complete-with-language-model"
4320 }
4321}
4322
4323async fn complete_with_language_model(
4324 request: proto::CompleteWithLanguageModel,
4325 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4326 session: Session,
4327 open_ai_api_key: Option<Arc<str>>,
4328 google_ai_api_key: Option<Arc<str>>,
4329 anthropic_api_key: Option<Arc<str>>,
4330) -> Result<()> {
4331 let Some(session) = session.for_user() else {
4332 return Err(anyhow!("user not found"))?;
4333 };
4334 authorize_access_to_language_models(&session).await?;
4335 session
4336 .rate_limiter
4337 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4338 .await?;
4339
4340 if request.model.starts_with("gpt") {
4341 let api_key =
4342 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4343 complete_with_open_ai(request, response, session, api_key).await?;
4344 } else if request.model.starts_with("gemini") {
4345 let api_key = google_ai_api_key
4346 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4347 complete_with_google_ai(request, response, session, api_key).await?;
4348 } else if request.model.starts_with("claude") {
4349 let api_key = anthropic_api_key
4350 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4351 complete_with_anthropic(request, response, session, api_key).await?;
4352 }
4353
4354 Ok(())
4355}
4356
4357async fn complete_with_open_ai(
4358 request: proto::CompleteWithLanguageModel,
4359 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4360 session: UserSession,
4361 api_key: Arc<str>,
4362) -> Result<()> {
4363 let mut completion_stream = open_ai::stream_completion(
4364 session.http_client.as_ref(),
4365 OPEN_AI_API_URL,
4366 &api_key,
4367 crate::ai::language_model_request_to_open_ai(request)?,
4368 None,
4369 )
4370 .await
4371 .context("open_ai::stream_completion request failed within collab")?;
4372
4373 while let Some(event) = completion_stream.next().await {
4374 let event = event?;
4375 response.send(proto::LanguageModelResponse {
4376 choices: event
4377 .choices
4378 .into_iter()
4379 .map(|choice| proto::LanguageModelChoiceDelta {
4380 index: choice.index,
4381 delta: Some(proto::LanguageModelResponseMessage {
4382 role: choice.delta.role.map(|role| match role {
4383 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4384 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4385 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4386 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4387 } as i32),
4388 content: choice.delta.content,
4389 tool_calls: choice
4390 .delta
4391 .tool_calls
4392 .into_iter()
4393 .map(|delta| proto::ToolCallDelta {
4394 index: delta.index as u32,
4395 id: delta.id,
4396 variant: match delta.function {
4397 Some(function) => {
4398 let name = function.name;
4399 let arguments = function.arguments;
4400
4401 Some(proto::tool_call_delta::Variant::Function(
4402 proto::tool_call_delta::FunctionCallDelta {
4403 name,
4404 arguments,
4405 },
4406 ))
4407 }
4408 None => None,
4409 },
4410 })
4411 .collect(),
4412 }),
4413 finish_reason: choice.finish_reason,
4414 })
4415 .collect(),
4416 })?;
4417 }
4418
4419 Ok(())
4420}
4421
4422async fn complete_with_google_ai(
4423 request: proto::CompleteWithLanguageModel,
4424 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4425 session: UserSession,
4426 api_key: Arc<str>,
4427) -> Result<()> {
4428 let mut stream = google_ai::stream_generate_content(
4429 session.http_client.clone(),
4430 google_ai::API_URL,
4431 api_key.as_ref(),
4432 crate::ai::language_model_request_to_google_ai(request)?,
4433 )
4434 .await
4435 .context("google_ai::stream_generate_content request failed")?;
4436
4437 while let Some(event) = stream.next().await {
4438 let event = event?;
4439 response.send(proto::LanguageModelResponse {
4440 choices: event
4441 .candidates
4442 .unwrap_or_default()
4443 .into_iter()
4444 .map(|candidate| proto::LanguageModelChoiceDelta {
4445 index: candidate.index as u32,
4446 delta: Some(proto::LanguageModelResponseMessage {
4447 role: Some(match candidate.content.role {
4448 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4449 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4450 } as i32),
4451 content: Some(
4452 candidate
4453 .content
4454 .parts
4455 .into_iter()
4456 .filter_map(|part| match part {
4457 google_ai::Part::TextPart(part) => Some(part.text),
4458 google_ai::Part::InlineDataPart(_) => None,
4459 })
4460 .collect(),
4461 ),
4462 // Tool calls are not supported for Google
4463 tool_calls: Vec::new(),
4464 }),
4465 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4466 })
4467 .collect(),
4468 })?;
4469 }
4470
4471 Ok(())
4472}
4473
4474async fn complete_with_anthropic(
4475 request: proto::CompleteWithLanguageModel,
4476 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4477 session: UserSession,
4478 api_key: Arc<str>,
4479) -> Result<()> {
4480 let model = anthropic::Model::from_id(&request.model)?;
4481
4482 let mut system_message = String::new();
4483 let messages = request
4484 .messages
4485 .into_iter()
4486 .filter_map(|message| {
4487 match message.role() {
4488 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4489 role: anthropic::Role::User,
4490 content: message.content,
4491 }),
4492 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4493 role: anthropic::Role::Assistant,
4494 content: message.content,
4495 }),
4496 // Anthropic's API breaks system instructions out as a separate field rather
4497 // than having a system message role.
4498 LanguageModelRole::LanguageModelSystem => {
4499 if !system_message.is_empty() {
4500 system_message.push_str("\n\n");
4501 }
4502 system_message.push_str(&message.content);
4503
4504 None
4505 }
4506 // We don't yet support tool calls for Anthropic
4507 LanguageModelRole::LanguageModelTool => None,
4508 }
4509 })
4510 .collect();
4511
4512 let mut stream = anthropic::stream_completion(
4513 session.http_client.as_ref(),
4514 anthropic::ANTHROPIC_API_URL,
4515 &api_key,
4516 anthropic::Request {
4517 model,
4518 messages,
4519 stream: true,
4520 system: system_message,
4521 max_tokens: 4092,
4522 },
4523 None,
4524 )
4525 .await?;
4526
4527 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4528
4529 while let Some(event) = stream.next().await {
4530 let event = event?;
4531
4532 match event {
4533 anthropic::ResponseEvent::MessageStart { message } => {
4534 if let Some(role) = message.role {
4535 if role == "assistant" {
4536 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4537 } else if role == "user" {
4538 current_role = proto::LanguageModelRole::LanguageModelUser;
4539 }
4540 }
4541 }
4542 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4543 match content_block {
4544 anthropic::ContentBlock::Text { text } => {
4545 if !text.is_empty() {
4546 response.send(proto::LanguageModelResponse {
4547 choices: vec![proto::LanguageModelChoiceDelta {
4548 index: 0,
4549 delta: Some(proto::LanguageModelResponseMessage {
4550 role: Some(current_role as i32),
4551 content: Some(text),
4552 tool_calls: Vec::new(),
4553 }),
4554 finish_reason: None,
4555 }],
4556 })?;
4557 }
4558 }
4559 }
4560 }
4561 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4562 anthropic::TextDelta::TextDelta { text } => {
4563 response.send(proto::LanguageModelResponse {
4564 choices: vec![proto::LanguageModelChoiceDelta {
4565 index: 0,
4566 delta: Some(proto::LanguageModelResponseMessage {
4567 role: Some(current_role as i32),
4568 content: Some(text),
4569 tool_calls: Vec::new(),
4570 }),
4571 finish_reason: None,
4572 }],
4573 })?;
4574 }
4575 },
4576 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4577 if let Some(stop_reason) = delta.stop_reason {
4578 response.send(proto::LanguageModelResponse {
4579 choices: vec![proto::LanguageModelChoiceDelta {
4580 index: 0,
4581 delta: None,
4582 finish_reason: Some(stop_reason),
4583 }],
4584 })?;
4585 }
4586 }
4587 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4588 anthropic::ResponseEvent::MessageStop {} => {}
4589 anthropic::ResponseEvent::Ping {} => {}
4590 }
4591 }
4592
4593 Ok(())
4594}
4595
4596struct CountTokensWithLanguageModelRateLimit;
4597
4598impl RateLimit for CountTokensWithLanguageModelRateLimit {
4599 fn capacity() -> usize {
4600 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4601 .ok()
4602 .and_then(|v| v.parse().ok())
4603 .unwrap_or(600) // Picked arbitrarily
4604 }
4605
4606 fn refill_duration() -> chrono::Duration {
4607 chrono::Duration::hours(1)
4608 }
4609
4610 fn db_name() -> &'static str {
4611 "count-tokens-with-language-model"
4612 }
4613}
4614
4615async fn count_tokens_with_language_model(
4616 request: proto::CountTokensWithLanguageModel,
4617 response: Response<proto::CountTokensWithLanguageModel>,
4618 session: UserSession,
4619 google_ai_api_key: Option<Arc<str>>,
4620) -> Result<()> {
4621 authorize_access_to_language_models(&session).await?;
4622
4623 if !request.model.starts_with("gemini") {
4624 return Err(anyhow!(
4625 "counting tokens for model: {:?} is not supported",
4626 request.model
4627 ))?;
4628 }
4629
4630 session
4631 .rate_limiter
4632 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4633 .await?;
4634
4635 let api_key = google_ai_api_key
4636 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4637 let tokens_response = google_ai::count_tokens(
4638 session.http_client.as_ref(),
4639 google_ai::API_URL,
4640 &api_key,
4641 crate::ai::count_tokens_request_to_google_ai(request)?,
4642 )
4643 .await?;
4644 response.send(proto::CountTokensResponse {
4645 token_count: tokens_response.total_tokens as u32,
4646 })?;
4647 Ok(())
4648}
4649
4650struct ComputeEmbeddingsRateLimit;
4651
4652impl RateLimit for ComputeEmbeddingsRateLimit {
4653 fn capacity() -> usize {
4654 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4655 .ok()
4656 .and_then(|v| v.parse().ok())
4657 .unwrap_or(5000) // Picked arbitrarily
4658 }
4659
4660 fn refill_duration() -> chrono::Duration {
4661 chrono::Duration::hours(1)
4662 }
4663
4664 fn db_name() -> &'static str {
4665 "compute-embeddings"
4666 }
4667}
4668
4669async fn compute_embeddings(
4670 request: proto::ComputeEmbeddings,
4671 response: Response<proto::ComputeEmbeddings>,
4672 session: UserSession,
4673 api_key: Option<Arc<str>>,
4674) -> Result<()> {
4675 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4676 authorize_access_to_language_models(&session).await?;
4677
4678 session
4679 .rate_limiter
4680 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4681 .await?;
4682
4683 let embeddings = match request.model.as_str() {
4684 "openai/text-embedding-3-small" => {
4685 open_ai::embed(
4686 session.http_client.as_ref(),
4687 OPEN_AI_API_URL,
4688 &api_key,
4689 OpenAiEmbeddingModel::TextEmbedding3Small,
4690 request.texts.iter().map(|text| text.as_str()),
4691 )
4692 .await?
4693 }
4694 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4695 };
4696
4697 let embeddings = request
4698 .texts
4699 .iter()
4700 .map(|text| {
4701 let mut hasher = sha2::Sha256::new();
4702 hasher.update(text.as_bytes());
4703 let result = hasher.finalize();
4704 result.to_vec()
4705 })
4706 .zip(
4707 embeddings
4708 .data
4709 .into_iter()
4710 .map(|embedding| embedding.embedding),
4711 )
4712 .collect::<HashMap<_, _>>();
4713
4714 let db = session.db().await;
4715 db.save_embeddings(&request.model, &embeddings)
4716 .await
4717 .context("failed to save embeddings")
4718 .trace_err();
4719
4720 response.send(proto::ComputeEmbeddingsResponse {
4721 embeddings: embeddings
4722 .into_iter()
4723 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4724 .collect(),
4725 })?;
4726 Ok(())
4727}
4728
4729async fn get_cached_embeddings(
4730 request: proto::GetCachedEmbeddings,
4731 response: Response<proto::GetCachedEmbeddings>,
4732 session: UserSession,
4733) -> Result<()> {
4734 authorize_access_to_language_models(&session).await?;
4735
4736 let db = session.db().await;
4737 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4738
4739 response.send(proto::GetCachedEmbeddingsResponse {
4740 embeddings: embeddings
4741 .into_iter()
4742 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4743 .collect(),
4744 })?;
4745 Ok(())
4746}
4747
4748async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4749 let db = session.db().await;
4750 let flags = db.get_user_flags(session.user_id()).await?;
4751 if flags.iter().any(|flag| flag == "language-models") {
4752 Ok(())
4753 } else {
4754 Err(anyhow!("permission denied"))?
4755 }
4756}
4757
4758/// Get a Supermaven API key for the user
4759async fn get_supermaven_api_key(
4760 _request: proto::GetSupermavenApiKey,
4761 response: Response<proto::GetSupermavenApiKey>,
4762 session: UserSession,
4763) -> Result<()> {
4764 let user_id: String = session.user_id().to_string();
4765 if !session.is_staff() {
4766 return Err(anyhow!("supermaven not enabled for this account"))?;
4767 }
4768
4769 let email = session
4770 .email()
4771 .ok_or_else(|| anyhow!("user must have an email"))?;
4772
4773 let supermaven_admin_api = session
4774 .supermaven_client
4775 .as_ref()
4776 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4777
4778 let result = supermaven_admin_api
4779 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4780 .await?;
4781
4782 response.send(proto::GetSupermavenApiKeyResponse {
4783 api_key: result.api_key,
4784 })?;
4785
4786 Ok(())
4787}
4788
4789/// Start receiving chat updates for a channel
4790async fn join_channel_chat(
4791 request: proto::JoinChannelChat,
4792 response: Response<proto::JoinChannelChat>,
4793 session: UserSession,
4794) -> Result<()> {
4795 let channel_id = ChannelId::from_proto(request.channel_id);
4796
4797 let db = session.db().await;
4798 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4799 .await?;
4800 let messages = db
4801 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4802 .await?;
4803 response.send(proto::JoinChannelChatResponse {
4804 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4805 messages,
4806 })?;
4807 Ok(())
4808}
4809
4810/// Stop receiving chat updates for a channel
4811async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4812 let channel_id = ChannelId::from_proto(request.channel_id);
4813 session
4814 .db()
4815 .await
4816 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4817 .await?;
4818 Ok(())
4819}
4820
4821/// Retrieve the chat history for a channel
4822async fn get_channel_messages(
4823 request: proto::GetChannelMessages,
4824 response: Response<proto::GetChannelMessages>,
4825 session: UserSession,
4826) -> Result<()> {
4827 let channel_id = ChannelId::from_proto(request.channel_id);
4828 let messages = session
4829 .db()
4830 .await
4831 .get_channel_messages(
4832 channel_id,
4833 session.user_id(),
4834 MESSAGE_COUNT_PER_PAGE,
4835 Some(MessageId::from_proto(request.before_message_id)),
4836 )
4837 .await?;
4838 response.send(proto::GetChannelMessagesResponse {
4839 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4840 messages,
4841 })?;
4842 Ok(())
4843}
4844
4845/// Retrieve specific chat messages
4846async fn get_channel_messages_by_id(
4847 request: proto::GetChannelMessagesById,
4848 response: Response<proto::GetChannelMessagesById>,
4849 session: UserSession,
4850) -> Result<()> {
4851 let message_ids = request
4852 .message_ids
4853 .iter()
4854 .map(|id| MessageId::from_proto(*id))
4855 .collect::<Vec<_>>();
4856 let messages = session
4857 .db()
4858 .await
4859 .get_channel_messages_by_id(session.user_id(), &message_ids)
4860 .await?;
4861 response.send(proto::GetChannelMessagesResponse {
4862 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4863 messages,
4864 })?;
4865 Ok(())
4866}
4867
4868/// Retrieve the current users notifications
4869async fn get_notifications(
4870 request: proto::GetNotifications,
4871 response: Response<proto::GetNotifications>,
4872 session: UserSession,
4873) -> Result<()> {
4874 let notifications = session
4875 .db()
4876 .await
4877 .get_notifications(
4878 session.user_id(),
4879 NOTIFICATION_COUNT_PER_PAGE,
4880 request
4881 .before_id
4882 .map(|id| db::NotificationId::from_proto(id)),
4883 )
4884 .await?;
4885 response.send(proto::GetNotificationsResponse {
4886 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4887 notifications,
4888 })?;
4889 Ok(())
4890}
4891
4892/// Mark notifications as read
4893async fn mark_notification_as_read(
4894 request: proto::MarkNotificationRead,
4895 response: Response<proto::MarkNotificationRead>,
4896 session: UserSession,
4897) -> Result<()> {
4898 let database = &session.db().await;
4899 let notifications = database
4900 .mark_notification_as_read_by_id(
4901 session.user_id(),
4902 NotificationId::from_proto(request.notification_id),
4903 )
4904 .await?;
4905 send_notifications(
4906 &*session.connection_pool().await,
4907 &session.peer,
4908 notifications,
4909 );
4910 response.send(proto::Ack {})?;
4911 Ok(())
4912}
4913
4914/// Get the current users information
4915async fn get_private_user_info(
4916 _request: proto::GetPrivateUserInfo,
4917 response: Response<proto::GetPrivateUserInfo>,
4918 session: UserSession,
4919) -> Result<()> {
4920 let db = session.db().await;
4921
4922 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4923 let user = db
4924 .get_user_by_id(session.user_id())
4925 .await?
4926 .ok_or_else(|| anyhow!("user not found"))?;
4927 let flags = db.get_user_flags(session.user_id()).await?;
4928
4929 response.send(proto::GetPrivateUserInfoResponse {
4930 metrics_id,
4931 staff: user.admin,
4932 flags,
4933 })?;
4934 Ok(())
4935}
4936
4937fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4938 match message {
4939 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4940 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4941 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4942 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4943 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4944 code: frame.code.into(),
4945 reason: frame.reason,
4946 })),
4947 }
4948}
4949
4950fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4951 match message {
4952 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4953 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4954 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4955 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4956 AxumMessage::Close(frame) => {
4957 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4958 code: frame.code.into(),
4959 reason: frame.reason,
4960 }))
4961 }
4962 }
4963}
4964
4965fn notify_membership_updated(
4966 connection_pool: &mut ConnectionPool,
4967 result: MembershipUpdated,
4968 user_id: UserId,
4969 peer: &Peer,
4970) {
4971 for membership in &result.new_channels.channel_memberships {
4972 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4973 }
4974 for channel_id in &result.removed_channels {
4975 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4976 }
4977
4978 let user_channels_update = proto::UpdateUserChannels {
4979 channel_memberships: result
4980 .new_channels
4981 .channel_memberships
4982 .iter()
4983 .map(|cm| proto::ChannelMembership {
4984 channel_id: cm.channel_id.to_proto(),
4985 role: cm.role.into(),
4986 })
4987 .collect(),
4988 ..Default::default()
4989 };
4990
4991 let mut update = build_channels_update(result.new_channels, vec![]);
4992 update.delete_channels = result
4993 .removed_channels
4994 .into_iter()
4995 .map(|id| id.to_proto())
4996 .collect();
4997 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4998
4999 for connection_id in connection_pool.user_connection_ids(user_id) {
5000 peer.send(connection_id, user_channels_update.clone())
5001 .trace_err();
5002 peer.send(connection_id, update.clone()).trace_err();
5003 }
5004}
5005
5006fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5007 proto::UpdateUserChannels {
5008 channel_memberships: channels
5009 .channel_memberships
5010 .iter()
5011 .map(|m| proto::ChannelMembership {
5012 channel_id: m.channel_id.to_proto(),
5013 role: m.role.into(),
5014 })
5015 .collect(),
5016 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5017 observed_channel_message_id: channels.observed_channel_messages.clone(),
5018 }
5019}
5020
5021fn build_channels_update(
5022 channels: ChannelsForUser,
5023 channel_invites: Vec<db::Channel>,
5024) -> proto::UpdateChannels {
5025 let mut update = proto::UpdateChannels::default();
5026
5027 for channel in channels.channels {
5028 update.channels.push(channel.to_proto());
5029 }
5030
5031 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5032 update.latest_channel_message_ids = channels.latest_channel_messages;
5033
5034 for (channel_id, participants) in channels.channel_participants {
5035 update
5036 .channel_participants
5037 .push(proto::ChannelParticipants {
5038 channel_id: channel_id.to_proto(),
5039 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5040 });
5041 }
5042
5043 for channel in channel_invites {
5044 update.channel_invitations.push(channel.to_proto());
5045 }
5046
5047 update.hosted_projects = channels.hosted_projects;
5048 update
5049}
5050
5051fn build_initial_contacts_update(
5052 contacts: Vec<db::Contact>,
5053 pool: &ConnectionPool,
5054) -> proto::UpdateContacts {
5055 let mut update = proto::UpdateContacts::default();
5056
5057 for contact in contacts {
5058 match contact {
5059 db::Contact::Accepted { user_id, busy } => {
5060 update.contacts.push(contact_for_user(user_id, busy, &pool));
5061 }
5062 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5063 db::Contact::Incoming { user_id } => {
5064 update
5065 .incoming_requests
5066 .push(proto::IncomingContactRequest {
5067 requester_id: user_id.to_proto(),
5068 })
5069 }
5070 }
5071 }
5072
5073 update
5074}
5075
5076fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5077 proto::Contact {
5078 user_id: user_id.to_proto(),
5079 online: pool.is_user_online(user_id),
5080 busy,
5081 }
5082}
5083
5084fn room_updated(room: &proto::Room, peer: &Peer) {
5085 broadcast(
5086 None,
5087 room.participants
5088 .iter()
5089 .filter_map(|participant| Some(participant.peer_id?.into())),
5090 |peer_id| {
5091 peer.send(
5092 peer_id,
5093 proto::RoomUpdated {
5094 room: Some(room.clone()),
5095 },
5096 )
5097 },
5098 );
5099}
5100
5101fn channel_updated(
5102 channel: &db::channel::Model,
5103 room: &proto::Room,
5104 peer: &Peer,
5105 pool: &ConnectionPool,
5106) {
5107 let participants = room
5108 .participants
5109 .iter()
5110 .map(|p| p.user_id)
5111 .collect::<Vec<_>>();
5112
5113 broadcast(
5114 None,
5115 pool.channel_connection_ids(channel.root_id())
5116 .filter_map(|(channel_id, role)| {
5117 role.can_see_channel(channel.visibility).then(|| channel_id)
5118 }),
5119 |peer_id| {
5120 peer.send(
5121 peer_id,
5122 proto::UpdateChannels {
5123 channel_participants: vec![proto::ChannelParticipants {
5124 channel_id: channel.id.to_proto(),
5125 participant_user_ids: participants.clone(),
5126 }],
5127 ..Default::default()
5128 },
5129 )
5130 },
5131 );
5132}
5133
5134async fn send_dev_server_projects_update(
5135 user_id: UserId,
5136 mut status: proto::DevServerProjectsUpdate,
5137 session: &Session,
5138) {
5139 let pool = session.connection_pool().await;
5140 for dev_server in &mut status.dev_servers {
5141 dev_server.status =
5142 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5143 }
5144 let connections = pool.user_connection_ids(user_id);
5145 for connection_id in connections {
5146 session.peer.send(connection_id, status.clone()).trace_err();
5147 }
5148}
5149
5150async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5151 let db = session.db().await;
5152
5153 let contacts = db.get_contacts(user_id).await?;
5154 let busy = db.is_user_busy(user_id).await?;
5155
5156 let pool = session.connection_pool().await;
5157 let updated_contact = contact_for_user(user_id, busy, &pool);
5158 for contact in contacts {
5159 if let db::Contact::Accepted {
5160 user_id: contact_user_id,
5161 ..
5162 } = contact
5163 {
5164 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5165 session
5166 .peer
5167 .send(
5168 contact_conn_id,
5169 proto::UpdateContacts {
5170 contacts: vec![updated_contact.clone()],
5171 remove_contacts: Default::default(),
5172 incoming_requests: Default::default(),
5173 remove_incoming_requests: Default::default(),
5174 outgoing_requests: Default::default(),
5175 remove_outgoing_requests: Default::default(),
5176 },
5177 )
5178 .trace_err();
5179 }
5180 }
5181 }
5182 Ok(())
5183}
5184
5185async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5186 log::info!("lost dev server connection, unsharing projects");
5187 let project_ids = session
5188 .db()
5189 .await
5190 .get_stale_dev_server_projects(session.connection_id)
5191 .await?;
5192
5193 for project_id in project_ids {
5194 // not unshare re-checks the connection ids match, so we get away with no transaction
5195 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5196 }
5197
5198 let user_id = session.dev_server().user_id;
5199 let update = session
5200 .db()
5201 .await
5202 .dev_server_projects_update(user_id)
5203 .await?;
5204
5205 send_dev_server_projects_update(user_id, update, session).await;
5206
5207 Ok(())
5208}
5209
/// Remove `connection_id` from whatever room it is in and notify everyone affected.
///
/// If the connection was not in a room, this returns `Ok(())` without doing
/// anything. Otherwise it:
/// - notifies collaborators on every project the user left (`project_left`);
/// - broadcasts the updated room, and, when the room belongs to a channel,
///   the channel's participant list;
/// - sends `CallCanceled` to every user whose pending call into this room was canceled;
/// - refreshes contact status for the leaving user and those callees;
/// - removes the user from the LiveKit room, deleting it if the database
///   reported the room as deleted.
///
/// LiveKit failures are logged rather than propagated.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred-initialized so that `left_room` stays confined to the `if let`
    // below; everything needed afterwards is moved out of it there.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): `left_room` and the `session.db()` guard both live until the
    // end of this `if let`. If `leave_room` returns a lock-holding guard, this
    // scoping is what releases it before the notifications below. Keep that in
    // mind before restructuring (e.g. into `let ... else`).
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        // The connection was not in a room; nothing to clean up.
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the awaits below.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5283
5284async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5285 let left_channel_buffers = session
5286 .db()
5287 .await
5288 .leave_channel_buffers(session.connection_id)
5289 .await?;
5290
5291 for left_buffer in left_channel_buffers {
5292 channel_buffer_updated(
5293 session.connection_id,
5294 left_buffer.connections,
5295 &proto::UpdateChannelBufferCollaborators {
5296 channel_id: left_buffer.channel_id.to_proto(),
5297 collaborators: left_buffer.collaborators,
5298 },
5299 &session.peer,
5300 );
5301 }
5302
5303 Ok(())
5304}
5305
5306fn project_left(project: &db::LeftProject, session: &UserSession) {
5307 for connection_id in &project.connection_ids {
5308 if project.should_unshare {
5309 session
5310 .peer
5311 .send(
5312 *connection_id,
5313 proto::UnshareProject {
5314 project_id: project.id.to_proto(),
5315 },
5316 )
5317 .trace_err();
5318 } else {
5319 session
5320 .peer
5321 .send(
5322 *connection_id,
5323 proto::RemoveProjectCollaborator {
5324 project_id: project.id.to_proto(),
5325 peer_id: Some(session.connection_id.into()),
5326 },
5327 )
5328 .trace_err();
5329 }
5330 }
5331}
5332
/// Extension for fallible values that should be logged, not propagated, on failure.
pub trait ResultExt {
    /// The success type produced when there is no error.
    type Ok;

    /// Convert into an `Option`, reporting the error (if any) instead of returning it.
    fn trace_err(self) -> Option<Self::Ok>;
}
5338
5339impl<T, E> ResultExt for Result<T, E>
5340where
5341 E: std::fmt::Debug,
5342{
5343 type Ok = T;
5344
5345 #[track_caller]
5346 fn trace_err(self) -> Option<T> {
5347 match self {
5348 Ok(value) => Some(value),
5349 Err(error) => {
5350 tracing::error!("{:?}", error);
5351 None
5352 }
5353 }
5354 }
5355}