1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use prometheus::{register_int_gauge, IntGauge};
46use rpc::{
47 proto::{
48 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
49 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
50 },
51 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
52};
53use semantic_version::SemanticVersion;
54use serde::{Serialize, Serializer};
55use std::{
56 any::TypeId,
57 future::Future,
58 marker::PhantomData,
59 mem,
60 net::SocketAddr,
61 ops::{Deref, DerefMut},
62 rc::Rc,
63 sync::{
64 atomic::{AtomicBool, Ordering::SeqCst},
65 Arc, OnceLock,
66 },
67 time::{Duration, Instant},
68};
69use time::OffsetDateTime;
70use tokio::sync::{watch, Semaphore};
71use tower::ServiceBuilder;
72use tracing::{
73 field::{self},
74 info_span, instrument, Instrument,
75};
76use util::http::IsahcHttpClient;
77
78use self::connection_pool::VersionedMessage;
79
/// Reconnection grace period.
///
/// NOTE(review): not referenced in this part of the file; presumably how long a dropped
/// connection's state is kept before it is treated as gone — confirm at the use site.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` clears resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; waiting a little longer than
/// that means the old pods are gone before their resources are cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

// Paging sizes and limits. NOTE(review): their consumers are outside this view.
/// Page size for channel-message queries (presumably).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Page size for notification queries (presumably).
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on a channel message's length (unit — bytes vs chars — not visible here).
const MAX_MESSAGE_LEN: usize = 1024;
88
/// Type-erased handler stored in `Server::handlers`.
///
/// It receives a boxed envelope (downcast back to the concrete message type inside the
/// handler) and the sender's `Session`. It returns a boxed future that performs the work and
/// does its own logging, which is why the output is `()`.
type MessageHandler =
    Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection state handed to every message handler.
///
/// Cloning is cheap: every field except `principal` is an id, an `Arc`, or an executor
/// handle (assumed cheap to clone — confirm for `Executor`).
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection.
    principal: Principal,
    /// This connection's id within `peer`.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it via `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    peer: Arc<Peer>,
    /// Pool shared with the `Server`; acquire it via `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Present only when a Supermaven admin API key is configured (see `handle_connection`).
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    http_client: Arc<IsahcHttpClient>,
    rate_limiter: Arc<RateLimiter>,
    /// Not read directly (underscore-prefixed); kept alongside the session.
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the connection pool and the handler table.
pub struct Server {
    /// Current server id. Behind a mutex so `reset` (tests only) can replace it.
    id: parking_lot::Mutex<ServerId>,
    peer: Arc<Peer>,
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    app_state: Arc<AppState>,
    /// Message handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Broadcasts `true` when the server is tearing down (and `false` again after `reset`).
    teardown: watch::Sender<bool>,
}
384
/// Lock guard over the shared [`ConnectionPool`].
pub(crate) struct ConnectionPoolGuard<'a> {
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    /// `Rc` is `!Send`, so this marker makes the guard `!Send`. Presumably this stops the
    /// synchronous pool lock from being held across an `.await` inside the `Send` futures that
    /// handlers return — confirm.
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's peer and connection pool.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    /// Serialized through the guard's `Deref` target. The `Deref` impl is not in this view;
    /// presumably it targets `ConnectionPool`.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates a server with the given `id` and registers every RPC message handler.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` reject sessions whose
    /// principal is not a user. Handlers wrapped in `dev_server_handler` reject sessions that
    /// are not dev servers. Unwrapped handlers receive the raw `Session`. Registering two
    /// handlers for the same message type panics (see `add_handler`).
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if ServerId's inner value can exceed u32 — confirm its range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Requests forwarded to the project host. The read-only / mutating split presumably
            // reflects the permission each forwarding helper checks — confirm in those helpers.
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            // Buffer traffic and host-to-guest broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers and channel chat.
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            // AI-related requests. These closures capture the API keys from config.
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Streaming completion: not wrapped in `user_handler`, so it receives the raw Session.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
626
    /// Spawns a detached task that cleans up resources left behind by stale servers, then
    /// returns `Ok(())` immediately.
    ///
    /// The task sleeps for `CLEANUP_TIMEOUT` and then asks the database for stale room and
    /// channel ids in this environment relative to this server's id. Presumably these are
    /// rooms and channel buffers still owned by servers that have since gone away — confirm
    /// in `stale_server_resource_ids`. It then:
    /// - clears stale collaborators from each channel buffer and notifies remaining connections;
    /// - clears stale participants from each room, broadcasts the refreshed room (and channel),
    ///   sends `CallCanceled` to users whose pending calls were dropped, refreshes contact
    ///   entries of affected users, and deletes the LiveKit room when no participants remain;
    /// - finally deletes the stale server records.
    ///
    /// Database and send failures are only logged (`trace_err`); the cleanup continues.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Created here so the countdown starts when `start` is called, not when the task runs.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Channel buffers: drop stale collaborators and tell everyone still connected.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Rooms: drop stale participants, then fan out the consequences.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        // Push each affected user's refreshed contact entry to their accepted contacts.
                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
772
773 pub fn teardown(&self) {
774 self.peer.teardown();
775 self.connection_pool.lock().reset();
776 let _ = self.teardown.send(true);
777 }
778
779 #[cfg(test)]
780 pub fn reset(&self, id: ServerId) {
781 self.teardown();
782 *self.id.lock() = id;
783 self.peer.reset(id.0 as u32);
784 let _ = self.teardown.send(false);
785 }
786
787 #[cfg(test)]
788 pub fn id(&self) -> ServerId {
789 *self.id.lock()
790 }
791
792 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
793 where
794 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
795 Fut: 'static + Send + Future<Output = Result<()>>,
796 M: EnvelopedMessage,
797 {
798 let prev_handler = self.handlers.insert(
799 TypeId::of::<M>(),
800 Box::new(move |envelope, session| {
801 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
802 let received_at = envelope.received_at;
803 tracing::info!("message received");
804 let start_time = Instant::now();
805 let future = (handler)(*envelope, session);
806 async move {
807 let result = future.await;
808 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
809 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
810 let queue_duration_ms = total_duration_ms - processing_duration_ms;
811 let payload_type = M::NAME;
812
813 match result {
814 Err(error) => {
815 tracing::error!(
816 ?error,
817 total_duration_ms,
818 processing_duration_ms,
819 queue_duration_ms,
820 payload_type,
821 "error handling message"
822 )
823 }
824 Ok(()) => tracing::info!(
825 total_duration_ms,
826 processing_duration_ms,
827 queue_duration_ms,
828 "finished handling message"
829 ),
830 }
831 }
832 .boxed()
833 }),
834 );
835 if prev_handler.is_some() {
836 panic!("registered a handler for the same message twice");
837 }
838 self
839 }
840
841 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
842 where
843 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
844 Fut: 'static + Send + Future<Output = Result<()>>,
845 M: EnvelopedMessage,
846 {
847 self.add_handler(move |envelope, session| handler(envelope.payload, session));
848 self
849 }
850
851 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
852 where
853 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
854 Fut: Send + Future<Output = Result<()>>,
855 M: RequestMessage,
856 {
857 let handler = Arc::new(handler);
858 self.add_handler(move |envelope, session| {
859 let receipt = envelope.receipt();
860 let handler = handler.clone();
861 async move {
862 let peer = session.peer.clone();
863 let responded = Arc::new(AtomicBool::default());
864 let response = Response {
865 peer: peer.clone(),
866 responded: responded.clone(),
867 receipt,
868 };
869 match (handler)(envelope.payload, response, session).await {
870 Ok(()) => {
871 if responded.load(std::sync::atomic::Ordering::SeqCst) {
872 Ok(())
873 } else {
874 Err(anyhow!("handler did not send a response"))?
875 }
876 }
877 Err(error) => {
878 let proto_err = match &error {
879 Error::Internal(err) => err.to_proto(),
880 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
881 };
882 peer.respond_with_error(receipt, proto_err)?;
883 Err(error)
884 }
885 }
886 }
887 })
888 }
889
890 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
891 where
892 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
893 Fut: Send + Future<Output = Result<()>>,
894 M: RequestMessage,
895 {
896 let handler = Arc::new(handler);
897 self.add_handler(move |envelope, session| {
898 let receipt = envelope.receipt();
899 let handler = handler.clone();
900 async move {
901 let peer = session.peer.clone();
902 let response = StreamingResponse {
903 peer: peer.clone(),
904 receipt,
905 };
906 match (handler)(envelope.payload, response, session).await {
907 Ok(()) => {
908 peer.end_stream(receipt)?;
909 Ok(())
910 }
911 Err(error) => {
912 let proto_err = match &error {
913 Error::Internal(err) => err.to_proto(),
914 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
915 };
916 peer.respond_with_error(receipt, proto_err)?;
917 Err(error)
918 }
919 }
920 }
921 })
922 }
923
924 #[allow(clippy::too_many_arguments)]
925 pub fn handle_connection(
926 self: &Arc<Self>,
927 connection: Connection,
928 address: String,
929 principal: Principal,
930 zed_version: ZedVersion,
931 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
932 executor: Executor,
933 ) -> impl Future<Output = ()> {
934 let this = self.clone();
935 let span = info_span!("handle connection", %address,
936 connection_id=field::Empty,
937 user_id=field::Empty,
938 login=field::Empty,
939 impersonator=field::Empty,
940 dev_server_id=field::Empty
941 );
942 principal.update_span(&span);
943
944 let mut teardown = self.teardown.subscribe();
945 async move {
946 if *teardown.borrow() {
947 tracing::error!("server is tearing down");
948 return
949 }
950 let (connection_id, handle_io, mut incoming_rx) = this
951 .peer
952 .add_connection(connection, {
953 let executor = executor.clone();
954 move |duration| executor.sleep(duration)
955 });
956 tracing::Span::current().record("connection_id", format!("{}", connection_id));
957 tracing::info!("connection opened");
958
959 let http_client = match IsahcHttpClient::new() {
960 Ok(http_client) => Arc::new(http_client),
961 Err(error) => {
962 tracing::error!(?error, "failed to create HTTP client");
963 return;
964 }
965 };
966
967 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
968 Some(Arc::new(SupermavenAdminApi::new(
969 supermaven_admin_api_key.to_string(),
970 http_client.clone(),
971 )))
972 } else {
973 None
974 };
975
976 let session = Session {
977 principal: principal.clone(),
978 connection_id,
979 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
980 peer: this.peer.clone(),
981 connection_pool: this.connection_pool.clone(),
982 live_kit_client: this.app_state.live_kit_client.clone(),
983 http_client,
984 rate_limiter: this.app_state.rate_limiter.clone(),
985 _executor: executor.clone(),
986 supermaven_client,
987 };
988
989 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
990 tracing::error!(?error, "failed to send initial client update");
991 return;
992 }
993
994 let handle_io = handle_io.fuse();
995 futures::pin_mut!(handle_io);
996
997 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
998 // This prevents deadlocks when e.g., client A performs a request to client B and
999 // client B performs a request to client A. If both clients stop processing further
1000 // messages until their respective request completes, they won't have a chance to
1001 // respond to the other client's request and cause a deadlock.
1002 //
1003 // This arrangement ensures we will attempt to process earlier messages first, but fall
1004 // back to processing messages arrived later in the spirit of making progress.
1005 let mut foreground_message_handlers = FuturesUnordered::new();
1006 let concurrent_handlers = Arc::new(Semaphore::new(256));
1007 loop {
1008 let next_message = async {
1009 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1010 let message = incoming_rx.next().await;
1011 (permit, message)
1012 }.fuse();
1013 futures::pin_mut!(next_message);
1014 futures::select_biased! {
1015 _ = teardown.changed().fuse() => return,
1016 result = handle_io => {
1017 if let Err(error) = result {
1018 tracing::error!(?error, "error handling I/O");
1019 }
1020 break;
1021 }
1022 _ = foreground_message_handlers.next() => {}
1023 next_message = next_message => {
1024 let (permit, message) = next_message;
1025 if let Some(message) = message {
1026 let type_name = message.payload_type_name();
1027 // note: we copy all the fields from the parent span so we can query them in the logs.
1028 // (https://github.com/tokio-rs/tracing/issues/2670).
1029 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1030 user_id=field::Empty,
1031 login=field::Empty,
1032 impersonator=field::Empty,
1033 dev_server_id=field::Empty
1034 );
1035 principal.update_span(&span);
1036 let span_enter = span.enter();
1037 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1038 let is_background = message.is_background();
1039 let handle_message = (handler)(message, session.clone());
1040 drop(span_enter);
1041
1042 let handle_message = async move {
1043 handle_message.await;
1044 drop(permit);
1045 }.instrument(span);
1046 if is_background {
1047 executor.spawn_detached(handle_message);
1048 } else {
1049 foreground_message_handlers.push(handle_message);
1050 }
1051 } else {
1052 tracing::error!("no message handler");
1053 }
1054 } else {
1055 tracing::info!("connection closed");
1056 break;
1057 }
1058 }
1059 }
1060 }
1061
1062 drop(foreground_message_handlers);
1063 tracing::info!("signing out");
1064 if let Err(error) = connection_lost(session, teardown, executor).await {
1065 tracing::error!(?error, "error signing out");
1066 }
1067
1068 }.instrument(span)
1069 }
1070
    /// Performs the initial handshake for a freshly established connection.
    ///
    /// Always sends `proto::Hello` carrying the connection's peer id, then reports the
    /// `ConnectionId` through `send_connection_id` if a sender was supplied (a dropped
    /// receiver is ignored). What follows depends on the principal:
    ///
    /// * Users (including impersonated users):
    ///   - On the user's first-ever connection, sends `ShowContacts` and records
    ///     `connected_once` in the database.
    ///   - Loads contacts, channels, channel invites and dev-server project status
    ///     concurrently.
    ///   - Registers the connection in the pool and subscribes it to each channel membership.
    ///   - Sends the initial contacts, user-channel and channel updates, then the dev-server
    ///     project status and any pending incoming call.
    ///   - Finally calls `update_user_contacts` for the user.
    /// * Dev servers:
    ///   - Fails with `ErrorCode::DevServerAlreadyOnline` if the pool already has a
    ///     connection for this dev server.
    ///   - Otherwise registers it and sends its project list as `DevServerInstructions`.
    ///   - Pushes the refreshed dev-server project status to the owning user
    ///     (`dev_server.user_id`).
    ///
    /// # Errors
    /// Returns the first send or database call that fails; later steps are not attempted.
    ///
    /// NOTE(review): if an error occurs after the connection or dev server has been added to
    /// the pool, nothing here removes it again, and the visible caller returns on error without
    /// running `connection_lost`. For dev servers this could leave a stale entry that rejects
    /// later reconnects with `DevServerAlreadyOnline`. Confirm that cleanup happens elsewhere.
    async fn send_initial_client_update(
        &self,
        connection_id: ConnectionId,
        principal: &Principal,
        zed_version: ZedVersion,
        mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
        session: &Session,
    ) -> Result<()> {
        self.peer.send(
            connection_id,
            proto::Hello {
                peer_id: Some(connection_id.into()),
            },
        )?;
        tracing::info!("sent hello message");
        // The receiving end is optional; a failed send (receiver dropped) is deliberately ignored.
        if let Some(send_connection_id) = send_connection_id.take() {
            let _ = send_connection_id.send(connection_id);
        }

        match principal {
            Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
                // First connection ever for this user: prompt the client to show contacts and
                // persist the flag so this happens only once.
                if !user.connected_once {
                    self.peer.send(connection_id, proto::ShowContacts {})?;
                    self.app_state
                        .db
                        .set_user_connected_once(user.id, true)
                        .await?;
                }

                // Independent queries; run them concurrently and fail on the first error.
                let (contacts, channels_for_user, channel_invites, dev_server_projects) =
                    future::try_join4(
                        self.app_state.db.get_contacts(user.id),
                        self.app_state.db.get_channels_for_user(user.id),
                        self.app_state.db.get_channel_invites_for_user(user.id),
                        self.app_state.db.dev_server_projects_update(user.id),
                    )
                    .await?;

                // The pool lock is held for the registration and the three sends below;
                // `build_initial_contacts_update` reads the pool while it is locked.
                {
                    let mut pool = self.connection_pool.lock();
                    pool.add_connection(connection_id, user.id, user.admin, zed_version);
                    for membership in &channels_for_user.channel_memberships {
                        pool.subscribe_to_channel(user.id, membership.channel_id, membership.role)
                    }
                    self.peer.send(
                        connection_id,
                        build_initial_contacts_update(contacts, &pool),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_update_user_channels(&channels_for_user),
                    )?;
                    self.peer.send(
                        connection_id,
                        build_channels_update(channels_for_user, channel_invites),
                    )?;
                }
                send_dev_server_projects_update(user.id, dev_server_projects, session).await;

                // Deliver a call that was already ringing before this connection came up.
                if let Some(incoming_call) =
                    self.app_state.db.incoming_call_for_user(user.id).await?
                {
                    self.peer.send(connection_id, incoming_call)?;
                }

                update_user_contacts(user.id, &session).await?;
            }
            Principal::DevServer(dev_server) => {
                // Only one live connection per dev server is allowed; check and register under
                // the same lock acquisition.
                {
                    let mut pool = self.connection_pool.lock();
                    if pool.dev_server_connection_id(dev_server.id).is_some() {
                        return Err(anyhow!(ErrorCode::DevServerAlreadyOnline))?;
                    };
                    pool.add_dev_server(connection_id, dev_server.id, zed_version);
                }

                let projects = self
                    .app_state
                    .db
                    .get_projects_for_dev_server(dev_server.id)
                    .await?;
                self.peer
                    .send(connection_id, proto::DevServerInstructions { projects })?;

                // Tell the owning user's clients about the dev server's new status.
                let status = self
                    .app_state
                    .db
                    .dev_server_projects_update(dev_server.user_id)
                    .await?;
                send_dev_server_projects_update(dev_server.user_id, status, &session).await;
            }
        }

        Ok(())
    }
