1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period.
///
/// NOTE(review): not referenced in this chunk — presumably how long a dropped connection may
/// rejoin before its state is cleaned up; confirm at the call sites.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// How long [`Server::start`] waits before cleaning up resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting 15s means the old
/// pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// NOTE(review): the following limits are not referenced in this chunk. Their names suggest the
// channel-message page size, the maximum channel-message length, and the notification page
// size — confirm at the call sites.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;

/// Type-erased handler stored in `Server::handlers`, keyed by the message type's `TypeId`.
///
/// It receives the boxed envelope and the sender's `Session` and returns a boxed future that
/// does the work. `Server::add_handler` builds these and logs each handler's outcome.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed to every message handler.
///
/// `Server::handle_connection` builds one `Session` per connection and clones it for each
/// incoming message, so its fields are shared handles.
#[derive(Clone)]
struct Session {
    /// Identity that authenticated this connection.
    principal: Principal,
    /// Id the peer assigned to this connection.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it with `Session::db`.
    ///
    /// The mutex is created per connection in `handle_connection`, so it serializes database
    /// access among this connection's handlers only, not across connections.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// RPC peer shared with the server.
    peer: Arc<Peer>,
    /// Connection pool shared with the server; lock it with `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit client copied from the app state, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin client; `Some` only when a Supermaven admin API key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<IsahcHttpClient>,
    /// Rate limiter shared through the app state.
    rate_limiter: Arc<RateLimiter>,
    /// Executor handle kept with the session (not read in this chunk).
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The RPC server: owns the peer, the shared connection pool, and the message-handler table.
pub struct Server {
    /// Current server id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer shared with every session.
    peer: Arc<Peer>,
    /// Connection pool shared with every session.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, configuration, executor, LiveKit client, ...).
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`. Every connection task subscribes to it and stops when it
    /// changes.
    teardown: watch::Sender<bool>,
}
384
/// Guard over the locked [`ConnectionPool`], returned by `Session::connection_pool`.
///
/// The `PhantomData<Rc<()>>` marker makes the guard `!Send`. A future that holds it across an
/// `.await` is therefore not `Send` and cannot satisfy the `Send` bounds on handler futures,
/// which keeps the synchronous lock from being held across suspension points.
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's RPC peer and connection pool.
///
/// The pool lock is held for as long as the snapshot exists, because it owns the
/// `ConnectionPoolGuard`. The pool is serialized through [`serialize_deref`], i.e. via whatever
/// the guard dereferences to. NOTE(review): the guard's `Deref` impl is outside this chunk —
/// presumably it derefs to `ConnectionPool`; confirm.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
407 pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
408 let mut server = Self {
409 id: parking_lot::Mutex::new(id),
410 peer: Peer::new(id.0 as u32),
411 app_state: app_state.clone(),
412 connection_pool: Default::default(),
413 handlers: Default::default(),
414 teardown: watch::channel(false).0,
415 };
416
417 server
418 .add_request_handler(ping)
419 .add_request_handler(user_handler(create_room))
420 .add_request_handler(user_handler(join_room))
421 .add_request_handler(user_handler(rejoin_room))
422 .add_request_handler(user_handler(leave_room))
423 .add_request_handler(user_handler(set_room_participant_role))
424 .add_request_handler(user_handler(call))
425 .add_request_handler(user_handler(cancel_call))
426 .add_message_handler(user_message_handler(decline_call))
427 .add_request_handler(user_handler(update_participant_location))
428 .add_request_handler(user_handler(share_project))
429 .add_message_handler(unshare_project)
430 .add_request_handler(user_handler(join_project))
431 .add_request_handler(user_handler(join_hosted_project))
432 .add_request_handler(user_handler(rejoin_dev_server_projects))
433 .add_request_handler(user_handler(create_dev_server_project))
434 .add_request_handler(user_handler(delete_dev_server_project))
435 .add_request_handler(user_handler(create_dev_server))
436 .add_request_handler(user_handler(regenerate_dev_server_token))
437 .add_request_handler(user_handler(rename_dev_server))
438 .add_request_handler(user_handler(delete_dev_server))
439 .add_request_handler(dev_server_handler(share_dev_server_project))
440 .add_request_handler(dev_server_handler(shutdown_dev_server))
441 .add_request_handler(dev_server_handler(reconnect_dev_server))
442 .add_message_handler(user_message_handler(leave_project))
443 .add_request_handler(update_project)
444 .add_request_handler(update_worktree)
445 .add_message_handler(start_language_server)
446 .add_message_handler(update_language_server)
447 .add_message_handler(update_diagnostic_summary)
448 .add_message_handler(update_worktree_settings)
449 .add_request_handler(user_handler(
450 forward_read_only_project_request::<proto::GetHover>,
451 ))
452 .add_request_handler(user_handler(
453 forward_read_only_project_request::<proto::GetDefinition>,
454 ))
455 .add_request_handler(user_handler(
456 forward_read_only_project_request::<proto::GetTypeDefinition>,
457 ))
458 .add_request_handler(user_handler(
459 forward_read_only_project_request::<proto::GetReferences>,
460 ))
461 .add_request_handler(user_handler(
462 forward_read_only_project_request::<proto::SearchProject>,
463 ))
464 .add_request_handler(user_handler(
465 forward_read_only_project_request::<proto::GetDocumentHighlights>,
466 ))
467 .add_request_handler(user_handler(
468 forward_read_only_project_request::<proto::GetProjectSymbols>,
469 ))
470 .add_request_handler(user_handler(
471 forward_read_only_project_request::<proto::OpenBufferForSymbol>,
472 ))
473 .add_request_handler(user_handler(
474 forward_read_only_project_request::<proto::OpenBufferById>,
475 ))
476 .add_request_handler(user_handler(
477 forward_read_only_project_request::<proto::SynchronizeBuffers>,
478 ))
479 .add_request_handler(user_handler(
480 forward_read_only_project_request::<proto::InlayHints>,
481 ))
482 .add_request_handler(user_handler(
483 forward_read_only_project_request::<proto::OpenBufferByPath>,
484 ))
485 .add_request_handler(user_handler(
486 forward_mutating_project_request::<proto::GetCompletions>,
487 ))
488 .add_request_handler(user_handler(
489 forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
490 ))
491 .add_request_handler(user_handler(
492 forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
493 ))
494 .add_request_handler(user_handler(
495 forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
496 ))
497 .add_request_handler(user_handler(
498 forward_mutating_project_request::<proto::GetCodeActions>,
499 ))
500 .add_request_handler(user_handler(
501 forward_mutating_project_request::<proto::ApplyCodeAction>,
502 ))
503 .add_request_handler(user_handler(
504 forward_mutating_project_request::<proto::PrepareRename>,
505 ))
506 .add_request_handler(user_handler(
507 forward_mutating_project_request::<proto::PerformRename>,
508 ))
509 .add_request_handler(user_handler(
510 forward_mutating_project_request::<proto::ReloadBuffers>,
511 ))
512 .add_request_handler(user_handler(
513 forward_mutating_project_request::<proto::FormatBuffers>,
514 ))
515 .add_request_handler(user_handler(
516 forward_mutating_project_request::<proto::CreateProjectEntry>,
517 ))
518 .add_request_handler(user_handler(
519 forward_mutating_project_request::<proto::RenameProjectEntry>,
520 ))
521 .add_request_handler(user_handler(
522 forward_mutating_project_request::<proto::CopyProjectEntry>,
523 ))
524 .add_request_handler(user_handler(
525 forward_mutating_project_request::<proto::DeleteProjectEntry>,
526 ))
527 .add_request_handler(user_handler(
528 forward_mutating_project_request::<proto::ExpandProjectEntry>,
529 ))
530 .add_request_handler(user_handler(
531 forward_mutating_project_request::<proto::OnTypeFormatting>,
532 ))
533 .add_request_handler(user_handler(
534 forward_versioned_mutating_project_request::<proto::SaveBuffer>,
535 ))
536 .add_request_handler(user_handler(
537 forward_mutating_project_request::<proto::BlameBuffer>,
538 ))
539 .add_request_handler(user_handler(
540 forward_mutating_project_request::<proto::MultiLspQuery>,
541 ))
542 .add_message_handler(create_buffer_for_peer)
543 .add_request_handler(update_buffer)
544 .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
545 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
546 .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
547 .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
548 .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
549 .add_request_handler(get_users)
550 .add_request_handler(user_handler(fuzzy_search_users))
551 .add_request_handler(user_handler(request_contact))
552 .add_request_handler(user_handler(remove_contact))
553 .add_request_handler(user_handler(respond_to_contact_request))
554 .add_request_handler(user_handler(create_channel))
555 .add_request_handler(user_handler(delete_channel))
556 .add_request_handler(user_handler(invite_channel_member))
557 .add_request_handler(user_handler(remove_channel_member))
558 .add_request_handler(user_handler(set_channel_member_role))
559 .add_request_handler(user_handler(set_channel_visibility))
560 .add_request_handler(user_handler(rename_channel))
561 .add_request_handler(user_handler(join_channel_buffer))
562 .add_request_handler(user_handler(leave_channel_buffer))
563 .add_message_handler(user_message_handler(update_channel_buffer))
564 .add_request_handler(user_handler(rejoin_channel_buffers))
565 .add_request_handler(user_handler(get_channel_members))
566 .add_request_handler(user_handler(respond_to_channel_invite))
567 .add_request_handler(user_handler(join_channel))
568 .add_request_handler(user_handler(join_channel_chat))
569 .add_message_handler(user_message_handler(leave_channel_chat))
570 .add_request_handler(user_handler(send_channel_message))
571 .add_request_handler(user_handler(remove_channel_message))
572 .add_request_handler(user_handler(update_channel_message))
573 .add_request_handler(user_handler(get_channel_messages))
574 .add_request_handler(user_handler(get_channel_messages_by_id))
575 .add_request_handler(user_handler(get_notifications))
576 .add_request_handler(user_handler(mark_notification_as_read))
577 .add_request_handler(user_handler(move_channel))
578 .add_request_handler(user_handler(follow))
579 .add_message_handler(user_message_handler(unfollow))
580 .add_message_handler(user_message_handler(update_followers))
581 .add_request_handler(user_handler(get_private_user_info))
582 .add_message_handler(user_message_handler(acknowledge_channel_message))
583 .add_message_handler(user_message_handler(acknowledge_buffer_version))
584 .add_request_handler(user_handler(get_supermaven_api_key))
585 .add_streaming_request_handler({
586 let app_state = app_state.clone();
587 move |request, response, session| {
588 complete_with_language_model(
589 request,
590 response,
591 session,
592 app_state.config.openai_api_key.clone(),
593 app_state.config.google_ai_api_key.clone(),
594 app_state.config.anthropic_api_key.clone(),
595 )
596 }
597 })
598 .add_request_handler({
599 let app_state = app_state.clone();
600 user_handler(move |request, response, session| {
601 count_tokens_with_language_model(
602 request,
603 response,
604 session,
605 app_state.config.google_ai_api_key.clone(),
606 )
607 })
608 })
609 .add_request_handler({
610 user_handler(move |request, response, session| {
611 get_cached_embeddings(request, response, session)
612 })
613 })
614 .add_request_handler({
615 let app_state = app_state.clone();
616 user_handler(move |request, response, session| {
617 compute_embeddings(
618 request,
619 response,
620 session,
621 app_state.config.openai_api_key.clone(),
622 )
623 })
624 });
625
626 Arc::new(server)
627 }
628
    /// Starts background cleanup of resources left behind by stale servers.
    ///
    /// Spawns a detached task and returns `Ok(())` immediately. After [`CLEANUP_TIMEOUT`] the
    /// task:
    /// - asks the database for stale room and channel ids;
    /// - clears stale channel-buffer collaborators and room participants, notifying the
    ///   affected connections;
    /// - cancels calls to users whose calls were canceled;
    /// - sends refreshed contact info (including busy status) to the accepted contacts of the
    ///   affected users;
    /// - deletes LiveKit rooms that were left with no participants;
    /// - finally deletes the stale server records.
    ///
    /// Every database and send error is traced and skipped rather than aborting the cleanup.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created up front; the spawned task awaits it before doing any cleanup.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                // NOTE(review): presumably the rooms/channels still attributed to servers other
                // than `server_id` in this environment — confirm in `stale_server_resource_ids`.
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and send the updated
                    // collaborator list to the connections the database reports back.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Filled in while the room is refreshed and acted on afterwards; they
                        // stay empty/false if refreshing the room fails.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            // The LiveKit room is only deleted once nobody is left in the room.
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        {
                            // Scoped so the pool lock is released before the `.await`s below.
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both queries succeeded. The pool is locked after
                            // the queries, so no `.await` happens while it is held.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    // Only accepted contacts receive the update.
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if retrieving the stale resource ids failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
774
775 pub fn teardown(&self) {
776 self.peer.teardown();
777 self.connection_pool.lock().reset();
778 let _ = self.teardown.send(true);
779 }
780
781 #[cfg(test)]
782 pub fn reset(&self, id: ServerId) {
783 self.teardown();
784 *self.id.lock() = id;
785 self.peer.reset(id.0 as u32);
786 let _ = self.teardown.send(false);
787 }
788
789 #[cfg(test)]
790 pub fn id(&self) -> ServerId {
791 *self.id.lock()
792 }
793
794 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
795 where
796 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
797 Fut: 'static + Send + Future<Output = Result<()>>,
798 M: EnvelopedMessage,
799 {
800 let prev_handler = self.handlers.insert(
801 TypeId::of::<M>(),
802 Box::new(move |envelope, session| {
803 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
804 let received_at = envelope.received_at;
805 tracing::info!("message received");
806 let start_time = Instant::now();
807 let future = (handler)(*envelope, session);
808 async move {
809 let result = future.await;
810 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
811 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
812 let queue_duration_ms = total_duration_ms - processing_duration_ms;
813 let payload_type = M::NAME;
814
815 match result {
816 Err(error) => {
817 tracing::error!(
818 ?error,
819 total_duration_ms,
820 processing_duration_ms,
821 queue_duration_ms,
822 payload_type,
823 "error handling message"
824 )
825 }
826 Ok(()) => tracing::info!(
827 total_duration_ms,
828 processing_duration_ms,
829 queue_duration_ms,
830 "finished handling message"
831 ),
832 }
833 }
834 .boxed()
835 }),
836 );
837 if prev_handler.is_some() {
838 panic!("registered a handler for the same message twice");
839 }
840 self
841 }
842
843 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
844 where
845 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
846 Fut: 'static + Send + Future<Output = Result<()>>,
847 M: EnvelopedMessage,
848 {
849 self.add_handler(move |envelope, session| handler(envelope.payload, session));
850 self
851 }
852
853 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
854 where
855 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
856 Fut: Send + Future<Output = Result<()>>,
857 M: RequestMessage,
858 {
859 let handler = Arc::new(handler);
860 self.add_handler(move |envelope, session| {
861 let receipt = envelope.receipt();
862 let handler = handler.clone();
863 async move {
864 let peer = session.peer.clone();
865 let responded = Arc::new(AtomicBool::default());
866 let response = Response {
867 peer: peer.clone(),
868 responded: responded.clone(),
869 receipt,
870 };
871 match (handler)(envelope.payload, response, session).await {
872 Ok(()) => {
873 if responded.load(std::sync::atomic::Ordering::SeqCst) {
874 Ok(())
875 } else {
876 Err(anyhow!("handler did not send a response"))?
877 }
878 }
879 Err(error) => {
880 let proto_err = match &error {
881 Error::Internal(err) => err.to_proto(),
882 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
883 };
884 peer.respond_with_error(receipt, proto_err)?;
885 Err(error)
886 }
887 }
888 }
889 })
890 }
891
892 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
893 where
894 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
895 Fut: Send + Future<Output = Result<()>>,
896 M: RequestMessage,
897 {
898 let handler = Arc::new(handler);
899 self.add_handler(move |envelope, session| {
900 let receipt = envelope.receipt();
901 let handler = handler.clone();
902 async move {
903 let peer = session.peer.clone();
904 let response = StreamingResponse {
905 peer: peer.clone(),
906 receipt,
907 };
908 match (handler)(envelope.payload, response, session).await {
909 Ok(()) => {
910 peer.end_stream(receipt)?;
911 Ok(())
912 }
913 Err(error) => {
914 let proto_err = match &error {
915 Error::Internal(err) => err.to_proto(),
916 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
917 };
918 peer.respond_with_error(receipt, proto_err)?;
919 Err(error)
920 }
921 }
922 }
923 })
924 }
925
    /// Returns a future that serves `connection` until the client disconnects, its I/O fails,
    /// or the server tears down.
    ///
    /// The future:
    /// - registers the connection with the peer and records its id on the tracing span;
    /// - builds the connection's [`Session`] and sends the initial client update;
    /// - dispatches each incoming message to its registered handler. Background messages are
    ///   spawned detached; foreground messages run concurrently inside this task.
    ///
    /// At most 256 handlers of this connection are in flight at once. When the loop exits for
    /// any reason other than teardown, `connection_lost` is called to sign the session out.
    #[allow(clippy::too_many_arguments)]
    pub fn handle_connection(
        self: &Arc<Self>,
        connection: Connection,
        address: String,
        principal: Principal,
        zed_version: ZedVersion,
        send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        executor: Executor,
    ) -> impl Future<Output = ()> {
        let this = self.clone();
        // Fields declared empty here are filled in by `record` calls once they are known.
        let span = info_span!("handle connection", %address,
            connection_id=field::Empty,
            user_id=field::Empty,
            login=field::Empty,
            impersonator=field::Empty,
            dev_server_id=field::Empty
        );
        principal.update_span(&span);

        let mut teardown = self.teardown.subscribe();
        async move {
            // Refuse new connections once the teardown flag is set.
            if *teardown.borrow() {
                tracing::error!("server is tearing down");
                return
            }
            let (connection_id, handle_io, mut incoming_rx) = this
                .peer
                .add_connection(connection, {
                    let executor = executor.clone();
                    move |duration| executor.sleep(duration)
                });
            tracing::Span::current().record("connection_id", format!("{}", connection_id));
            tracing::info!("connection opened");

            // NOTE(review): this early return and the one after `send_initial_client_update`
            // happen after `add_connection` but skip `connection_lost` — confirm the peer
            // connection is cleaned up elsewhere.
            let http_client = match IsahcHttpClient::new() {
                Ok(http_client) => Arc::new(http_client),
                Err(error) => {
                    tracing::error!(?error, "failed to create HTTP client");
                    return;
                }
            };

            // Only built when a Supermaven admin API key is configured.
            let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
                Some(Arc::new(SupermavenAdminApi::new(
                    supermaven_admin_api_key.to_string(),
                    http_client.clone(),
                )))
            } else {
                None
            };

            let session = Session {
                principal: principal.clone(),
                connection_id,
                db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
                peer: this.peer.clone(),
                connection_pool: this.connection_pool.clone(),
                live_kit_client: this.app_state.live_kit_client.clone(),
                http_client,
                rate_limiter: this.app_state.rate_limiter.clone(),
                _executor: executor.clone(),
                supermaven_client,
            };

            if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
                tracing::error!(?error, "failed to send initial client update");
                return;
            }

            let handle_io = handle_io.fuse();
            futures::pin_mut!(handle_io);

            // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
            // This prevents deadlocks when e.g., client A performs a request to client B and
            // client B performs a request to client A. If both clients stop processing further
            // messages until their respective request completes, they won't have a chance to
            // respond to the other client's request and cause a deadlock.
            //
            // This arrangement ensures we will attempt to process earlier messages first, but fall
            // back to processing messages arrived later in the spirit of making progress.
            let mut foreground_message_handlers = FuturesUnordered::new();
            // A permit is taken before reading each message and released when its handler
            // finishes, capping this connection's in-flight handlers (foreground and
            // background) at 256.
            let concurrent_handlers = Arc::new(Semaphore::new(256));
            loop {
                let next_message = async {
                    let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                    let message = incoming_rx.next().await;
                    (permit, message)
                }.fuse();
                futures::pin_mut!(next_message);
                // Branches are polled in order: teardown first, then I/O completion, then
                // progress on foreground handlers, and only then the next incoming message.
                futures::select_biased! {
                    // Teardown returns immediately, skipping `connection_lost` below.
                    _ = teardown.changed().fuse() => return,
                    result = handle_io => {
                        if let Err(error) = result {
                            tracing::error!(?error, "error handling I/O");
                        }
                        break;
                    }
                    _ = foreground_message_handlers.next() => {}
                    next_message = next_message => {
                        let (permit, message) = next_message;
                        if let Some(message) = message {
                            let type_name = message.payload_type_name();
                            // note: we copy all the fields from the parent span so we can query them in the logs.
                            // (https://github.com/tokio-rs/tracing/issues/2670).
                            let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                                user_id=field::Empty,
                                login=field::Empty,
                                impersonator=field::Empty,
                                dev_server_id=field::Empty
                            );
                            principal.update_span(&span);
                            // The handler is invoked inside the span; the entered guard is
                            // dropped before the future is stored, which is instrumented instead.
                            let span_enter = span.enter();
                            if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                                let is_background = message.is_background();
                                let handle_message = (handler)(message, session.clone());
                                drop(span_enter);

                                let handle_message = async move {
                                    handle_message.await;
                                    drop(permit);
                                }.instrument(span);
                                if is_background {
                                    executor.spawn_detached(handle_message);
                                } else {
                                    foreground_message_handlers.push(handle_message);
                                }
                            } else {
                                tracing::error!("no message handler");
                            }
                        } else {
                            tracing::info!("connection closed");
                            break;
                        }
                    }
                }
            }

            // Pending foreground handlers are dropped (cancelled) before signing out.
            drop(foreground_message_handlers);
            tracing::info!("signing out");
            if let Err(error) = connection_lost(session, teardown, executor).await {
                tracing::error!(?error, "error signing out");
            }

        }.instrument(span)
    }
1072
1073 async fn send_initial_client_update(
1074 &self,
1075 connection_id: ConnectionId,
1076 principal: &Principal,
1077 zed_version: ZedVersion,
1078 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1079 session: &Session,
1080 ) -> Result<()> {
1081 self.peer.send(
1082 connection_id,
1083 proto::Hello {
1084 peer_id: Some(connection_id.into()),
1085 },
1086 )?;
1087 tracing::info!("sent hello message");
1088 if let Some(send_connection_id) = send_connection_id.take() {
1089 let _ = send_connection_id.send(connection_id);
1090 }
1091
1092 match principal {
1093 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1094 if !user.connected_once {
1095 self.peer.send(connection_id, proto::ShowContacts {})?;
1096 self.app_state
1097 .db
1098 .set_user_connected_once(user.id, true)
1099 .await?;
1100 }
1101
1102 let (contacts, channels_for_user, channel_invites, dev_server_projects) =
1103 future::try_join4(
1104 self.app_state.db.get_contacts(user.id),
1105 self.app_state.db.get_channels_for_user(user.id),
1106 self.app_state.db.get_channel_invites_for_user(user.id),
1107 self.app_state.db.dev_server_projects_update(user.id),
1108 )
1109 .await?;
1110
1111 {
1112 let mut pool = self.connection_pool.lock();
1113 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1114 for membership in &channels_for_user.channel_memberships {
1115 pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
1116 }
1117 self.peer.send(
1118 connection_id,
1119 build_initial_contacts_update(contacts, &pool),
1120 )?;
1121 self.peer.send(
1122 connection_id,
1123 build_update_user_channels(&channels_for_user),
1124 )?;
1125 self.peer.send(
1126 connection_id,
1127 build_channels_update(channels_for_user, channel_invites),
1128 )?;
1129 }
1130 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1131
1132 if let Some(incoming_call) =
1133 self.app_state.db.incoming_call_for_user(user.id).await?
1134 {
1135 self.peer.send(connection_id, incoming_call)?;
1136 }
1137
1138 update_user_contacts(user.id, &session).await?;
1139 }
1140 Principal::DevServer(dev_server) => {
1141 {
1142 let mut pool = self.connection_pool.lock();
1143 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1144 {
1145 self.peer.send(
1146 stale_connection_id,
1147 proto::ShutdownDevServer {
1148 reason: Some(
1149 "another dev server connected with the same token".to_string(),
1150 ),
1151 },
1152 )?;
1153 pool.remove_connection(stale_connection_id)?;
1154 };
1155 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1156 }
1157
1158 let projects = self
1159 .app_state
1160 .db
1161 .get_projects_for_dev_server(dev_server.id)
1162 .await?;
1163 self.peer
1164 .send(connection_id, proto::DevServerInstructions { projects })?;
1165
1166 let status = self
1167 .app_state
1168 .db
1169 .dev_server_projects_update(dev_server.user_id)
1170 .await?;
1171 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1172 }
1173 }
1174
1175 Ok(())
1176 }
1177
1178 pub async fn invite_code_redeemed(
1179 self: &Arc<Self>,
1180 inviter_id: UserId,
1181 invitee_id: UserId,
1182 ) -> Result<()> {
1183 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1184 if let Some(code) = &user.invite_code {
1185 let pool = self.connection_pool.lock();
1186 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1187 for connection_id in pool.user_connection_ids(inviter_id) {
1188 self.peer.send(
1189 connection_id,
1190 proto::UpdateContacts {
1191 contacts: vec![invitee_contact.clone()],
1192 ..Default::default()
1193 },
1194 )?;
1195 self.peer.send(
1196 connection_id,
1197 proto::UpdateInviteInfo {
1198 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1199 count: user.invite_count as u32,
1200 },
1201 )?;
1202 }
1203 }
1204 }
1205 Ok(())
1206 }
1207
1208 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1209 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1210 if let Some(invite_code) = &user.invite_code {
1211 let pool = self.connection_pool.lock();
1212 for connection_id in pool.user_connection_ids(user_id) {
1213 self.peer.send(
1214 connection_id,
1215 proto::UpdateInviteInfo {
1216 url: format!(
1217 "{}{}",
1218 self.app_state.config.invite_link_prefix, invite_code
1219 ),
1220 count: user.invite_count as u32,
1221 },
1222 )?;
1223 }
1224 }
1225 }
1226 Ok(())
1227 }
1228
1229 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1230 ServerSnapshot {
1231 connection_pool: ConnectionPoolGuard {
1232 guard: self.connection_pool.lock(),
1233 _not_send: PhantomData,
1234 },
1235 peer: &self.peer,
1236 }
1237 }
1238}
1239
1240impl<'a> Deref for ConnectionPoolGuard<'a> {
1241 type Target = ConnectionPool;
1242
1243 fn deref(&self) -> &Self::Target {
1244 &self.guard
1245 }
1246}
1247
1248impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1249 fn deref_mut(&mut self) -> &mut Self::Target {
1250 &mut self.guard
1251 }
1252}
1253
1254impl<'a> Drop for ConnectionPoolGuard<'a> {
1255 fn drop(&mut self) {
1256 #[cfg(test)]
1257 self.check_invariants();
1258 }
1259}
1260
1261fn broadcast<F>(
1262 sender_id: Option<ConnectionId>,
1263 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1264 mut f: F,
1265) where
1266 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1267{
1268 for receiver_id in receiver_ids {
1269 if Some(receiver_id) != sender_id {
1270 if let Err(error) = f(receiver_id) {
1271 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1272 }
1273 }
1274 }
1275}
1276
1277pub struct ProtocolVersion(u32);
1278
1279impl Header for ProtocolVersion {
1280 fn name() -> &'static HeaderName {
1281 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1282 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1283 }
1284
1285 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1286 where
1287 Self: Sized,
1288 I: Iterator<Item = &'i axum::http::HeaderValue>,
1289 {
1290 let version = values
1291 .next()
1292 .ok_or_else(axum::headers::Error::invalid)?
1293 .to_str()
1294 .map_err(|_| axum::headers::Error::invalid())?
1295 .parse()
1296 .map_err(|_| axum::headers::Error::invalid())?;
1297 Ok(Self(version))
1298 }
1299
1300 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1301 values.extend([self.0.to_string().parse().unwrap()]);
1302 }
1303}
1304
1305pub struct AppVersionHeader(SemanticVersion);
1306impl Header for AppVersionHeader {
1307 fn name() -> &'static HeaderName {
1308 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1309 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1310 }
1311
1312 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1313 where
1314 Self: Sized,
1315 I: Iterator<Item = &'i axum::http::HeaderValue>,
1316 {
1317 let version = values
1318 .next()
1319 .ok_or_else(axum::headers::Error::invalid)?
1320 .to_str()
1321 .map_err(|_| axum::headers::Error::invalid())?
1322 .parse()
1323 .map_err(|_| axum::headers::Error::invalid())?;
1324 Ok(Self(version))
1325 }
1326
1327 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1328 values.extend([self.0.to_string().parse().unwrap()]);
1329 }
1330}
1331
1332pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1333 Router::new()
1334 .route("/rpc", get(handle_websocket_request))
1335 .layer(
1336 ServiceBuilder::new()
1337 .layer(Extension(server.app_state.clone()))
1338 .layer(middleware::from_fn(auth::validate_header)),
1339 )
1340 .route("/metrics", get(handle_metrics))
1341 .layer(Extension(server))
1342}
1343
1344pub async fn handle_websocket_request(
1345 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1346 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1347 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1348 Extension(server): Extension<Arc<Server>>,
1349 Extension(principal): Extension<Principal>,
1350 ws: WebSocketUpgrade,
1351) -> axum::response::Response {
1352 if protocol_version != rpc::PROTOCOL_VERSION {
1353 return (
1354 StatusCode::UPGRADE_REQUIRED,
1355 "client must be upgraded".to_string(),
1356 )
1357 .into_response();
1358 }
1359
1360 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1361 return (
1362 StatusCode::UPGRADE_REQUIRED,
1363 "no version header found".to_string(),
1364 )
1365 .into_response();
1366 };
1367
1368 if !version.can_collaborate() {
1369 return (
1370 StatusCode::UPGRADE_REQUIRED,
1371 "client must be upgraded".to_string(),
1372 )
1373 .into_response();
1374 }
1375
1376 let socket_address = socket_address.to_string();
1377 ws.on_upgrade(move |socket| {
1378 let socket = socket
1379 .map_ok(to_tungstenite_message)
1380 .err_into()
1381 .with(|message| async move { Ok(to_axum_message(message)) });
1382 let connection = Connection::new(Box::pin(socket));
1383 async move {
1384 server
1385 .handle_connection(
1386 connection,
1387 socket_address,
1388 principal,
1389 version,
1390 None,
1391 Executor::Production,
1392 )
1393 .await;
1394 }
1395 })
1396}
1397
1398pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1399 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1400 let connections_metric = CONNECTIONS_METRIC
1401 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1402
1403 let connections = server
1404 .connection_pool
1405 .lock()
1406 .connections()
1407 .filter(|connection| !connection.admin)
1408 .count();
1409 connections_metric.set(connections as _);
1410
1411 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1412 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1413 register_int_gauge!(
1414 "shared_projects",
1415 "number of open projects with one or more guests"
1416 )
1417 .unwrap()
1418 });
1419
1420 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1421 shared_projects_metric.set(shared_projects as _);
1422
1423 let encoder = prometheus::TextEncoder::new();
1424 let metric_families = prometheus::gather();
1425 let encoded_metrics = encoder
1426 .encode_to_string(&metric_families)
1427 .map_err(|err| anyhow!("{}", err))?;
1428 Ok(encoded_metrics)
1429}
1430
/// Cleans up after a connection has gone away.
///
/// Immediately:
/// - the peer connection is disconnected;
/// - the connection is removed from the connection pool;
/// - the database is told the connection was lost (errors there are only logged).
///
/// The remaining cleanup runs only after `RECONNECT_TIMEOUT` has elapsed. Presumably this
/// gives the client a chance to reconnect first.
///
/// For users (including impersonated users), the deferred cleanup:
/// - leaves the current room and any channel buffers;
/// - declines a pending call if the user has no other online connection;
/// - refreshes contacts.
///
/// For dev servers, it runs `lost_dev_server_connection`. If `teardown` changes before the
/// timeout elapses, the deferred cleanup is skipped entirely.
///
/// # Errors
///
/// Returns an error if the connection cannot be removed from the pool. It also returns an
/// error if the deferred contact update or dev server cleanup fails.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // `select_biased!` polls branches in order: if the timeout and teardown are both ready,
    // the cleanup branch wins.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // The enclosing match arm only admits user principals, so `for_user` is
                    // expected to return `Some` here.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call if no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1485
1486/// Acknowledges a ping from a client, used to keep the connection alive.
1487async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1488 response.send(proto::Ack {})?;
1489 Ok(())
1490}
1491
1492/// Creates a new room for calling (outside of channels)
1493async fn create_room(
1494 _request: proto::CreateRoom,
1495 response: Response<proto::CreateRoom>,
1496 session: UserSession,
1497) -> Result<()> {
1498 let live_kit_room = nanoid::nanoid!(30);
1499
1500 let live_kit_connection_info = util::maybe!(async {
1501 let live_kit = session.live_kit_client.as_ref();
1502 let live_kit = live_kit?;
1503 let user_id = session.user_id().to_string();
1504
1505 let token = live_kit
1506 .room_token(&live_kit_room, &user_id.to_string())
1507 .trace_err()?;
1508
1509 Some(proto::LiveKitConnectionInfo {
1510 server_url: live_kit.url().into(),
1511 token,
1512 can_publish: true,
1513 })
1514 })
1515 .await;
1516
1517 let room = session
1518 .db()
1519 .await
1520 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1521 .await?;
1522
1523 response.send(proto::CreateRoomResponse {
1524 room: Some(room.clone()),
1525 live_kit_connection_info,
1526 })?;
1527
1528 update_user_contacts(session.user_id(), &session).await?;
1529 Ok(())
1530}
1531
1532/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1533async fn join_room(
1534 request: proto::JoinRoom,
1535 response: Response<proto::JoinRoom>,
1536 session: UserSession,
1537) -> Result<()> {
1538 let room_id = RoomId::from_proto(request.id);
1539
1540 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1541
1542 if let Some(channel_id) = channel_id {
1543 return join_channel_internal(channel_id, Box::new(response), session).await;
1544 }
1545
1546 let joined_room = {
1547 let room = session
1548 .db()
1549 .await
1550 .join_room(room_id, session.user_id(), session.connection_id)
1551 .await?;
1552 room_updated(&room.room, &session.peer);
1553 room.into_inner()
1554 };
1555
1556 for connection_id in session
1557 .connection_pool()
1558 .await
1559 .user_connection_ids(session.user_id())
1560 {
1561 session
1562 .peer
1563 .send(
1564 connection_id,
1565 proto::CallCanceled {
1566 room_id: room_id.to_proto(),
1567 },
1568 )
1569 .trace_err();
1570 }
1571
1572 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1573 if let Some(token) = live_kit
1574 .room_token(
1575 &joined_room.room.live_kit_room,
1576 &session.user_id().to_string(),
1577 )
1578 .trace_err()
1579 {
1580 Some(proto::LiveKitConnectionInfo {
1581 server_url: live_kit.url().into(),
1582 token,
1583 can_publish: true,
1584 })
1585 } else {
1586 None
1587 }
1588 } else {
1589 None
1590 };
1591
1592 response.send(proto::JoinRoomResponse {
1593 room: Some(joined_room.room),
1594 channel_id: None,
1595 live_kit_connection_info,
1596 })?;
1597
1598 update_user_contacts(session.user_id(), &session).await?;
1599 Ok(())
1600}
1601
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// The database's `rejoin_room` returns the room, `reshared_projects`,
/// `rejoined_projects` and the room's `channel`, if any. The steps are, in order:
///
/// 1. The caller is answered with the room and both project lists.
/// 2. The room is broadcast.
/// 3. For every re-shared project, each collaborator is told that this user's peer id
///    moved from `old_connection_id` to the current connection. The project's worktrees
///    are then forwarded to every collaborator except this connection.
/// 4. Rejoined projects are replayed through `notify_rejoined_projects`.
///
/// After the database result is released, the channel (if any) is updated and the user's
/// contacts are refreshed.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room;
    let channel;
    // Scope for the database result; `room` and `channel` are moved out before it ends.
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Collaborators must learn the new peer id; send failures are only logged.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1693
1694fn notify_rejoined_projects(
1695 rejoined_projects: &mut Vec<RejoinedProject>,
1696 session: &UserSession,
1697) -> Result<()> {
1698 for project in rejoined_projects.iter() {
1699 for collaborator in &project.collaborators {
1700 session
1701 .peer
1702 .send(
1703 collaborator.connection_id,
1704 proto::UpdateProjectCollaborator {
1705 project_id: project.id.to_proto(),
1706 old_peer_id: Some(project.old_connection_id.into()),
1707 new_peer_id: Some(session.connection_id.into()),
1708 },
1709 )
1710 .trace_err();
1711 }
1712 }
1713
1714 for project in rejoined_projects {
1715 for worktree in mem::take(&mut project.worktrees) {
1716 #[cfg(any(test, feature = "test-support"))]
1717 const MAX_CHUNK_SIZE: usize = 2;
1718 #[cfg(not(any(test, feature = "test-support")))]
1719 const MAX_CHUNK_SIZE: usize = 256;
1720
1721 // Stream this worktree's entries.
1722 let message = proto::UpdateWorktree {
1723 project_id: project.id.to_proto(),
1724 worktree_id: worktree.id,
1725 abs_path: worktree.abs_path.clone(),
1726 root_name: worktree.root_name,
1727 updated_entries: worktree.updated_entries,
1728 removed_entries: worktree.removed_entries,
1729 scan_id: worktree.scan_id,
1730 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1731 updated_repositories: worktree.updated_repositories,
1732 removed_repositories: worktree.removed_repositories,
1733 };
1734 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1735 session.peer.send(session.connection_id, update.clone())?;
1736 }
1737
1738 // Stream this worktree's diagnostics.
1739 for summary in worktree.diagnostic_summaries {
1740 session.peer.send(
1741 session.connection_id,
1742 proto::UpdateDiagnosticSummary {
1743 project_id: project.id.to_proto(),
1744 worktree_id: worktree.id,
1745 summary: Some(summary),
1746 },
1747 )?;
1748 }
1749
1750 for settings_file in worktree.settings_files {
1751 session.peer.send(
1752 session.connection_id,
1753 proto::UpdateWorktreeSettings {
1754 project_id: project.id.to_proto(),
1755 worktree_id: worktree.id,
1756 path: settings_file.path,
1757 content: Some(settings_file.content),
1758 },
1759 )?;
1760 }
1761 }
1762
1763 for language_server in &project.language_servers {
1764 session.peer.send(
1765 session.connection_id,
1766 proto::UpdateLanguageServer {
1767 project_id: project.id.to_proto(),
1768 language_server_id: language_server.id,
1769 variant: Some(
1770 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1771 proto::LspDiskBasedDiagnosticsUpdated {},
1772 ),
1773 ),
1774 },
1775 )?;
1776 }
1777 }
1778 Ok(())
1779}
1780
1781/// leave room disconnects from the room.
1782async fn leave_room(
1783 _: proto::LeaveRoom,
1784 response: Response<proto::LeaveRoom>,
1785 session: UserSession,
1786) -> Result<()> {
1787 leave_room_for_session(&session, session.connection_id).await?;
1788 response.send(proto::Ack {})?;
1789 Ok(())
1790}
1791
1792/// Updates the permissions of someone else in the room.
1793async fn set_room_participant_role(
1794 request: proto::SetRoomParticipantRole,
1795 response: Response<proto::SetRoomParticipantRole>,
1796 session: UserSession,
1797) -> Result<()> {
1798 let user_id = UserId::from_proto(request.user_id);
1799 let role = ChannelRole::from(request.role());
1800
1801 let (live_kit_room, can_publish) = {
1802 let room = session
1803 .db()
1804 .await
1805 .set_room_participant_role(
1806 session.user_id(),
1807 RoomId::from_proto(request.room_id),
1808 user_id,
1809 role,
1810 )
1811 .await?;
1812
1813 let live_kit_room = room.live_kit_room.clone();
1814 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1815 room_updated(&room, &session.peer);
1816 (live_kit_room, can_publish)
1817 };
1818
1819 if let Some(live_kit) = session.live_kit_client.as_ref() {
1820 live_kit
1821 .update_participant(
1822 live_kit_room.clone(),
1823 request.user_id.to_string(),
1824 live_kit_server::proto::ParticipantPermission {
1825 can_subscribe: true,
1826 can_publish,
1827 can_publish_data: can_publish,
1828 hidden: false,
1829 recorder: false,
1830 },
1831 )
1832 .await
1833 .trace_err();
1834 }
1835
1836 response.send(proto::Ack {})?;
1837 Ok(())
1838}
1839
/// Call someone else into the current room.
///
/// The callee must be a contact of the caller. The call is recorded in the database, the
/// room is broadcast, and the callee's contacts are refreshed. Every connection of the
/// callee is then sent the incoming call concurrently, and the first successful response
/// acknowledges the request. If no connection responds successfully, which includes the
/// callee having no connections, the call is marked as failed, the room is broadcast
/// again, the callee's contacts are refreshed, and an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // The `&mut *` borrow keeps the database result alive until the end of this block, so
    // it is released before the contacts update below.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        // Move the message out of the borrowed result, leaving a default in its place.
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // The connection pool guard is a temporary of this statement, so it is released before
    // any of the ring requests are awaited.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first connection to respond successfully completes the call; failures are logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: record the failure and broadcast the updated room.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1908
1909/// Cancel an outgoing call.
1910async fn cancel_call(
1911 request: proto::CancelCall,
1912 response: Response<proto::CancelCall>,
1913 session: UserSession,
1914) -> Result<()> {
1915 let called_user_id = UserId::from_proto(request.called_user_id);
1916 let room_id = RoomId::from_proto(request.room_id);
1917 {
1918 let room = session
1919 .db()
1920 .await
1921 .cancel_call(room_id, session.connection_id, called_user_id)
1922 .await?;
1923 room_updated(&room, &session.peer);
1924 }
1925
1926 for connection_id in session
1927 .connection_pool()
1928 .await
1929 .user_connection_ids(called_user_id)
1930 {
1931 session
1932 .peer
1933 .send(
1934 connection_id,
1935 proto::CallCanceled {
1936 room_id: room_id.to_proto(),
1937 },
1938 )
1939 .trace_err();
1940 }
1941 response.send(proto::Ack {})?;
1942
1943 update_user_contacts(called_user_id, &session).await?;
1944 Ok(())
1945}
1946
1947/// Decline an incoming call.
1948async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1949 let room_id = RoomId::from_proto(message.room_id);
1950 {
1951 let room = session
1952 .db()
1953 .await
1954 .decline_call(Some(room_id), session.user_id())
1955 .await?
1956 .ok_or_else(|| anyhow!("failed to decline call"))?;
1957 room_updated(&room, &session.peer);
1958 }
1959
1960 for connection_id in session
1961 .connection_pool()
1962 .await
1963 .user_connection_ids(session.user_id())
1964 {
1965 session
1966 .peer
1967 .send(
1968 connection_id,
1969 proto::CallCanceled {
1970 room_id: room_id.to_proto(),
1971 },
1972 )
1973 .trace_err();
1974 }
1975 update_user_contacts(session.user_id(), &session).await?;
1976 Ok(())
1977}
1978
1979/// Updates other participants in the room with your current location.
1980async fn update_participant_location(
1981 request: proto::UpdateParticipantLocation,
1982 response: Response<proto::UpdateParticipantLocation>,
1983 session: UserSession,
1984) -> Result<()> {
1985 let room_id = RoomId::from_proto(request.room_id);
1986 let location = request
1987 .location
1988 .ok_or_else(|| anyhow!("invalid location"))?;
1989
1990 let db = session.db().await;
1991 let room = db
1992 .update_room_participant_location(room_id, session.connection_id, location)
1993 .await?;
1994
1995 room_updated(&room, &session.peer);
1996 response.send(proto::Ack {})?;
1997 Ok(())
1998}
1999
2000/// Share a project into the room.
2001async fn share_project(
2002 request: proto::ShareProject,
2003 response: Response<proto::ShareProject>,
2004 session: UserSession,
2005) -> Result<()> {
2006 let (project_id, room) = &*session
2007 .db()
2008 .await
2009 .share_project(
2010 RoomId::from_proto(request.room_id),
2011 session.connection_id,
2012 &request.worktrees,
2013 request
2014 .dev_server_project_id
2015 .map(|id| DevServerProjectId::from_proto(id)),
2016 )
2017 .await?;
2018 response.send(proto::ShareProjectResponse {
2019 project_id: project_id.to_proto(),
2020 })?;
2021 room_updated(&room, &session.peer);
2022
2023 Ok(())
2024}
2025
2026/// Unshare a project from the room.
2027async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2028 let project_id = ProjectId::from_proto(message.project_id);
2029 unshare_project_internal(
2030 project_id,
2031 session.connection_id,
2032 session.user_id(),
2033 &session,
2034 )
2035 .await
2036}
2037
/// Stops sharing `project_id` on behalf of `connection_id` (and `user_id`, if any).
///
/// Every guest connection other than `connection_id` is sent `UnshareProject`. If the
/// database reports a room, the updated room is broadcast. When the database reports that
/// the project should be deleted, it is deleted after the unshare result has been released.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // The block scope drops `room_guard` before the project is (possibly) deleted below.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copied out so the flag outlives the guard.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2076
2077/// DevServer makes a project available online
2078async fn share_dev_server_project(
2079 request: proto::ShareDevServerProject,
2080 response: Response<proto::ShareDevServerProject>,
2081 session: DevServerSession,
2082) -> Result<()> {
2083 let (dev_server_project, user_id, status) = session
2084 .db()
2085 .await
2086 .share_dev_server_project(
2087 DevServerProjectId::from_proto(request.dev_server_project_id),
2088 session.dev_server_id(),
2089 session.connection_id,
2090 &request.worktrees,
2091 )
2092 .await?;
2093 let Some(project_id) = dev_server_project.project_id else {
2094 return Err(anyhow!("failed to share remote project"))?;
2095 };
2096
2097 send_dev_server_projects_update(user_id, status, &session).await;
2098
2099 response.send(proto::ShareProjectResponse { project_id })?;
2100
2101 Ok(())
2102}
2103
/// Join someone else's shared project.
///
/// Registers this connection with the project in the database, which returns the project
/// and a replica id. It then hands off to `join_project_internal`, which notifies the
/// existing collaborators and sends the project's state to the joining connection.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    // The `&mut *` borrow keeps the join result alive until this function returns; only the
    // session's DB handle is released early via `drop(db)`.
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2122
2123trait JoinProjectInternalResponse {
2124 fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
2125}
2126impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2127 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2128 Response::<proto::JoinProject>::send(self, result)
2129 }
2130}
2131impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2132 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2133 Response::<proto::JoinHostedProject>::send(self, result)
2134 }
2135}
2136
2137fn join_project_internal(
2138 response: impl JoinProjectInternalResponse,
2139 session: UserSession,
2140 project: &mut Project,
2141 replica_id: &ReplicaId,
2142) -> Result<()> {
2143 let collaborators = project
2144 .collaborators
2145 .iter()
2146 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2147 .map(|collaborator| collaborator.to_proto())
2148 .collect::<Vec<_>>();
2149 let project_id = project.id;
2150 let guest_user_id = session.user_id();
2151
2152 let worktrees = project
2153 .worktrees
2154 .iter()
2155 .map(|(id, worktree)| proto::WorktreeMetadata {
2156 id: *id,
2157 root_name: worktree.root_name.clone(),
2158 visible: worktree.visible,
2159 abs_path: worktree.abs_path.clone(),
2160 })
2161 .collect::<Vec<_>>();
2162
2163 let add_project_collaborator = proto::AddProjectCollaborator {
2164 project_id: project_id.to_proto(),
2165 collaborator: Some(proto::Collaborator {
2166 peer_id: Some(session.connection_id.into()),
2167 replica_id: replica_id.0 as u32,
2168 user_id: guest_user_id.to_proto(),
2169 }),
2170 };
2171
2172 for collaborator in &collaborators {
2173 session
2174 .peer
2175 .send(
2176 collaborator.peer_id.unwrap().into(),
2177 add_project_collaborator.clone(),
2178 )
2179 .trace_err();
2180 }
2181
2182 // First, we send the metadata associated with each worktree.
2183 response.send(proto::JoinProjectResponse {
2184 project_id: project.id.0 as u64,
2185 worktrees: worktrees.clone(),
2186 replica_id: replica_id.0 as u32,
2187 collaborators: collaborators.clone(),
2188 language_servers: project.language_servers.clone(),
2189 role: project.role.into(),
2190 dev_server_project_id: project
2191 .dev_server_project_id
2192 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2193 })?;
2194
2195 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2196 #[cfg(any(test, feature = "test-support"))]
2197 const MAX_CHUNK_SIZE: usize = 2;
2198 #[cfg(not(any(test, feature = "test-support")))]
2199 const MAX_CHUNK_SIZE: usize = 256;
2200
2201 // Stream this worktree's entries.
2202 let message = proto::UpdateWorktree {
2203 project_id: project_id.to_proto(),
2204 worktree_id,
2205 abs_path: worktree.abs_path.clone(),
2206 root_name: worktree.root_name,
2207 updated_entries: worktree.entries,
2208 removed_entries: Default::default(),
2209 scan_id: worktree.scan_id,
2210 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2211 updated_repositories: worktree.repository_entries.into_values().collect(),
2212 removed_repositories: Default::default(),
2213 };
2214 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2215 session.peer.send(session.connection_id, update.clone())?;
2216 }
2217
2218 // Stream this worktree's diagnostics.
2219 for summary in worktree.diagnostic_summaries {
2220 session.peer.send(
2221 session.connection_id,
2222 proto::UpdateDiagnosticSummary {
2223 project_id: project_id.to_proto(),
2224 worktree_id: worktree.id,
2225 summary: Some(summary),
2226 },
2227 )?;
2228 }
2229
2230 for settings_file in worktree.settings_files {
2231 session.peer.send(
2232 session.connection_id,
2233 proto::UpdateWorktreeSettings {
2234 project_id: project_id.to_proto(),
2235 worktree_id: worktree.id,
2236 path: settings_file.path,
2237 content: Some(settings_file.content),
2238 },
2239 )?;
2240 }
2241 }
2242
2243 for language_server in &project.language_servers {
2244 session.peer.send(
2245 session.connection_id,
2246 proto::UpdateLanguageServer {
2247 project_id: project_id.to_proto(),
2248 language_server_id: language_server.id,
2249 variant: Some(
2250 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2251 proto::LspDiskBasedDiagnosticsUpdated {},
2252 ),
2253 ),
2254 },
2255 )?;
2256 }
2257
2258 Ok(())
2259}
2260
2261/// Leave someone elses shared project.
2262async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2263 let sender_id = session.connection_id;
2264 let project_id = ProjectId::from_proto(request.project_id);
2265 let db = session.db().await;
2266 if db.is_hosted_project(project_id).await? {
2267 let project = db.leave_hosted_project(project_id, sender_id).await?;
2268 project_left(&project, &session);
2269 return Ok(());
2270 }
2271
2272 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2273 tracing::info!(
2274 %project_id,
2275 "leave project"
2276 );
2277
2278 project_left(&project, &session);
2279 if let Some(room) = room {
2280 room_updated(&room, &session.peer);
2281 }
2282
2283 Ok(())
2284}
2285
2286async fn join_hosted_project(
2287 request: proto::JoinHostedProject,
2288 response: Response<proto::JoinHostedProject>,
2289 session: UserSession,
2290) -> Result<()> {
2291 let (mut project, replica_id) = session
2292 .db()
2293 .await
2294 .join_hosted_project(
2295 ProjectId(request.project_id as i32),
2296 session.user_id(),
2297 session.connection_id,
2298 )
2299 .await?;
2300
2301 join_project_internal(response, session, &mut project, &replica_id)
2302}
2303
2304async fn create_dev_server_project(
2305 request: proto::CreateDevServerProject,
2306 response: Response<proto::CreateDevServerProject>,
2307 session: UserSession,
2308) -> Result<()> {
2309 let dev_server_id = DevServerId(request.dev_server_id as i32);
2310 let dev_server_connection_id = session
2311 .connection_pool()
2312 .await
2313 .dev_server_connection_id(dev_server_id);
2314 let Some(dev_server_connection_id) = dev_server_connection_id else {
2315 Err(ErrorCode::DevServerOffline
2316 .message("Cannot create a remote project when the dev server is offline".to_string())
2317 .anyhow())?
2318 };
2319
2320 let path = request.path.clone();
2321 //Check that the path exists on the dev server
2322 session
2323 .peer
2324 .forward_request(
2325 session.connection_id,
2326 dev_server_connection_id,
2327 proto::ValidateDevServerProjectRequest { path: path.clone() },
2328 )
2329 .await?;
2330
2331 let (dev_server_project, update) = session
2332 .db()
2333 .await
2334 .create_dev_server_project(
2335 DevServerId(request.dev_server_id as i32),
2336 &request.path,
2337 session.user_id(),
2338 )
2339 .await?;
2340
2341 let projects = session
2342 .db()
2343 .await
2344 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2345 .await?;
2346
2347 session.peer.send(
2348 dev_server_connection_id,
2349 proto::DevServerInstructions { projects },
2350 )?;
2351
2352 send_dev_server_projects_update(session.user_id(), update, &session).await;
2353
2354 response.send(proto::CreateDevServerProjectResponse {
2355 dev_server_project: Some(dev_server_project.to_proto(None)),
2356 })?;
2357 Ok(())
2358}
2359
2360async fn create_dev_server(
2361 request: proto::CreateDevServer,
2362 response: Response<proto::CreateDevServer>,
2363 session: UserSession,
2364) -> Result<()> {
2365 let access_token = auth::random_token();
2366 let hashed_access_token = auth::hash_access_token(&access_token);
2367
2368 if request.name.is_empty() {
2369 return Err(proto::ErrorCode::Forbidden
2370 .message("Dev server name cannot be empty".to_string())
2371 .anyhow())?;
2372 }
2373
2374 let (dev_server, status) = session
2375 .db()
2376 .await
2377 .create_dev_server(
2378 &request.name,
2379 request.ssh_connection_string.as_deref(),
2380 &hashed_access_token,
2381 session.user_id(),
2382 )
2383 .await?;
2384
2385 send_dev_server_projects_update(session.user_id(), status, &session).await;
2386
2387 response.send(proto::CreateDevServerResponse {
2388 dev_server_id: dev_server.id.0 as u64,
2389 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2390 name: request.name,
2391 })?;
2392 Ok(())
2393}
2394
2395async fn regenerate_dev_server_token(
2396 request: proto::RegenerateDevServerToken,
2397 response: Response<proto::RegenerateDevServerToken>,
2398 session: UserSession,
2399) -> Result<()> {
2400 let dev_server_id = DevServerId(request.dev_server_id as i32);
2401 let access_token = auth::random_token();
2402 let hashed_access_token = auth::hash_access_token(&access_token);
2403
2404 let connection_id = session
2405 .connection_pool()
2406 .await
2407 .dev_server_connection_id(dev_server_id);
2408 if let Some(connection_id) = connection_id {
2409 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2410 session.peer.send(
2411 connection_id,
2412 proto::ShutdownDevServer {
2413 reason: Some("dev server token was regenerated".to_string()),
2414 },
2415 )?;
2416 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2417 }
2418
2419 let status = session
2420 .db()
2421 .await
2422 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2423 .await?;
2424
2425 send_dev_server_projects_update(session.user_id(), status, &session).await;
2426
2427 response.send(proto::RegenerateDevServerTokenResponse {
2428 dev_server_id: dev_server_id.to_proto(),
2429 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2430 })?;
2431 Ok(())
2432}
2433
2434async fn rename_dev_server(
2435 request: proto::RenameDevServer,
2436 response: Response<proto::RenameDevServer>,
2437 session: UserSession,
2438) -> Result<()> {
2439 if request.name.trim().is_empty() {
2440 return Err(proto::ErrorCode::Forbidden
2441 .message("Dev server name cannot be empty".to_string())
2442 .anyhow())?;
2443 }
2444
2445 let dev_server_id = DevServerId(request.dev_server_id as i32);
2446 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2447 if dev_server.user_id != session.user_id() {
2448 return Err(anyhow!(ErrorCode::Forbidden))?;
2449 }
2450
2451 let status = session
2452 .db()
2453 .await
2454 .rename_dev_server(
2455 dev_server_id,
2456 &request.name,
2457 request.ssh_connection_string.as_deref(),
2458 session.user_id(),
2459 )
2460 .await?;
2461
2462 send_dev_server_projects_update(session.user_id(), status, &session).await;
2463
2464 response.send(proto::Ack {})?;
2465 Ok(())
2466}
2467
2468async fn delete_dev_server(
2469 request: proto::DeleteDevServer,
2470 response: Response<proto::DeleteDevServer>,
2471 session: UserSession,
2472) -> Result<()> {
2473 let dev_server_id = DevServerId(request.dev_server_id as i32);
2474 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2475 if dev_server.user_id != session.user_id() {
2476 return Err(anyhow!(ErrorCode::Forbidden))?;
2477 }
2478
2479 let connection_id = session
2480 .connection_pool()
2481 .await
2482 .dev_server_connection_id(dev_server_id);
2483 if let Some(connection_id) = connection_id {
2484 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2485 session.peer.send(
2486 connection_id,
2487 proto::ShutdownDevServer {
2488 reason: Some("dev server was deleted".to_string()),
2489 },
2490 )?;
2491 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2492 }
2493
2494 let status = session
2495 .db()
2496 .await
2497 .delete_dev_server(dev_server_id, session.user_id())
2498 .await?;
2499
2500 send_dev_server_projects_update(session.user_id(), status, &session).await;
2501
2502 response.send(proto::Ack {})?;
2503 Ok(())
2504}
2505
2506async fn delete_dev_server_project(
2507 request: proto::DeleteDevServerProject,
2508 response: Response<proto::DeleteDevServerProject>,
2509 session: UserSession,
2510) -> Result<()> {
2511 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2512 let dev_server_project = session
2513 .db()
2514 .await
2515 .get_dev_server_project(dev_server_project_id)
2516 .await?;
2517
2518 let dev_server = session
2519 .db()
2520 .await
2521 .get_dev_server(dev_server_project.dev_server_id)
2522 .await?;
2523 if dev_server.user_id != session.user_id() {
2524 return Err(anyhow!(ErrorCode::Forbidden))?;
2525 }
2526
2527 let dev_server_connection_id = session
2528 .connection_pool()
2529 .await
2530 .dev_server_connection_id(dev_server.id);
2531
2532 if let Some(dev_server_connection_id) = dev_server_connection_id {
2533 let project = session
2534 .db()
2535 .await
2536 .find_dev_server_project(dev_server_project_id)
2537 .await;
2538 if let Ok(project) = project {
2539 unshare_project_internal(
2540 project.id,
2541 dev_server_connection_id,
2542 Some(session.user_id()),
2543 &session,
2544 )
2545 .await?;
2546 }
2547 }
2548
2549 let (projects, status) = session
2550 .db()
2551 .await
2552 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2553 .await?;
2554
2555 if let Some(dev_server_connection_id) = dev_server_connection_id {
2556 session.peer.send(
2557 dev_server_connection_id,
2558 proto::DevServerInstructions { projects },
2559 )?;
2560 }
2561
2562 send_dev_server_projects_update(session.user_id(), status, &session).await;
2563
2564 response.send(proto::Ack {})?;
2565 Ok(())
2566}
2567
2568async fn rejoin_dev_server_projects(
2569 request: proto::RejoinRemoteProjects,
2570 response: Response<proto::RejoinRemoteProjects>,
2571 session: UserSession,
2572) -> Result<()> {
2573 let mut rejoined_projects = {
2574 let db = session.db().await;
2575 db.rejoin_dev_server_projects(
2576 &request.rejoined_projects,
2577 session.user_id(),
2578 session.0.connection_id,
2579 )
2580 .await?
2581 };
2582 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2583
2584 response.send(proto::RejoinRemoteProjectsResponse {
2585 rejoined_projects: rejoined_projects
2586 .into_iter()
2587 .map(|project| project.to_proto())
2588 .collect(),
2589 })
2590}
2591
2592async fn reconnect_dev_server(
2593 request: proto::ReconnectDevServer,
2594 response: Response<proto::ReconnectDevServer>,
2595 session: DevServerSession,
2596) -> Result<()> {
2597 let reshared_projects = {
2598 let db = session.db().await;
2599 db.reshare_dev_server_projects(
2600 &request.reshared_projects,
2601 session.dev_server_id(),
2602 session.0.connection_id,
2603 )
2604 .await?
2605 };
2606
2607 for project in &reshared_projects {
2608 for collaborator in &project.collaborators {
2609 session
2610 .peer
2611 .send(
2612 collaborator.connection_id,
2613 proto::UpdateProjectCollaborator {
2614 project_id: project.id.to_proto(),
2615 old_peer_id: Some(project.old_connection_id.into()),
2616 new_peer_id: Some(session.connection_id.into()),
2617 },
2618 )
2619 .trace_err();
2620 }
2621
2622 broadcast(
2623 Some(session.connection_id),
2624 project
2625 .collaborators
2626 .iter()
2627 .map(|collaborator| collaborator.connection_id),
2628 |connection_id| {
2629 session.peer.forward_send(
2630 session.connection_id,
2631 connection_id,
2632 proto::UpdateProject {
2633 project_id: project.id.to_proto(),
2634 worktrees: project.worktrees.clone(),
2635 },
2636 )
2637 },
2638 );
2639 }
2640
2641 response.send(proto::ReconnectDevServerResponse {
2642 reshared_projects: reshared_projects
2643 .iter()
2644 .map(|project| proto::ResharedProject {
2645 id: project.id.to_proto(),
2646 collaborators: project
2647 .collaborators
2648 .iter()
2649 .map(|collaborator| collaborator.to_proto())
2650 .collect(),
2651 })
2652 .collect(),
2653 })?;
2654
2655 Ok(())
2656}
2657
2658async fn shutdown_dev_server(
2659 _: proto::ShutdownDevServer,
2660 response: Response<proto::ShutdownDevServer>,
2661 session: DevServerSession,
2662) -> Result<()> {
2663 response.send(proto::Ack {})?;
2664 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2665 remove_dev_server_connection(session.dev_server_id(), &session).await
2666}
2667
2668async fn shutdown_dev_server_internal(
2669 dev_server_id: DevServerId,
2670 connection_id: ConnectionId,
2671 session: &Session,
2672) -> Result<()> {
2673 let (dev_server_projects, dev_server) = {
2674 let db = session.db().await;
2675 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2676 let dev_server = db.get_dev_server(dev_server_id).await?;
2677 (dev_server_projects, dev_server)
2678 };
2679
2680 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2681 unshare_project_internal(
2682 ProjectId::from_proto(project_id),
2683 connection_id,
2684 None,
2685 session,
2686 )
2687 .await?;
2688 }
2689
2690 session
2691 .connection_pool()
2692 .await
2693 .set_dev_server_offline(dev_server_id);
2694
2695 let status = session
2696 .db()
2697 .await
2698 .dev_server_projects_update(dev_server.user_id)
2699 .await?;
2700 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2701
2702 Ok(())
2703}
2704
2705async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2706 let dev_server_connection = session
2707 .connection_pool()
2708 .await
2709 .dev_server_connection_id(dev_server_id);
2710
2711 if let Some(dev_server_connection) = dev_server_connection {
2712 session
2713 .connection_pool()
2714 .await
2715 .remove_connection(dev_server_connection)?;
2716 }
2717 Ok(())
2718}
2719
2720/// Updates other participants with changes to the project
2721async fn update_project(
2722 request: proto::UpdateProject,
2723 response: Response<proto::UpdateProject>,
2724 session: Session,
2725) -> Result<()> {
2726 let project_id = ProjectId::from_proto(request.project_id);
2727 let (room, guest_connection_ids) = &*session
2728 .db()
2729 .await
2730 .update_project(project_id, session.connection_id, &request.worktrees)
2731 .await?;
2732 broadcast(
2733 Some(session.connection_id),
2734 guest_connection_ids.iter().copied(),
2735 |connection_id| {
2736 session
2737 .peer
2738 .forward_send(session.connection_id, connection_id, request.clone())
2739 },
2740 );
2741 if let Some(room) = room {
2742 room_updated(&room, &session.peer);
2743 }
2744 response.send(proto::Ack {})?;
2745
2746 Ok(())
2747}
2748
2749/// Updates other participants with changes to the worktree
2750async fn update_worktree(
2751 request: proto::UpdateWorktree,
2752 response: Response<proto::UpdateWorktree>,
2753 session: Session,
2754) -> Result<()> {
2755 let guest_connection_ids = session
2756 .db()
2757 .await
2758 .update_worktree(&request, session.connection_id)
2759 .await?;
2760
2761 broadcast(
2762 Some(session.connection_id),
2763 guest_connection_ids.iter().copied(),
2764 |connection_id| {
2765 session
2766 .peer
2767 .forward_send(session.connection_id, connection_id, request.clone())
2768 },
2769 );
2770 response.send(proto::Ack {})?;
2771 Ok(())
2772}
2773
2774/// Updates other participants with changes to the diagnostics
2775async fn update_diagnostic_summary(
2776 message: proto::UpdateDiagnosticSummary,
2777 session: Session,
2778) -> Result<()> {
2779 let guest_connection_ids = session
2780 .db()
2781 .await
2782 .update_diagnostic_summary(&message, session.connection_id)
2783 .await?;
2784
2785 broadcast(
2786 Some(session.connection_id),
2787 guest_connection_ids.iter().copied(),
2788 |connection_id| {
2789 session
2790 .peer
2791 .forward_send(session.connection_id, connection_id, message.clone())
2792 },
2793 );
2794
2795 Ok(())
2796}
2797
2798/// Updates other participants with changes to the worktree settings
2799async fn update_worktree_settings(
2800 message: proto::UpdateWorktreeSettings,
2801 session: Session,
2802) -> Result<()> {
2803 let guest_connection_ids = session
2804 .db()
2805 .await
2806 .update_worktree_settings(&message, session.connection_id)
2807 .await?;
2808
2809 broadcast(
2810 Some(session.connection_id),
2811 guest_connection_ids.iter().copied(),
2812 |connection_id| {
2813 session
2814 .peer
2815 .forward_send(session.connection_id, connection_id, message.clone())
2816 },
2817 );
2818
2819 Ok(())
2820}
2821
2822/// Notify other participants that a language server has started.
2823async fn start_language_server(
2824 request: proto::StartLanguageServer,
2825 session: Session,
2826) -> Result<()> {
2827 let guest_connection_ids = session
2828 .db()
2829 .await
2830 .start_language_server(&request, session.connection_id)
2831 .await?;
2832
2833 broadcast(
2834 Some(session.connection_id),
2835 guest_connection_ids.iter().copied(),
2836 |connection_id| {
2837 session
2838 .peer
2839 .forward_send(session.connection_id, connection_id, request.clone())
2840 },
2841 );
2842 Ok(())
2843}
2844
2845/// Notify other participants that a language server has changed.
2846async fn update_language_server(
2847 request: proto::UpdateLanguageServer,
2848 session: Session,
2849) -> Result<()> {
2850 let project_id = ProjectId::from_proto(request.project_id);
2851 let project_connection_ids = session
2852 .db()
2853 .await
2854 .project_connection_ids(project_id, session.connection_id, true)
2855 .await?;
2856 broadcast(
2857 Some(session.connection_id),
2858 project_connection_ids.iter().copied(),
2859 |connection_id| {
2860 session
2861 .peer
2862 .forward_send(session.connection_id, connection_id, request.clone())
2863 },
2864 );
2865 Ok(())
2866}
2867
2868/// forward a project request to the host. These requests should be read only
2869/// as guests are allowed to send them.
2870async fn forward_read_only_project_request<T>(
2871 request: T,
2872 response: Response<T>,
2873 session: UserSession,
2874) -> Result<()>
2875where
2876 T: EntityMessage + RequestMessage,
2877{
2878 let project_id = ProjectId::from_proto(request.remote_entity_id());
2879 let host_connection_id = session
2880 .db()
2881 .await
2882 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2883 .await?;
2884 let payload = session
2885 .peer
2886 .forward_request(session.connection_id, host_connection_id, request)
2887 .await?;
2888 response.send(payload)?;
2889 Ok(())
2890}
2891
2892/// forward a project request to the host. These requests are disallowed
2893/// for guests.
2894async fn forward_mutating_project_request<T>(
2895 request: T,
2896 response: Response<T>,
2897 session: UserSession,
2898) -> Result<()>
2899where
2900 T: EntityMessage + RequestMessage,
2901{
2902 let project_id = ProjectId::from_proto(request.remote_entity_id());
2903
2904 let host_connection_id = session
2905 .db()
2906 .await
2907 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2908 .await?;
2909 let payload = session
2910 .peer
2911 .forward_request(session.connection_id, host_connection_id, request)
2912 .await?;
2913 response.send(payload)?;
2914 Ok(())
2915}
2916
2917/// forward a project request to the host. These requests are disallowed
2918/// for guests.
2919async fn forward_versioned_mutating_project_request<T>(
2920 request: T,
2921 response: Response<T>,
2922 session: UserSession,
2923) -> Result<()>
2924where
2925 T: EntityMessage + RequestMessage + VersionedMessage,
2926{
2927 let project_id = ProjectId::from_proto(request.remote_entity_id());
2928
2929 let host_connection_id = session
2930 .db()
2931 .await
2932 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2933 .await?;
2934 if let Some(host_version) = session
2935 .connection_pool()
2936 .await
2937 .connection(host_connection_id)
2938 .map(|c| c.zed_version)
2939 {
2940 if let Some(min_required_version) = request.required_host_version() {
2941 if min_required_version > host_version {
2942 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2943 .with_tag("required", &min_required_version.to_string())))?;
2944 }
2945 }
2946 }
2947
2948 let payload = session
2949 .peer
2950 .forward_request(session.connection_id, host_connection_id, request)
2951 .await?;
2952 response.send(payload)?;
2953 Ok(())
2954}
2955
2956/// Notify other participants that a new buffer has been created
2957async fn create_buffer_for_peer(
2958 request: proto::CreateBufferForPeer,
2959 session: Session,
2960) -> Result<()> {
2961 session
2962 .db()
2963 .await
2964 .check_user_is_project_host(
2965 ProjectId::from_proto(request.project_id),
2966 session.connection_id,
2967 )
2968 .await?;
2969 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2970 session
2971 .peer
2972 .forward_send(session.connection_id, peer_id.into(), request)?;
2973 Ok(())
2974}
2975
2976/// Notify other participants that a buffer has been updated. This is
2977/// allowed for guests as long as the update is limited to selections.
2978async fn update_buffer(
2979 request: proto::UpdateBuffer,
2980 response: Response<proto::UpdateBuffer>,
2981 session: Session,
2982) -> Result<()> {
2983 let project_id = ProjectId::from_proto(request.project_id);
2984 let mut capability = Capability::ReadOnly;
2985
2986 for op in request.operations.iter() {
2987 match op.variant {
2988 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2989 Some(_) => capability = Capability::ReadWrite,
2990 }
2991 }
2992
2993 let host = {
2994 let guard = session
2995 .db()
2996 .await
2997 .connections_for_buffer_update(
2998 project_id,
2999 session.principal_id(),
3000 session.connection_id,
3001 capability,
3002 )
3003 .await?;
3004
3005 let (host, guests) = &*guard;
3006
3007 broadcast(
3008 Some(session.connection_id),
3009 guests.clone(),
3010 |connection_id| {
3011 session
3012 .peer
3013 .forward_send(session.connection_id, connection_id, request.clone())
3014 },
3015 );
3016
3017 *host
3018 };
3019
3020 if host != session.connection_id {
3021 session
3022 .peer
3023 .forward_request(session.connection_id, host, request.clone())
3024 .await?;
3025 }
3026
3027 response.send(proto::Ack {})?;
3028 Ok(())
3029}
3030
3031/// Notify other participants that a project has been updated.
3032async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3033 request: T,
3034 session: Session,
3035) -> Result<()> {
3036 let project_id = ProjectId::from_proto(request.remote_entity_id());
3037 let project_connection_ids = session
3038 .db()
3039 .await
3040 .project_connection_ids(project_id, session.connection_id, false)
3041 .await?;
3042
3043 broadcast(
3044 Some(session.connection_id),
3045 project_connection_ids.iter().copied(),
3046 |connection_id| {
3047 session
3048 .peer
3049 .forward_send(session.connection_id, connection_id, request.clone())
3050 },
3051 );
3052 Ok(())
3053}
3054
3055/// Start following another user in a call.
3056async fn follow(
3057 request: proto::Follow,
3058 response: Response<proto::Follow>,
3059 session: UserSession,
3060) -> Result<()> {
3061 let room_id = RoomId::from_proto(request.room_id);
3062 let project_id = request.project_id.map(ProjectId::from_proto);
3063 let leader_id = request
3064 .leader_id
3065 .ok_or_else(|| anyhow!("invalid leader id"))?
3066 .into();
3067 let follower_id = session.connection_id;
3068
3069 session
3070 .db()
3071 .await
3072 .check_room_participants(room_id, leader_id, session.connection_id)
3073 .await?;
3074
3075 let response_payload = session
3076 .peer
3077 .forward_request(session.connection_id, leader_id, request)
3078 .await?;
3079 response.send(response_payload)?;
3080
3081 if let Some(project_id) = project_id {
3082 let room = session
3083 .db()
3084 .await
3085 .follow(room_id, project_id, leader_id, follower_id)
3086 .await?;
3087 room_updated(&room, &session.peer);
3088 }
3089
3090 Ok(())
3091}
3092
3093/// Stop following another user in a call.
3094async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3095 let room_id = RoomId::from_proto(request.room_id);
3096 let project_id = request.project_id.map(ProjectId::from_proto);
3097 let leader_id = request
3098 .leader_id
3099 .ok_or_else(|| anyhow!("invalid leader id"))?
3100 .into();
3101 let follower_id = session.connection_id;
3102
3103 session
3104 .db()
3105 .await
3106 .check_room_participants(room_id, leader_id, session.connection_id)
3107 .await?;
3108
3109 session
3110 .peer
3111 .forward_send(session.connection_id, leader_id, request)?;
3112
3113 if let Some(project_id) = project_id {
3114 let room = session
3115 .db()
3116 .await
3117 .unfollow(room_id, project_id, leader_id, follower_id)
3118 .await?;
3119 room_updated(&room, &session.peer);
3120 }
3121
3122 Ok(())
3123}
3124
3125/// Notify everyone following you of your current location.
3126async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3127 let room_id = RoomId::from_proto(request.room_id);
3128 let database = session.db.lock().await;
3129
3130 let connection_ids = if let Some(project_id) = request.project_id {
3131 let project_id = ProjectId::from_proto(project_id);
3132 database
3133 .project_connection_ids(project_id, session.connection_id, true)
3134 .await?
3135 } else {
3136 database
3137 .room_connection_ids(room_id, session.connection_id)
3138 .await?
3139 };
3140
3141 // For now, don't send view update messages back to that view's current leader.
3142 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3143 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3144 _ => None,
3145 });
3146
3147 for connection_id in connection_ids.iter().cloned() {
3148 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3149 session
3150 .peer
3151 .forward_send(session.connection_id, connection_id, request.clone())?;
3152 }
3153 }
3154 Ok(())
3155}
3156
3157/// Get public data about users.
3158async fn get_users(
3159 request: proto::GetUsers,
3160 response: Response<proto::GetUsers>,
3161 session: Session,
3162) -> Result<()> {
3163 let user_ids = request
3164 .user_ids
3165 .into_iter()
3166 .map(UserId::from_proto)
3167 .collect();
3168 let users = session
3169 .db()
3170 .await
3171 .get_users_by_ids(user_ids)
3172 .await?
3173 .into_iter()
3174 .map(|user| proto::User {
3175 id: user.id.to_proto(),
3176 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3177 github_login: user.github_login,
3178 })
3179 .collect();
3180 response.send(proto::UsersResponse { users })?;
3181 Ok(())
3182}
3183
3184/// Search for users (to invite) buy Github login
3185async fn fuzzy_search_users(
3186 request: proto::FuzzySearchUsers,
3187 response: Response<proto::FuzzySearchUsers>,
3188 session: UserSession,
3189) -> Result<()> {
3190 let query = request.query;
3191 let users = match query.len() {
3192 0 => vec![],
3193 1 | 2 => session
3194 .db()
3195 .await
3196 .get_user_by_github_login(&query)
3197 .await?
3198 .into_iter()
3199 .collect(),
3200 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3201 };
3202 let users = users
3203 .into_iter()
3204 .filter(|user| user.id != session.user_id())
3205 .map(|user| proto::User {
3206 id: user.id.to_proto(),
3207 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3208 github_login: user.github_login,
3209 })
3210 .collect();
3211 response.send(proto::UsersResponse { users })?;
3212 Ok(())
3213}
3214
3215/// Send a contact request to another user.
3216async fn request_contact(
3217 request: proto::RequestContact,
3218 response: Response<proto::RequestContact>,
3219 session: UserSession,
3220) -> Result<()> {
3221 let requester_id = session.user_id();
3222 let responder_id = UserId::from_proto(request.responder_id);
3223 if requester_id == responder_id {
3224 return Err(anyhow!("cannot add yourself as a contact"))?;
3225 }
3226
3227 let notifications = session
3228 .db()
3229 .await
3230 .send_contact_request(requester_id, responder_id)
3231 .await?;
3232
3233 // Update outgoing contact requests of requester
3234 let mut update = proto::UpdateContacts::default();
3235 update.outgoing_requests.push(responder_id.to_proto());
3236 for connection_id in session
3237 .connection_pool()
3238 .await
3239 .user_connection_ids(requester_id)
3240 {
3241 session.peer.send(connection_id, update.clone())?;
3242 }
3243
3244 // Update incoming contact requests of responder
3245 let mut update = proto::UpdateContacts::default();
3246 update
3247 .incoming_requests
3248 .push(proto::IncomingContactRequest {
3249 requester_id: requester_id.to_proto(),
3250 });
3251 let connection_pool = session.connection_pool().await;
3252 for connection_id in connection_pool.user_connection_ids(responder_id) {
3253 session.peer.send(connection_id, update.clone())?;
3254 }
3255
3256 send_notifications(&connection_pool, &session.peer, notifications);
3257
3258 response.send(proto::Ack {})?;
3259 Ok(())
3260}
3261
3262/// Accept or decline a contact request
3263async fn respond_to_contact_request(
3264 request: proto::RespondToContactRequest,
3265 response: Response<proto::RespondToContactRequest>,
3266 session: UserSession,
3267) -> Result<()> {
3268 let responder_id = session.user_id();
3269 let requester_id = UserId::from_proto(request.requester_id);
3270 let db = session.db().await;
3271 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3272 db.dismiss_contact_notification(responder_id, requester_id)
3273 .await?;
3274 } else {
3275 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3276
3277 let notifications = db
3278 .respond_to_contact_request(responder_id, requester_id, accept)
3279 .await?;
3280 let requester_busy = db.is_user_busy(requester_id).await?;
3281 let responder_busy = db.is_user_busy(responder_id).await?;
3282
3283 let pool = session.connection_pool().await;
3284 // Update responder with new contact
3285 let mut update = proto::UpdateContacts::default();
3286 if accept {
3287 update
3288 .contacts
3289 .push(contact_for_user(requester_id, requester_busy, &pool));
3290 }
3291 update
3292 .remove_incoming_requests
3293 .push(requester_id.to_proto());
3294 for connection_id in pool.user_connection_ids(responder_id) {
3295 session.peer.send(connection_id, update.clone())?;
3296 }
3297
3298 // Update requester with new contact
3299 let mut update = proto::UpdateContacts::default();
3300 if accept {
3301 update
3302 .contacts
3303 .push(contact_for_user(responder_id, responder_busy, &pool));
3304 }
3305 update
3306 .remove_outgoing_requests
3307 .push(responder_id.to_proto());
3308
3309 for connection_id in pool.user_connection_ids(requester_id) {
3310 session.peer.send(connection_id, update.clone())?;
3311 }
3312
3313 send_notifications(&pool, &session.peer, notifications);
3314 }
3315
3316 response.send(proto::Ack {})?;
3317 Ok(())
3318}
3319
3320/// Remove a contact.
3321async fn remove_contact(
3322 request: proto::RemoveContact,
3323 response: Response<proto::RemoveContact>,
3324 session: UserSession,
3325) -> Result<()> {
3326 let requester_id = session.user_id();
3327 let responder_id = UserId::from_proto(request.user_id);
3328 let db = session.db().await;
3329 let (contact_accepted, deleted_notification_id) =
3330 db.remove_contact(requester_id, responder_id).await?;
3331
3332 let pool = session.connection_pool().await;
3333 // Update outgoing contact requests of requester
3334 let mut update = proto::UpdateContacts::default();
3335 if contact_accepted {
3336 update.remove_contacts.push(responder_id.to_proto());
3337 } else {
3338 update
3339 .remove_outgoing_requests
3340 .push(responder_id.to_proto());
3341 }
3342 for connection_id in pool.user_connection_ids(requester_id) {
3343 session.peer.send(connection_id, update.clone())?;
3344 }
3345
3346 // Update incoming contact requests of responder
3347 let mut update = proto::UpdateContacts::default();
3348 if contact_accepted {
3349 update.remove_contacts.push(requester_id.to_proto());
3350 } else {
3351 update
3352 .remove_incoming_requests
3353 .push(requester_id.to_proto());
3354 }
3355 for connection_id in pool.user_connection_ids(responder_id) {
3356 session.peer.send(connection_id, update.clone())?;
3357 if let Some(notification_id) = deleted_notification_id {
3358 session.peer.send(
3359 connection_id,
3360 proto::DeleteNotification {
3361 notification_id: notification_id.to_proto(),
3362 },
3363 )?;
3364 }
3365 }
3366
3367 response.send(proto::Ack {})?;
3368 Ok(())
3369}
3370
3371/// Creates a new channel.
3372async fn create_channel(
3373 request: proto::CreateChannel,
3374 response: Response<proto::CreateChannel>,
3375 session: UserSession,
3376) -> Result<()> {
3377 let db = session.db().await;
3378
3379 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3380 let (channel, membership) = db
3381 .create_channel(&request.name, parent_id, session.user_id())
3382 .await?;
3383
3384 let root_id = channel.root_id();
3385 let channel = Channel::from_model(channel);
3386
3387 response.send(proto::CreateChannelResponse {
3388 channel: Some(channel.to_proto()),
3389 parent_id: request.parent_id,
3390 })?;
3391
3392 let mut connection_pool = session.connection_pool().await;
3393 if let Some(membership) = membership {
3394 connection_pool.subscribe_to_channel(
3395 membership.user_id,
3396 membership.channel_id,
3397 membership.role,
3398 );
3399 let update = proto::UpdateUserChannels {
3400 channel_memberships: vec![proto::ChannelMembership {
3401 channel_id: membership.channel_id.to_proto(),
3402 role: membership.role.into(),
3403 }],
3404 ..Default::default()
3405 };
3406 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3407 session.peer.send(connection_id, update.clone())?;
3408 }
3409 }
3410
3411 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3412 if !role.can_see_channel(channel.visibility) {
3413 continue;
3414 }
3415
3416 let update = proto::UpdateChannels {
3417 channels: vec![channel.to_proto()],
3418 ..Default::default()
3419 };
3420 session.peer.send(connection_id, update.clone())?;
3421 }
3422
3423 Ok(())
3424}
3425
3426/// Delete a channel
3427async fn delete_channel(
3428 request: proto::DeleteChannel,
3429 response: Response<proto::DeleteChannel>,
3430 session: UserSession,
3431) -> Result<()> {
3432 let db = session.db().await;
3433
3434 let channel_id = request.channel_id;
3435 let (root_channel, removed_channels) = db
3436 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3437 .await?;
3438 response.send(proto::Ack {})?;
3439
3440 // Notify members of removed channels
3441 let mut update = proto::UpdateChannels::default();
3442 update
3443 .delete_channels
3444 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3445
3446 let connection_pool = session.connection_pool().await;
3447 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3448 session.peer.send(connection_id, update.clone())?;
3449 }
3450
3451 Ok(())
3452}
3453
3454/// Invite someone to join a channel.
3455async fn invite_channel_member(
3456 request: proto::InviteChannelMember,
3457 response: Response<proto::InviteChannelMember>,
3458 session: UserSession,
3459) -> Result<()> {
3460 let db = session.db().await;
3461 let channel_id = ChannelId::from_proto(request.channel_id);
3462 let invitee_id = UserId::from_proto(request.user_id);
3463 let InviteMemberResult {
3464 channel,
3465 notifications,
3466 } = db
3467 .invite_channel_member(
3468 channel_id,
3469 invitee_id,
3470 session.user_id(),
3471 request.role().into(),
3472 )
3473 .await?;
3474
3475 let update = proto::UpdateChannels {
3476 channel_invitations: vec![channel.to_proto()],
3477 ..Default::default()
3478 };
3479
3480 let connection_pool = session.connection_pool().await;
3481 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3482 session.peer.send(connection_id, update.clone())?;
3483 }
3484
3485 send_notifications(&connection_pool, &session.peer, notifications);
3486
3487 response.send(proto::Ack {})?;
3488 Ok(())
3489}
3490
3491/// remove someone from a channel
3492async fn remove_channel_member(
3493 request: proto::RemoveChannelMember,
3494 response: Response<proto::RemoveChannelMember>,
3495 session: UserSession,
3496) -> Result<()> {
3497 let db = session.db().await;
3498 let channel_id = ChannelId::from_proto(request.channel_id);
3499 let member_id = UserId::from_proto(request.user_id);
3500
3501 let RemoveChannelMemberResult {
3502 membership_update,
3503 notification_id,
3504 } = db
3505 .remove_channel_member(channel_id, member_id, session.user_id())
3506 .await?;
3507
3508 let mut connection_pool = session.connection_pool().await;
3509 notify_membership_updated(
3510 &mut connection_pool,
3511 membership_update,
3512 member_id,
3513 &session.peer,
3514 );
3515 for connection_id in connection_pool.user_connection_ids(member_id) {
3516 if let Some(notification_id) = notification_id {
3517 session
3518 .peer
3519 .send(
3520 connection_id,
3521 proto::DeleteNotification {
3522 notification_id: notification_id.to_proto(),
3523 },
3524 )
3525 .trace_err();
3526 }
3527 }
3528
3529 response.send(proto::Ack {})?;
3530 Ok(())
3531}
3532
3533/// Toggle the channel between public and private.
3534/// Care is taken to maintain the invariant that public channels only descend from public channels,
3535/// (though members-only channels can appear at any point in the hierarchy).
3536async fn set_channel_visibility(
3537 request: proto::SetChannelVisibility,
3538 response: Response<proto::SetChannelVisibility>,
3539 session: UserSession,
3540) -> Result<()> {
3541 let db = session.db().await;
3542 let channel_id = ChannelId::from_proto(request.channel_id);
3543 let visibility = request.visibility().into();
3544
3545 let channel_model = db
3546 .set_channel_visibility(channel_id, visibility, session.user_id())
3547 .await?;
3548 let root_id = channel_model.root_id();
3549 let channel = Channel::from_model(channel_model);
3550
3551 let mut connection_pool = session.connection_pool().await;
3552 for (user_id, role) in connection_pool
3553 .channel_user_ids(root_id)
3554 .collect::<Vec<_>>()
3555 .into_iter()
3556 {
3557 let update = if role.can_see_channel(channel.visibility) {
3558 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3559 proto::UpdateChannels {
3560 channels: vec![channel.to_proto()],
3561 ..Default::default()
3562 }
3563 } else {
3564 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3565 proto::UpdateChannels {
3566 delete_channels: vec![channel.id.to_proto()],
3567 ..Default::default()
3568 }
3569 };
3570
3571 for connection_id in connection_pool.user_connection_ids(user_id) {
3572 session.peer.send(connection_id, update.clone())?;
3573 }
3574 }
3575
3576 response.send(proto::Ack {})?;
3577 Ok(())
3578}
3579
3580/// Alter the role for a user in the channel.
3581async fn set_channel_member_role(
3582 request: proto::SetChannelMemberRole,
3583 response: Response<proto::SetChannelMemberRole>,
3584 session: UserSession,
3585) -> Result<()> {
3586 let db = session.db().await;
3587 let channel_id = ChannelId::from_proto(request.channel_id);
3588 let member_id = UserId::from_proto(request.user_id);
3589 let result = db
3590 .set_channel_member_role(
3591 channel_id,
3592 session.user_id(),
3593 member_id,
3594 request.role().into(),
3595 )
3596 .await?;
3597
3598 match result {
3599 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3600 let mut connection_pool = session.connection_pool().await;
3601 notify_membership_updated(
3602 &mut connection_pool,
3603 membership_update,
3604 member_id,
3605 &session.peer,
3606 )
3607 }
3608 db::SetMemberRoleResult::InviteUpdated(channel) => {
3609 let update = proto::UpdateChannels {
3610 channel_invitations: vec![channel.to_proto()],
3611 ..Default::default()
3612 };
3613
3614 for connection_id in session
3615 .connection_pool()
3616 .await
3617 .user_connection_ids(member_id)
3618 {
3619 session.peer.send(connection_id, update.clone())?;
3620 }
3621 }
3622 }
3623
3624 response.send(proto::Ack {})?;
3625 Ok(())
3626}
3627
3628/// Change the name of a channel
3629async fn rename_channel(
3630 request: proto::RenameChannel,
3631 response: Response<proto::RenameChannel>,
3632 session: UserSession,
3633) -> Result<()> {
3634 let db = session.db().await;
3635 let channel_id = ChannelId::from_proto(request.channel_id);
3636 let channel_model = db
3637 .rename_channel(channel_id, session.user_id(), &request.name)
3638 .await?;
3639 let root_id = channel_model.root_id();
3640 let channel = Channel::from_model(channel_model);
3641
3642 response.send(proto::RenameChannelResponse {
3643 channel: Some(channel.to_proto()),
3644 })?;
3645
3646 let connection_pool = session.connection_pool().await;
3647 let update = proto::UpdateChannels {
3648 channels: vec![channel.to_proto()],
3649 ..Default::default()
3650 };
3651 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3652 if role.can_see_channel(channel.visibility) {
3653 session.peer.send(connection_id, update.clone())?;
3654 }
3655 }
3656
3657 Ok(())
3658}
3659
3660/// Move a channel to a new parent.
3661async fn move_channel(
3662 request: proto::MoveChannel,
3663 response: Response<proto::MoveChannel>,
3664 session: UserSession,
3665) -> Result<()> {
3666 let channel_id = ChannelId::from_proto(request.channel_id);
3667 let to = ChannelId::from_proto(request.to);
3668
3669 let (root_id, channels) = session
3670 .db()
3671 .await
3672 .move_channel(channel_id, to, session.user_id())
3673 .await?;
3674
3675 let connection_pool = session.connection_pool().await;
3676 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3677 let channels = channels
3678 .iter()
3679 .filter_map(|channel| {
3680 if role.can_see_channel(channel.visibility) {
3681 Some(channel.to_proto())
3682 } else {
3683 None
3684 }
3685 })
3686 .collect::<Vec<_>>();
3687 if channels.is_empty() {
3688 continue;
3689 }
3690
3691 let update = proto::UpdateChannels {
3692 channels,
3693 ..Default::default()
3694 };
3695
3696 session.peer.send(connection_id, update.clone())?;
3697 }
3698
3699 response.send(Ack {})?;
3700 Ok(())
3701}
3702
3703/// Get the list of channel members
3704async fn get_channel_members(
3705 request: proto::GetChannelMembers,
3706 response: Response<proto::GetChannelMembers>,
3707 session: UserSession,
3708) -> Result<()> {
3709 let db = session.db().await;
3710 let channel_id = ChannelId::from_proto(request.channel_id);
3711 let limit = if request.limit == 0 {
3712 u16::MAX as u64
3713 } else {
3714 request.limit
3715 };
3716 let (members, users) = db
3717 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3718 .await?;
3719 response.send(proto::GetChannelMembersResponse { members, users })?;
3720 Ok(())
3721}
3722
3723/// Accept or decline a channel invitation.
3724async fn respond_to_channel_invite(
3725 request: proto::RespondToChannelInvite,
3726 response: Response<proto::RespondToChannelInvite>,
3727 session: UserSession,
3728) -> Result<()> {
3729 let db = session.db().await;
3730 let channel_id = ChannelId::from_proto(request.channel_id);
3731 let RespondToChannelInvite {
3732 membership_update,
3733 notifications,
3734 } = db
3735 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3736 .await?;
3737
3738 let mut connection_pool = session.connection_pool().await;
3739 if let Some(membership_update) = membership_update {
3740 notify_membership_updated(
3741 &mut connection_pool,
3742 membership_update,
3743 session.user_id(),
3744 &session.peer,
3745 );
3746 } else {
3747 let update = proto::UpdateChannels {
3748 remove_channel_invitations: vec![channel_id.to_proto()],
3749 ..Default::default()
3750 };
3751
3752 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3753 session.peer.send(connection_id, update.clone())?;
3754 }
3755 };
3756
3757 send_notifications(&connection_pool, &session.peer, notifications);
3758
3759 response.send(proto::Ack {})?;
3760
3761 Ok(())
3762}
3763
3764/// Join the channels' room
3765async fn join_channel(
3766 request: proto::JoinChannel,
3767 response: Response<proto::JoinChannel>,
3768 session: UserSession,
3769) -> Result<()> {
3770 let channel_id = ChannelId::from_proto(request.channel_id);
3771 join_channel_internal(channel_id, Box::new(response), session).await
3772}
3773
/// A response handle that can answer a channel join with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>`, so
/// `join_channel_internal` can reply to either request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
/// Replies to a `JoinChannel` request.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // `Type::item` path lookup prefers inherent associated functions over trait methods,
        // so this calls `Response`'s own `send` rather than recursing into this impl.
        Response::<proto::JoinChannel>::send(self, result)
    }
}
/// Replies to a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        // Same delegation as above: resolves to the inherent `Response::send`.
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3787
/// Join the room of channel `channel_id` for the user behind `session`, replying through
/// `response`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that room first.
///    The `db` guard is dropped around `leave_room_for_session` and re-acquired afterwards.
/// 2. Join the channel's room in the database. This yields the joined room, an optional
///    membership update, and the user's role in the channel.
/// 3. If a LiveKit client is configured, build connection info:
///    - `ChannelRole::Guest` gets a guest token with `can_publish: false`;
///    - every other role gets a room token with `can_publish: true`.
///    If token creation fails, `trace_err()` turns it into `None`, so the reply carries no
///    LiveKit info instead of failing the join.
/// 4. While the `db` guard is still held: reply to the caller, propagate any membership
///    update through the connection pool, and broadcast the updated room.
/// 5. After the inner block ends (releasing `db` and the pool guard): broadcast the channel
///    update and refresh the user's contacts.
///
/// # Errors
/// Returns an error if any of these fail:
/// - a database call, leaving the stale room, sending the reply, or updating contacts;
/// - the database did not return the joined room's channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room (presumably because
            // `leave_room_for_session` needs the database itself), then take it back.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests may join audio but not publish; `?` on `trace_err()` yields `None` on token
        // errors, so the join still succeeds without LiveKit info.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        // `db` and `connection_pool` are dropped here, at the end of this block.
        joined_room
    };

    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3878