1166
1167 pub async fn invite_code_redeemed(
1168 self: &Arc<Self>,
1169 inviter_id: UserId,
1170 invitee_id: UserId,
1171 ) -> Result<()> {
1172 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1173 if let Some(code) = &user.invite_code {
1174 let pool = self.connection_pool.lock();
1175 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1176 for connection_id in pool.user_connection_ids(inviter_id) {
1177 self.peer.send(
1178 connection_id,
1179 proto::UpdateContacts {
1180 contacts: vec![invitee_contact.clone()],
1181 ..Default::default()
1182 },
1183 )?;
1184 self.peer.send(
1185 connection_id,
1186 proto::UpdateInviteInfo {
1187 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1188 count: user.invite_count as u32,
1189 },
1190 )?;
1191 }
1192 }
1193 }
1194 Ok(())
1195 }
1196
1197 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1198 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1199 if let Some(invite_code) = &user.invite_code {
1200 let pool = self.connection_pool.lock();
1201 for connection_id in pool.user_connection_ids(user_id) {
1202 self.peer.send(
1203 connection_id,
1204 proto::UpdateInviteInfo {
1205 url: format!(
1206 "{}{}",
1207 self.app_state.config.invite_link_prefix, invite_code
1208 ),
1209 count: user.invite_count as u32,
1210 },
1211 )?;
1212 }
1213 }
1214 }
1215 Ok(())
1216 }
1217
1218 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1219 ServerSnapshot {
1220 connection_pool: ConnectionPoolGuard {
1221 guard: self.connection_pool.lock(),
1222 _not_send: PhantomData,
1223 },
1224 peer: &self.peer,
1225 }
1226 }
1227}
1228
1229impl<'a> Deref for ConnectionPoolGuard<'a> {
1230 type Target = ConnectionPool;
1231
1232 fn deref(&self) -> &Self::Target {
1233 &self.guard
1234 }
1235}
1236
1237impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1238 fn deref_mut(&mut self) -> &mut Self::Target {
1239 &mut self.guard
1240 }
1241}
1242
/// In test builds, checks the connection pool's invariants each time a guard is released.
///
/// This way an inconsistent pool is detected (presumably by a panic inside
/// `check_invariants`) right after the mutation that caused it.
impl<'a> Drop for ConnectionPoolGuard<'a> {
    fn drop(&mut self) {
        // Compiled only under `cfg(test)`; in other builds this `drop` has no body of its own
        // and only the guard field's own drop runs.
        #[cfg(test)]
        self.check_invariants();
    }
}
1249
1250fn broadcast<F>(
1251 sender_id: Option<ConnectionId>,
1252 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1253 mut f: F,
1254) where
1255 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1256{
1257 for receiver_id in receiver_ids {
1258 if Some(receiver_id) != sender_id {
1259 if let Err(error) = f(receiver_id) {
1260 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1261 }
1262 }
1263 }
1264}
1265
1266pub struct ProtocolVersion(u32);
1267
1268impl Header for ProtocolVersion {
1269 fn name() -> &'static HeaderName {
1270 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1271 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1272 }
1273
1274 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1275 where
1276 Self: Sized,
1277 I: Iterator<Item = &'i axum::http::HeaderValue>,
1278 {
1279 let version = values
1280 .next()
1281 .ok_or_else(axum::headers::Error::invalid)?
1282 .to_str()
1283 .map_err(|_| axum::headers::Error::invalid())?
1284 .parse()
1285 .map_err(|_| axum::headers::Error::invalid())?;
1286 Ok(Self(version))
1287 }
1288
1289 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1290 values.extend([self.0.to_string().parse().unwrap()]);
1291 }
1292}
1293
1294pub struct AppVersionHeader(SemanticVersion);
1295impl Header for AppVersionHeader {
1296 fn name() -> &'static HeaderName {
1297 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1298 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1299 }
1300
1301 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1302 where
1303 Self: Sized,
1304 I: Iterator<Item = &'i axum::http::HeaderValue>,
1305 {
1306 let version = values
1307 .next()
1308 .ok_or_else(axum::headers::Error::invalid)?
1309 .to_str()
1310 .map_err(|_| axum::headers::Error::invalid())?
1311 .parse()
1312 .map_err(|_| axum::headers::Error::invalid())?;
1313 Ok(Self(version))
1314 }
1315
1316 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1317 values.extend([self.0.to_string().parse().unwrap()]);
1318 }
1319}
1320
/// Builds the collab server's router with the `/rpc` WebSocket endpoint and `/metrics`.
///
/// Order matters here. `Router::layer` wraps only the routes registered before it, so:
/// - the first layer (the `AppState` extension plus `auth::validate_header`) covers `/rpc`
///   only;
/// - `/metrics`, added afterwards, is not behind that middleware;
/// - the final `Extension(server)` layer covers both routes.
pub fn routes(server: Arc<Server>) -> Router<(), Body> {
    Router::new()
        .route("/rpc", get(handle_websocket_request))
        // Authentication stack for the routes registered so far (`/rpc`).
        .layer(
            ServiceBuilder::new()
                .layer(Extension(server.app_state.clone()))
                .layer(middleware::from_fn(auth::validate_header)),
        )
        // Registered after the auth layer, so not wrapped by it.
        .route("/metrics", get(handle_metrics))
        .layer(Extension(server))
}
1332
1333pub async fn handle_websocket_request(
1334 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1335 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1336 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1337 Extension(server): Extension<Arc<Server>>,
1338 Extension(principal): Extension<Principal>,
1339 ws: WebSocketUpgrade,
1340) -> axum::response::Response {
1341 if protocol_version != rpc::PROTOCOL_VERSION {
1342 return (
1343 StatusCode::UPGRADE_REQUIRED,
1344 "client must be upgraded".to_string(),
1345 )
1346 .into_response();
1347 }
1348
1349 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1350 return (
1351 StatusCode::UPGRADE_REQUIRED,
1352 "no version header found".to_string(),
1353 )
1354 .into_response();
1355 };
1356
1357 if !version.can_collaborate() {
1358 return (
1359 StatusCode::UPGRADE_REQUIRED,
1360 "client must be upgraded".to_string(),
1361 )
1362 .into_response();
1363 }
1364
1365 let socket_address = socket_address.to_string();
1366 ws.on_upgrade(move |socket| {
1367 let socket = socket
1368 .map_ok(to_tungstenite_message)
1369 .err_into()
1370 .with(|message| async move { Ok(to_axum_message(message)) });
1371 let connection = Connection::new(Box::pin(socket));
1372 async move {
1373 server
1374 .handle_connection(
1375 connection,
1376 socket_address,
1377 principal,
1378 version,
1379 None,
1380 Executor::Production,
1381 )
1382 .await;
1383 }
1384 })
1385}
1386
1387pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1388 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1389 let connections_metric = CONNECTIONS_METRIC
1390 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1391
1392 let connections = server
1393 .connection_pool
1394 .lock()
1395 .connections()
1396 .filter(|connection| !connection.admin)
1397 .count();
1398 connections_metric.set(connections as _);
1399
1400 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1401 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1402 register_int_gauge!(
1403 "shared_projects",
1404 "number of open projects with one or more guests"
1405 )
1406 .unwrap()
1407 });
1408
1409 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1410 shared_projects_metric.set(shared_projects as _);
1411
1412 let encoder = prometheus::TextEncoder::new();
1413 let metric_families = prometheus::gather();
1414 let encoded_metrics = encoder
1415 .encode_to_string(&metric_families)
1416 .map_err(|err| anyhow!("{}", err))?;
1417 Ok(encoded_metrics)
1418}
1419
/// Cleans up after a client connection ends.
///
/// Immediate steps:
/// - disconnects the peer;
/// - removes the connection from the pool (this error is returned);
/// - records the lost connection in the database (errors here are only logged).
///
/// The function then waits for `RECONNECT_TIMEOUT`. If `teardown` changes first (presumably
/// server shutdown; TODO confirm), nothing more happens. Otherwise the remaining
/// per-principal state is released:
/// - Users: leaves the room and channel buffers held by this connection (errors logged).
///   If the user has no other live connection, calls `decline_call(None, ..)` and broadcasts
///   any room it returns. Finally calls `update_user_contacts`.
/// - Dev servers: delegates to `lost_dev_server_connection`.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Resources are released only after `RECONNECT_TIMEOUT`. Presumably this gives a client
    // that drops briefly time to come back (cf. `rejoin_room`); confirm.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only act on calls if no other connection of this user is still online;
                    // the pool guard is a temporary of the `if` condition and is released
                    // before the body runs.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1474
1475/// Acknowledges a ping from a client, used to keep the connection alive.
1476async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1477 response.send(proto::Ack {})?;
1478 Ok(())
1479}
1480
1481/// Creates a new room for calling (outside of channels)
1482async fn create_room(
1483 _request: proto::CreateRoom,
1484 response: Response<proto::CreateRoom>,
1485 session: UserSession,
1486) -> Result<()> {
1487 let live_kit_room = nanoid::nanoid!(30);
1488
1489 let live_kit_connection_info = util::maybe!(async {
1490 let live_kit = session.live_kit_client.as_ref();
1491 let live_kit = live_kit?;
1492 let user_id = session.user_id().to_string();
1493
1494 let token = live_kit
1495 .room_token(&live_kit_room, &user_id.to_string())
1496 .trace_err()?;
1497
1498 Some(proto::LiveKitConnectionInfo {
1499 server_url: live_kit.url().into(),
1500 token,
1501 can_publish: true,
1502 })
1503 })
1504 .await;
1505
1506 let room = session
1507 .db()
1508 .await
1509 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1510 .await?;
1511
1512 response.send(proto::CreateRoomResponse {
1513 room: Some(room.clone()),
1514 live_kit_connection_info,
1515 })?;
1516
1517 update_user_contacts(session.user_id(), &session).await?;
1518 Ok(())
1519}
1520
/// Joins a room from an invitation. If the room belongs to a channel, this is equivalent to
/// joining that channel.
///
/// Channel rooms are delegated entirely to `join_channel_internal`. For other rooms:
/// 1. Adds this connection to the room in the database and broadcasts the updated room.
/// 2. Sends `CallCanceled` for the room to every connection of the joining user.
/// 3. Mints a LiveKit token when LiveKit is configured (failures are logged and the token is
///    omitted).
/// 4. Returns the joined room.
/// 5. Calls `update_user_contacts` for the user.
async fn join_room(
    request: proto::JoinRoom,
    response: Response<proto::JoinRoom>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.id);

    let channel_id = session.db().await.channel_id_for_room(room_id).await?;

    if let Some(channel_id) = channel_id {
        return join_channel_internal(channel_id, Box::new(response), session).await;
    }

    // The value returned by `join_room` is broadcast while still held, then unwrapped with
    // `into_inner` at the end of this block.
    let joined_room = {
        let room = session
            .db()
            .await
            .join_room(room_id, session.user_id(), session.connection_id)
            .await?;
        room_updated(&room.room, &session.peer);
        room.into_inner()
    };

    // Presumably stops the ring on the user's other devices now that one of them joined.
    // The pool guard is a temporary of the loop expression and is held for the whole loop.
    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }

    let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
        if let Some(token) = live_kit
            .room_token(
                &joined_room.room.live_kit_room,
                &session.user_id().to_string(),
            )
            .trace_err()
        {
            Some(proto::LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish: true,
            })
        } else {
            None
        }
    } else {
        None
    };

    response.send(proto::JoinRoomResponse {
        room: Some(joined_room.room),
        channel_id: None,
        live_kit_connection_info,
    })?;

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1590
/// Reconnects this connection to a room after a connection error, using the state the
/// client reports in `request`.
///
/// While the database's rejoin result is held (the inner block):
/// 1. Responds with the room and the reshared and rejoined projects.
/// 2. Broadcasts the room.
/// 3. For each reshared project, tells every collaborator that `old_connection_id` was
///    replaced by this connection, and forwards the project's worktrees to them.
/// 4. Re-streams rejoined projects through `notify_rejoined_projects`.
///
/// After the block, the owning channel (if any) is updated and `update_user_contacts` is
/// called for the user.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in inside the block below so the rejoin guard is released before the channel
    // update and the contacts refresh.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Re-point collaborators from the old connection's peer id to this one.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Forward the project's current worktree list to every collaborator except this
            // connection.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1682
1683fn notify_rejoined_projects(
1684 rejoined_projects: &mut Vec<RejoinedProject>,
1685 session: &UserSession,
1686) -> Result<()> {
1687 for project in rejoined_projects.iter() {
1688 for collaborator in &project.collaborators {
1689 session
1690 .peer
1691 .send(
1692 collaborator.connection_id,
1693 proto::UpdateProjectCollaborator {
1694 project_id: project.id.to_proto(),
1695 old_peer_id: Some(project.old_connection_id.into()),
1696 new_peer_id: Some(session.connection_id.into()),
1697 },
1698 )
1699 .trace_err();
1700 }
1701 }
1702
1703 for project in rejoined_projects {
1704 for worktree in mem::take(&mut project.worktrees) {
1705 #[cfg(any(test, feature = "test-support"))]
1706 const MAX_CHUNK_SIZE: usize = 2;
1707 #[cfg(not(any(test, feature = "test-support")))]
1708 const MAX_CHUNK_SIZE: usize = 256;
1709
1710 // Stream this worktree's entries.
1711 let message = proto::UpdateWorktree {
1712 project_id: project.id.to_proto(),
1713 worktree_id: worktree.id,
1714 abs_path: worktree.abs_path.clone(),
1715 root_name: worktree.root_name,
1716 updated_entries: worktree.updated_entries,
1717 removed_entries: worktree.removed_entries,
1718 scan_id: worktree.scan_id,
1719 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1720 updated_repositories: worktree.updated_repositories,
1721 removed_repositories: worktree.removed_repositories,
1722 };
1723 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1724 session.peer.send(session.connection_id, update.clone())?;
1725 }
1726
1727 // Stream this worktree's diagnostics.
1728 for summary in worktree.diagnostic_summaries {
1729 session.peer.send(
1730 session.connection_id,
1731 proto::UpdateDiagnosticSummary {
1732 project_id: project.id.to_proto(),
1733 worktree_id: worktree.id,
1734 summary: Some(summary),
1735 },
1736 )?;
1737 }
1738
1739 for settings_file in worktree.settings_files {
1740 session.peer.send(
1741 session.connection_id,
1742 proto::UpdateWorktreeSettings {
1743 project_id: project.id.to_proto(),
1744 worktree_id: worktree.id,
1745 path: settings_file.path,
1746 content: Some(settings_file.content),
1747 },
1748 )?;
1749 }
1750 }
1751
1752 for language_server in &project.language_servers {
1753 session.peer.send(
1754 session.connection_id,
1755 proto::UpdateLanguageServer {
1756 project_id: project.id.to_proto(),
1757 language_server_id: language_server.id,
1758 variant: Some(
1759 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1760 proto::LspDiskBasedDiagnosticsUpdated {},
1761 ),
1762 ),
1763 },
1764 )?;
1765 }
1766 }
1767 Ok(())
1768}
1769
1770/// leave room disconnects from the room.
1771async fn leave_room(
1772 _: proto::LeaveRoom,
1773 response: Response<proto::LeaveRoom>,
1774 session: UserSession,
1775) -> Result<()> {
1776 leave_room_for_session(&session, session.connection_id).await?;
1777 response.send(proto::Ack {})?;
1778 Ok(())
1779}
1780
1781/// Updates the permissions of someone else in the room.
1782async fn set_room_participant_role(
1783 request: proto::SetRoomParticipantRole,
1784 response: Response<proto::SetRoomParticipantRole>,
1785 session: UserSession,
1786) -> Result<()> {
1787 let user_id = UserId::from_proto(request.user_id);
1788 let role = ChannelRole::from(request.role());
1789
1790 let (live_kit_room, can_publish) = {
1791 let room = session
1792 .db()
1793 .await
1794 .set_room_participant_role(
1795 session.user_id(),
1796 RoomId::from_proto(request.room_id),
1797 user_id,
1798 role,
1799 )
1800 .await?;
1801
1802 let live_kit_room = room.live_kit_room.clone();
1803 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1804 room_updated(&room, &session.peer);
1805 (live_kit_room, can_publish)
1806 };
1807
1808 if let Some(live_kit) = session.live_kit_client.as_ref() {
1809 live_kit
1810 .update_participant(
1811 live_kit_room.clone(),
1812 request.user_id.to_string(),
1813 live_kit_server::proto::ParticipantPermission {
1814 can_subscribe: true,
1815 can_publish,
1816 can_publish_data: can_publish,
1817 hidden: false,
1818 recorder: false,
1819 },
1820 )
1821 .await
1822 .trace_err();
1823 }
1824
1825 response.send(proto::Ack {})?;
1826 Ok(())
1827}
1828
/// Rings another user into the caller's current room.
///
/// The callee must be a contact of the caller. Steps:
/// 1. Records the call in the database and broadcasts the room.
/// 2. Calls `update_user_contacts` for the callee.
/// 3. Sends the incoming call as a request to every connection of the callee. The first
///    successful reply acknowledges the caller's request.
/// 4. If no connection replies successfully (including when the callee has no connections),
///    marks the call as failed in the database, broadcasts the room again, updates the
///    callee's contacts and returns an error.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the callee is not a contact;
/// - a database call or send fails;
/// - nobody answered ("failed to ring user").
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Scoped so the value returned by `call` is dropped before any further awaits; the
    // incoming-call message is moved out with `mem::take`.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // One ring request per connection of the callee, awaited in completion order.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // No connection acknowledged the call: undo it.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1897