3879/// Start editing the channel notes
3880async fn join_channel_buffer(
3881 request: proto::JoinChannelBuffer,
3882 response: Response<proto::JoinChannelBuffer>,
3883 session: UserSession,
3884) -> Result<()> {
3885 let db = session.db().await;
3886 let channel_id = ChannelId::from_proto(request.channel_id);
3887
3888 let open_response = db
3889 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3890 .await?;
3891
3892 let collaborators = open_response.collaborators.clone();
3893 response.send(open_response)?;
3894
3895 let update = UpdateChannelBufferCollaborators {
3896 channel_id: channel_id.to_proto(),
3897 collaborators: collaborators.clone(),
3898 };
3899 channel_buffer_updated(
3900 session.connection_id,
3901 collaborators
3902 .iter()
3903 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3904 &update,
3905 &session.peer,
3906 );
3907
3908 Ok(())
3909}
3910
3911/// Edit the channel notes
3912async fn update_channel_buffer(
3913 request: proto::UpdateChannelBuffer,
3914 session: UserSession,
3915) -> Result<()> {
3916 let db = session.db().await;
3917 let channel_id = ChannelId::from_proto(request.channel_id);
3918
3919 let (collaborators, epoch, version) = db
3920 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3921 .await?;
3922
3923 channel_buffer_updated(
3924 session.connection_id,
3925 collaborators.clone(),
3926 &proto::UpdateChannelBuffer {
3927 channel_id: channel_id.to_proto(),
3928 operations: request.operations,
3929 },
3930 &session.peer,
3931 );
3932
3933 let pool = &*session.connection_pool().await;
3934
3935 let non_collaborators =
3936 pool.channel_connection_ids(channel_id)
3937 .filter_map(|(connection_id, _)| {
3938 if collaborators.contains(&connection_id) {
3939 None
3940 } else {
3941 Some(connection_id)
3942 }
3943 });
3944
3945 broadcast(None, non_collaborators, |peer_id| {
3946 session.peer.send(
3947 peer_id,
3948 proto::UpdateChannels {
3949 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3950 channel_id: channel_id.to_proto(),
3951 epoch: epoch as u64,
3952 version: version.clone(),
3953 }],
3954 ..Default::default()
3955 },
3956 )
3957 });
3958
3959 Ok(())
3960}
3961
3962/// Rejoin the channel notes after a connection blip
3963async fn rejoin_channel_buffers(
3964 request: proto::RejoinChannelBuffers,
3965 response: Response<proto::RejoinChannelBuffers>,
3966 session: UserSession,
3967) -> Result<()> {
3968 let db = session.db().await;
3969 let buffers = db
3970 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3971 .await?;
3972
3973 for rejoined_buffer in &buffers {
3974 let collaborators_to_notify = rejoined_buffer
3975 .buffer
3976 .collaborators
3977 .iter()
3978 .filter_map(|c| Some(c.peer_id?.into()));
3979 channel_buffer_updated(
3980 session.connection_id,
3981 collaborators_to_notify,
3982 &proto::UpdateChannelBufferCollaborators {
3983 channel_id: rejoined_buffer.buffer.channel_id,
3984 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3985 },
3986 &session.peer,
3987 );
3988 }
3989
3990 response.send(proto::RejoinChannelBuffersResponse {
3991 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3992 })?;
3993
3994 Ok(())
3995}
3996
3997/// Stop editing the channel notes
3998async fn leave_channel_buffer(
3999 request: proto::LeaveChannelBuffer,
4000 response: Response<proto::LeaveChannelBuffer>,
4001 session: UserSession,
4002) -> Result<()> {
4003 let db = session.db().await;
4004 let channel_id = ChannelId::from_proto(request.channel_id);
4005
4006 let left_buffer = db
4007 .leave_channel_buffer(channel_id, session.connection_id)
4008 .await?;
4009
4010 response.send(Ack {})?;
4011
4012 channel_buffer_updated(
4013 session.connection_id,
4014 left_buffer.connections,
4015 &proto::UpdateChannelBufferCollaborators {
4016 channel_id: channel_id.to_proto(),
4017 collaborators: left_buffer.collaborators,
4018 },
4019 &session.peer,
4020 );
4021
4022 Ok(())
4023}
4024
4025fn channel_buffer_updated<T: EnvelopedMessage>(
4026 sender_id: ConnectionId,
4027 collaborators: impl IntoIterator<Item = ConnectionId>,
4028 message: &T,
4029 peer: &Peer,
4030) {
4031 broadcast(Some(sender_id), collaborators, |peer_id| {
4032 peer.send(peer_id, message.clone())
4033 });
4034}
4035
4036fn send_notifications(
4037 connection_pool: &ConnectionPool,
4038 peer: &Peer,
4039 notifications: db::NotificationBatch,
4040) {
4041 for (user_id, notification) in notifications {
4042 for connection_id in connection_pool.user_connection_ids(user_id) {
4043 if let Err(error) = peer.send(
4044 connection_id,
4045 proto::AddNotification {
4046 notification: Some(notification.clone()),
4047 },
4048 ) {
4049 tracing::error!(
4050 "failed to send notification to {:?} {}",
4051 connection_id,
4052 error
4053 );
4054 }
4055 }
4056 }
4057}
4058
4059/// Send a message to the channel
4060async fn send_channel_message(
4061 request: proto::SendChannelMessage,
4062 response: Response<proto::SendChannelMessage>,
4063 session: UserSession,
4064) -> Result<()> {
4065 // Validate the message body.
4066 let body = request.body.trim().to_string();
4067 if body.len() > MAX_MESSAGE_LEN {
4068 return Err(anyhow!("message is too long"))?;
4069 }
4070 if body.is_empty() {
4071 return Err(anyhow!("message can't be blank"))?;
4072 }
4073
4074 // TODO: adjust mentions if body is trimmed
4075
4076 let timestamp = OffsetDateTime::now_utc();
4077 let nonce = request
4078 .nonce
4079 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4080
4081 let channel_id = ChannelId::from_proto(request.channel_id);
4082 let CreatedChannelMessage {
4083 message_id,
4084 participant_connection_ids,
4085 notifications,
4086 } = session
4087 .db()
4088 .await
4089 .create_channel_message(
4090 channel_id,
4091 session.user_id(),
4092 &body,
4093 &request.mentions,
4094 timestamp,
4095 nonce.clone().into(),
4096 match request.reply_to_message_id {
4097 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4098 None => None,
4099 },
4100 )
4101 .await?;
4102
4103 let message = proto::ChannelMessage {
4104 sender_id: session.user_id().to_proto(),
4105 id: message_id.to_proto(),
4106 body,
4107 mentions: request.mentions,
4108 timestamp: timestamp.unix_timestamp() as u64,
4109 nonce: Some(nonce),
4110 reply_to_message_id: request.reply_to_message_id,
4111 edited_at: None,
4112 };
4113 broadcast(
4114 Some(session.connection_id),
4115 participant_connection_ids.clone(),
4116 |connection| {
4117 session.peer.send(
4118 connection,
4119 proto::ChannelMessageSent {
4120 channel_id: channel_id.to_proto(),
4121 message: Some(message.clone()),
4122 },
4123 )
4124 },
4125 );
4126 response.send(proto::SendChannelMessageResponse {
4127 message: Some(message),
4128 })?;
4129
4130 let pool = &*session.connection_pool().await;
4131 let non_participants =
4132 pool.channel_connection_ids(channel_id)
4133 .filter_map(|(connection_id, _)| {
4134 if participant_connection_ids.contains(&connection_id) {
4135 None
4136 } else {
4137 Some(connection_id)
4138 }
4139 });
4140 broadcast(None, non_participants, |peer_id| {
4141 session.peer.send(
4142 peer_id,
4143 proto::UpdateChannels {
4144 latest_channel_message_ids: vec![proto::ChannelMessageId {
4145 channel_id: channel_id.to_proto(),
4146 message_id: message_id.to_proto(),
4147 }],
4148 ..Default::default()
4149 },
4150 )
4151 });
4152 send_notifications(pool, &session.peer, notifications);
4153
4154 Ok(())
4155}
4156
4157/// Delete a channel message
4158async fn remove_channel_message(
4159 request: proto::RemoveChannelMessage,
4160 response: Response<proto::RemoveChannelMessage>,
4161 session: UserSession,
4162) -> Result<()> {
4163 let channel_id = ChannelId::from_proto(request.channel_id);
4164 let message_id = MessageId::from_proto(request.message_id);
4165 let (connection_ids, existing_notification_ids) = session
4166 .db()
4167 .await
4168 .remove_channel_message(channel_id, message_id, session.user_id())
4169 .await?;
4170
4171 broadcast(
4172 Some(session.connection_id),
4173 connection_ids,
4174 move |connection| {
4175 session.peer.send(connection, request.clone())?;
4176
4177 for notification_id in &existing_notification_ids {
4178 session.peer.send(
4179 connection,
4180 proto::DeleteNotification {
4181 notification_id: (*notification_id).to_proto(),
4182 },
4183 )?;
4184 }
4185
4186 Ok(())
4187 },
4188 );
4189 response.send(proto::Ack {})?;
4190 Ok(())
4191}
4192
4193async fn update_channel_message(
4194 request: proto::UpdateChannelMessage,
4195 response: Response<proto::UpdateChannelMessage>,
4196 session: UserSession,
4197) -> Result<()> {
4198 let channel_id = ChannelId::from_proto(request.channel_id);
4199 let message_id = MessageId::from_proto(request.message_id);
4200 let updated_at = OffsetDateTime::now_utc();
4201 let UpdatedChannelMessage {
4202 message_id,
4203 participant_connection_ids,
4204 notifications,
4205 reply_to_message_id,
4206 timestamp,
4207 deleted_mention_notification_ids,
4208 updated_mention_notifications,
4209 } = session
4210 .db()
4211 .await
4212 .update_channel_message(
4213 channel_id,
4214 message_id,
4215 session.user_id(),
4216 request.body.as_str(),
4217 &request.mentions,
4218 updated_at,
4219 )
4220 .await?;
4221
4222 let nonce = request
4223 .nonce
4224 .clone()
4225 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4226
4227 let message = proto::ChannelMessage {
4228 sender_id: session.user_id().to_proto(),
4229 id: message_id.to_proto(),
4230 body: request.body.clone(),
4231 mentions: request.mentions.clone(),
4232 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4233 nonce: Some(nonce),
4234 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4235 edited_at: Some(updated_at.unix_timestamp() as u64),
4236 };
4237
4238 response.send(proto::Ack {})?;
4239
4240 let pool = &*session.connection_pool().await;
4241 broadcast(
4242 Some(session.connection_id),
4243 participant_connection_ids,
4244 |connection| {
4245 session.peer.send(
4246 connection,
4247 proto::ChannelMessageUpdate {
4248 channel_id: channel_id.to_proto(),
4249 message: Some(message.clone()),
4250 },
4251 )?;
4252
4253 for notification_id in &deleted_mention_notification_ids {
4254 session.peer.send(
4255 connection,
4256 proto::DeleteNotification {
4257 notification_id: (*notification_id).to_proto(),
4258 },
4259 )?;
4260 }
4261
4262 for notification in &updated_mention_notifications {
4263 session.peer.send(
4264 connection,
4265 proto::UpdateNotification {
4266 notification: Some(notification.clone()),
4267 },
4268 )?;
4269 }
4270
4271 Ok(())
4272 },
4273 );
4274
4275 send_notifications(pool, &session.peer, notifications);
4276
4277 Ok(())
4278}
4279
4280/// Mark a channel message as read
4281async fn acknowledge_channel_message(
4282 request: proto::AckChannelMessage,
4283 session: UserSession,
4284) -> Result<()> {
4285 let channel_id = ChannelId::from_proto(request.channel_id);
4286 let message_id = MessageId::from_proto(request.message_id);
4287 let notifications = session
4288 .db()
4289 .await
4290 .observe_channel_message(channel_id, session.user_id(), message_id)
4291 .await?;
4292 send_notifications(
4293 &*session.connection_pool().await,
4294 &session.peer,
4295 notifications,
4296 );
4297 Ok(())
4298}
4299
4300/// Mark a buffer version as synced
4301async fn acknowledge_buffer_version(
4302 request: proto::AckBufferOperation,
4303 session: UserSession,
4304) -> Result<()> {
4305 let buffer_id = BufferId::from_proto(request.buffer_id);
4306 session
4307 .db()
4308 .await
4309 .observe_buffer_version(
4310 buffer_id,
4311 session.user_id(),
4312 request.epoch as i32,
4313 &request.version,
4314 )
4315 .await?;
4316 Ok(())
4317}
4318
4319struct CompleteWithLanguageModelRateLimit;
4320
4321impl RateLimit for CompleteWithLanguageModelRateLimit {
4322 fn capacity() -> usize {
4323 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4324 .ok()
4325 .and_then(|v| v.parse().ok())
4326 .unwrap_or(120) // Picked arbitrarily
4327 }
4328
4329 fn refill_duration() -> chrono::Duration {
4330 chrono::Duration::hours(1)
4331 }
4332
4333 fn db_name() -> &'static str {
4334 "complete-with-language-model"
4335 }
4336}
4337
4338async fn complete_with_language_model(
4339 request: proto::CompleteWithLanguageModel,
4340 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4341 session: Session,
4342 open_ai_api_key: Option<Arc<str>>,
4343 google_ai_api_key: Option<Arc<str>>,
4344 anthropic_api_key: Option<Arc<str>>,
4345) -> Result<()> {
4346 let Some(session) = session.for_user() else {
4347 return Err(anyhow!("user not found"))?;
4348 };
4349 authorize_access_to_language_models(&session).await?;
4350 session
4351 .rate_limiter
4352 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4353 .await?;
4354
4355 if request.model.starts_with("gpt") {
4356 let api_key =
4357 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4358 complete_with_open_ai(request, response, session, api_key).await?;
4359 } else if request.model.starts_with("gemini") {
4360 let api_key = google_ai_api_key
4361 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4362 complete_with_google_ai(request, response, session, api_key).await?;
4363 } else if request.model.starts_with("claude") {
4364 let api_key = anthropic_api_key
4365 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4366 complete_with_anthropic(request, response, session, api_key).await?;
4367 }
4368
4369 Ok(())
4370}
4371
4372async fn complete_with_open_ai(
4373 request: proto::CompleteWithLanguageModel,
4374 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4375 session: UserSession,
4376 api_key: Arc<str>,
4377) -> Result<()> {
4378 let mut completion_stream = open_ai::stream_completion(
4379 session.http_client.as_ref(),
4380 OPEN_AI_API_URL,
4381 &api_key,
4382 crate::ai::language_model_request_to_open_ai(request)?,
4383 None,
4384 )
4385 .await
4386 .context("open_ai::stream_completion request failed within collab")?;
4387
4388 while let Some(event) = completion_stream.next().await {
4389 let event = event?;
4390 response.send(proto::LanguageModelResponse {
4391 choices: event
4392 .choices
4393 .into_iter()
4394 .map(|choice| proto::LanguageModelChoiceDelta {
4395 index: choice.index,
4396 delta: Some(proto::LanguageModelResponseMessage {
4397 role: choice.delta.role.map(|role| match role {
4398 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4399 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4400 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4401 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4402 } as i32),
4403 content: choice.delta.content,
4404 tool_calls: choice
4405 .delta
4406 .tool_calls
4407 .into_iter()
4408 .map(|delta| proto::ToolCallDelta {
4409 index: delta.index as u32,
4410 id: delta.id,
4411 variant: match delta.function {
4412 Some(function) => {
4413 let name = function.name;
4414 let arguments = function.arguments;
4415
4416 Some(proto::tool_call_delta::Variant::Function(
4417 proto::tool_call_delta::FunctionCallDelta {
4418 name,
4419 arguments,
4420 },
4421 ))
4422 }
4423 None => None,
4424 },
4425 })
4426 .collect(),
4427 }),
4428 finish_reason: choice.finish_reason,
4429 })
4430 .collect(),
4431 })?;
4432 }
4433
4434 Ok(())
4435}
4436
4437async fn complete_with_google_ai(
4438 request: proto::CompleteWithLanguageModel,
4439 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4440 session: UserSession,
4441 api_key: Arc<str>,
4442) -> Result<()> {
4443 let mut stream = google_ai::stream_generate_content(
4444 session.http_client.clone(),
4445 google_ai::API_URL,
4446 api_key.as_ref(),
4447 crate::ai::language_model_request_to_google_ai(request)?,
4448 )
4449 .await
4450 .context("google_ai::stream_generate_content request failed")?;
4451
4452 while let Some(event) = stream.next().await {
4453 let event = event?;
4454 response.send(proto::LanguageModelResponse {
4455 choices: event
4456 .candidates
4457 .unwrap_or_default()
4458 .into_iter()
4459 .map(|candidate| proto::LanguageModelChoiceDelta {
4460 index: candidate.index as u32,
4461 delta: Some(proto::LanguageModelResponseMessage {
4462 role: Some(match candidate.content.role {
4463 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4464 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4465 } as i32),
4466 content: Some(
4467 candidate
4468 .content
4469 .parts
4470 .into_iter()
4471 .filter_map(|part| match part {
4472 google_ai::Part::TextPart(part) => Some(part.text),
4473 google_ai::Part::InlineDataPart(_) => None,
4474 })
4475 .collect(),
4476 ),
4477 // Tool calls are not supported for Google
4478 tool_calls: Vec::new(),
4479 }),
4480 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4481 })
4482 .collect(),
4483 })?;
4484 }
4485
4486 Ok(())
4487}
4488
4489async fn complete_with_anthropic(
4490 request: proto::CompleteWithLanguageModel,
4491 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4492 session: UserSession,
4493 api_key: Arc<str>,
4494) -> Result<()> {
4495 let model = anthropic::Model::from_id(&request.model)?;
4496
4497 let mut system_message = String::new();
4498 let messages = request
4499 .messages
4500 .into_iter()
4501 .filter_map(|message| {
4502 match message.role() {
4503 LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
4504 role: anthropic::Role::User,
4505 content: message.content,
4506 }),
4507 LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
4508 role: anthropic::Role::Assistant,
4509 content: message.content,
4510 }),
4511 // Anthropic's API breaks system instructions out as a separate field rather
4512 // than having a system message role.
4513 LanguageModelRole::LanguageModelSystem => {
4514 if !system_message.is_empty() {
4515 system_message.push_str("\n\n");
4516 }
4517 system_message.push_str(&message.content);
4518
4519 None
4520 }
4521 // We don't yet support tool calls for Anthropic
4522 LanguageModelRole::LanguageModelTool => None,
4523 }
4524 })
4525 .collect();
4526
4527 let mut stream = anthropic::stream_completion(
4528 session.http_client.as_ref(),
4529 anthropic::ANTHROPIC_API_URL,
4530 &api_key,
4531 anthropic::Request {
4532 model,
4533 messages,
4534 stream: true,
4535 system: system_message,
4536 max_tokens: 4092,
4537 },
4538 None,
4539 )
4540 .await?;
4541
4542 let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
4543
4544 while let Some(event) = stream.next().await {
4545 let event = event?;
4546
4547 match event {
4548 anthropic::ResponseEvent::MessageStart { message } => {
4549 if let Some(role) = message.role {
4550 if role == "assistant" {
4551 current_role = proto::LanguageModelRole::LanguageModelAssistant;
4552 } else if role == "user" {
4553 current_role = proto::LanguageModelRole::LanguageModelUser;
4554 }
4555 }
4556 }
4557 anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
4558 match content_block {
4559 anthropic::ContentBlock::Text { text } => {
4560 if !text.is_empty() {
4561 response.send(proto::LanguageModelResponse {
4562 choices: vec![proto::LanguageModelChoiceDelta {
4563 index: 0,
4564 delta: Some(proto::LanguageModelResponseMessage {
4565 role: Some(current_role as i32),
4566 content: Some(text),
4567 tool_calls: Vec::new(),
4568 }),
4569 finish_reason: None,
4570 }],
4571 })?;
4572 }
4573 }
4574 }
4575 }
4576 anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
4577 anthropic::TextDelta::TextDelta { text } => {
4578 response.send(proto::LanguageModelResponse {
4579 choices: vec![proto::LanguageModelChoiceDelta {
4580 index: 0,
4581 delta: Some(proto::LanguageModelResponseMessage {
4582 role: Some(current_role as i32),
4583 content: Some(text),
4584 tool_calls: Vec::new(),
4585 }),
4586 finish_reason: None,
4587 }],
4588 })?;
4589 }
4590 },
4591 anthropic::ResponseEvent::MessageDelta { delta, .. } => {
4592 if let Some(stop_reason) = delta.stop_reason {
4593 response.send(proto::LanguageModelResponse {
4594 choices: vec![proto::LanguageModelChoiceDelta {
4595 index: 0,
4596 delta: None,
4597 finish_reason: Some(stop_reason),
4598 }],
4599 })?;
4600 }
4601 }
4602 anthropic::ResponseEvent::ContentBlockStop { .. } => {}
4603 anthropic::ResponseEvent::MessageStop {} => {}
4604 anthropic::ResponseEvent::Ping {} => {}
4605 }
4606 }
4607
4608 Ok(())
4609}
4610
4611struct CountTokensWithLanguageModelRateLimit;
4612
4613impl RateLimit for CountTokensWithLanguageModelRateLimit {
4614 fn capacity() -> usize {
4615 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4616 .ok()
4617 .and_then(|v| v.parse().ok())
4618 .unwrap_or(600) // Picked arbitrarily
4619 }
4620
4621 fn refill_duration() -> chrono::Duration {
4622 chrono::Duration::hours(1)
4623 }
4624
4625 fn db_name() -> &'static str {
4626 "count-tokens-with-language-model"
4627 }
4628}
4629
4630async fn count_tokens_with_language_model(
4631 request: proto::CountTokensWithLanguageModel,
4632 response: Response<proto::CountTokensWithLanguageModel>,
4633 session: UserSession,
4634 google_ai_api_key: Option<Arc<str>>,
4635) -> Result<()> {
4636 authorize_access_to_language_models(&session).await?;
4637
4638 if !request.model.starts_with("gemini") {
4639 return Err(anyhow!(
4640 "counting tokens for model: {:?} is not supported",
4641 request.model
4642 ))?;
4643 }
4644
4645 session
4646 .rate_limiter
4647 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4648 .await?;
4649
4650 let api_key = google_ai_api_key
4651 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4652 let tokens_response = google_ai::count_tokens(
4653 session.http_client.as_ref(),
4654 google_ai::API_URL,
4655 &api_key,
4656 crate::ai::count_tokens_request_to_google_ai(request)?,
4657 )
4658 .await?;
4659 response.send(proto::CountTokensResponse {
4660 token_count: tokens_response.total_tokens as u32,
4661 })?;
4662 Ok(())
4663}
4664
4665struct ComputeEmbeddingsRateLimit;
4666
4667impl RateLimit for ComputeEmbeddingsRateLimit {
4668 fn capacity() -> usize {
4669 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4670 .ok()
4671 .and_then(|v| v.parse().ok())
4672 .unwrap_or(5000) // Picked arbitrarily
4673 }
4674
4675 fn refill_duration() -> chrono::Duration {
4676 chrono::Duration::hours(1)
4677 }
4678
4679 fn db_name() -> &'static str {
4680 "compute-embeddings"
4681 }
4682}
4683
4684async fn compute_embeddings(
4685 request: proto::ComputeEmbeddings,
4686 response: Response<proto::ComputeEmbeddings>,
4687 session: UserSession,
4688 api_key: Option<Arc<str>>,
4689) -> Result<()> {
4690 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4691 authorize_access_to_language_models(&session).await?;
4692
4693 session
4694 .rate_limiter
4695 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4696 .await?;
4697
4698 let embeddings = match request.model.as_str() {
4699 "openai/text-embedding-3-small" => {
4700 open_ai::embed(
4701 session.http_client.as_ref(),
4702 OPEN_AI_API_URL,
4703 &api_key,
4704 OpenAiEmbeddingModel::TextEmbedding3Small,
4705 request.texts.iter().map(|text| text.as_str()),
4706 )
4707 .await?
4708 }
4709 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4710 };
4711
4712 let embeddings = request
4713 .texts
4714 .iter()
4715 .map(|text| {
4716 let mut hasher = sha2::Sha256::new();
4717 hasher.update(text.as_bytes());
4718 let result = hasher.finalize();
4719 result.to_vec()
4720 })
4721 .zip(
4722 embeddings
4723 .data
4724 .into_iter()
4725 .map(|embedding| embedding.embedding),
4726 )
4727 .collect::<HashMap<_, _>>();
4728
4729 let db = session.db().await;
4730 db.save_embeddings(&request.model, &embeddings)
4731 .await
4732 .context("failed to save embeddings")
4733 .trace_err();
4734
4735 response.send(proto::ComputeEmbeddingsResponse {
4736 embeddings: embeddings
4737 .into_iter()
4738 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4739 .collect(),
4740 })?;
4741 Ok(())
4742}
4743
4744async fn get_cached_embeddings(
4745 request: proto::GetCachedEmbeddings,
4746 response: Response<proto::GetCachedEmbeddings>,
4747 session: UserSession,
4748) -> Result<()> {
4749 authorize_access_to_language_models(&session).await?;
4750
4751 let db = session.db().await;
4752 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4753
4754 response.send(proto::GetCachedEmbeddingsResponse {
4755 embeddings: embeddings
4756 .into_iter()
4757 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4758 .collect(),
4759 })?;
4760 Ok(())
4761}
4762
4763async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4764 let db = session.db().await;
4765 let flags = db.get_user_flags(session.user_id()).await?;
4766 if flags.iter().any(|flag| flag == "language-models") {
4767 Ok(())
4768 } else {
4769 Err(anyhow!("permission denied"))?
4770 }
4771}
4772
4773/// Get a Supermaven API key for the user
4774async fn get_supermaven_api_key(
4775 _request: proto::GetSupermavenApiKey,
4776 response: Response<proto::GetSupermavenApiKey>,
4777 session: UserSession,
4778) -> Result<()> {
4779 let user_id: String = session.user_id().to_string();
4780 if !session.is_staff() {
4781 return Err(anyhow!("supermaven not enabled for this account"))?;
4782 }
4783
4784 let email = session
4785 .email()
4786 .ok_or_else(|| anyhow!("user must have an email"))?;
4787
4788 let supermaven_admin_api = session
4789 .supermaven_client
4790 .as_ref()
4791 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4792
4793 let result = supermaven_admin_api
4794 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4795 .await?;
4796
4797 response.send(proto::GetSupermavenApiKeyResponse {
4798 api_key: result.api_key,
4799 })?;
4800
4801 Ok(())
4802}
4803
4804/// Start receiving chat updates for a channel
4805async fn join_channel_chat(
4806 request: proto::JoinChannelChat,
4807 response: Response<proto::JoinChannelChat>,
4808 session: UserSession,
4809) -> Result<()> {
4810 let channel_id = ChannelId::from_proto(request.channel_id);
4811
4812 let db = session.db().await;
4813 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4814 .await?;
4815 let messages = db
4816 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4817 .await?;
4818 response.send(proto::JoinChannelChatResponse {
4819 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4820 messages,
4821 })?;
4822 Ok(())
4823}
4824
4825/// Stop receiving chat updates for a channel
4826async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4827 let channel_id = ChannelId::from_proto(request.channel_id);
4828 session
4829 .db()
4830 .await
4831 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4832 .await?;
4833 Ok(())
4834}
4835
4836/// Retrieve the chat history for a channel
4837async fn get_channel_messages(
4838 request: proto::GetChannelMessages,
4839 response: Response<proto::GetChannelMessages>,
4840 session: UserSession,
4841) -> Result<()> {
4842 let channel_id = ChannelId::from_proto(request.channel_id);
4843 let messages = session
4844 .db()
4845 .await
4846 .get_channel_messages(
4847 channel_id,
4848 session.user_id(),
4849 MESSAGE_COUNT_PER_PAGE,
4850 Some(MessageId::from_proto(request.before_message_id)),
4851 )
4852 .await?;
4853 response.send(proto::GetChannelMessagesResponse {
4854 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4855 messages,
4856 })?;
4857 Ok(())
4858}
4859
4860/// Retrieve specific chat messages
4861async fn get_channel_messages_by_id(
4862 request: proto::GetChannelMessagesById,
4863 response: Response<proto::GetChannelMessagesById>,
4864 session: UserSession,
4865) -> Result<()> {
4866 let message_ids = request
4867 .message_ids
4868 .iter()
4869 .map(|id| MessageId::from_proto(*id))
4870 .collect::<Vec<_>>();
4871 let messages = session
4872 .db()
4873 .await
4874 .get_channel_messages_by_id(session.user_id(), &message_ids)
4875 .await?;
4876 response.send(proto::GetChannelMessagesResponse {
4877 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4878 messages,
4879 })?;
4880 Ok(())
4881}
4882
4883/// Retrieve the current users notifications
4884async fn get_notifications(
4885 request: proto::GetNotifications,
4886 response: Response<proto::GetNotifications>,
4887 session: UserSession,
4888) -> Result<()> {
4889 let notifications = session
4890 .db()
4891 .await
4892 .get_notifications(
4893 session.user_id(),
4894 NOTIFICATION_COUNT_PER_PAGE,
4895 request
4896 .before_id
4897 .map(|id| db::NotificationId::from_proto(id)),
4898 )
4899 .await?;
4900 response.send(proto::GetNotificationsResponse {
4901 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4902 notifications,
4903 })?;
4904 Ok(())
4905}
4906
4907/// Mark notifications as read
4908async fn mark_notification_as_read(
4909 request: proto::MarkNotificationRead,
4910 response: Response<proto::MarkNotificationRead>,
4911 session: UserSession,
4912) -> Result<()> {
4913 let database = &session.db().await;
4914 let notifications = database
4915 .mark_notification_as_read_by_id(
4916 session.user_id(),
4917 NotificationId::from_proto(request.notification_id),
4918 )
4919 .await?;
4920 send_notifications(
4921 &*session.connection_pool().await,
4922 &session.peer,
4923 notifications,
4924 );
4925 response.send(proto::Ack {})?;
4926 Ok(())
4927}
4928
4929/// Get the current users information
4930async fn get_private_user_info(
4931 _request: proto::GetPrivateUserInfo,
4932 response: Response<proto::GetPrivateUserInfo>,
4933 session: UserSession,
4934) -> Result<()> {
4935 let db = session.db().await;
4936
4937 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4938 let user = db
4939 .get_user_by_id(session.user_id())
4940 .await?
4941 .ok_or_else(|| anyhow!("user not found"))?;
4942 let flags = db.get_user_flags(session.user_id()).await?;
4943
4944 response.send(proto::GetPrivateUserInfoResponse {
4945 metrics_id,
4946 staff: user.admin,
4947 flags,
4948 })?;
4949 Ok(())
4950}
4951
4952fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4953 match message {
4954 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4955 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4956 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4957 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4958 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4959 code: frame.code.into(),
4960 reason: frame.reason,
4961 })),
4962 }
4963}
4964
4965fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4966 match message {
4967 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4968 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4969 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4970 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4971 AxumMessage::Close(frame) => {
4972 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4973 code: frame.code.into(),
4974 reason: frame.reason,
4975 }))
4976 }
4977 }
4978}
4979
4980fn notify_membership_updated(
4981 connection_pool: &mut ConnectionPool,
4982 result: MembershipUpdated,
4983 user_id: UserId,
4984 peer: &Peer,
4985) {
4986 for membership in &result.new_channels.channel_memberships {
4987 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4988 }
4989 for channel_id in &result.removed_channels {
4990 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4991 }
4992
4993 let user_channels_update = proto::UpdateUserChannels {
4994 channel_memberships: result
4995 .new_channels
4996 .channel_memberships
4997 .iter()
4998 .map(|cm| proto::ChannelMembership {
4999 channel_id: cm.channel_id.to_proto(),
5000 role: cm.role.into(),
5001 })
5002 .collect(),
5003 ..Default::default()
5004 };
5005
5006 let mut update = build_channels_update(result.new_channels, vec![]);
5007 update.delete_channels = result
5008 .removed_channels
5009 .into_iter()
5010 .map(|id| id.to_proto())
5011 .collect();
5012 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5013
5014 for connection_id in connection_pool.user_connection_ids(user_id) {
5015 peer.send(connection_id, user_channels_update.clone())
5016 .trace_err();
5017 peer.send(connection_id, update.clone()).trace_err();
5018 }
5019}
5020
5021fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5022 proto::UpdateUserChannels {
5023 channel_memberships: channels
5024 .channel_memberships
5025 .iter()
5026 .map(|m| proto::ChannelMembership {
5027 channel_id: m.channel_id.to_proto(),
5028 role: m.role.into(),
5029 })
5030 .collect(),
5031 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5032 observed_channel_message_id: channels.observed_channel_messages.clone(),
5033 }
5034}
5035
5036fn build_channels_update(
5037 channels: ChannelsForUser,
5038 channel_invites: Vec<db::Channel>,
5039) -> proto::UpdateChannels {
5040 let mut update = proto::UpdateChannels::default();
5041
5042 for channel in channels.channels {
5043 update.channels.push(channel.to_proto());
5044 }
5045
5046 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5047 update.latest_channel_message_ids = channels.latest_channel_messages;
5048
5049 for (channel_id, participants) in channels.channel_participants {
5050 update
5051 .channel_participants
5052 .push(proto::ChannelParticipants {
5053 channel_id: channel_id.to_proto(),
5054 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5055 });
5056 }
5057
5058 for channel in channel_invites {
5059 update.channel_invitations.push(channel.to_proto());
5060 }
5061
5062 update.hosted_projects = channels.hosted_projects;
5063 update
5064}
5065
5066fn build_initial_contacts_update(
5067 contacts: Vec<db::Contact>,
5068 pool: &ConnectionPool,
5069) -> proto::UpdateContacts {
5070 let mut update = proto::UpdateContacts::default();
5071
5072 for contact in contacts {
5073 match contact {
5074 db::Contact::Accepted { user_id, busy } => {
5075 update.contacts.push(contact_for_user(user_id, busy, &pool));
5076 }
5077 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5078 db::Contact::Incoming { user_id } => {
5079 update
5080 .incoming_requests
5081 .push(proto::IncomingContactRequest {
5082 requester_id: user_id.to_proto(),
5083 })
5084 }
5085 }
5086 }
5087
5088 update
5089}
5090
5091fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5092 proto::Contact {
5093 user_id: user_id.to_proto(),
5094 online: pool.is_user_online(user_id),
5095 busy,
5096 }
5097}
5098
5099fn room_updated(room: &proto::Room, peer: &Peer) {
5100 broadcast(
5101 None,
5102 room.participants
5103 .iter()
5104 .filter_map(|participant| Some(participant.peer_id?.into())),
5105 |peer_id| {
5106 peer.send(
5107 peer_id,
5108 proto::RoomUpdated {
5109 room: Some(room.clone()),
5110 },
5111 )
5112 },
5113 );
5114}
5115
5116fn channel_updated(
5117 channel: &db::channel::Model,
5118 room: &proto::Room,
5119 peer: &Peer,
5120 pool: &ConnectionPool,
5121) {
5122 let participants = room
5123 .participants
5124 .iter()
5125 .map(|p| p.user_id)
5126 .collect::<Vec<_>>();
5127
5128 broadcast(
5129 None,
5130 pool.channel_connection_ids(channel.root_id())
5131 .filter_map(|(channel_id, role)| {
5132 role.can_see_channel(channel.visibility).then(|| channel_id)
5133 }),
5134 |peer_id| {
5135 peer.send(
5136 peer_id,
5137 proto::UpdateChannels {
5138 channel_participants: vec![proto::ChannelParticipants {
5139 channel_id: channel.id.to_proto(),
5140 participant_user_ids: participants.clone(),
5141 }],
5142 ..Default::default()
5143 },
5144 )
5145 },
5146 );
5147}
5148
5149async fn send_dev_server_projects_update(
5150 user_id: UserId,
5151 mut status: proto::DevServerProjectsUpdate,
5152 session: &Session,
5153) {
5154 let pool = session.connection_pool().await;
5155 for dev_server in &mut status.dev_servers {
5156 dev_server.status =
5157 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5158 }
5159 let connections = pool.user_connection_ids(user_id);
5160 for connection_id in connections {
5161 session.peer.send(connection_id, status.clone()).trace_err();
5162 }
5163}
5164
5165async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5166 let db = session.db().await;
5167
5168 let contacts = db.get_contacts(user_id).await?;
5169 let busy = db.is_user_busy(user_id).await?;
5170
5171 let pool = session.connection_pool().await;
5172 let updated_contact = contact_for_user(user_id, busy, &pool);
5173 for contact in contacts {
5174 if let db::Contact::Accepted {
5175 user_id: contact_user_id,
5176 ..
5177 } = contact
5178 {
5179 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5180 session
5181 .peer
5182 .send(
5183 contact_conn_id,
5184 proto::UpdateContacts {
5185 contacts: vec![updated_contact.clone()],
5186 remove_contacts: Default::default(),
5187 incoming_requests: Default::default(),
5188 remove_incoming_requests: Default::default(),
5189 outgoing_requests: Default::default(),
5190 remove_outgoing_requests: Default::default(),
5191 },
5192 )
5193 .trace_err();
5194 }
5195 }
5196 }
5197 Ok(())
5198}
5199
5200async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5201 log::info!("lost dev server connection, unsharing projects");
5202 let project_ids = session
5203 .db()
5204 .await
5205 .get_stale_dev_server_projects(session.connection_id)
5206 .await?;
5207
5208 for project_id in project_ids {
5209 // not unshare re-checks the connection ids match, so we get away with no transaction
5210 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5211 }
5212
5213 let user_id = session.dev_server().user_id;
5214 let update = session
5215 .db()
5216 .await
5217 .dev_server_projects_update(user_id)
5218 .await?;
5219
5220 send_dev_server_projects_update(user_id, update, session).await;
5221
5222 Ok(())
5223}
5224
/// Removes `connection_id` from the room it is in (if any) and notifies everyone affected.
///
/// Notified parties:
/// - collaborators on each project that was left (`project_left`);
/// - the remaining room participants (`room_updated`);
/// - channel members, when the room belongs to a channel (`channel_updated`);
/// - users whose pending calls into this room were canceled (`CallCanceled`).
///
/// It also refreshes the contact entries of the leaving user and of each canceled callee,
/// and removes the user from the LiveKit room, deleting that room when the database reports
/// the room itself was deleted.
///
/// Returns `Ok(())` without doing anything if the connection was not in a room.
///
/// # Errors
/// Propagates errors from `leave_room` and `update_user_contacts`. LiveKit failures and
/// message send failures are only logged.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Deferred initialization: each of these is assigned exactly once in the `if let` branch
    // below; the `else` branch returns early, so they are always initialized afterwards.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the value returned by `session.db().await` is a temporary of this
    // `if let` scrutinee, so it likely stays alive for the whole `if let` body. If it is a
    // lock guard, that lock is held while projects/participants are notified — confirm this
    // is intended.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        // Move data out of `left_room` with `mem::take` instead of cloning it.
        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): `live_kit_room` is taken before `room` is moved out, so the `room`
        // broadcast by `room_updated` below carries an empty `live_kit_room` — confirm this
        // is intended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so `pool` is dropped before `update_user_contacts` below, which acquires the
    // connection pool again.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // LiveKit cleanup is best-effort: failures are logged via `trace_err` and ignored.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5298
5299async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5300 let left_channel_buffers = session
5301 .db()
5302 .await
5303 .leave_channel_buffers(session.connection_id)
5304 .await?;
5305
5306 for left_buffer in left_channel_buffers {
5307 channel_buffer_updated(
5308 session.connection_id,
5309 left_buffer.connections,
5310 &proto::UpdateChannelBufferCollaborators {
5311 channel_id: left_buffer.channel_id.to_proto(),
5312 collaborators: left_buffer.collaborators,
5313 },
5314 &session.peer,
5315 );
5316 }
5317
5318 Ok(())
5319}
5320
5321fn project_left(project: &db::LeftProject, session: &UserSession) {
5322 for connection_id in &project.connection_ids {
5323 if project.should_unshare {
5324 session
5325 .peer
5326 .send(
5327 *connection_id,
5328 proto::UnshareProject {
5329 project_id: project.id.to_proto(),
5330 },
5331 )
5332 .trace_err();
5333 } else {
5334 session
5335 .peer
5336 .send(
5337 *connection_id,
5338 proto::RemoveProjectCollaborator {
5339 project_id: project.id.to_proto(),
5340 peer_id: Some(session.connection_id.into()),
5341 },
5342 )
5343 .trace_err();
5344 }
5345 }
5346}
5347
/// Extension for fallible values whose errors should be logged rather than propagated.
pub trait ResultExt {
    /// The success type produced when no error is present.
    type Ok;

    /// Logs the error, if any, and converts `self` into an `Option` of the success value.
    fn trace_err(self) -> Option<Self::Ok>;
}
5353
5354impl<T, E> ResultExt for Result<T, E>
5355where
5356 E: std::fmt::Debug,
5357{
5358 type Ok = T;
5359
5360 #[track_caller]
5361 fn trace_err(self) -> Option<T> {
5362 match self {
5363 Ok(value) => Some(value),
5364 Err(error) => {
5365 tracing::error!("{:?}", error);
5366 None
5367 }
5368 }
5369 }
5370}