/// Cancels an outgoing call that this connection placed to `called_user_id`.
///
/// Steps, in order:
/// 1. Cancels the call in the database and broadcasts the room.
/// 2. Sends `CallCanceled` to every connection of the callee (send failures are logged).
/// 3. Acknowledges the request.
/// 4. Calls `update_user_contacts` for the callee.
async fn cancel_call(
    request: proto::CancelCall,
    response: Response<proto::CancelCall>,
    session: UserSession,
) -> Result<()> {
    let called_user_id = UserId::from_proto(request.called_user_id);
    let room_id = RoomId::from_proto(request.room_id);
    // Scoped so the database result is dropped before the pool is locked below.
    {
        let room = session
            .db()
            .await
            .cancel_call(room_id, session.connection_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    response.send(proto::Ack {})?;

    update_user_contacts(called_user_id, &session).await?;
    Ok(())
}
1935
/// Declines an incoming call to the given room.
///
/// Steps, in order:
/// 1. Declines the call in the database and broadcasts the room.
/// 2. Sends `CallCanceled` to every connection of the declining user, presumably so their
///    other devices stop ringing (send failures are logged).
/// 3. Calls `update_user_contacts` for the user.
///
/// # Errors
/// Returns "failed to decline call" when the database reports no room for the declined call.
async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
    let room_id = RoomId::from_proto(message.room_id);
    {
        let room = session
            .db()
            .await
            .decline_call(Some(room_id), session.user_id())
            .await?
            .ok_or_else(|| anyhow!("failed to decline call"))?;
        room_updated(&room, &session.peer);
    }

    for connection_id in session
        .connection_pool()
        .await
        .user_connection_ids(session.user_id())
    {
        session
            .peer
            .send(
                connection_id,
                proto::CallCanceled {
                    room_id: room_id.to_proto(),
                },
            )
            .trace_err();
    }
    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1967
/// Records this connection's current location in the room and broadcasts the updated room to
/// the other participants.
///
/// # Errors
/// Returns "invalid location" when the request carries no location.
async fn update_participant_location(
    request: proto::UpdateParticipantLocation,
    response: Response<proto::UpdateParticipantLocation>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let location = request
        .location
        .ok_or_else(|| anyhow!("invalid location"))?;

    // `db` is not dropped explicitly, so the database handle stays locked through the
    // broadcast and the acknowledgement below.
    let db = session.db().await;
    let room = db
        .update_room_participant_location(room_id, session.connection_id, location)
        .await?;

    room_updated(&room, &session.peer);
    response.send(proto::Ack {})?;
    Ok(())
}
1988
/// Shares one of this connection's projects into a room.
///
/// The worktrees come from the request, optionally linked to a dev-server project. The new
/// project id is returned to the caller, and then the room is broadcast.
async fn share_project(
    request: proto::ShareProject,
    response: Response<proto::ShareProject>,
    session: UserSession,
) -> Result<()> {
    // Borrows into the value returned by `share_project`, which lives until the end of the
    // function.
    let (project_id, room) = &*session
        .db()
        .await
        .share_project(
            RoomId::from_proto(request.room_id),
            session.connection_id,
            &request.worktrees,
            request
                .dev_server_project_id
                .map(|id| DevServerProjectId::from_proto(id)),
        )
        .await?;
    response.send(proto::ShareProjectResponse {
        project_id: project_id.to_proto(),
    })?;
    room_updated(&room, &session.peer);

    Ok(())
}
2014
2015/// Unshare a project from the room.
2016async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2017 let project_id = ProjectId::from_proto(message.project_id);
2018 unshare_project_internal(
2019 project_id,
2020 session.connection_id,
2021 session.user_id(),
2022 &session,
2023 )
2024 .await
2025}
2026
/// Unshares `project_id` on behalf of `connection_id` / `user_id`.
///
/// Sends `UnshareProject` to every guest connection the database returns, except
/// `connection_id` itself (send failures are logged by `broadcast`). If the project was in a
/// room, that room is broadcast too.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    let (room, guest_connection_ids) = &*session
        .db()
        .await
        .unshare_project(project_id, connection_id, user_id)
        .await?;

    let message = proto::UnshareProject {
        project_id: project_id.to_proto(),
    };

    broadcast(
        Some(connection_id),
        guest_connection_ids.iter().copied(),
        |conn_id| session.peer.send(conn_id, message.clone()),
    );
    if let Some(room) = room {
        room_updated(room, &session.peer);
    }

    Ok(())
}
2054
/// Called by a dev server to make one of its dev-server projects available online.
///
/// The project is shared in the database with the worktrees the dev server reported. The
/// refreshed dev-server project status is pushed to the owning user, and then the project id
/// is returned to the dev server.
///
/// # Errors
/// Returns "failed to share remote project" if the shared dev-server project has no project
/// id.
async fn share_dev_server_project(
    request: proto::ShareDevServerProject,
    response: Response<proto::ShareDevServerProject>,
    session: DevServerSession,
) -> Result<()> {
    let (dev_server_project, user_id, status) = session
        .db()
        .await
        .share_dev_server_project(
            DevServerProjectId::from_proto(request.dev_server_project_id),
            session.dev_server_id(),
            session.connection_id,
            &request.worktrees,
        )
        .await?;
    let Some(project_id) = dev_server_project.project_id else {
        return Err(anyhow!("failed to share remote project"))?;
    };

    send_dev_server_projects_update(user_id, status, &session).await;

    response.send(proto::ShareProjectResponse { project_id })?;

    Ok(())
}
2081
/// Joins someone else's shared project.
///
/// Registers this connection as a collaborator in the database, then hands off to
/// `join_project_internal` to notify the others and stream the project to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the database handle before streaming the project. The value returned by
    // `join_project`, which `project` and `replica_id` borrow from, stays alive until the
    // function returns.
    drop(db);
    // NOTE(review): this message says "remote" even for ordinary shared projects; confirm the
    // intended wording.
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2100
/// A response that can be answered with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the reply to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // Fully qualified call to the `Response::send` defined on the concrete type. It is
        // presumably spelled out so it cannot resolve back to this trait method.
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        // See the `JoinProject` impl: delegates to the type's own `send`.
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2114
2115fn join_project_internal(
2116 response: impl JoinProjectInternalResponse,
2117 session: UserSession,
2118 project: &mut Project,
2119 replica_id: &ReplicaId,
2120) -> Result<()> {
2121 let collaborators = project
2122 .collaborators
2123 .iter()
2124 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2125 .map(|collaborator| collaborator.to_proto())
2126 .collect::<Vec<_>>();
2127 let project_id = project.id;
2128 let guest_user_id = session.user_id();
2129
2130 let worktrees = project
2131 .worktrees
2132 .iter()
2133 .map(|(id, worktree)| proto::WorktreeMetadata {
2134 id: *id,
2135 root_name: worktree.root_name.clone(),
2136 visible: worktree.visible,
2137 abs_path: worktree.abs_path.clone(),
2138 })
2139 .collect::<Vec<_>>();
2140
2141 let add_project_collaborator = proto::AddProjectCollaborator {
2142 project_id: project_id.to_proto(),
2143 collaborator: Some(proto::Collaborator {
2144 peer_id: Some(session.connection_id.into()),
2145 replica_id: replica_id.0 as u32,
2146 user_id: guest_user_id.to_proto(),
2147 }),
2148 };
2149
2150 for collaborator in &collaborators {
2151 session
2152 .peer
2153 .send(
2154 collaborator.peer_id.unwrap().into(),
2155 add_project_collaborator.clone(),
2156 )
2157 .trace_err();
2158 }
2159
2160 // First, we send the metadata associated with each worktree.
2161 response.send(proto::JoinProjectResponse {
2162 project_id: project.id.0 as u64,
2163 worktrees: worktrees.clone(),
2164 replica_id: replica_id.0 as u32,
2165 collaborators: collaborators.clone(),
2166 language_servers: project.language_servers.clone(),
2167 role: project.role.into(),
2168 dev_server_project_id: project
2169 .dev_server_project_id
2170 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2171 })?;
2172
2173 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2174 #[cfg(any(test, feature = "test-support"))]
2175 const MAX_CHUNK_SIZE: usize = 2;
2176 #[cfg(not(any(test, feature = "test-support")))]
2177 const MAX_CHUNK_SIZE: usize = 256;
2178
2179 // Stream this worktree's entries.
2180 let message = proto::UpdateWorktree {
2181 project_id: project_id.to_proto(),
2182 worktree_id,
2183 abs_path: worktree.abs_path.clone(),
2184 root_name: worktree.root_name,
2185 updated_entries: worktree.entries,
2186 removed_entries: Default::default(),
2187 scan_id: worktree.scan_id,
2188 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2189 updated_repositories: worktree.repository_entries.into_values().collect(),
2190 removed_repositories: Default::default(),
2191 };
2192 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2193 session.peer.send(session.connection_id, update.clone())?;
2194 }
2195
2196 // Stream this worktree's diagnostics.
2197 for summary in worktree.diagnostic_summaries {
2198 session.peer.send(
2199 session.connection_id,
2200 proto::UpdateDiagnosticSummary {
2201 project_id: project_id.to_proto(),
2202 worktree_id: worktree.id,
2203 summary: Some(summary),
2204 },
2205 )?;
2206 }
2207
2208 for settings_file in worktree.settings_files {
2209 session.peer.send(
2210 session.connection_id,
2211 proto::UpdateWorktreeSettings {
2212 project_id: project_id.to_proto(),
2213 worktree_id: worktree.id,
2214 path: settings_file.path,
2215 content: Some(settings_file.content),
2216 },
2217 )?;
2218 }
2219 }
2220
2221 for language_server in &project.language_servers {
2222 session.peer.send(
2223 session.connection_id,
2224 proto::UpdateLanguageServer {
2225 project_id: project_id.to_proto(),
2226 language_server_id: language_server.id,
2227 variant: Some(
2228 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2229 proto::LspDiskBasedDiagnosticsUpdated {},
2230 ),
2231 ),
2232 },
2233 )?;
2234 }
2235
2236 Ok(())
2237}
2238
2239/// Leave someone elses shared project.
2240async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2241 let sender_id = session.connection_id;
2242 let project_id = ProjectId::from_proto(request.project_id);
2243 let db = session.db().await;
2244 if db.is_hosted_project(project_id).await? {
2245 let project = db.leave_hosted_project(project_id, sender_id).await?;
2246 project_left(&project, &session);
2247 return Ok(());
2248 }
2249
2250 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2251 tracing::info!(
2252 %project_id,
2253 "leave project"
2254 );
2255
2256 project_left(&project, &session);
2257 if let Some(room) = room {
2258 room_updated(&room, &session.peer);
2259 }
2260
2261 Ok(())
2262}
2263
2264async fn join_hosted_project(
2265 request: proto::JoinHostedProject,
2266 response: Response<proto::JoinHostedProject>,
2267 session: UserSession,
2268) -> Result<()> {
2269 let (mut project, replica_id) = session
2270 .db()
2271 .await
2272 .join_hosted_project(
2273 ProjectId(request.project_id as i32),
2274 session.user_id(),
2275 session.connection_id,
2276 )
2277 .await?;
2278
2279 join_project_internal(response, session, &mut project, &replica_id)
2280}
2281
2282async fn create_dev_server_project(
2283 request: proto::CreateDevServerProject,
2284 response: Response<proto::CreateDevServerProject>,
2285 session: UserSession,
2286) -> Result<()> {
2287 let dev_server_id = DevServerId(request.dev_server_id as i32);
2288 let dev_server_connection_id = session
2289 .connection_pool()
2290 .await
2291 .dev_server_connection_id(dev_server_id);
2292 let Some(dev_server_connection_id) = dev_server_connection_id else {
2293 Err(ErrorCode::DevServerOffline
2294 .message("Cannot create a remote project when the dev server is offline".to_string())
2295 .anyhow())?
2296 };
2297
2298 let path = request.path.clone();
2299 //Check that the path exists on the dev server
2300 session
2301 .peer
2302 .forward_request(
2303 session.connection_id,
2304 dev_server_connection_id,
2305 proto::ValidateDevServerProjectRequest { path: path.clone() },
2306 )
2307 .await?;
2308
2309 let (dev_server_project, update) = session
2310 .db()
2311 .await
2312 .create_dev_server_project(
2313 DevServerId(request.dev_server_id as i32),
2314 &request.path,
2315 session.user_id(),
2316 )
2317 .await?;
2318
2319 let projects = session
2320 .db()
2321 .await
2322 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2323 .await?;
2324
2325 session.peer.send(
2326 dev_server_connection_id,
2327 proto::DevServerInstructions { projects },
2328 )?;
2329
2330 send_dev_server_projects_update(session.user_id(), update, &session).await;
2331
2332 response.send(proto::CreateDevServerProjectResponse {
2333 dev_server_project: Some(dev_server_project.to_proto(None)),
2334 })?;
2335 Ok(())
2336}
2337
2338async fn create_dev_server(
2339 request: proto::CreateDevServer,
2340 response: Response<proto::CreateDevServer>,
2341 session: UserSession,
2342) -> Result<()> {
2343 let access_token = auth::random_token();
2344 let hashed_access_token = auth::hash_access_token(&access_token);
2345
2346 let (dev_server, status) = session
2347 .db()
2348 .await
2349 .create_dev_server(&request.name, &hashed_access_token, session.user_id())
2350 .await?;
2351
2352 send_dev_server_projects_update(session.user_id(), status, &session).await;
2353
2354 response.send(proto::CreateDevServerResponse {
2355 dev_server_id: dev_server.id.0 as u64,
2356 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2357 name: request.name.clone(),
2358 })?;
2359 Ok(())
2360}
2361
2362async fn delete_dev_server(
2363 request: proto::DeleteDevServer,
2364 response: Response<proto::DeleteDevServer>,
2365 session: UserSession,
2366) -> Result<()> {
2367 let dev_server_id = DevServerId(request.dev_server_id as i32);
2368 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2369 if dev_server.user_id != session.user_id() {
2370 return Err(anyhow!(ErrorCode::Forbidden))?;
2371 }
2372
2373 let connection_id = session
2374 .connection_pool()
2375 .await
2376 .dev_server_connection_id(dev_server_id);
2377 if let Some(connection_id) = connection_id {
2378 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2379 session
2380 .peer
2381 .send(connection_id, proto::ShutdownDevServer {})?;
2382 }
2383
2384 let status = session
2385 .db()
2386 .await
2387 .delete_dev_server(dev_server_id, session.user_id())
2388 .await?;
2389
2390 send_dev_server_projects_update(session.user_id(), status, &session).await;
2391
2392 response.send(proto::Ack {})?;
2393 Ok(())
2394}
2395
2396async fn delete_dev_server_project(
2397 request: proto::DeleteDevServerProject,
2398 response: Response<proto::DeleteDevServerProject>,
2399 session: UserSession,
2400) -> Result<()> {
2401 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2402 let dev_server_project = session
2403 .db()
2404 .await
2405 .get_dev_server_project(dev_server_project_id)
2406 .await?;
2407
2408 let dev_server = session
2409 .db()
2410 .await
2411 .get_dev_server(dev_server_project.dev_server_id)
2412 .await?;
2413 if dev_server.user_id != session.user_id() {
2414 return Err(anyhow!(ErrorCode::Forbidden))?;
2415 }
2416
2417 let dev_server_connection_id = session
2418 .connection_pool()
2419 .await
2420 .dev_server_connection_id(dev_server.id);
2421
2422 if let Some(dev_server_connection_id) = dev_server_connection_id {
2423 let project = session
2424 .db()
2425 .await
2426 .find_dev_server_project(dev_server_project_id)
2427 .await;
2428 if let Ok(project) = project {
2429 unshare_project_internal(
2430 project.id,
2431 dev_server_connection_id,
2432 Some(session.user_id()),
2433 &session,
2434 )
2435 .await?;
2436 }
2437 }
2438
2439 let (projects, status) = session
2440 .db()
2441 .await
2442 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2443 .await?;
2444
2445 if let Some(dev_server_connection_id) = dev_server_connection_id {
2446 session.peer.send(
2447 dev_server_connection_id,
2448 proto::DevServerInstructions { projects },
2449 )?;
2450 }
2451
2452 send_dev_server_projects_update(session.user_id(), status, &session).await;
2453
2454 response.send(proto::Ack {})?;
2455 Ok(())
2456}
2457
2458async fn rejoin_dev_server_projects(
2459 request: proto::RejoinRemoteProjects,
2460 response: Response<proto::RejoinRemoteProjects>,
2461 session: UserSession,
2462) -> Result<()> {
2463 let mut rejoined_projects = {
2464 let db = session.db().await;
2465 db.rejoin_dev_server_projects(
2466 &request.rejoined_projects,
2467 session.user_id(),
2468 session.0.connection_id,
2469 )
2470 .await?
2471 };
2472 notify_rejoined_projects(&mut rejoined_projects, &session)?;
2473
2474 response.send(proto::RejoinRemoteProjectsResponse {
2475 rejoined_projects: rejoined_projects
2476 .into_iter()
2477 .map(|project| project.to_proto())
2478 .collect(),
2479 })
2480}
2481
2482async fn reconnect_dev_server(
2483 request: proto::ReconnectDevServer,
2484 response: Response<proto::ReconnectDevServer>,
2485 session: DevServerSession,
2486) -> Result<()> {
2487 let reshared_projects = {
2488 let db = session.db().await;
2489 db.reshare_dev_server_projects(
2490 &request.reshared_projects,
2491 session.dev_server_id(),
2492 session.0.connection_id,
2493 )
2494 .await?
2495 };
2496
2497 for project in &reshared_projects {
2498 for collaborator in &project.collaborators {
2499 session
2500 .peer
2501 .send(
2502 collaborator.connection_id,
2503 proto::UpdateProjectCollaborator {
2504 project_id: project.id.to_proto(),
2505 old_peer_id: Some(project.old_connection_id.into()),
2506 new_peer_id: Some(session.connection_id.into()),
2507 },
2508 )
2509 .trace_err();
2510 }
2511
2512 broadcast(
2513 Some(session.connection_id),
2514 project
2515 .collaborators
2516 .iter()
2517 .map(|collaborator| collaborator.connection_id),
2518 |connection_id| {
2519 session.peer.forward_send(
2520 session.connection_id,
2521 connection_id,
2522 proto::UpdateProject {
2523 project_id: project.id.to_proto(),
2524 worktrees: project.worktrees.clone(),
2525 },
2526 )
2527 },
2528 );
2529 }
2530
2531 response.send(proto::ReconnectDevServerResponse {
2532 reshared_projects: reshared_projects
2533 .iter()
2534 .map(|project| proto::ResharedProject {
2535 id: project.id.to_proto(),
2536 collaborators: project
2537 .collaborators
2538 .iter()
2539 .map(|collaborator| collaborator.to_proto())
2540 .collect(),
2541 })
2542 .collect(),
2543 })?;
2544
2545 Ok(())
2546}
2547
2548async fn shutdown_dev_server(
2549 _: proto::ShutdownDevServer,
2550 response: Response<proto::ShutdownDevServer>,
2551 session: DevServerSession,
2552) -> Result<()> {
2553 response.send(proto::Ack {})?;
2554 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await
2555}
2556
2557async fn shutdown_dev_server_internal(
2558 dev_server_id: DevServerId,
2559 connection_id: ConnectionId,
2560 session: &Session,
2561) -> Result<()> {
2562 let (dev_server_projects, dev_server) = {
2563 let db = session.db().await;
2564 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2565 let dev_server = db.get_dev_server(dev_server_id).await?;
2566 (dev_server_projects, dev_server)
2567 };
2568
2569 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2570 unshare_project_internal(
2571 ProjectId::from_proto(project_id),
2572 connection_id,
2573 None,
2574 session,
2575 )
2576 .await?;
2577 }
2578
2579 session
2580 .connection_pool()
2581 .await
2582 .set_dev_server_offline(dev_server_id);
2583
2584 let status = session
2585 .db()
2586 .await
2587 .dev_server_projects_update(dev_server.user_id)
2588 .await?;
2589 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2590
2591 Ok(())
2592}
2593
2594/// Updates other participants with changes to the project
2595async fn update_project(
2596 request: proto::UpdateProject,
2597 response: Response<proto::UpdateProject>,
2598 session: Session,
2599) -> Result<()> {
2600 let project_id = ProjectId::from_proto(request.project_id);
2601 let (room, guest_connection_ids) = &*session
2602 .db()
2603 .await
2604 .update_project(project_id, session.connection_id, &request.worktrees)
2605 .await?;
2606 broadcast(
2607 Some(session.connection_id),
2608 guest_connection_ids.iter().copied(),
2609 |connection_id| {
2610 session
2611 .peer
2612 .forward_send(session.connection_id, connection_id, request.clone())
2613 },
2614 );
2615 if let Some(room) = room {
2616 room_updated(&room, &session.peer);
2617 }
2618 response.send(proto::Ack {})?;
2619
2620 Ok(())
2621}
2622
2623/// Updates other participants with changes to the worktree
2624async fn update_worktree(
2625 request: proto::UpdateWorktree,
2626 response: Response<proto::UpdateWorktree>,
2627 session: Session,
2628) -> Result<()> {
2629 let guest_connection_ids = session
2630 .db()
2631 .await
2632 .update_worktree(&request, session.connection_id)
2633 .await?;
2634
2635 broadcast(
2636 Some(session.connection_id),
2637 guest_connection_ids.iter().copied(),
2638 |connection_id| {
2639 session
2640 .peer
2641 .forward_send(session.connection_id, connection_id, request.clone())
2642 },
2643 );
2644 response.send(proto::Ack {})?;
2645 Ok(())
2646}
2647
2648/// Updates other participants with changes to the diagnostics
2649async fn update_diagnostic_summary(
2650 message: proto::UpdateDiagnosticSummary,
2651 session: Session,
2652) -> Result<()> {
2653 let guest_connection_ids = session
2654 .db()
2655 .await
2656 .update_diagnostic_summary(&message, session.connection_id)
2657 .await?;
2658
2659 broadcast(
2660 Some(session.connection_id),
2661 guest_connection_ids.iter().copied(),
2662 |connection_id| {
2663 session
2664 .peer
2665 .forward_send(session.connection_id, connection_id, message.clone())
2666 },
2667 );
2668
2669 Ok(())
2670}
2671
2672/// Updates other participants with changes to the worktree settings
2673async fn update_worktree_settings(
2674 message: proto::UpdateWorktreeSettings,
2675 session: Session,
2676) -> Result<()> {
2677 let guest_connection_ids = session
2678 .db()
2679 .await
2680 .update_worktree_settings(&message, session.connection_id)
2681 .await?;
2682
2683 broadcast(
2684 Some(session.connection_id),
2685 guest_connection_ids.iter().copied(),
2686 |connection_id| {
2687 session
2688 .peer
2689 .forward_send(session.connection_id, connection_id, message.clone())
2690 },
2691 );
2692
2693 Ok(())
2694}
2695
2696/// Notify other participants that a language server has started.
2697async fn start_language_server(
2698 request: proto::StartLanguageServer,
2699 session: Session,
2700) -> Result<()> {
2701 let guest_connection_ids = session
2702 .db()
2703 .await
2704 .start_language_server(&request, session.connection_id)
2705 .await?;
2706
2707 broadcast(
2708 Some(session.connection_id),
2709 guest_connection_ids.iter().copied(),
2710 |connection_id| {
2711 session
2712 .peer
2713 .forward_send(session.connection_id, connection_id, request.clone())
2714 },
2715 );
2716 Ok(())
2717}
2718
2719/// Notify other participants that a language server has changed.
2720async fn update_language_server(
2721 request: proto::UpdateLanguageServer,
2722 session: Session,
2723) -> Result<()> {
2724 let project_id = ProjectId::from_proto(request.project_id);
2725 let project_connection_ids = session
2726 .db()
2727 .await
2728 .project_connection_ids(project_id, session.connection_id, true)
2729 .await?;
2730 broadcast(
2731 Some(session.connection_id),
2732 project_connection_ids.iter().copied(),
2733 |connection_id| {
2734 session
2735 .peer
2736 .forward_send(session.connection_id, connection_id, request.clone())
2737 },
2738 );
2739 Ok(())
2740}
2741
2742/// forward a project request to the host. These requests should be read only
2743/// as guests are allowed to send them.
2744async fn forward_read_only_project_request<T>(
2745 request: T,
2746 response: Response<T>,
2747 session: UserSession,
2748) -> Result<()>
2749where
2750 T: EntityMessage + RequestMessage,
2751{
2752 let project_id = ProjectId::from_proto(request.remote_entity_id());
2753 let host_connection_id = session
2754 .db()
2755 .await
2756 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2757 .await?;
2758 let payload = session
2759 .peer
2760 .forward_request(session.connection_id, host_connection_id, request)
2761 .await?;
2762 response.send(payload)?;
2763 Ok(())
2764}
2765
2766/// forward a project request to the host. These requests are disallowed
2767/// for guests.
2768async fn forward_mutating_project_request<T>(
2769 request: T,
2770 response: Response<T>,
2771 session: UserSession,
2772) -> Result<()>
2773where
2774 T: EntityMessage + RequestMessage,
2775{
2776 let project_id = ProjectId::from_proto(request.remote_entity_id());
2777
2778 let host_connection_id = session
2779 .db()
2780 .await
2781 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2782 .await?;
2783 let payload = session
2784 .peer
2785 .forward_request(session.connection_id, host_connection_id, request)
2786 .await?;
2787 response.send(payload)?;
2788 Ok(())
2789}
2790
2791/// forward a project request to the host. These requests are disallowed
2792/// for guests.
2793async fn forward_versioned_mutating_project_request<T>(
2794 request: T,
2795 response: Response<T>,
2796 session: UserSession,
2797) -> Result<()>
2798where
2799 T: EntityMessage + RequestMessage + VersionedMessage,
2800{
2801 let project_id = ProjectId::from_proto(request.remote_entity_id());
2802
2803 let host_connection_id = session
2804 .db()
2805 .await
2806 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2807 .await?;
2808 if let Some(host_version) = session
2809 .connection_pool()
2810 .await
2811 .connection(host_connection_id)
2812 .map(|c| c.zed_version)
2813 {
2814 if let Some(min_required_version) = request.required_host_version() {
2815 if min_required_version > host_version {
2816 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
2817 .with_tag("required", &min_required_version.to_string())))?;
2818 }
2819 }
2820 }
2821
2822 let payload = session
2823 .peer
2824 .forward_request(session.connection_id, host_connection_id, request)
2825 .await?;
2826 response.send(payload)?;
2827 Ok(())
2828}
2829
2830/// Notify other participants that a new buffer has been created
2831async fn create_buffer_for_peer(
2832 request: proto::CreateBufferForPeer,
2833 session: Session,
2834) -> Result<()> {
2835 session
2836 .db()
2837 .await
2838 .check_user_is_project_host(
2839 ProjectId::from_proto(request.project_id),
2840 session.connection_id,
2841 )
2842 .await?;
2843 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
2844 session
2845 .peer
2846 .forward_send(session.connection_id, peer_id.into(), request)?;
2847 Ok(())
2848}
2849
2850/// Notify other participants that a buffer has been updated. This is
2851/// allowed for guests as long as the update is limited to selections.
2852async fn update_buffer(
2853 request: proto::UpdateBuffer,
2854 response: Response<proto::UpdateBuffer>,
2855 session: Session,
2856) -> Result<()> {
2857 let project_id = ProjectId::from_proto(request.project_id);
2858 let mut capability = Capability::ReadOnly;
2859
2860 for op in request.operations.iter() {
2861 match op.variant {
2862 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
2863 Some(_) => capability = Capability::ReadWrite,
2864 }
2865 }
2866
2867 let host = {
2868 let guard = session
2869 .db()
2870 .await
2871 .connections_for_buffer_update(
2872 project_id,
2873 session.principal_id(),
2874 session.connection_id,
2875 capability,
2876 )
2877 .await?;
2878
2879 let (host, guests) = &*guard;
2880
2881 broadcast(
2882 Some(session.connection_id),
2883 guests.clone(),
2884 |connection_id| {
2885 session
2886 .peer
2887 .forward_send(session.connection_id, connection_id, request.clone())
2888 },
2889 );
2890
2891 *host
2892 };
2893
2894 if host != session.connection_id {
2895 session
2896 .peer
2897 .forward_request(session.connection_id, host, request.clone())
2898 .await?;
2899 }
2900
2901 response.send(proto::Ack {})?;
2902 Ok(())
2903}
2904
2905/// Notify other participants that a project has been updated.
2906async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
2907 request: T,
2908 session: Session,
2909) -> Result<()> {
2910 let project_id = ProjectId::from_proto(request.remote_entity_id());
2911 let project_connection_ids = session
2912 .db()
2913 .await
2914 .project_connection_ids(project_id, session.connection_id, false)
2915 .await?;
2916
2917 broadcast(
2918 Some(session.connection_id),
2919 project_connection_ids.iter().copied(),
2920 |connection_id| {
2921 session
2922 .peer
2923 .forward_send(session.connection_id, connection_id, request.clone())
2924 },
2925 );
2926 Ok(())
2927}
2928
2929/// Start following another user in a call.
2930async fn follow(
2931 request: proto::Follow,
2932 response: Response<proto::Follow>,
2933 session: UserSession,
2934) -> Result<()> {
2935 let room_id = RoomId::from_proto(request.room_id);
2936 let project_id = request.project_id.map(ProjectId::from_proto);
2937 let leader_id = request
2938 .leader_id
2939 .ok_or_else(|| anyhow!("invalid leader id"))?
2940 .into();
2941 let follower_id = session.connection_id;
2942
2943 session
2944 .db()
2945 .await
2946 .check_room_participants(room_id, leader_id, session.connection_id)
2947 .await?;
2948
2949 let response_payload = session
2950 .peer
2951 .forward_request(session.connection_id, leader_id, request)
2952 .await?;
2953 response.send(response_payload)?;
2954
2955 if let Some(project_id) = project_id {
2956 let room = session
2957 .db()
2958 .await
2959 .follow(room_id, project_id, leader_id, follower_id)
2960 .await?;
2961 room_updated(&room, &session.peer);
2962 }
2963
2964 Ok(())
2965}
2966
2967/// Stop following another user in a call.
2968async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
2969 let room_id = RoomId::from_proto(request.room_id);
2970 let project_id = request.project_id.map(ProjectId::from_proto);
2971 let leader_id = request
2972 .leader_id
2973 .ok_or_else(|| anyhow!("invalid leader id"))?
2974 .into();
2975 let follower_id = session.connection_id;
2976
2977 session
2978 .db()
2979 .await
2980 .check_room_participants(room_id, leader_id, session.connection_id)
2981 .await?;
2982
2983 session
2984 .peer
2985 .forward_send(session.connection_id, leader_id, request)?;
2986
2987 if let Some(project_id) = project_id {
2988 let room = session
2989 .db()
2990 .await
2991 .unfollow(room_id, project_id, leader_id, follower_id)
2992 .await?;
2993 room_updated(&room, &session.peer);
2994 }
2995
2996 Ok(())
2997}
2998
2999/// Notify everyone following you of your current location.
3000async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3001 let room_id = RoomId::from_proto(request.room_id);
3002 let database = session.db.lock().await;
3003
3004 let connection_ids = if let Some(project_id) = request.project_id {
3005 let project_id = ProjectId::from_proto(project_id);
3006 database
3007 .project_connection_ids(project_id, session.connection_id, true)
3008 .await?
3009 } else {
3010 database
3011 .room_connection_ids(room_id, session.connection_id)
3012 .await?
3013 };
3014
3015 // For now, don't send view update messages back to that view's current leader.
3016 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3017 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3018 _ => None,
3019 });
3020
3021 for connection_id in connection_ids.iter().cloned() {
3022 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3023 session
3024 .peer
3025 .forward_send(session.connection_id, connection_id, request.clone())?;
3026 }
3027 }
3028 Ok(())
3029}
3030
3031/// Get public data about users.
3032async fn get_users(
3033 request: proto::GetUsers,
3034 response: Response<proto::GetUsers>,
3035 session: Session,
3036) -> Result<()> {
3037 let user_ids = request
3038 .user_ids
3039 .into_iter()
3040 .map(UserId::from_proto)
3041 .collect();
3042 let users = session
3043 .db()
3044 .await
3045 .get_users_by_ids(user_ids)
3046 .await?
3047 .into_iter()
3048 .map(|user| proto::User {
3049 id: user.id.to_proto(),
3050 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3051 github_login: user.github_login,
3052 })
3053 .collect();
3054 response.send(proto::UsersResponse { users })?;
3055 Ok(())
3056}
3057
3058/// Search for users (to invite) buy Github login
3059async fn fuzzy_search_users(
3060 request: proto::FuzzySearchUsers,
3061 response: Response<proto::FuzzySearchUsers>,
3062 session: UserSession,
3063) -> Result<()> {
3064 let query = request.query;
3065 let users = match query.len() {
3066 0 => vec![],
3067 1 | 2 => session
3068 .db()
3069 .await
3070 .get_user_by_github_login(&query)
3071 .await?
3072 .into_iter()
3073 .collect(),
3074 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3075 };
3076 let users = users
3077 .into_iter()
3078 .filter(|user| user.id != session.user_id())
3079 .map(|user| proto::User {
3080 id: user.id.to_proto(),
3081 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3082 github_login: user.github_login,
3083 })
3084 .collect();
3085 response.send(proto::UsersResponse { users })?;
3086 Ok(())
3087}
3088
3089/// Send a contact request to another user.
3090async fn request_contact(
3091 request: proto::RequestContact,
3092 response: Response<proto::RequestContact>,
3093 session: UserSession,
3094) -> Result<()> {
3095 let requester_id = session.user_id();
3096 let responder_id = UserId::from_proto(request.responder_id);
3097 if requester_id == responder_id {
3098 return Err(anyhow!("cannot add yourself as a contact"))?;
3099 }
3100
3101 let notifications = session
3102 .db()
3103 .await
3104 .send_contact_request(requester_id, responder_id)
3105 .await?;
3106
3107 // Update outgoing contact requests of requester
3108 let mut update = proto::UpdateContacts::default();
3109 update.outgoing_requests.push(responder_id.to_proto());
3110 for connection_id in session
3111 .connection_pool()
3112 .await
3113 .user_connection_ids(requester_id)
3114 {
3115 session.peer.send(connection_id, update.clone())?;
3116 }
3117
3118 // Update incoming contact requests of responder
3119 let mut update = proto::UpdateContacts::default();
3120 update
3121 .incoming_requests
3122 .push(proto::IncomingContactRequest {
3123 requester_id: requester_id.to_proto(),
3124 });
3125 let connection_pool = session.connection_pool().await;
3126 for connection_id in connection_pool.user_connection_ids(responder_id) {
3127 session.peer.send(connection_id, update.clone())?;
3128 }
3129
3130 send_notifications(&connection_pool, &session.peer, notifications);
3131
3132 response.send(proto::Ack {})?;
3133 Ok(())
3134}
3135
3136/// Accept or decline a contact request
3137async fn respond_to_contact_request(
3138 request: proto::RespondToContactRequest,
3139 response: Response<proto::RespondToContactRequest>,
3140 session: UserSession,
3141) -> Result<()> {
3142 let responder_id = session.user_id();
3143 let requester_id = UserId::from_proto(request.requester_id);
3144 let db = session.db().await;
3145 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3146 db.dismiss_contact_notification(responder_id, requester_id)
3147 .await?;
3148 } else {
3149 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3150
3151 let notifications = db
3152 .respond_to_contact_request(responder_id, requester_id, accept)
3153 .await?;
3154 let requester_busy = db.is_user_busy(requester_id).await?;
3155 let responder_busy = db.is_user_busy(responder_id).await?;
3156
3157 let pool = session.connection_pool().await;
3158 // Update responder with new contact
3159 let mut update = proto::UpdateContacts::default();
3160 if accept {
3161 update
3162 .contacts
3163 .push(contact_for_user(requester_id, requester_busy, &pool));
3164 }
3165 update
3166 .remove_incoming_requests
3167 .push(requester_id.to_proto());
3168 for connection_id in pool.user_connection_ids(responder_id) {
3169 session.peer.send(connection_id, update.clone())?;
3170 }
3171
3172 // Update requester with new contact
3173 let mut update = proto::UpdateContacts::default();
3174 if accept {
3175 update
3176 .contacts
3177 .push(contact_for_user(responder_id, responder_busy, &pool));
3178 }
3179 update
3180 .remove_outgoing_requests
3181 .push(responder_id.to_proto());
3182
3183 for connection_id in pool.user_connection_ids(requester_id) {
3184 session.peer.send(connection_id, update.clone())?;
3185 }
3186
3187 send_notifications(&pool, &session.peer, notifications);
3188 }
3189
3190 response.send(proto::Ack {})?;
3191 Ok(())
3192}
3193
3194/// Remove a contact.
3195async fn remove_contact(
3196 request: proto::RemoveContact,
3197 response: Response<proto::RemoveContact>,
3198 session: UserSession,
3199) -> Result<()> {
3200 let requester_id = session.user_id();
3201 let responder_id = UserId::from_proto(request.user_id);
3202 let db = session.db().await;
3203 let (contact_accepted, deleted_notification_id) =
3204 db.remove_contact(requester_id, responder_id).await?;
3205
3206 let pool = session.connection_pool().await;
3207 // Update outgoing contact requests of requester
3208 let mut update = proto::UpdateContacts::default();
3209 if contact_accepted {
3210 update.remove_contacts.push(responder_id.to_proto());
3211 } else {
3212 update
3213 .remove_outgoing_requests
3214 .push(responder_id.to_proto());
3215 }
3216 for connection_id in pool.user_connection_ids(requester_id) {
3217 session.peer.send(connection_id, update.clone())?;
3218 }
3219
3220 // Update incoming contact requests of responder
3221 let mut update = proto::UpdateContacts::default();
3222 if contact_accepted {
3223 update.remove_contacts.push(requester_id.to_proto());
3224 } else {
3225 update
3226 .remove_incoming_requests
3227 .push(requester_id.to_proto());
3228 }
3229 for connection_id in pool.user_connection_ids(responder_id) {
3230 session.peer.send(connection_id, update.clone())?;
3231 if let Some(notification_id) = deleted_notification_id {
3232 session.peer.send(
3233 connection_id,
3234 proto::DeleteNotification {
3235 notification_id: notification_id.to_proto(),
3236 },
3237 )?;
3238 }
3239 }
3240
3241 response.send(proto::Ack {})?;
3242 Ok(())
3243}
3244
3245/// Creates a new channel.
3246async fn create_channel(
3247 request: proto::CreateChannel,
3248 response: Response<proto::CreateChannel>,
3249 session: UserSession,
3250) -> Result<()> {
3251 let db = session.db().await;
3252
3253 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3254 let (channel, membership) = db
3255 .create_channel(&request.name, parent_id, session.user_id())
3256 .await?;
3257
3258 let root_id = channel.root_id();
3259 let channel = Channel::from_model(channel);
3260
3261 response.send(proto::CreateChannelResponse {
3262 channel: Some(channel.to_proto()),
3263 parent_id: request.parent_id,
3264 })?;
3265
3266 let mut connection_pool = session.connection_pool().await;
3267 if let Some(membership) = membership {
3268 connection_pool.subscribe_to_channel(
3269 membership.user_id,
3270 membership.channel_id,
3271 membership.role,
3272 );
3273 let update = proto::UpdateUserChannels {
3274 channel_memberships: vec![proto::ChannelMembership {
3275 channel_id: membership.channel_id.to_proto(),
3276 role: membership.role.into(),
3277 }],
3278 ..Default::default()
3279 };
3280 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3281 session.peer.send(connection_id, update.clone())?;
3282 }
3283 }
3284
3285 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3286 if !role.can_see_channel(channel.visibility) {
3287 continue;
3288 }
3289
3290 let update = proto::UpdateChannels {
3291 channels: vec![channel.to_proto()],
3292 ..Default::default()
3293 };
3294 session.peer.send(connection_id, update.clone())?;
3295 }
3296
3297 Ok(())
3298}
3299
3300/// Delete a channel
3301async fn delete_channel(
3302 request: proto::DeleteChannel,
3303 response: Response<proto::DeleteChannel>,
3304 session: UserSession,
3305) -> Result<()> {
3306 let db = session.db().await;
3307
3308 let channel_id = request.channel_id;
3309 let (root_channel, removed_channels) = db
3310 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3311 .await?;
3312 response.send(proto::Ack {})?;
3313
3314 // Notify members of removed channels
3315 let mut update = proto::UpdateChannels::default();
3316 update
3317 .delete_channels
3318 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3319
3320 let connection_pool = session.connection_pool().await;
3321 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3322 session.peer.send(connection_id, update.clone())?;
3323 }
3324
3325 Ok(())
3326}
3327
3328/// Invite someone to join a channel.
3329async fn invite_channel_member(
3330 request: proto::InviteChannelMember,
3331 response: Response<proto::InviteChannelMember>,
3332 session: UserSession,
3333) -> Result<()> {
3334 let db = session.db().await;
3335 let channel_id = ChannelId::from_proto(request.channel_id);
3336 let invitee_id = UserId::from_proto(request.user_id);
3337 let InviteMemberResult {
3338 channel,
3339 notifications,
3340 } = db
3341 .invite_channel_member(
3342 channel_id,
3343 invitee_id,
3344 session.user_id(),
3345 request.role().into(),
3346 )
3347 .await?;
3348
3349 let update = proto::UpdateChannels {
3350 channel_invitations: vec![channel.to_proto()],
3351 ..Default::default()
3352 };
3353
3354 let connection_pool = session.connection_pool().await;
3355 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3356 session.peer.send(connection_id, update.clone())?;
3357 }
3358
3359 send_notifications(&connection_pool, &session.peer, notifications);
3360
3361 response.send(proto::Ack {})?;
3362 Ok(())
3363}
3364
3365/// remove someone from a channel
3366async fn remove_channel_member(
3367 request: proto::RemoveChannelMember,
3368 response: Response<proto::RemoveChannelMember>,
3369 session: UserSession,
3370) -> Result<()> {
3371 let db = session.db().await;
3372 let channel_id = ChannelId::from_proto(request.channel_id);
3373 let member_id = UserId::from_proto(request.user_id);
3374
3375 let RemoveChannelMemberResult {
3376 membership_update,
3377 notification_id,
3378 } = db
3379 .remove_channel_member(channel_id, member_id, session.user_id())
3380 .await?;
3381
3382 let mut connection_pool = session.connection_pool().await;
3383 notify_membership_updated(
3384 &mut connection_pool,
3385 membership_update,
3386 member_id,
3387 &session.peer,
3388 );
3389 for connection_id in connection_pool.user_connection_ids(member_id) {
3390 if let Some(notification_id) = notification_id {
3391 session
3392 .peer
3393 .send(
3394 connection_id,
3395 proto::DeleteNotification {
3396 notification_id: notification_id.to_proto(),
3397 },
3398 )
3399 .trace_err();
3400 }
3401 }
3402
3403 response.send(proto::Ack {})?;
3404 Ok(())
3405}
3406
3407/// Toggle the channel between public and private.
3408/// Care is taken to maintain the invariant that public channels only descend from public channels,
3409/// (though members-only channels can appear at any point in the hierarchy).
3410async fn set_channel_visibility(
3411 request: proto::SetChannelVisibility,
3412 response: Response<proto::SetChannelVisibility>,
3413 session: UserSession,
3414) -> Result<()> {
3415 let db = session.db().await;
3416 let channel_id = ChannelId::from_proto(request.channel_id);
3417 let visibility = request.visibility().into();
3418
3419 let channel_model = db
3420 .set_channel_visibility(channel_id, visibility, session.user_id())
3421 .await?;
3422 let root_id = channel_model.root_id();
3423 let channel = Channel::from_model(channel_model);
3424
3425 let mut connection_pool = session.connection_pool().await;
3426 for (user_id, role) in connection_pool
3427 .channel_user_ids(root_id)
3428 .collect::<Vec<_>>()
3429 .into_iter()
3430 {
3431 let update = if role.can_see_channel(channel.visibility) {
3432 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3433 proto::UpdateChannels {
3434 channels: vec![channel.to_proto()],
3435 ..Default::default()
3436 }
3437 } else {
3438 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3439 proto::UpdateChannels {
3440 delete_channels: vec![channel.id.to_proto()],
3441 ..Default::default()
3442 }
3443 };
3444
3445 for connection_id in connection_pool.user_connection_ids(user_id) {
3446 session.peer.send(connection_id, update.clone())?;
3447 }
3448 }
3449
3450 response.send(proto::Ack {})?;
3451 Ok(())
3452}
3453
3454/// Alter the role for a user in the channel.
3455async fn set_channel_member_role(
3456 request: proto::SetChannelMemberRole,
3457 response: Response<proto::SetChannelMemberRole>,
3458 session: UserSession,
3459) -> Result<()> {
3460 let db = session.db().await;
3461 let channel_id = ChannelId::from_proto(request.channel_id);
3462 let member_id = UserId::from_proto(request.user_id);
3463 let result = db
3464 .set_channel_member_role(
3465 channel_id,
3466 session.user_id(),
3467 member_id,
3468 request.role().into(),
3469 )
3470 .await?;
3471
3472 match result {
3473 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3474 let mut connection_pool = session.connection_pool().await;
3475 notify_membership_updated(
3476 &mut connection_pool,
3477 membership_update,
3478 member_id,
3479 &session.peer,
3480 )
3481 }
3482 db::SetMemberRoleResult::InviteUpdated(channel) => {
3483 let update = proto::UpdateChannels {
3484 channel_invitations: vec![channel.to_proto()],
3485 ..Default::default()
3486 };
3487
3488 for connection_id in session
3489 .connection_pool()
3490 .await
3491 .user_connection_ids(member_id)
3492 {
3493 session.peer.send(connection_id, update.clone())?;
3494 }
3495 }
3496 }
3497
3498 response.send(proto::Ack {})?;
3499 Ok(())
3500}
3501
3502/// Change the name of a channel
3503async fn rename_channel(
3504 request: proto::RenameChannel,
3505 response: Response<proto::RenameChannel>,
3506 session: UserSession,
3507) -> Result<()> {
3508 let db = session.db().await;
3509 let channel_id = ChannelId::from_proto(request.channel_id);
3510 let channel_model = db
3511 .rename_channel(channel_id, session.user_id(), &request.name)
3512 .await?;
3513 let root_id = channel_model.root_id();
3514 let channel = Channel::from_model(channel_model);
3515
3516 response.send(proto::RenameChannelResponse {
3517 channel: Some(channel.to_proto()),
3518 })?;
3519
3520 let connection_pool = session.connection_pool().await;
3521 let update = proto::UpdateChannels {
3522 channels: vec![channel.to_proto()],
3523 ..Default::default()
3524 };
3525 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3526 if role.can_see_channel(channel.visibility) {
3527 session.peer.send(connection_id, update.clone())?;
3528 }
3529 }
3530
3531 Ok(())
3532}
3533
3534/// Move a channel to a new parent.
3535async fn move_channel(
3536 request: proto::MoveChannel,
3537 response: Response<proto::MoveChannel>,
3538 session: UserSession,
3539) -> Result<()> {
3540 let channel_id = ChannelId::from_proto(request.channel_id);
3541 let to = ChannelId::from_proto(request.to);
3542
3543 let (root_id, channels) = session
3544 .db()
3545 .await
3546 .move_channel(channel_id, to, session.user_id())
3547 .await?;
3548
3549 let connection_pool = session.connection_pool().await;
3550 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3551 let channels = channels
3552 .iter()
3553 .filter_map(|channel| {
3554 if role.can_see_channel(channel.visibility) {
3555 Some(channel.to_proto())
3556 } else {
3557 None
3558 }
3559 })
3560 .collect::<Vec<_>>();
3561 if channels.is_empty() {
3562 continue;
3563 }
3564
3565 let update = proto::UpdateChannels {
3566 channels,
3567 ..Default::default()
3568 };
3569
3570 session.peer.send(connection_id, update.clone())?;
3571 }
3572
3573 response.send(Ack {})?;
3574 Ok(())
3575}
3576
3577/// Get the list of channel members
3578async fn get_channel_members(
3579 request: proto::GetChannelMembers,
3580 response: Response<proto::GetChannelMembers>,
3581 session: UserSession,
3582) -> Result<()> {
3583 let db = session.db().await;
3584 let channel_id = ChannelId::from_proto(request.channel_id);
3585 let members = db
3586 .get_channel_participant_details(channel_id, session.user_id())
3587 .await?;
3588 response.send(proto::GetChannelMembersResponse { members })?;
3589 Ok(())
3590}
3591
3592/// Accept or decline a channel invitation.
3593async fn respond_to_channel_invite(
3594 request: proto::RespondToChannelInvite,
3595 response: Response<proto::RespondToChannelInvite>,
3596 session: UserSession,
3597) -> Result<()> {
3598 let db = session.db().await;
3599 let channel_id = ChannelId::from_proto(request.channel_id);
3600 let RespondToChannelInvite {
3601 membership_update,
3602 notifications,
3603 } = db
3604 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3605 .await?;
3606
3607 let mut connection_pool = session.connection_pool().await;
3608 if let Some(membership_update) = membership_update {
3609 notify_membership_updated(
3610 &mut connection_pool,
3611 membership_update,
3612 session.user_id(),
3613 &session.peer,
3614 );
3615 } else {
3616 let update = proto::UpdateChannels {
3617 remove_channel_invitations: vec![channel_id.to_proto()],
3618 ..Default::default()
3619 };
3620
3621 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3622 session.peer.send(connection_id, update.clone())?;
3623 }
3624 };
3625
3626 send_notifications(&connection_pool, &session.peer, notifications);
3627
3628 response.send(proto::Ack {})?;
3629
3630 Ok(())
3631}
3632
3633/// Join the channels' room
3634async fn join_channel(
3635 request: proto::JoinChannel,
3636 response: Response<proto::JoinChannel>,
3637 session: UserSession,
3638) -> Result<()> {
3639 let channel_id = ChannelId::from_proto(request.channel_id);
3640 join_channel_internal(channel_id, Box::new(response), session).await
3641}
3642
/// A reply handle that can answer a channel-join request with a
/// `proto::JoinRoomResponse`.
///
/// It is implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can serve either
/// request type.
trait JoinChannelInternalResponse {
    /// Send `result` as the reply, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to `Response::<proto::JoinChannel>::send`, named with an explicit type path.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Forwards to `Response::<proto::JoinRoom>::send`, named with an explicit type path.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3656
/// Join the room attached to `channel_id` on behalf of `session`'s user.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, leave that
///    room first.
/// 2. Join the channel through the database.
/// 3. Build LiveKit credentials when a LiveKit client is configured:
///    - guests get a guest token with `can_publish: false`;
///    - everyone else gets a room token with `can_publish: true`.
///    If token creation fails (passed through `trace_err`), no LiveKit info is sent.
/// 4. Reply via `response` and propagate any membership change.
/// 5. Broadcast the room and channel updates, then refresh the user's contacts.
///
/// Returns an error if any awaited step fails or if the database returns no
/// channel for the joined room ("channel not returned").
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // Scoped so that the `db` and `connection_pool` guards taken inside are dropped
    // before the connection pool is locked again for `channel_updated` below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the database guard while leaving the stale room, then re-acquire it.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` when there is no LiveKit client or token creation failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Reply to the joining client before notifying anyone else.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // The scoped guards above are gone; lock the pool afresh for the channel broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
3747
3748/// Start editing the channel notes
3749async fn join_channel_buffer(
3750 request: proto::JoinChannelBuffer,
3751 response: Response<proto::JoinChannelBuffer>,
3752 session: UserSession,
3753) -> Result<()> {
3754 let db = session.db().await;
3755 let channel_id = ChannelId::from_proto(request.channel_id);
3756
3757 let open_response = db
3758 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
3759 .await?;
3760
3761 let collaborators = open_response.collaborators.clone();
3762 response.send(open_response)?;
3763
3764 let update = UpdateChannelBufferCollaborators {
3765 channel_id: channel_id.to_proto(),
3766 collaborators: collaborators.clone(),
3767 };
3768 channel_buffer_updated(
3769 session.connection_id,
3770 collaborators
3771 .iter()
3772 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
3773 &update,
3774 &session.peer,
3775 );
3776
3777 Ok(())
3778}
3779
3780/// Edit the channel notes
3781async fn update_channel_buffer(
3782 request: proto::UpdateChannelBuffer,
3783 session: UserSession,
3784) -> Result<()> {
3785 let db = session.db().await;
3786 let channel_id = ChannelId::from_proto(request.channel_id);
3787
3788 let (collaborators, non_collaborators, epoch, version) = db
3789 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
3790 .await?;
3791
3792 channel_buffer_updated(
3793 session.connection_id,
3794 collaborators,
3795 &proto::UpdateChannelBuffer {
3796 channel_id: channel_id.to_proto(),
3797 operations: request.operations,
3798 },
3799 &session.peer,
3800 );
3801
3802 let pool = &*session.connection_pool().await;
3803
3804 broadcast(
3805 None,
3806 non_collaborators
3807 .iter()
3808 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
3809 |peer_id| {
3810 session.peer.send(
3811 peer_id,
3812 proto::UpdateChannels {
3813 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
3814 channel_id: channel_id.to_proto(),
3815 epoch: epoch as u64,
3816 version: version.clone(),
3817 }],
3818 ..Default::default()
3819 },
3820 )
3821 },
3822 );
3823
3824 Ok(())
3825}
3826
3827/// Rejoin the channel notes after a connection blip
3828async fn rejoin_channel_buffers(
3829 request: proto::RejoinChannelBuffers,
3830 response: Response<proto::RejoinChannelBuffers>,
3831 session: UserSession,
3832) -> Result<()> {
3833 let db = session.db().await;
3834 let buffers = db
3835 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
3836 .await?;
3837
3838 for rejoined_buffer in &buffers {
3839 let collaborators_to_notify = rejoined_buffer
3840 .buffer
3841 .collaborators
3842 .iter()
3843 .filter_map(|c| Some(c.peer_id?.into()));
3844 channel_buffer_updated(
3845 session.connection_id,
3846 collaborators_to_notify,
3847 &proto::UpdateChannelBufferCollaborators {
3848 channel_id: rejoined_buffer.buffer.channel_id,
3849 collaborators: rejoined_buffer.buffer.collaborators.clone(),
3850 },
3851 &session.peer,
3852 );
3853 }
3854
3855 response.send(proto::RejoinChannelBuffersResponse {
3856 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
3857 })?;
3858
3859 Ok(())
3860}
3861
3862/// Stop editing the channel notes
3863async fn leave_channel_buffer(
3864 request: proto::LeaveChannelBuffer,
3865 response: Response<proto::LeaveChannelBuffer>,
3866 session: UserSession,
3867) -> Result<()> {
3868 let db = session.db().await;
3869 let channel_id = ChannelId::from_proto(request.channel_id);
3870
3871 let left_buffer = db
3872 .leave_channel_buffer(channel_id, session.connection_id)
3873 .await?;
3874
3875 response.send(Ack {})?;
3876
3877 channel_buffer_updated(
3878 session.connection_id,
3879 left_buffer.connections,
3880 &proto::UpdateChannelBufferCollaborators {
3881 channel_id: channel_id.to_proto(),
3882 collaborators: left_buffer.collaborators,
3883 },
3884 &session.peer,
3885 );
3886
3887 Ok(())
3888}
3889
3890fn channel_buffer_updated<T: EnvelopedMessage>(
3891 sender_id: ConnectionId,
3892 collaborators: impl IntoIterator<Item = ConnectionId>,
3893 message: &T,
3894 peer: &Peer,
3895) {
3896 broadcast(Some(sender_id), collaborators, |peer_id| {
3897 peer.send(peer_id, message.clone())
3898 });
3899}
3900
3901fn send_notifications(
3902 connection_pool: &ConnectionPool,
3903 peer: &Peer,
3904 notifications: db::NotificationBatch,
3905) {
3906 for (user_id, notification) in notifications {
3907 for connection_id in connection_pool.user_connection_ids(user_id) {
3908 if let Err(error) = peer.send(
3909 connection_id,
3910 proto::AddNotification {
3911 notification: Some(notification.clone()),
3912 },
3913 ) {
3914 tracing::error!(
3915 "failed to send notification to {:?} {}",
3916 connection_id,
3917 error
3918 );
3919 }
3920 }
3921 }
3922}
3923
3924/// Send a message to the channel
3925async fn send_channel_message(
3926 request: proto::SendChannelMessage,
3927 response: Response<proto::SendChannelMessage>,
3928 session: UserSession,
3929) -> Result<()> {
3930 // Validate the message body.
3931 let body = request.body.trim().to_string();
3932 if body.len() > MAX_MESSAGE_LEN {
3933 return Err(anyhow!("message is too long"))?;
3934 }
3935 if body.is_empty() {
3936 return Err(anyhow!("message can't be blank"))?;
3937 }
3938
3939 // TODO: adjust mentions if body is trimmed
3940
3941 let timestamp = OffsetDateTime::now_utc();
3942 let nonce = request
3943 .nonce
3944 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
3945
3946 let channel_id = ChannelId::from_proto(request.channel_id);
3947 let CreatedChannelMessage {
3948 message_id,
3949 participant_connection_ids,
3950 channel_members,
3951 notifications,
3952 } = session
3953 .db()
3954 .await
3955 .create_channel_message(
3956 channel_id,
3957 session.user_id(),
3958 &body,
3959 &request.mentions,
3960 timestamp,
3961 nonce.clone().into(),
3962 match request.reply_to_message_id {
3963 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
3964 None => None,
3965 },
3966 )
3967 .await?;
3968
3969 let message = proto::ChannelMessage {
3970 sender_id: session.user_id().to_proto(),
3971 id: message_id.to_proto(),
3972 body,
3973 mentions: request.mentions,
3974 timestamp: timestamp.unix_timestamp() as u64,
3975 nonce: Some(nonce),
3976 reply_to_message_id: request.reply_to_message_id,
3977 edited_at: None,
3978 };
3979 broadcast(
3980 Some(session.connection_id),
3981 participant_connection_ids,
3982 |connection| {
3983 session.peer.send(
3984 connection,
3985 proto::ChannelMessageSent {
3986 channel_id: channel_id.to_proto(),
3987 message: Some(message.clone()),
3988 },
3989 )
3990 },
3991 );
3992 response.send(proto::SendChannelMessageResponse {
3993 message: Some(message),
3994 })?;
3995
3996 let pool = &*session.connection_pool().await;
3997 broadcast(
3998 None,
3999 channel_members
4000 .iter()
4001 .flat_map(|user_id| pool.user_connection_ids(*user_id)),
4002 |peer_id| {
4003 session.peer.send(
4004 peer_id,
4005 proto::UpdateChannels {
4006 latest_channel_message_ids: vec![proto::ChannelMessageId {
4007 channel_id: channel_id.to_proto(),
4008 message_id: message_id.to_proto(),
4009 }],
4010 ..Default::default()
4011 },
4012 )
4013 },
4014 );
4015 send_notifications(pool, &session.peer, notifications);
4016
4017 Ok(())
4018}
4019
4020/// Delete a channel message
4021async fn remove_channel_message(
4022 request: proto::RemoveChannelMessage,
4023 response: Response<proto::RemoveChannelMessage>,
4024 session: UserSession,
4025) -> Result<()> {
4026 let channel_id = ChannelId::from_proto(request.channel_id);
4027 let message_id = MessageId::from_proto(request.message_id);
4028 let (connection_ids, existing_notification_ids) = session
4029 .db()
4030 .await
4031 .remove_channel_message(channel_id, message_id, session.user_id())
4032 .await?;
4033
4034 broadcast(
4035 Some(session.connection_id),
4036 connection_ids,
4037 move |connection| {
4038 session.peer.send(connection, request.clone())?;
4039
4040 for notification_id in &existing_notification_ids {
4041 session.peer.send(
4042 connection,
4043 proto::DeleteNotification {
4044 notification_id: (*notification_id).to_proto(),
4045 },
4046 )?;
4047 }
4048
4049 Ok(())
4050 },
4051 );
4052 response.send(proto::Ack {})?;
4053 Ok(())
4054}
4055
4056async fn update_channel_message(
4057 request: proto::UpdateChannelMessage,
4058 response: Response<proto::UpdateChannelMessage>,
4059 session: UserSession,
4060) -> Result<()> {
4061 let channel_id = ChannelId::from_proto(request.channel_id);
4062 let message_id = MessageId::from_proto(request.message_id);
4063 let updated_at = OffsetDateTime::now_utc();
4064 let UpdatedChannelMessage {
4065 message_id,
4066 participant_connection_ids,
4067 notifications,
4068 reply_to_message_id,
4069 timestamp,
4070 deleted_mention_notification_ids,
4071 updated_mention_notifications,
4072 } = session
4073 .db()
4074 .await
4075 .update_channel_message(
4076 channel_id,
4077 message_id,
4078 session.user_id(),
4079 request.body.as_str(),
4080 &request.mentions,
4081 updated_at,
4082 )
4083 .await?;
4084
4085 let nonce = request
4086 .nonce
4087 .clone()
4088 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4089
4090 let message = proto::ChannelMessage {
4091 sender_id: session.user_id().to_proto(),
4092 id: message_id.to_proto(),
4093 body: request.body.clone(),
4094 mentions: request.mentions.clone(),
4095 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4096 nonce: Some(nonce),
4097 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4098 edited_at: Some(updated_at.unix_timestamp() as u64),
4099 };
4100
4101 response.send(proto::Ack {})?;
4102
4103 let pool = &*session.connection_pool().await;
4104 broadcast(
4105 Some(session.connection_id),
4106 participant_connection_ids,
4107 |connection| {
4108 session.peer.send(
4109 connection,
4110 proto::ChannelMessageUpdate {
4111 channel_id: channel_id.to_proto(),
4112 message: Some(message.clone()),
4113 },
4114 )?;
4115
4116 for notification_id in &deleted_mention_notification_ids {
4117 session.peer.send(
4118 connection,
4119 proto::DeleteNotification {
4120 notification_id: (*notification_id).to_proto(),
4121 },
4122 )?;
4123 }
4124
4125 for notification in &updated_mention_notifications {
4126 session.peer.send(
4127 connection,
4128 proto::UpdateNotification {
4129 notification: Some(notification.clone()),
4130 },
4131 )?;
4132 }
4133
4134 Ok(())
4135 },
4136 );
4137
4138 send_notifications(pool, &session.peer, notifications);
4139
4140 Ok(())
4141}
4142
4143/// Mark a channel message as read
4144async fn acknowledge_channel_message(
4145 request: proto::AckChannelMessage,
4146 session: UserSession,
4147) -> Result<()> {
4148 let channel_id = ChannelId::from_proto(request.channel_id);
4149 let message_id = MessageId::from_proto(request.message_id);
4150 let notifications = session
4151 .db()
4152 .await
4153 .observe_channel_message(channel_id, session.user_id(), message_id)
4154 .await?;
4155 send_notifications(
4156 &*session.connection_pool().await,
4157 &session.peer,
4158 notifications,
4159 );
4160 Ok(())
4161}
4162
4163/// Mark a buffer version as synced
4164async fn acknowledge_buffer_version(
4165 request: proto::AckBufferOperation,
4166 session: UserSession,
4167) -> Result<()> {
4168 let buffer_id = BufferId::from_proto(request.buffer_id);
4169 session
4170 .db()
4171 .await
4172 .observe_buffer_version(
4173 buffer_id,
4174 session.user_id(),
4175 request.epoch as i32,
4176 &request.version,
4177 )
4178 .await?;
4179 Ok(())
4180}
4181
4182struct CompleteWithLanguageModelRateLimit;
4183
4184impl RateLimit for CompleteWithLanguageModelRateLimit {
4185 fn capacity() -> usize {
4186 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4187 .ok()
4188 .and_then(|v| v.parse().ok())
4189 .unwrap_or(120) // Picked arbitrarily
4190 }
4191
4192 fn refill_duration() -> chrono::Duration {
4193 chrono::Duration::hours(1)
4194 }
4195
4196 fn db_name() -> &'static str {
4197 "complete-with-language-model"
4198 }
4199}
4200
4201async fn complete_with_language_model(
4202 request: proto::CompleteWithLanguageModel,
4203 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4204 session: Session,
4205 open_ai_api_key: Option<Arc<str>>,
4206 google_ai_api_key: Option<Arc<str>>,
4207 anthropic_api_key: Option<Arc<str>>,
4208) -> Result<()> {
4209 let Some(session) = session.for_user() else {
4210 return Err(anyhow!("user not found"))?;
4211 };
4212 authorize_access_to_language_models(&session).await?;
4213 session
4214 .rate_limiter
4215 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4216 .await?;
4217
4218 if request.model.starts_with("gpt") {
4219 let api_key =
4220 open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
4221 complete_with_open_ai(request, response, session, api_key).await?;
4222 } else if request.model.starts_with("gemini") {
4223 let api_key = google_ai_api_key
4224 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4225 complete_with_google_ai(request, response, session, api_key).await?;
4226 } else if request.model.starts_with("claude") {
4227 let api_key = anthropic_api_key
4228 .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
4229 complete_with_anthropic(request, response, session, api_key).await?;
4230 }
4231
4232 Ok(())
4233}
4234
4235async fn complete_with_open_ai(
4236 request: proto::CompleteWithLanguageModel,
4237 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4238 session: UserSession,
4239 api_key: Arc<str>,
4240) -> Result<()> {
4241 let mut completion_stream = open_ai::stream_completion(
4242 session.http_client.as_ref(),
4243 OPEN_AI_API_URL,
4244 &api_key,
4245 crate::ai::language_model_request_to_open_ai(request)?,
4246 )
4247 .await
4248 .context("open_ai::stream_completion request failed within collab")?;
4249
4250 while let Some(event) = completion_stream.next().await {
4251 let event = event?;
4252 response.send(proto::LanguageModelResponse {
4253 choices: event
4254 .choices
4255 .into_iter()
4256 .map(|choice| proto::LanguageModelChoiceDelta {
4257 index: choice.index,
4258 delta: Some(proto::LanguageModelResponseMessage {
4259 role: choice.delta.role.map(|role| match role {
4260 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4261 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4262 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4263 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4264 } as i32),
4265 content: choice.delta.content,
4266 tool_calls: choice
4267 .delta
4268 .tool_calls
4269 .into_iter()
4270 .map(|delta| proto::ToolCallDelta {
4271 index: delta.index as u32,
4272 id: delta.id,
4273 variant: match delta.function {
4274 Some(function) => {
4275 let name = function.name;
4276 let arguments = function.arguments;
4277
4278 Some(proto::tool_call_delta::Variant::Function(
4279 proto::tool_call_delta::FunctionCallDelta {
4280 name,
4281 arguments,
4282 },
4283 ))
4284 }
4285 None => None,
4286 },
4287 })
4288 .collect(),
4289 }),
4290 finish_reason: choice.finish_reason,
4291 })
4292 .collect(),
4293 })?;
4294 }
4295
4296 Ok(())
4297}
4298
4299async fn complete_with_google_ai(
4300 request: proto::CompleteWithLanguageModel,
4301 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4302 session: UserSession,
4303 api_key: Arc<str>,
4304) -> Result<()> {
4305 let mut stream = google_ai::stream_generate_content(
4306 session.http_client.clone(),
4307 google_ai::API_URL,
4308 api_key.as_ref(),
4309 crate::ai::language_model_request_to_google_ai(request)?,
4310 )
4311 .await
4312 .context("google_ai::stream_generate_content request failed")?;
4313
4314 while let Some(event) = stream.next().await {
4315 let event = event?;
4316 response.send(proto::LanguageModelResponse {
4317 choices: event
4318 .candidates
4319 .unwrap_or_default()
4320 .into_iter()
4321 .map(|candidate| proto::LanguageModelChoiceDelta {
4322 index: candidate.index as u32,
4323 delta: Some(proto::LanguageModelResponseMessage {
4324 role: Some(match candidate.content.role {
4325 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4326 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4327 } as i32),
4328 content: Some(
4329 candidate
4330 .content
4331 .parts
4332 .into_iter()
4333 .filter_map(|part| match part {
4334 google_ai::Part::TextPart(part) => Some(part.text),
4335 google_ai::Part::InlineDataPart(_) => None,
4336 })
4337 .collect(),
4338 ),
4339 // Tool calls are not supported for Google
4340 tool_calls: Vec::new(),
4341 }),
4342 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4343 })
4344 .collect(),
4345 })?;
4346 }
4347
4348 Ok(())
4349}
4350
/// Stream a completion for `request` from Anthropic's API, forwarding every text
/// chunk to `response` as a `LanguageModelResponse` with a single choice at index 0.
///
/// `request.model` is resolved with `anthropic::Model::from_id`. System messages are
/// concatenated (separated by a blank line) into Anthropic's dedicated `system` field;
/// tool messages are dropped because tool calls are not supported for Anthropic here.
///
/// # Errors
/// Fails if the model id is not recognized, the request to Anthropic fails, a stream
/// event is an error, or sending on `response` fails.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates the content of all system messages; written to by the closure below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.clone(),
        "https://api.anthropic.com",
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 looks like it may be a typo for 4096; confirm the
            // intended maximum number of output tokens.
            max_tokens: 4092,
        },
    )
    .await?;

    // Role attached to forwarded chunks; defaults to assistant until a
    // `MessageStart` event says otherwise.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                // Only "assistant" and "user" are mapped; any other role string
                // leaves `current_role` unchanged.
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // Skip empty initial text so clients don't receive blank deltas.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // A stop reason is forwarded as a delta-less choice carrying only
                // the finish reason.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing that needs forwarding.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4471
4472struct CountTokensWithLanguageModelRateLimit;
4473
4474impl RateLimit for CountTokensWithLanguageModelRateLimit {
4475 fn capacity() -> usize {
4476 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4477 .ok()
4478 .and_then(|v| v.parse().ok())
4479 .unwrap_or(600) // Picked arbitrarily
4480 }
4481
4482 fn refill_duration() -> chrono::Duration {
4483 chrono::Duration::hours(1)
4484 }
4485
4486 fn db_name() -> &'static str {
4487 "count-tokens-with-language-model"
4488 }
4489}
4490
4491async fn count_tokens_with_language_model(
4492 request: proto::CountTokensWithLanguageModel,
4493 response: Response<proto::CountTokensWithLanguageModel>,
4494 session: UserSession,
4495 google_ai_api_key: Option<Arc<str>>,
4496) -> Result<()> {
4497 authorize_access_to_language_models(&session).await?;
4498
4499 if !request.model.starts_with("gemini") {
4500 return Err(anyhow!(
4501 "counting tokens for model: {:?} is not supported",
4502 request.model
4503 ))?;
4504 }
4505
4506 session
4507 .rate_limiter
4508 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4509 .await?;
4510
4511 let api_key = google_ai_api_key
4512 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4513 let tokens_response = google_ai::count_tokens(
4514 session.http_client.as_ref(),
4515 google_ai::API_URL,
4516 &api_key,
4517 crate::ai::count_tokens_request_to_google_ai(request)?,
4518 )
4519 .await?;
4520 response.send(proto::CountTokensResponse {
4521 token_count: tokens_response.total_tokens as u32,
4522 })?;
4523 Ok(())
4524}
4525
4526struct ComputeEmbeddingsRateLimit;
4527
4528impl RateLimit for ComputeEmbeddingsRateLimit {
4529 fn capacity() -> usize {
4530 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4531 .ok()
4532 .and_then(|v| v.parse().ok())
4533 .unwrap_or(5000) // Picked arbitrarily
4534 }
4535
4536 fn refill_duration() -> chrono::Duration {
4537 chrono::Duration::hours(1)
4538 }
4539
4540 fn db_name() -> &'static str {
4541 "compute-embeddings"
4542 }
4543}
4544
4545async fn compute_embeddings(
4546 request: proto::ComputeEmbeddings,
4547 response: Response<proto::ComputeEmbeddings>,
4548 session: UserSession,
4549 api_key: Option<Arc<str>>,
4550) -> Result<()> {
4551 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4552 authorize_access_to_language_models(&session).await?;
4553
4554 session
4555 .rate_limiter
4556 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4557 .await?;
4558
4559 let embeddings = match request.model.as_str() {
4560 "openai/text-embedding-3-small" => {
4561 open_ai::embed(
4562 session.http_client.as_ref(),
4563 OPEN_AI_API_URL,
4564 &api_key,
4565 OpenAiEmbeddingModel::TextEmbedding3Small,
4566 request.texts.iter().map(|text| text.as_str()),
4567 )
4568 .await?
4569 }
4570 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4571 };
4572
4573 let embeddings = request
4574 .texts
4575 .iter()
4576 .map(|text| {
4577 let mut hasher = sha2::Sha256::new();
4578 hasher.update(text.as_bytes());
4579 let result = hasher.finalize();
4580 result.to_vec()
4581 })
4582 .zip(
4583 embeddings
4584 .data
4585 .into_iter()
4586 .map(|embedding| embedding.embedding),
4587 )
4588 .collect::<HashMap<_, _>>();
4589
4590 let db = session.db().await;
4591 db.save_embeddings(&request.model, &embeddings)
4592 .await
4593 .context("failed to save embeddings")
4594 .trace_err();
4595
4596 response.send(proto::ComputeEmbeddingsResponse {
4597 embeddings: embeddings
4598 .into_iter()
4599 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4600 .collect(),
4601 })?;
4602 Ok(())
4603}
4604
4605async fn get_cached_embeddings(
4606 request: proto::GetCachedEmbeddings,
4607 response: Response<proto::GetCachedEmbeddings>,
4608 session: UserSession,
4609) -> Result<()> {
4610 authorize_access_to_language_models(&session).await?;
4611
4612 let db = session.db().await;
4613 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4614
4615 response.send(proto::GetCachedEmbeddingsResponse {
4616 embeddings: embeddings
4617 .into_iter()
4618 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4619 .collect(),
4620 })?;
4621 Ok(())
4622}
4623
4624async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4625 let db = session.db().await;
4626 let flags = db.get_user_flags(session.user_id()).await?;
4627 if flags.iter().any(|flag| flag == "language-models") {
4628 Ok(())
4629 } else {
4630 Err(anyhow!("permission denied"))?
4631 }
4632}
4633
4634/// Get a Supermaven API key for the user
4635async fn get_supermaven_api_key(
4636 _request: proto::GetSupermavenApiKey,
4637 response: Response<proto::GetSupermavenApiKey>,
4638 session: UserSession,
4639) -> Result<()> {
4640 let user_id: String = session.user_id().to_string();
4641 if !session.is_staff() {
4642 return Err(anyhow!("supermaven not enabled for this account"))?;
4643 }
4644
4645 let email = session
4646 .email()
4647 .ok_or_else(|| anyhow!("user must have an email"))?;
4648
4649 let supermaven_admin_api = session
4650 .supermaven_client
4651 .as_ref()
4652 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4653
4654 let result = supermaven_admin_api
4655 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4656 .await?;
4657
4658 response.send(proto::GetSupermavenApiKeyResponse {
4659 api_key: result.api_key,
4660 })?;
4661
4662 Ok(())
4663}
4664
4665/// Start receiving chat updates for a channel
4666async fn join_channel_chat(
4667 request: proto::JoinChannelChat,
4668 response: Response<proto::JoinChannelChat>,
4669 session: UserSession,
4670) -> Result<()> {
4671 let channel_id = ChannelId::from_proto(request.channel_id);
4672
4673 let db = session.db().await;
4674 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4675 .await?;
4676 let messages = db
4677 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4678 .await?;
4679 response.send(proto::JoinChannelChatResponse {
4680 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4681 messages,
4682 })?;
4683 Ok(())
4684}
4685
4686/// Stop receiving chat updates for a channel
4687async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4688 let channel_id = ChannelId::from_proto(request.channel_id);
4689 session
4690 .db()
4691 .await
4692 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4693 .await?;
4694 Ok(())
4695}
4696
4697/// Retrieve the chat history for a channel
4698async fn get_channel_messages(
4699 request: proto::GetChannelMessages,
4700 response: Response<proto::GetChannelMessages>,
4701 session: UserSession,
4702) -> Result<()> {
4703 let channel_id = ChannelId::from_proto(request.channel_id);
4704 let messages = session
4705 .db()
4706 .await
4707 .get_channel_messages(
4708 channel_id,
4709 session.user_id(),
4710 MESSAGE_COUNT_PER_PAGE,
4711 Some(MessageId::from_proto(request.before_message_id)),
4712 )
4713 .await?;
4714 response.send(proto::GetChannelMessagesResponse {
4715 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4716 messages,
4717 })?;
4718 Ok(())
4719}
4720
4721/// Retrieve specific chat messages
4722async fn get_channel_messages_by_id(
4723 request: proto::GetChannelMessagesById,
4724 response: Response<proto::GetChannelMessagesById>,
4725 session: UserSession,
4726) -> Result<()> {
4727 let message_ids = request
4728 .message_ids
4729 .iter()
4730 .map(|id| MessageId::from_proto(*id))
4731 .collect::<Vec<_>>();
4732 let messages = session
4733 .db()
4734 .await
4735 .get_channel_messages_by_id(session.user_id(), &message_ids)
4736 .await?;
4737 response.send(proto::GetChannelMessagesResponse {
4738 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4739 messages,
4740 })?;
4741 Ok(())
4742}
4743
4744/// Retrieve the current users notifications
4745async fn get_notifications(
4746 request: proto::GetNotifications,
4747 response: Response<proto::GetNotifications>,
4748 session: UserSession,
4749) -> Result<()> {
4750 let notifications = session
4751 .db()
4752 .await
4753 .get_notifications(
4754 session.user_id(),
4755 NOTIFICATION_COUNT_PER_PAGE,
4756 request
4757 .before_id
4758 .map(|id| db::NotificationId::from_proto(id)),
4759 )
4760 .await?;
4761 response.send(proto::GetNotificationsResponse {
4762 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4763 notifications,
4764 })?;
4765 Ok(())
4766}
4767
4768/// Mark notifications as read
4769async fn mark_notification_as_read(
4770 request: proto::MarkNotificationRead,
4771 response: Response<proto::MarkNotificationRead>,
4772 session: UserSession,
4773) -> Result<()> {
4774 let database = &session.db().await;
4775 let notifications = database
4776 .mark_notification_as_read_by_id(
4777 session.user_id(),
4778 NotificationId::from_proto(request.notification_id),
4779 )
4780 .await?;
4781 send_notifications(
4782 &*session.connection_pool().await,
4783 &session.peer,
4784 notifications,
4785 );
4786 response.send(proto::Ack {})?;
4787 Ok(())
4788}
4789
4790/// Get the current users information
4791async fn get_private_user_info(
4792 _request: proto::GetPrivateUserInfo,
4793 response: Response<proto::GetPrivateUserInfo>,
4794 session: UserSession,
4795) -> Result<()> {
4796 let db = session.db().await;
4797
4798 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4799 let user = db
4800 .get_user_by_id(session.user_id())
4801 .await?
4802 .ok_or_else(|| anyhow!("user not found"))?;
4803 let flags = db.get_user_flags(session.user_id()).await?;
4804
4805 response.send(proto::GetPrivateUserInfoResponse {
4806 metrics_id,
4807 staff: user.admin,
4808 flags,
4809 })?;
4810 Ok(())
4811}
4812
4813fn to_axum_message(message: TungsteniteMessage) -> AxumMessage {
4814 match message {
4815 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4816 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4817 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4818 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4819 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4820 code: frame.code.into(),
4821 reason: frame.reason,
4822 })),
4823 }
4824}
4825
4826fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4827 match message {
4828 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4829 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4830 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4831 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4832 AxumMessage::Close(frame) => {
4833 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4834 code: frame.code.into(),
4835 reason: frame.reason,
4836 }))
4837 }
4838 }
4839}
4840
4841fn notify_membership_updated(
4842 connection_pool: &mut ConnectionPool,
4843 result: MembershipUpdated,
4844 user_id: UserId,
4845 peer: &Peer,
4846) {
4847 for membership in &result.new_channels.channel_memberships {
4848 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4849 }
4850 for channel_id in &result.removed_channels {
4851 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4852 }
4853
4854 let user_channels_update = proto::UpdateUserChannels {
4855 channel_memberships: result
4856 .new_channels
4857 .channel_memberships
4858 .iter()
4859 .map(|cm| proto::ChannelMembership {
4860 channel_id: cm.channel_id.to_proto(),
4861 role: cm.role.into(),
4862 })
4863 .collect(),
4864 ..Default::default()
4865 };
4866
4867 let mut update = build_channels_update(result.new_channels, vec![]);
4868 update.delete_channels = result
4869 .removed_channels
4870 .into_iter()
4871 .map(|id| id.to_proto())
4872 .collect();
4873 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4874
4875 for connection_id in connection_pool.user_connection_ids(user_id) {
4876 peer.send(connection_id, user_channels_update.clone())
4877 .trace_err();
4878 peer.send(connection_id, update.clone()).trace_err();
4879 }
4880}
4881
4882fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
4883 proto::UpdateUserChannels {
4884 channel_memberships: channels
4885 .channel_memberships
4886 .iter()
4887 .map(|m| proto::ChannelMembership {
4888 channel_id: m.channel_id.to_proto(),
4889 role: m.role.into(),
4890 })
4891 .collect(),
4892 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
4893 observed_channel_message_id: channels.observed_channel_messages.clone(),
4894 }
4895}
4896
4897fn build_channels_update(
4898 channels: ChannelsForUser,
4899 channel_invites: Vec<db::Channel>,
4900) -> proto::UpdateChannels {
4901 let mut update = proto::UpdateChannels::default();
4902
4903 for channel in channels.channels {
4904 update.channels.push(channel.to_proto());
4905 }
4906
4907 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
4908 update.latest_channel_message_ids = channels.latest_channel_messages;
4909
4910 for (channel_id, participants) in channels.channel_participants {
4911 update
4912 .channel_participants
4913 .push(proto::ChannelParticipants {
4914 channel_id: channel_id.to_proto(),
4915 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
4916 });
4917 }
4918
4919 for channel in channel_invites {
4920 update.channel_invitations.push(channel.to_proto());
4921 }
4922
4923 update.hosted_projects = channels.hosted_projects;
4924 update
4925}
4926
4927fn build_initial_contacts_update(
4928 contacts: Vec<db::Contact>,
4929 pool: &ConnectionPool,
4930) -> proto::UpdateContacts {
4931 let mut update = proto::UpdateContacts::default();
4932
4933 for contact in contacts {
4934 match contact {
4935 db::Contact::Accepted { user_id, busy } => {
4936 update.contacts.push(contact_for_user(user_id, busy, &pool));
4937 }
4938 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
4939 db::Contact::Incoming { user_id } => {
4940 update
4941 .incoming_requests
4942 .push(proto::IncomingContactRequest {
4943 requester_id: user_id.to_proto(),
4944 })
4945 }
4946 }
4947 }
4948
4949 update
4950}
4951
4952fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
4953 proto::Contact {
4954 user_id: user_id.to_proto(),
4955 online: pool.is_user_online(user_id),
4956 busy,
4957 }
4958}
4959
4960fn room_updated(room: &proto::Room, peer: &Peer) {
4961 broadcast(
4962 None,
4963 room.participants
4964 .iter()
4965 .filter_map(|participant| Some(participant.peer_id?.into())),
4966 |peer_id| {
4967 peer.send(
4968 peer_id,
4969 proto::RoomUpdated {
4970 room: Some(room.clone()),
4971 },
4972 )
4973 },
4974 );
4975}
4976
4977fn channel_updated(
4978 channel: &db::channel::Model,
4979 room: &proto::Room,
4980 peer: &Peer,
4981 pool: &ConnectionPool,
4982) {
4983 let participants = room
4984 .participants
4985 .iter()
4986 .map(|p| p.user_id)
4987 .collect::<Vec<_>>();
4988
4989 broadcast(
4990 None,
4991 pool.channel_connection_ids(channel.root_id())
4992 .filter_map(|(channel_id, role)| {
4993 role.can_see_channel(channel.visibility).then(|| channel_id)
4994 }),
4995 |peer_id| {
4996 peer.send(
4997 peer_id,
4998 proto::UpdateChannels {
4999 channel_participants: vec![proto::ChannelParticipants {
5000 channel_id: channel.id.to_proto(),
5001 participant_user_ids: participants.clone(),
5002 }],
5003 ..Default::default()
5004 },
5005 )
5006 },
5007 );
5008}
5009
5010async fn send_dev_server_projects_update(
5011 user_id: UserId,
5012 mut status: proto::DevServerProjectsUpdate,
5013 session: &Session,
5014) {
5015 let pool = session.connection_pool().await;
5016 for dev_server in &mut status.dev_servers {
5017 dev_server.status =
5018 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5019 }
5020 let connections = pool.user_connection_ids(user_id);
5021 for connection_id in connections {
5022 session.peer.send(connection_id, status.clone()).trace_err();
5023 }
5024}
5025
5026async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5027 let db = session.db().await;
5028
5029 let contacts = db.get_contacts(user_id).await?;
5030 let busy = db.is_user_busy(user_id).await?;
5031
5032 let pool = session.connection_pool().await;
5033 let updated_contact = contact_for_user(user_id, busy, &pool);
5034 for contact in contacts {
5035 if let db::Contact::Accepted {
5036 user_id: contact_user_id,
5037 ..
5038 } = contact
5039 {
5040 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5041 session
5042 .peer
5043 .send(
5044 contact_conn_id,
5045 proto::UpdateContacts {
5046 contacts: vec![updated_contact.clone()],
5047 remove_contacts: Default::default(),
5048 incoming_requests: Default::default(),
5049 remove_incoming_requests: Default::default(),
5050 outgoing_requests: Default::default(),
5051 remove_outgoing_requests: Default::default(),
5052 },
5053 )
5054 .trace_err();
5055 }
5056 }
5057 }
5058 Ok(())
5059}
5060
5061async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5062 log::info!("lost dev server connection, unsharing projects");
5063 let project_ids = session
5064 .db()
5065 .await
5066 .get_stale_dev_server_projects(session.connection_id)
5067 .await?;
5068
5069 for project_id in project_ids {
5070 // not unshare re-checks the connection ids match, so we get away with no transaction
5071 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5072 }
5073
5074 let user_id = session.dev_server().user_id;
5075 let update = session
5076 .db()
5077 .await
5078 .dev_server_projects_update(user_id)
5079 .await?;
5080
5081 send_dev_server_projects_update(user_id, update, session).await;
5082
5083 Ok(())
5084}
5085
/// Remove `connection_id` from the room it is in and notify everyone affected.
///
/// If the database reports that the connection was not in a room, this returns
/// `Ok(())` without doing anything else. Otherwise it:
/// - notifies the collaborators of every project the connection left (`project_left`),
/// - broadcasts the updated room to its participants and, when the room belongs to
///   a channel, the new participant list to the channel's members,
/// - sends `CallCanceled` to every connection of users whose calls were canceled,
/// - refreshes the contact status of the leaving user and of those users,
/// - if a LiveKit client is configured, removes the user from the LiveKit room and
///   deletes that room when the database reports the room was deleted.
///
/// LiveKit failures are only logged; database and contact-update failures are
/// propagated.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Declared up front and initialized inside the `if let` below so the values
    // outlive the borrow of `left_room`.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        // NOTE(review): this runs before `room` is moved out on the next lines, so
        // the `room` passed to `room_updated` below carries the default (emptied)
        // `live_kit_room`. Keep this order unless that is confirmed unintended.
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    {
        // Scoped so the connection pool is released before contacts are updated.
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5159
5160async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5161 let left_channel_buffers = session
5162 .db()
5163 .await
5164 .leave_channel_buffers(session.connection_id)
5165 .await?;
5166
5167 for left_buffer in left_channel_buffers {
5168 channel_buffer_updated(
5169 session.connection_id,
5170 left_buffer.connections,
5171 &proto::UpdateChannelBufferCollaborators {
5172 channel_id: left_buffer.channel_id.to_proto(),
5173 collaborators: left_buffer.collaborators,
5174 },
5175 &session.peer,
5176 );
5177 }
5178
5179 Ok(())
5180}
5181
5182fn project_left(project: &db::LeftProject, session: &UserSession) {
5183 for connection_id in &project.connection_ids {
5184 if project.should_unshare {
5185 session
5186 .peer
5187 .send(
5188 *connection_id,
5189 proto::UnshareProject {
5190 project_id: project.id.to_proto(),
5191 },
5192 )
5193 .trace_err();
5194 } else {
5195 session
5196 .peer
5197 .send(
5198 *connection_id,
5199 proto::RemoveProjectCollaborator {
5200 project_id: project.id.to_proto(),
5201 peer_id: Some(session.connection_id.into()),
5202 },
5203 )
5204 .trace_err();
5205 }
5206 }
5207}
5208
5209pub trait ResultExt {
5210 type Ok;
5211
5212 fn trace_err(self) -> Option<Self::Ok>;
5213}
5214
5215impl<T, E> ResultExt for Result<T, E>
5216where
5217 E: std::fmt::Debug,
5218{
5219 type Ok = T;
5220
5221 #[track_caller]
5222 fn trace_err(self) -> Option<T> {
5223 match self {
5224 Ok(value) => Some(value),
5225 Err(error) => {
5226 tracing::error!("{:?}", error);
5227 None
5228 }
5229 }
5230 }
5231}