1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
50 RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// Time window granted for reconnection (see its call sites for exactly what it bounds).
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` cleans up resources owned by previous server instances.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully. Once they are gone, their
/// old resources can safely be cleaned up.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Page size for channel-message listings (presumably for chat history requests).
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Page size for notification listings.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
/// Upper bound on a message's length (units as checked by its users).
const MAX_MESSAGE_LEN: usize = 1024;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
/// Per-connection context handed to every message handler.
#[derive(Clone)]
struct Session {
    /// Who is on the other end of the connection (user, impersonated user, or dev server).
    principal: Principal,
    /// Id assigned to this connection when it was added to the peer.
    connection_id: ConnectionId,
    /// Database handle behind an async mutex; acquire it through `Session::db`.
    db: Arc<tokio::sync::Mutex<DbHandle>>,
    /// Shared RPC peer used to send messages and responses.
    peer: Arc<Peer>,
    /// Connection pool shared with the `Server`; acquire it through `Session::connection_pool`.
    connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// LiveKit client taken from the app state, if one is configured.
    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// Supermaven admin API client; `None` when no Supermaven admin key is configured.
    supermaven_client: Option<Arc<SupermavenAdminApi>>,
    /// HTTP client created for this connection.
    http_client: Arc<IsahcHttpClient>,
    /// Rate limiter shared through the app state.
    rate_limiter: Arc<RateLimiter>,
    /// Executor the connection was set up with; not read by the code visible here.
    _executor: Executor,
}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
/// The collaboration RPC server: owns the peer, the connection pool, and the table of
/// per-message handlers.
pub struct Server {
    /// This server instance's id; behind a mutex so the test-only `reset` can replace it.
    id: parking_lot::Mutex<ServerId>,
    /// RPC peer through which connections are added and messages are sent.
    peer: Arc<Peer>,
    /// Connections attached to this server; the same `Arc` is cloned into every `Session`.
    pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
    /// Shared application state (database, config, executor, ...).
    app_state: Arc<AppState>,
    /// Handlers keyed by the `TypeId` of the message type they handle.
    handlers: HashMap<TypeId, MessageHandler>,
    /// Set to `true` by `teardown`; connection tasks subscribe to it and stop when it changes.
    teardown: watch::Sender<bool>,
}
384
/// Guard over the shared [`ConnectionPool`] lock, as returned by `Session::connection_pool`.
pub(crate) struct ConnectionPoolGuard<'a> {
    /// The underlying `parking_lot` mutex guard.
    guard: parking_lot::MutexGuard<'a, ConnectionPool>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`. Any future
    // that holds the guard across an `.await` is therefore not `Send`. Presumably the intent is
    // to reject, at compile time, handlers that keep the pool locked while awaiting.
    _not_send: PhantomData<Rc<()>>,
}
389
/// Serializable view of the server's state: the peer plus the locked connection pool.
///
/// The pool is held through a [`ConnectionPoolGuard`], so it stays locked for as long as the
/// snapshot is alive.
#[derive(Serialize)]
pub struct ServerSnapshot<'a> {
    peer: &'a Peer,
    // The guard is not `Serialize` itself; serialize what it derefs to instead.
    #[serde(serialize_with = "serialize_deref")]
    connection_pool: ConnectionPoolGuard<'a>,
}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Creates the server and registers a handler for every RPC message it understands.
    ///
    /// Handlers wrapped in `user_handler`/`user_message_handler` reject non-user sessions,
    /// and handlers wrapped in `dev_server_handler` reject non-dev-server sessions. The rest
    /// accept any principal. Registering two handlers for the same message type panics (see
    /// `add_handler`).
    ///
    /// `app_state` is stored on the server. Its config also supplies the API keys captured by
    /// the language-model and embedding handlers registered at the end.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection calls `subscribe()` for its own receiver.
            teardown: watch::channel(false).0,
        };

        server
            .add_request_handler(ping)
            // Rooms and calls.
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to another participant (owner / read-only /
            // mutating / versioned-mutating forwarders).
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and messages broadcast from the project host.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, channel chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model completion (streamed) and embeddings; API keys come from config.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
640
    /// Starts background cleanup of resources left behind by previous server instances.
    ///
    /// Spawns a detached task that waits [`CLEANUP_TIMEOUT`], asks the database for stale
    /// rooms and channel buffers (`stale_server_resource_ids`), and then:
    /// - clears stale channel-buffer collaborators and sends the refreshed collaborator list
    ///   to the remaining connections of each buffer;
    /// - clears stale room participants, broadcasts the updated room (and its channel, if
    ///   any), sends `CallCanceled` to users whose calls were canceled, pushes the affected
    ///   users' updated contact entries to their accepted contacts, and deletes the LiveKit
    ///   room once it has no participants (only when a LiveKit client is configured);
    /// - finally calls `delete_stale_servers`, even if fetching the stale ids failed.
    ///
    /// Individual failures are passed through `trace_err()` and skipped rather than aborting
    /// the task. The method itself returns `Ok(())` right after spawning.
    pub async fn start(&self) -> Result<()> {
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        // Awaited inside the task before any cleanup happens.
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Stale channel buffers: drop old collaborators, tell the others.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    // Stale rooms: drop old participants and propagate the consequences.
                    for room_id in room_ids {
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
786
787 pub fn teardown(&self) {
788 self.peer.teardown();
789 self.connection_pool.lock().reset();
790 let _ = self.teardown.send(true);
791 }
792
793 #[cfg(test)]
794 pub fn reset(&self, id: ServerId) {
795 self.teardown();
796 *self.id.lock() = id;
797 self.peer.reset(id.0 as u32);
798 let _ = self.teardown.send(false);
799 }
800
801 #[cfg(test)]
802 pub fn id(&self) -> ServerId {
803 *self.id.lock()
804 }
805
806 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
807 where
808 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
809 Fut: 'static + Send + Future<Output = Result<()>>,
810 M: EnvelopedMessage,
811 {
812 let prev_handler = self.handlers.insert(
813 TypeId::of::<M>(),
814 Box::new(move |envelope, session| {
815 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
816 let received_at = envelope.received_at;
817 tracing::info!("message received");
818 let start_time = Instant::now();
819 let future = (handler)(*envelope, session);
820 async move {
821 let result = future.await;
822 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
823 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
824 let queue_duration_ms = total_duration_ms - processing_duration_ms;
825 let payload_type = M::NAME;
826
827 match result {
828 Err(error) => {
829 tracing::error!(
830 ?error,
831 total_duration_ms,
832 processing_duration_ms,
833 queue_duration_ms,
834 payload_type,
835 "error handling message"
836 )
837 }
838 Ok(()) => tracing::info!(
839 total_duration_ms,
840 processing_duration_ms,
841 queue_duration_ms,
842 "finished handling message"
843 ),
844 }
845 }
846 .boxed()
847 }),
848 );
849 if prev_handler.is_some() {
850 panic!("registered a handler for the same message twice");
851 }
852 self
853 }
854
855 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
856 where
857 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
858 Fut: 'static + Send + Future<Output = Result<()>>,
859 M: EnvelopedMessage,
860 {
861 self.add_handler(move |envelope, session| handler(envelope.payload, session));
862 self
863 }
864
865 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
866 where
867 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
868 Fut: Send + Future<Output = Result<()>>,
869 M: RequestMessage,
870 {
871 let handler = Arc::new(handler);
872 self.add_handler(move |envelope, session| {
873 let receipt = envelope.receipt();
874 let handler = handler.clone();
875 async move {
876 let peer = session.peer.clone();
877 let responded = Arc::new(AtomicBool::default());
878 let response = Response {
879 peer: peer.clone(),
880 responded: responded.clone(),
881 receipt,
882 };
883 match (handler)(envelope.payload, response, session).await {
884 Ok(()) => {
885 if responded.load(std::sync::atomic::Ordering::SeqCst) {
886 Ok(())
887 } else {
888 Err(anyhow!("handler did not send a response"))?
889 }
890 }
891 Err(error) => {
892 let proto_err = match &error {
893 Error::Internal(err) => err.to_proto(),
894 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
895 };
896 peer.respond_with_error(receipt, proto_err)?;
897 Err(error)
898 }
899 }
900 }
901 })
902 }
903
904 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
905 where
906 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
907 Fut: Send + Future<Output = Result<()>>,
908 M: RequestMessage,
909 {
910 let handler = Arc::new(handler);
911 self.add_handler(move |envelope, session| {
912 let receipt = envelope.receipt();
913 let handler = handler.clone();
914 async move {
915 let peer = session.peer.clone();
916 let response = StreamingResponse {
917 peer: peer.clone(),
918 receipt,
919 };
920 match (handler)(envelope.payload, response, session).await {
921 Ok(()) => {
922 peer.end_stream(receipt)?;
923 Ok(())
924 }
925 Err(error) => {
926 let proto_err = match &error {
927 Error::Internal(err) => err.to_proto(),
928 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
929 };
930 peer.respond_with_error(receipt, proto_err)?;
931 Err(error)
932 }
933 }
934 }
935 })
936 }
937
938 #[allow(clippy::too_many_arguments)]
939 pub fn handle_connection(
940 self: &Arc<Self>,
941 connection: Connection,
942 address: String,
943 principal: Principal,
944 zed_version: ZedVersion,
945 send_connection_id: Option<oneshot::Sender<ConnectionId>>,
946 executor: Executor,
947 ) -> impl Future<Output = ()> {
948 let this = self.clone();
949 let span = info_span!("handle connection", %address,
950 connection_id=field::Empty,
951 user_id=field::Empty,
952 login=field::Empty,
953 impersonator=field::Empty,
954 dev_server_id=field::Empty
955 );
956 principal.update_span(&span);
957
958 let mut teardown = self.teardown.subscribe();
959 async move {
960 if *teardown.borrow() {
961 tracing::error!("server is tearing down");
962 return
963 }
964 let (connection_id, handle_io, mut incoming_rx) = this
965 .peer
966 .add_connection(connection, {
967 let executor = executor.clone();
968 move |duration| executor.sleep(duration)
969 });
970 tracing::Span::current().record("connection_id", format!("{}", connection_id));
971 tracing::info!("connection opened");
972
973 let http_client = match IsahcHttpClient::new() {
974 Ok(http_client) => Arc::new(http_client),
975 Err(error) => {
976 tracing::error!(?error, "failed to create HTTP client");
977 return;
978 }
979 };
980
981 let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
982 Some(Arc::new(SupermavenAdminApi::new(
983 supermaven_admin_api_key.to_string(),
984 http_client.clone(),
985 )))
986 } else {
987 None
988 };
989
990 let session = Session {
991 principal: principal.clone(),
992 connection_id,
993 db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
994 peer: this.peer.clone(),
995 connection_pool: this.connection_pool.clone(),
996 live_kit_client: this.app_state.live_kit_client.clone(),
997 http_client,
998 rate_limiter: this.app_state.rate_limiter.clone(),
999 _executor: executor.clone(),
1000 supermaven_client,
1001 };
1002
1003 if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
1004 tracing::error!(?error, "failed to send initial client update");
1005 return;
1006 }
1007
1008 let handle_io = handle_io.fuse();
1009 futures::pin_mut!(handle_io);
1010
1011 // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
1012 // This prevents deadlocks when e.g., client A performs a request to client B and
1013 // client B performs a request to client A. If both clients stop processing further
1014 // messages until their respective request completes, they won't have a chance to
1015 // respond to the other client's request and cause a deadlock.
1016 //
1017 // This arrangement ensures we will attempt to process earlier messages first, but fall
1018 // back to processing messages arrived later in the spirit of making progress.
1019 let mut foreground_message_handlers = FuturesUnordered::new();
1020 let concurrent_handlers = Arc::new(Semaphore::new(256));
1021 loop {
1022 let next_message = async {
1023 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
1024 let message = incoming_rx.next().await;
1025 (permit, message)
1026 }.fuse();
1027 futures::pin_mut!(next_message);
1028 futures::select_biased! {
1029 _ = teardown.changed().fuse() => return,
1030 result = handle_io => {
1031 if let Err(error) = result {
1032 tracing::error!(?error, "error handling I/O");
1033 }
1034 break;
1035 }
1036 _ = foreground_message_handlers.next() => {}
1037 next_message = next_message => {
1038 let (permit, message) = next_message;
1039 if let Some(message) = message {
1040 let type_name = message.payload_type_name();
1041 // note: we copy all the fields from the parent span so we can query them in the logs.
1042 // (https://github.com/tokio-rs/tracing/issues/2670).
1043 let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
1044 user_id=field::Empty,
1045 login=field::Empty,
1046 impersonator=field::Empty,
1047 dev_server_id=field::Empty
1048 );
1049 principal.update_span(&span);
1050 let span_enter = span.enter();
1051 if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
1052 let is_background = message.is_background();
1053 let handle_message = (handler)(message, session.clone());
1054 drop(span_enter);
1055
1056 let handle_message = async move {
1057 handle_message.await;
1058 drop(permit);
1059 }.instrument(span);
1060 if is_background {
1061 executor.spawn_detached(handle_message);
1062 } else {
1063 foreground_message_handlers.push(handle_message);
1064 }
1065 } else {
1066 tracing::error!("no message handler");
1067 }
1068 } else {
1069 tracing::info!("connection closed");
1070 break;
1071 }
1072 }
1073 }
1074 }
1075
1076 drop(foreground_message_handlers);
1077 tracing::info!("signing out");
1078 if let Err(error) = connection_lost(session, teardown, executor).await {
1079 tracing::error!(?error, "error signing out");
1080 }
1081
1082 }.instrument(span)
1083 }
1084
1085 async fn send_initial_client_update(
1086 &self,
1087 connection_id: ConnectionId,
1088 principal: &Principal,
1089 zed_version: ZedVersion,
1090 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1091 session: &Session,
1092 ) -> Result<()> {
1093 self.peer.send(
1094 connection_id,
1095 proto::Hello {
1096 peer_id: Some(connection_id.into()),
1097 },
1098 )?;
1099 tracing::info!("sent hello message");
1100 if let Some(send_connection_id) = send_connection_id.take() {
1101 let _ = send_connection_id.send(connection_id);
1102 }
1103
1104 match principal {
1105 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1106 if !user.connected_once {
1107 self.peer.send(connection_id, proto::ShowContacts {})?;
1108 self.app_state
1109 .db
1110 .set_user_connected_once(user.id, true)
1111 .await?;
1112 }
1113
1114 let (contacts, dev_server_projects) = future::try_join(
1115 self.app_state.db.get_contacts(user.id),
1116 self.app_state.db.dev_server_projects_update(user.id),
1117 )
1118 .await?;
1119
1120 {
1121 let mut pool = self.connection_pool.lock();
1122 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1123 self.peer.send(
1124 connection_id,
1125 build_initial_contacts_update(contacts, &pool),
1126 )?;
1127 }
1128
1129 if should_auto_subscribe_to_channels(zed_version) {
1130 subscribe_user_to_channels(user.id, session).await?;
1131 }
1132
1133 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1134
1135 if let Some(incoming_call) =
1136 self.app_state.db.incoming_call_for_user(user.id).await?
1137 {
1138 self.peer.send(connection_id, incoming_call)?;
1139 }
1140
1141 update_user_contacts(user.id, &session).await?;
1142 }
1143 Principal::DevServer(dev_server) => {
1144 {
1145 let mut pool = self.connection_pool.lock();
1146 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1147 {
1148 self.peer.send(
1149 stale_connection_id,
1150 proto::ShutdownDevServer {
1151 reason: Some(
1152 "another dev server connected with the same token".to_string(),
1153 ),
1154 },
1155 )?;
1156 pool.remove_connection(stale_connection_id)?;
1157 };
1158 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1159 }
1160
1161 let projects = self
1162 .app_state
1163 .db
1164 .get_projects_for_dev_server(dev_server.id)
1165 .await?;
1166 self.peer
1167 .send(connection_id, proto::DevServerInstructions { projects })?;
1168
1169 let status = self
1170 .app_state
1171 .db
1172 .dev_server_projects_update(dev_server.user_id)
1173 .await?;
1174 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1175 }
1176 }
1177
1178 Ok(())
1179 }
1180
1181 pub async fn invite_code_redeemed(
1182 self: &Arc<Self>,
1183 inviter_id: UserId,
1184 invitee_id: UserId,
1185 ) -> Result<()> {
1186 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1187 if let Some(code) = &user.invite_code {
1188 let pool = self.connection_pool.lock();
1189 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1190 for connection_id in pool.user_connection_ids(inviter_id) {
1191 self.peer.send(
1192 connection_id,
1193 proto::UpdateContacts {
1194 contacts: vec![invitee_contact.clone()],
1195 ..Default::default()
1196 },
1197 )?;
1198 self.peer.send(
1199 connection_id,
1200 proto::UpdateInviteInfo {
1201 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1202 count: user.invite_count as u32,
1203 },
1204 )?;
1205 }
1206 }
1207 }
1208 Ok(())
1209 }
1210
1211 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1212 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1213 if let Some(invite_code) = &user.invite_code {
1214 let pool = self.connection_pool.lock();
1215 for connection_id in pool.user_connection_ids(user_id) {
1216 self.peer.send(
1217 connection_id,
1218 proto::UpdateInviteInfo {
1219 url: format!(
1220 "{}{}",
1221 self.app_state.config.invite_link_prefix, invite_code
1222 ),
1223 count: user.invite_count as u32,
1224 },
1225 )?;
1226 }
1227 }
1228 }
1229 Ok(())
1230 }
1231
1232 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1233 ServerSnapshot {
1234 connection_pool: ConnectionPoolGuard {
1235 guard: self.connection_pool.lock(),
1236 _not_send: PhantomData,
1237 },
1238 peer: &self.peer,
1239 }
1240 }
1241}
1242
1243impl<'a> Deref for ConnectionPoolGuard<'a> {
1244 type Target = ConnectionPool;
1245
1246 fn deref(&self) -> &Self::Target {
1247 &self.guard
1248 }
1249}
1250
1251impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1252 fn deref_mut(&mut self) -> &mut Self::Target {
1253 &mut self.guard
1254 }
1255}
1256
1257impl<'a> Drop for ConnectionPoolGuard<'a> {
1258 fn drop(&mut self) {
1259 #[cfg(test)]
1260 self.check_invariants();
1261 }
1262}
1263
1264fn broadcast<F>(
1265 sender_id: Option<ConnectionId>,
1266 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1267 mut f: F,
1268) where
1269 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1270{
1271 for receiver_id in receiver_ids {
1272 if Some(receiver_id) != sender_id {
1273 if let Err(error) = f(receiver_id) {
1274 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1275 }
1276 }
1277 }
1278}
1279
1280pub struct ProtocolVersion(u32);
1281
1282impl Header for ProtocolVersion {
1283 fn name() -> &'static HeaderName {
1284 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1285 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1286 }
1287
1288 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1289 where
1290 Self: Sized,
1291 I: Iterator<Item = &'i axum::http::HeaderValue>,
1292 {
1293 let version = values
1294 .next()
1295 .ok_or_else(axum::headers::Error::invalid)?
1296 .to_str()
1297 .map_err(|_| axum::headers::Error::invalid())?
1298 .parse()
1299 .map_err(|_| axum::headers::Error::invalid())?;
1300 Ok(Self(version))
1301 }
1302
1303 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1304 values.extend([self.0.to_string().parse().unwrap()]);
1305 }
1306}
1307
1308pub struct AppVersionHeader(SemanticVersion);
1309impl Header for AppVersionHeader {
1310 fn name() -> &'static HeaderName {
1311 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1312 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1313 }
1314
1315 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1316 where
1317 Self: Sized,
1318 I: Iterator<Item = &'i axum::http::HeaderValue>,
1319 {
1320 let version = values
1321 .next()
1322 .ok_or_else(axum::headers::Error::invalid)?
1323 .to_str()
1324 .map_err(|_| axum::headers::Error::invalid())?
1325 .parse()
1326 .map_err(|_| axum::headers::Error::invalid())?;
1327 Ok(Self(version))
1328 }
1329
1330 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1331 values.extend([self.0.to_string().parse().unwrap()]);
1332 }
1333}
1334
1335pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1336 Router::new()
1337 .route("/rpc", get(handle_websocket_request))
1338 .layer(
1339 ServiceBuilder::new()
1340 .layer(Extension(server.app_state.clone()))
1341 .layer(middleware::from_fn(auth::validate_header)),
1342 )
1343 .route("/metrics", get(handle_metrics))
1344 .layer(Extension(server))
1345}
1346
1347pub async fn handle_websocket_request(
1348 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1349 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1350 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1351 Extension(server): Extension<Arc<Server>>,
1352 Extension(principal): Extension<Principal>,
1353 ws: WebSocketUpgrade,
1354) -> axum::response::Response {
1355 if protocol_version != rpc::PROTOCOL_VERSION {
1356 return (
1357 StatusCode::UPGRADE_REQUIRED,
1358 "client must be upgraded".to_string(),
1359 )
1360 .into_response();
1361 }
1362
1363 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1364 return (
1365 StatusCode::UPGRADE_REQUIRED,
1366 "no version header found".to_string(),
1367 )
1368 .into_response();
1369 };
1370
1371 if !version.can_collaborate() {
1372 return (
1373 StatusCode::UPGRADE_REQUIRED,
1374 "client must be upgraded".to_string(),
1375 )
1376 .into_response();
1377 }
1378
1379 let socket_address = socket_address.to_string();
1380 ws.on_upgrade(move |socket| {
1381 let socket = socket
1382 .map_ok(to_tungstenite_message)
1383 .err_into()
1384 .with(|message| async move { to_axum_message(message) });
1385 let connection = Connection::new(Box::pin(socket));
1386 async move {
1387 server
1388 .handle_connection(
1389 connection,
1390 socket_address,
1391 principal,
1392 version,
1393 None,
1394 Executor::Production,
1395 )
1396 .await;
1397 }
1398 })
1399}
1400
1401pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1402 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1403 let connections_metric = CONNECTIONS_METRIC
1404 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1405
1406 let connections = server
1407 .connection_pool
1408 .lock()
1409 .connections()
1410 .filter(|connection| !connection.admin)
1411 .count();
1412 connections_metric.set(connections as _);
1413
1414 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1415 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1416 register_int_gauge!(
1417 "shared_projects",
1418 "number of open projects with one or more guests"
1419 )
1420 .unwrap()
1421 });
1422
1423 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1424 shared_projects_metric.set(shared_projects as _);
1425
1426 let encoder = prometheus::TextEncoder::new();
1427 let metric_families = prometheus::gather();
1428 let encoded_metrics = encoder
1429 .encode_to_string(&metric_families)
1430 .map_err(|err| anyhow!("{}", err))?;
1431 Ok(encoded_metrics)
1432}
1433
/// Cleans up after a connection has ended.
///
/// Cleanup happens in two phases.
///
/// Immediately:
/// - disconnects the peer;
/// - removes the connection from the pool;
/// - records the loss in the database. A database failure here is only
///   traced.
///
/// After `RECONNECT_TIMEOUT`:
/// - for users, leaves the room and the channel buffers, declines any pending
///   call if the user has no other connection online, and refreshes contacts;
/// - for dev servers, runs `lost_dev_server_connection`.
///
/// If the teardown watch channel changes before the timeout elapses, the
/// second phase is skipped.
///
/// # Errors
/// Returns an error if removing the connection from the pool fails. The
/// post-timeout contact refresh and the dev-server cleanup can also return
/// errors.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    // Immediate cleanup: disconnect the peer and drop the connection from the
    // pool. A pool error aborts everything below.
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    // Best-effort: a database failure here is traced and otherwise ignored.
    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Defer releasing the principal's resources by RECONNECT_TIMEOUT
    // (presumably to give the client a window to reconnect). `select_biased!`
    // polls the sleep branch first; a change on the teardown channel skips the
    // deferred cleanup entirely.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    // Expected to succeed: this arm only matches user principals.
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // This connection was removed from the pool above, so this
                    // checks whether the user still has any *other* connection.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    // Expected to succeed: this arm only matches dev-server principals.
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1488
1489/// Acknowledges a ping from a client, used to keep the connection alive.
1490async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1491 response.send(proto::Ack {})?;
1492 Ok(())
1493}
1494
1495/// Creates a new room for calling (outside of channels)
1496async fn create_room(
1497 _request: proto::CreateRoom,
1498 response: Response<proto::CreateRoom>,
1499 session: UserSession,
1500) -> Result<()> {
1501 let live_kit_room = nanoid::nanoid!(30);
1502
1503 let live_kit_connection_info = util::maybe!(async {
1504 let live_kit = session.live_kit_client.as_ref();
1505 let live_kit = live_kit?;
1506 let user_id = session.user_id().to_string();
1507
1508 let token = live_kit
1509 .room_token(&live_kit_room, &user_id.to_string())
1510 .trace_err()?;
1511
1512 Some(proto::LiveKitConnectionInfo {
1513 server_url: live_kit.url().into(),
1514 token,
1515 can_publish: true,
1516 })
1517 })
1518 .await;
1519
1520 let room = session
1521 .db()
1522 .await
1523 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1524 .await?;
1525
1526 response.send(proto::CreateRoomResponse {
1527 room: Some(room.clone()),
1528 live_kit_connection_info,
1529 })?;
1530
1531 update_user_contacts(session.user_id(), &session).await?;
1532 Ok(())
1533}
1534
1535/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1536async fn join_room(
1537 request: proto::JoinRoom,
1538 response: Response<proto::JoinRoom>,
1539 session: UserSession,
1540) -> Result<()> {
1541 let room_id = RoomId::from_proto(request.id);
1542
1543 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1544
1545 if let Some(channel_id) = channel_id {
1546 return join_channel_internal(channel_id, Box::new(response), session).await;
1547 }
1548
1549 let joined_room = {
1550 let room = session
1551 .db()
1552 .await
1553 .join_room(room_id, session.user_id(), session.connection_id)
1554 .await?;
1555 room_updated(&room.room, &session.peer);
1556 room.into_inner()
1557 };
1558
1559 for connection_id in session
1560 .connection_pool()
1561 .await
1562 .user_connection_ids(session.user_id())
1563 {
1564 session
1565 .peer
1566 .send(
1567 connection_id,
1568 proto::CallCanceled {
1569 room_id: room_id.to_proto(),
1570 },
1571 )
1572 .trace_err();
1573 }
1574
1575 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1576 if let Some(token) = live_kit
1577 .room_token(
1578 &joined_room.room.live_kit_room,
1579 &session.user_id().to_string(),
1580 )
1581 .trace_err()
1582 {
1583 Some(proto::LiveKitConnectionInfo {
1584 server_url: live_kit.url().into(),
1585 token,
1586 can_publish: true,
1587 })
1588 } else {
1589 None
1590 }
1591 } else {
1592 None
1593 };
1594
1595 response.send(proto::JoinRoomResponse {
1596 room: Some(joined_room.room),
1597 channel_id: None,
1598 live_kit_connection_info,
1599 })?;
1600
1601 update_user_contacts(session.user_id(), &session).await?;
1602 Ok(())
1603}
1604
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps, in order:
/// 1. Restores the caller's membership via the database `rejoin_room` call.
/// 2. Replies with the room plus the re-shared and re-joined projects.
/// 3. Broadcasts the room.
/// 4. For each re-shared project, tells its collaborators about the caller's
///    new connection id and forwards the project's worktrees to them.
/// 5. Hands the re-joined projects to `notify_rejoined_projects`.
/// 6. Refreshes the room's channel (if any), then the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Assigned inside the block below. Only these two values outlive the
    // rejoin result, which is consumed by `into_inner` at the end of the block.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        // Reply to the rejoining client before notifying anyone else.
        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        // Re-shared projects (presumably ones the caller hosts).
        // - Every collaborator is pointed from the old connection id to the
        //   caller's new one.
        // - The current worktree list is forwarded to every collaborator
        //   except the caller.
        // Send failures to collaborators are discarded.
        for project in &rejoined_room.reshared_projects {
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        // Re-joined projects: notify their collaborators and replay project
        // state to the caller. This moves the worktrees out of the projects.
        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    // The rejoin result has been released by now. The connection pool guard
    // below lives only for the duration of the `channel_updated` call.
    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1696
1697fn notify_rejoined_projects(
1698 rejoined_projects: &mut Vec<RejoinedProject>,
1699 session: &UserSession,
1700) -> Result<()> {
1701 for project in rejoined_projects.iter() {
1702 for collaborator in &project.collaborators {
1703 session
1704 .peer
1705 .send(
1706 collaborator.connection_id,
1707 proto::UpdateProjectCollaborator {
1708 project_id: project.id.to_proto(),
1709 old_peer_id: Some(project.old_connection_id.into()),
1710 new_peer_id: Some(session.connection_id.into()),
1711 },
1712 )
1713 .trace_err();
1714 }
1715 }
1716
1717 for project in rejoined_projects {
1718 for worktree in mem::take(&mut project.worktrees) {
1719 #[cfg(any(test, feature = "test-support"))]
1720 const MAX_CHUNK_SIZE: usize = 2;
1721 #[cfg(not(any(test, feature = "test-support")))]
1722 const MAX_CHUNK_SIZE: usize = 256;
1723
1724 // Stream this worktree's entries.
1725 let message = proto::UpdateWorktree {
1726 project_id: project.id.to_proto(),
1727 worktree_id: worktree.id,
1728 abs_path: worktree.abs_path.clone(),
1729 root_name: worktree.root_name,
1730 updated_entries: worktree.updated_entries,
1731 removed_entries: worktree.removed_entries,
1732 scan_id: worktree.scan_id,
1733 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1734 updated_repositories: worktree.updated_repositories,
1735 removed_repositories: worktree.removed_repositories,
1736 };
1737 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1738 session.peer.send(session.connection_id, update.clone())?;
1739 }
1740
1741 // Stream this worktree's diagnostics.
1742 for summary in worktree.diagnostic_summaries {
1743 session.peer.send(
1744 session.connection_id,
1745 proto::UpdateDiagnosticSummary {
1746 project_id: project.id.to_proto(),
1747 worktree_id: worktree.id,
1748 summary: Some(summary),
1749 },
1750 )?;
1751 }
1752
1753 for settings_file in worktree.settings_files {
1754 session.peer.send(
1755 session.connection_id,
1756 proto::UpdateWorktreeSettings {
1757 project_id: project.id.to_proto(),
1758 worktree_id: worktree.id,
1759 path: settings_file.path,
1760 content: Some(settings_file.content),
1761 },
1762 )?;
1763 }
1764 }
1765
1766 for language_server in &project.language_servers {
1767 session.peer.send(
1768 session.connection_id,
1769 proto::UpdateLanguageServer {
1770 project_id: project.id.to_proto(),
1771 language_server_id: language_server.id,
1772 variant: Some(
1773 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1774 proto::LspDiskBasedDiagnosticsUpdated {},
1775 ),
1776 ),
1777 },
1778 )?;
1779 }
1780 }
1781 Ok(())
1782}
1783
1784/// leave room disconnects from the room.
1785async fn leave_room(
1786 _: proto::LeaveRoom,
1787 response: Response<proto::LeaveRoom>,
1788 session: UserSession,
1789) -> Result<()> {
1790 leave_room_for_session(&session, session.connection_id).await?;
1791 response.send(proto::Ack {})?;
1792 Ok(())
1793}
1794
1795/// Updates the permissions of someone else in the room.
1796async fn set_room_participant_role(
1797 request: proto::SetRoomParticipantRole,
1798 response: Response<proto::SetRoomParticipantRole>,
1799 session: UserSession,
1800) -> Result<()> {
1801 let user_id = UserId::from_proto(request.user_id);
1802 let role = ChannelRole::from(request.role());
1803
1804 let (live_kit_room, can_publish) = {
1805 let room = session
1806 .db()
1807 .await
1808 .set_room_participant_role(
1809 session.user_id(),
1810 RoomId::from_proto(request.room_id),
1811 user_id,
1812 role,
1813 )
1814 .await?;
1815
1816 let live_kit_room = room.live_kit_room.clone();
1817 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1818 room_updated(&room, &session.peer);
1819 (live_kit_room, can_publish)
1820 };
1821
1822 if let Some(live_kit) = session.live_kit_client.as_ref() {
1823 live_kit
1824 .update_participant(
1825 live_kit_room.clone(),
1826 request.user_id.to_string(),
1827 live_kit_server::proto::ParticipantPermission {
1828 can_subscribe: true,
1829 can_publish,
1830 can_publish_data: can_publish,
1831 hidden: false,
1832 recorder: false,
1833 },
1834 )
1835 .await
1836 .trace_err();
1837 }
1838
1839 response.send(proto::Ack {})?;
1840 Ok(())
1841}
1842
1843/// Call someone else into the current room
1844async fn call(
1845 request: proto::Call,
1846 response: Response<proto::Call>,
1847 session: UserSession,
1848) -> Result<()> {
1849 let room_id = RoomId::from_proto(request.room_id);
1850 let calling_user_id = session.user_id();
1851 let calling_connection_id = session.connection_id;
1852 let called_user_id = UserId::from_proto(request.called_user_id);
1853 let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
1854 if !session
1855 .db()
1856 .await
1857 .has_contact(calling_user_id, called_user_id)
1858 .await?
1859 {
1860 return Err(anyhow!("cannot call a user who isn't a contact"))?;
1861 }
1862
1863 let incoming_call = {
1864 let (room, incoming_call) = &mut *session
1865 .db()
1866 .await
1867 .call(
1868 room_id,
1869 calling_user_id,
1870 calling_connection_id,
1871 called_user_id,
1872 initial_project_id,
1873 )
1874 .await?;
1875 room_updated(&room, &session.peer);
1876 mem::take(incoming_call)
1877 };
1878 update_user_contacts(called_user_id, &session).await?;
1879
1880 let mut calls = session
1881 .connection_pool()
1882 .await
1883 .user_connection_ids(called_user_id)
1884 .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
1885 .collect::<FuturesUnordered<_>>();
1886
1887 while let Some(call_response) = calls.next().await {
1888 match call_response.as_ref() {
1889 Ok(_) => {
1890 response.send(proto::Ack {})?;
1891 return Ok(());
1892 }
1893 Err(_) => {
1894 call_response.trace_err();
1895 }
1896 }
1897 }
1898
1899 {
1900 let room = session
1901 .db()
1902 .await
1903 .call_failed(room_id, called_user_id)
1904 .await?;
1905 room_updated(&room, &session.peer);
1906 }
1907 update_user_contacts(called_user_id, &session).await?;
1908
1909 Err(anyhow!("failed to ring user"))?
1910}
1911
1912/// Cancel an outgoing call.
1913async fn cancel_call(
1914 request: proto::CancelCall,
1915 response: Response<proto::CancelCall>,
1916 session: UserSession,
1917) -> Result<()> {
1918 let called_user_id = UserId::from_proto(request.called_user_id);
1919 let room_id = RoomId::from_proto(request.room_id);
1920 {
1921 let room = session
1922 .db()
1923 .await
1924 .cancel_call(room_id, session.connection_id, called_user_id)
1925 .await?;
1926 room_updated(&room, &session.peer);
1927 }
1928
1929 for connection_id in session
1930 .connection_pool()
1931 .await
1932 .user_connection_ids(called_user_id)
1933 {
1934 session
1935 .peer
1936 .send(
1937 connection_id,
1938 proto::CallCanceled {
1939 room_id: room_id.to_proto(),
1940 },
1941 )
1942 .trace_err();
1943 }
1944 response.send(proto::Ack {})?;
1945
1946 update_user_contacts(called_user_id, &session).await?;
1947 Ok(())
1948}
1949
1950/// Decline an incoming call.
1951async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1952 let room_id = RoomId::from_proto(message.room_id);
1953 {
1954 let room = session
1955 .db()
1956 .await
1957 .decline_call(Some(room_id), session.user_id())
1958 .await?
1959 .ok_or_else(|| anyhow!("failed to decline call"))?;
1960 room_updated(&room, &session.peer);
1961 }
1962
1963 for connection_id in session
1964 .connection_pool()
1965 .await
1966 .user_connection_ids(session.user_id())
1967 {
1968 session
1969 .peer
1970 .send(
1971 connection_id,
1972 proto::CallCanceled {
1973 room_id: room_id.to_proto(),
1974 },
1975 )
1976 .trace_err();
1977 }
1978 update_user_contacts(session.user_id(), &session).await?;
1979 Ok(())
1980}
1981
1982/// Updates other participants in the room with your current location.
1983async fn update_participant_location(
1984 request: proto::UpdateParticipantLocation,
1985 response: Response<proto::UpdateParticipantLocation>,
1986 session: UserSession,
1987) -> Result<()> {
1988 let room_id = RoomId::from_proto(request.room_id);
1989 let location = request
1990 .location
1991 .ok_or_else(|| anyhow!("invalid location"))?;
1992
1993 let db = session.db().await;
1994 let room = db
1995 .update_room_participant_location(room_id, session.connection_id, location)
1996 .await?;
1997
1998 room_updated(&room, &session.peer);
1999 response.send(proto::Ack {})?;
2000 Ok(())
2001}
2002
2003/// Share a project into the room.
2004async fn share_project(
2005 request: proto::ShareProject,
2006 response: Response<proto::ShareProject>,
2007 session: UserSession,
2008) -> Result<()> {
2009 let (project_id, room) = &*session
2010 .db()
2011 .await
2012 .share_project(
2013 RoomId::from_proto(request.room_id),
2014 session.connection_id,
2015 &request.worktrees,
2016 request
2017 .dev_server_project_id
2018 .map(|id| DevServerProjectId::from_proto(id)),
2019 )
2020 .await?;
2021 response.send(proto::ShareProjectResponse {
2022 project_id: project_id.to_proto(),
2023 })?;
2024 room_updated(&room, &session.peer);
2025
2026 Ok(())
2027}
2028
2029/// Unshare a project from the room.
2030async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2031 let project_id = ProjectId::from_proto(message.project_id);
2032 unshare_project_internal(
2033 project_id,
2034 session.connection_id,
2035 session.user_id(),
2036 &session,
2037 )
2038 .await
2039}
2040
/// Stops sharing `project_id` from `connection_id`.
///
/// Steps, in order:
/// 1. Records the unshare via the database `unshare_project` call; `user_id`
///    is passed through unchanged.
/// 2. Sends `UnshareProject` to the guest connections the database reports,
///    except `connection_id`.
/// 3. Broadcasts the room when one is returned.
/// 4. Deletes the project if the database reports that it should be deleted.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    // Confine the unshare result to this block so it is dropped before
    // `delete_project` runs below; only the `delete` flag escapes.
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        // Per-guest send failures are logged by `broadcast` and do not abort.
        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2079
2080/// DevServer makes a project available online
2081async fn share_dev_server_project(
2082 request: proto::ShareDevServerProject,
2083 response: Response<proto::ShareDevServerProject>,
2084 session: DevServerSession,
2085) -> Result<()> {
2086 let (dev_server_project, user_id, status) = session
2087 .db()
2088 .await
2089 .share_dev_server_project(
2090 DevServerProjectId::from_proto(request.dev_server_project_id),
2091 session.dev_server_id(),
2092 session.connection_id,
2093 &request.worktrees,
2094 )
2095 .await?;
2096 let Some(project_id) = dev_server_project.project_id else {
2097 return Err(anyhow!("failed to share remote project"))?;
2098 };
2099
2100 send_dev_server_projects_update(user_id, status, &session).await;
2101
2102 response.send(proto::ShareProjectResponse { project_id })?;
2103
2104 Ok(())
2105}
2106
/// Join someone else's shared project.
///
/// Registers this connection with `request.project_id` via the database
/// `join_project` call. The resulting project and replica id are then handed
/// to `join_project_internal`, which notifies the existing collaborators and
/// sends the project to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's database handle before streaming the project.
    // The join result bound above stays alive until the end of the function.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2125
/// Abstracts over the response handles that answer a join with a
/// `proto::JoinProjectResponse`, so that `join_project_internal` can serve both
/// `JoinProject` and `JoinHostedProject` requests.
trait JoinProjectInternalResponse {
    /// Sends the join response back to the requesting client.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
// The fully-qualified calls below resolve to `Response`'s own `send`
// (inherent items take precedence), not back into this trait.
impl JoinProjectInternalResponse for Response<proto::JoinProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinProject>::send(self, result)
    }
}
impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
    fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
        Response::<proto::JoinHostedProject>::send(self, result)
    }
}
2139
2140fn join_project_internal(
2141 response: impl JoinProjectInternalResponse,
2142 session: UserSession,
2143 project: &mut Project,
2144 replica_id: &ReplicaId,
2145) -> Result<()> {
2146 let collaborators = project
2147 .collaborators
2148 .iter()
2149 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2150 .map(|collaborator| collaborator.to_proto())
2151 .collect::<Vec<_>>();
2152 let project_id = project.id;
2153 let guest_user_id = session.user_id();
2154
2155 let worktrees = project
2156 .worktrees
2157 .iter()
2158 .map(|(id, worktree)| proto::WorktreeMetadata {
2159 id: *id,
2160 root_name: worktree.root_name.clone(),
2161 visible: worktree.visible,
2162 abs_path: worktree.abs_path.clone(),
2163 })
2164 .collect::<Vec<_>>();
2165
2166 let add_project_collaborator = proto::AddProjectCollaborator {
2167 project_id: project_id.to_proto(),
2168 collaborator: Some(proto::Collaborator {
2169 peer_id: Some(session.connection_id.into()),
2170 replica_id: replica_id.0 as u32,
2171 user_id: guest_user_id.to_proto(),
2172 }),
2173 };
2174
2175 for collaborator in &collaborators {
2176 session
2177 .peer
2178 .send(
2179 collaborator.peer_id.unwrap().into(),
2180 add_project_collaborator.clone(),
2181 )
2182 .trace_err();
2183 }
2184
2185 // First, we send the metadata associated with each worktree.
2186 response.send(proto::JoinProjectResponse {
2187 project_id: project.id.0 as u64,
2188 worktrees: worktrees.clone(),
2189 replica_id: replica_id.0 as u32,
2190 collaborators: collaborators.clone(),
2191 language_servers: project.language_servers.clone(),
2192 role: project.role.into(),
2193 dev_server_project_id: project
2194 .dev_server_project_id
2195 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2196 })?;
2197
2198 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2199 #[cfg(any(test, feature = "test-support"))]
2200 const MAX_CHUNK_SIZE: usize = 2;
2201 #[cfg(not(any(test, feature = "test-support")))]
2202 const MAX_CHUNK_SIZE: usize = 256;
2203
2204 // Stream this worktree's entries.
2205 let message = proto::UpdateWorktree {
2206 project_id: project_id.to_proto(),
2207 worktree_id,
2208 abs_path: worktree.abs_path.clone(),
2209 root_name: worktree.root_name,
2210 updated_entries: worktree.entries,
2211 removed_entries: Default::default(),
2212 scan_id: worktree.scan_id,
2213 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2214 updated_repositories: worktree.repository_entries.into_values().collect(),
2215 removed_repositories: Default::default(),
2216 };
2217 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2218 session.peer.send(session.connection_id, update.clone())?;
2219 }
2220
2221 // Stream this worktree's diagnostics.
2222 for summary in worktree.diagnostic_summaries {
2223 session.peer.send(
2224 session.connection_id,
2225 proto::UpdateDiagnosticSummary {
2226 project_id: project_id.to_proto(),
2227 worktree_id: worktree.id,
2228 summary: Some(summary),
2229 },
2230 )?;
2231 }
2232
2233 for settings_file in worktree.settings_files {
2234 session.peer.send(
2235 session.connection_id,
2236 proto::UpdateWorktreeSettings {
2237 project_id: project_id.to_proto(),
2238 worktree_id: worktree.id,
2239 path: settings_file.path,
2240 content: Some(settings_file.content),
2241 },
2242 )?;
2243 }
2244 }
2245
2246 for language_server in &project.language_servers {
2247 session.peer.send(
2248 session.connection_id,
2249 proto::UpdateLanguageServer {
2250 project_id: project_id.to_proto(),
2251 language_server_id: language_server.id,
2252 variant: Some(
2253 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2254 proto::LspDiskBasedDiagnosticsUpdated {},
2255 ),
2256 ),
2257 },
2258 )?;
2259 }
2260
2261 Ok(())
2262}
2263
2264/// Leave someone elses shared project.
2265async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2266 let sender_id = session.connection_id;
2267 let project_id = ProjectId::from_proto(request.project_id);
2268 let db = session.db().await;
2269 if db.is_hosted_project(project_id).await? {
2270 let project = db.leave_hosted_project(project_id, sender_id).await?;
2271 project_left(&project, &session);
2272 return Ok(());
2273 }
2274
2275 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2276 tracing::info!(
2277 %project_id,
2278 "leave project"
2279 );
2280
2281 project_left(&project, &session);
2282 if let Some(room) = room {
2283 room_updated(&room, &session.peer);
2284 }
2285
2286 Ok(())
2287}
2288
2289async fn join_hosted_project(
2290 request: proto::JoinHostedProject,
2291 response: Response<proto::JoinHostedProject>,
2292 session: UserSession,
2293) -> Result<()> {
2294 let (mut project, replica_id) = session
2295 .db()
2296 .await
2297 .join_hosted_project(
2298 ProjectId(request.project_id as i32),
2299 session.user_id(),
2300 session.connection_id,
2301 )
2302 .await?;
2303
2304 join_project_internal(response, session, &mut project, &replica_id)
2305}
2306
2307async fn list_remote_directory(
2308 request: proto::ListRemoteDirectory,
2309 response: Response<proto::ListRemoteDirectory>,
2310 session: UserSession,
2311) -> Result<()> {
2312 let dev_server_id = DevServerId(request.dev_server_id as i32);
2313 let dev_server_connection_id = session
2314 .connection_pool()
2315 .await
2316 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2317
2318 session
2319 .db()
2320 .await
2321 .get_dev_server_for_user(dev_server_id, session.user_id())
2322 .await?;
2323
2324 response.send(
2325 session
2326 .peer
2327 .forward_request(session.connection_id, dev_server_connection_id, request)
2328 .await?,
2329 )?;
2330 Ok(())
2331}
2332
2333async fn update_dev_server_project(
2334 request: proto::UpdateDevServerProject,
2335 response: Response<proto::UpdateDevServerProject>,
2336 session: UserSession,
2337) -> Result<()> {
2338 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2339
2340 let (dev_server_project, update) = session
2341 .db()
2342 .await
2343 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2344 .await?;
2345
2346 let projects = session
2347 .db()
2348 .await
2349 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2350 .await?;
2351
2352 let dev_server_connection_id = session
2353 .connection_pool()
2354 .await
2355 .dev_server_connection_id_supporting(
2356 dev_server_project.dev_server_id,
2357 ZedVersion::with_list_directory(),
2358 )?;
2359
2360 session.peer.send(
2361 dev_server_connection_id,
2362 proto::DevServerInstructions { projects },
2363 )?;
2364
2365 send_dev_server_projects_update(session.user_id(), update, &session).await;
2366
2367 response.send(proto::Ack {})
2368}
2369
2370async fn create_dev_server_project(
2371 request: proto::CreateDevServerProject,
2372 response: Response<proto::CreateDevServerProject>,
2373 session: UserSession,
2374) -> Result<()> {
2375 let dev_server_id = DevServerId(request.dev_server_id as i32);
2376 let dev_server_connection_id = session
2377 .connection_pool()
2378 .await
2379 .dev_server_connection_id(dev_server_id);
2380 let Some(dev_server_connection_id) = dev_server_connection_id else {
2381 Err(ErrorCode::DevServerOffline
2382 .message("Cannot create a remote project when the dev server is offline".to_string())
2383 .anyhow())?
2384 };
2385
2386 let path = request.path.clone();
2387 //Check that the path exists on the dev server
2388 session
2389 .peer
2390 .forward_request(
2391 session.connection_id,
2392 dev_server_connection_id,
2393 proto::ValidateDevServerProjectRequest { path: path.clone() },
2394 )
2395 .await?;
2396
2397 let (dev_server_project, update) = session
2398 .db()
2399 .await
2400 .create_dev_server_project(
2401 DevServerId(request.dev_server_id as i32),
2402 &request.path,
2403 session.user_id(),
2404 )
2405 .await?;
2406
2407 let projects = session
2408 .db()
2409 .await
2410 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2411 .await?;
2412
2413 session.peer.send(
2414 dev_server_connection_id,
2415 proto::DevServerInstructions { projects },
2416 )?;
2417
2418 send_dev_server_projects_update(session.user_id(), update, &session).await;
2419
2420 response.send(proto::CreateDevServerProjectResponse {
2421 dev_server_project: Some(dev_server_project.to_proto(None)),
2422 })?;
2423 Ok(())
2424}
2425
2426async fn create_dev_server(
2427 request: proto::CreateDevServer,
2428 response: Response<proto::CreateDevServer>,
2429 session: UserSession,
2430) -> Result<()> {
2431 let access_token = auth::random_token();
2432 let hashed_access_token = auth::hash_access_token(&access_token);
2433
2434 if request.name.is_empty() {
2435 return Err(proto::ErrorCode::Forbidden
2436 .message("Dev server name cannot be empty".to_string())
2437 .anyhow())?;
2438 }
2439
2440 let (dev_server, status) = session
2441 .db()
2442 .await
2443 .create_dev_server(
2444 &request.name,
2445 request.ssh_connection_string.as_deref(),
2446 &hashed_access_token,
2447 session.user_id(),
2448 )
2449 .await?;
2450
2451 send_dev_server_projects_update(session.user_id(), status, &session).await;
2452
2453 response.send(proto::CreateDevServerResponse {
2454 dev_server_id: dev_server.id.0 as u64,
2455 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2456 name: request.name,
2457 })?;
2458 Ok(())
2459}
2460
2461async fn regenerate_dev_server_token(
2462 request: proto::RegenerateDevServerToken,
2463 response: Response<proto::RegenerateDevServerToken>,
2464 session: UserSession,
2465) -> Result<()> {
2466 let dev_server_id = DevServerId(request.dev_server_id as i32);
2467 let access_token = auth::random_token();
2468 let hashed_access_token = auth::hash_access_token(&access_token);
2469
2470 let connection_id = session
2471 .connection_pool()
2472 .await
2473 .dev_server_connection_id(dev_server_id);
2474 if let Some(connection_id) = connection_id {
2475 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2476 session.peer.send(
2477 connection_id,
2478 proto::ShutdownDevServer {
2479 reason: Some("dev server token was regenerated".to_string()),
2480 },
2481 )?;
2482 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2483 }
2484
2485 let status = session
2486 .db()
2487 .await
2488 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2489 .await?;
2490
2491 send_dev_server_projects_update(session.user_id(), status, &session).await;
2492
2493 response.send(proto::RegenerateDevServerTokenResponse {
2494 dev_server_id: dev_server_id.to_proto(),
2495 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2496 })?;
2497 Ok(())
2498}
2499
2500async fn rename_dev_server(
2501 request: proto::RenameDevServer,
2502 response: Response<proto::RenameDevServer>,
2503 session: UserSession,
2504) -> Result<()> {
2505 if request.name.trim().is_empty() {
2506 return Err(proto::ErrorCode::Forbidden
2507 .message("Dev server name cannot be empty".to_string())
2508 .anyhow())?;
2509 }
2510
2511 let dev_server_id = DevServerId(request.dev_server_id as i32);
2512 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2513 if dev_server.user_id != session.user_id() {
2514 return Err(anyhow!(ErrorCode::Forbidden))?;
2515 }
2516
2517 let status = session
2518 .db()
2519 .await
2520 .rename_dev_server(
2521 dev_server_id,
2522 &request.name,
2523 request.ssh_connection_string.as_deref(),
2524 session.user_id(),
2525 )
2526 .await?;
2527
2528 send_dev_server_projects_update(session.user_id(), status, &session).await;
2529
2530 response.send(proto::Ack {})?;
2531 Ok(())
2532}
2533
2534async fn delete_dev_server(
2535 request: proto::DeleteDevServer,
2536 response: Response<proto::DeleteDevServer>,
2537 session: UserSession,
2538) -> Result<()> {
2539 let dev_server_id = DevServerId(request.dev_server_id as i32);
2540 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2541 if dev_server.user_id != session.user_id() {
2542 return Err(anyhow!(ErrorCode::Forbidden))?;
2543 }
2544
2545 let connection_id = session
2546 .connection_pool()
2547 .await
2548 .dev_server_connection_id(dev_server_id);
2549 if let Some(connection_id) = connection_id {
2550 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2551 session.peer.send(
2552 connection_id,
2553 proto::ShutdownDevServer {
2554 reason: Some("dev server was deleted".to_string()),
2555 },
2556 )?;
2557 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2558 }
2559
2560 let status = session
2561 .db()
2562 .await
2563 .delete_dev_server(dev_server_id, session.user_id())
2564 .await?;
2565
2566 send_dev_server_projects_update(session.user_id(), status, &session).await;
2567
2568 response.send(proto::Ack {})?;
2569 Ok(())
2570}
2571
2572async fn delete_dev_server_project(
2573 request: proto::DeleteDevServerProject,
2574 response: Response<proto::DeleteDevServerProject>,
2575 session: UserSession,
2576) -> Result<()> {
2577 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2578 let dev_server_project = session
2579 .db()
2580 .await
2581 .get_dev_server_project(dev_server_project_id)
2582 .await?;
2583
2584 let dev_server = session
2585 .db()
2586 .await
2587 .get_dev_server(dev_server_project.dev_server_id)
2588 .await?;
2589 if dev_server.user_id != session.user_id() {
2590 return Err(anyhow!(ErrorCode::Forbidden))?;
2591 }
2592
2593 let dev_server_connection_id = session
2594 .connection_pool()
2595 .await
2596 .dev_server_connection_id(dev_server.id);
2597
2598 if let Some(dev_server_connection_id) = dev_server_connection_id {
2599 let project = session
2600 .db()
2601 .await
2602 .find_dev_server_project(dev_server_project_id)
2603 .await;
2604 if let Ok(project) = project {
2605 unshare_project_internal(
2606 project.id,
2607 dev_server_connection_id,
2608 Some(session.user_id()),
2609 &session,
2610 )
2611 .await?;
2612 }
2613 }
2614
2615 let (projects, status) = session
2616 .db()
2617 .await
2618 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2619 .await?;
2620
2621 if let Some(dev_server_connection_id) = dev_server_connection_id {
2622 session.peer.send(
2623 dev_server_connection_id,
2624 proto::DevServerInstructions { projects },
2625 )?;
2626 }
2627
2628 send_dev_server_projects_update(session.user_id(), status, &session).await;
2629
2630 response.send(proto::Ack {})?;
2631 Ok(())
2632}
2633
2634async fn rejoin_dev_server_projects(
2635 request: proto::RejoinRemoteProjects,
2636 response: Response<proto::RejoinRemoteProjects>,
2637 session: UserSession,
2638) -> Result<()> {
2639 let mut rejoined_projects = {
2640 let db = session.db().await;
2641 db.rejoin_dev_server_projects(
2642 &request.rejoined_projects,
2643 session.user_id(),
2644 session.0.connection_id,
2645 )
2646 .await?
2647 };
2648 response.send(proto::RejoinRemoteProjectsResponse {
2649 rejoined_projects: rejoined_projects
2650 .iter()
2651 .map(|project| project.to_proto())
2652 .collect(),
2653 })?;
2654 notify_rejoined_projects(&mut rejoined_projects, &session)
2655}
2656
2657async fn reconnect_dev_server(
2658 request: proto::ReconnectDevServer,
2659 response: Response<proto::ReconnectDevServer>,
2660 session: DevServerSession,
2661) -> Result<()> {
2662 let reshared_projects = {
2663 let db = session.db().await;
2664 db.reshare_dev_server_projects(
2665 &request.reshared_projects,
2666 session.dev_server_id(),
2667 session.0.connection_id,
2668 )
2669 .await?
2670 };
2671
2672 for project in &reshared_projects {
2673 for collaborator in &project.collaborators {
2674 session
2675 .peer
2676 .send(
2677 collaborator.connection_id,
2678 proto::UpdateProjectCollaborator {
2679 project_id: project.id.to_proto(),
2680 old_peer_id: Some(project.old_connection_id.into()),
2681 new_peer_id: Some(session.connection_id.into()),
2682 },
2683 )
2684 .trace_err();
2685 }
2686
2687 broadcast(
2688 Some(session.connection_id),
2689 project
2690 .collaborators
2691 .iter()
2692 .map(|collaborator| collaborator.connection_id),
2693 |connection_id| {
2694 session.peer.forward_send(
2695 session.connection_id,
2696 connection_id,
2697 proto::UpdateProject {
2698 project_id: project.id.to_proto(),
2699 worktrees: project.worktrees.clone(),
2700 },
2701 )
2702 },
2703 );
2704 }
2705
2706 response.send(proto::ReconnectDevServerResponse {
2707 reshared_projects: reshared_projects
2708 .iter()
2709 .map(|project| proto::ResharedProject {
2710 id: project.id.to_proto(),
2711 collaborators: project
2712 .collaborators
2713 .iter()
2714 .map(|collaborator| collaborator.to_proto())
2715 .collect(),
2716 })
2717 .collect(),
2718 })?;
2719
2720 Ok(())
2721}
2722
2723async fn shutdown_dev_server(
2724 _: proto::ShutdownDevServer,
2725 response: Response<proto::ShutdownDevServer>,
2726 session: DevServerSession,
2727) -> Result<()> {
2728 response.send(proto::Ack {})?;
2729 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2730 remove_dev_server_connection(session.dev_server_id(), &session).await
2731}
2732
2733async fn shutdown_dev_server_internal(
2734 dev_server_id: DevServerId,
2735 connection_id: ConnectionId,
2736 session: &Session,
2737) -> Result<()> {
2738 let (dev_server_projects, dev_server) = {
2739 let db = session.db().await;
2740 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2741 let dev_server = db.get_dev_server(dev_server_id).await?;
2742 (dev_server_projects, dev_server)
2743 };
2744
2745 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2746 unshare_project_internal(
2747 ProjectId::from_proto(project_id),
2748 connection_id,
2749 None,
2750 session,
2751 )
2752 .await?;
2753 }
2754
2755 session
2756 .connection_pool()
2757 .await
2758 .set_dev_server_offline(dev_server_id);
2759
2760 let status = session
2761 .db()
2762 .await
2763 .dev_server_projects_update(dev_server.user_id)
2764 .await?;
2765 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2766
2767 Ok(())
2768}
2769
2770async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2771 let dev_server_connection = session
2772 .connection_pool()
2773 .await
2774 .dev_server_connection_id(dev_server_id);
2775
2776 if let Some(dev_server_connection) = dev_server_connection {
2777 session
2778 .connection_pool()
2779 .await
2780 .remove_connection(dev_server_connection)?;
2781 }
2782 Ok(())
2783}
2784
2785/// Updates other participants with changes to the project
2786async fn update_project(
2787 request: proto::UpdateProject,
2788 response: Response<proto::UpdateProject>,
2789 session: Session,
2790) -> Result<()> {
2791 let project_id = ProjectId::from_proto(request.project_id);
2792 let (room, guest_connection_ids) = &*session
2793 .db()
2794 .await
2795 .update_project(project_id, session.connection_id, &request.worktrees)
2796 .await?;
2797 broadcast(
2798 Some(session.connection_id),
2799 guest_connection_ids.iter().copied(),
2800 |connection_id| {
2801 session
2802 .peer
2803 .forward_send(session.connection_id, connection_id, request.clone())
2804 },
2805 );
2806 if let Some(room) = room {
2807 room_updated(&room, &session.peer);
2808 }
2809 response.send(proto::Ack {})?;
2810
2811 Ok(())
2812}
2813
2814/// Updates other participants with changes to the worktree
2815async fn update_worktree(
2816 request: proto::UpdateWorktree,
2817 response: Response<proto::UpdateWorktree>,
2818 session: Session,
2819) -> Result<()> {
2820 let guest_connection_ids = session
2821 .db()
2822 .await
2823 .update_worktree(&request, session.connection_id)
2824 .await?;
2825
2826 broadcast(
2827 Some(session.connection_id),
2828 guest_connection_ids.iter().copied(),
2829 |connection_id| {
2830 session
2831 .peer
2832 .forward_send(session.connection_id, connection_id, request.clone())
2833 },
2834 );
2835 response.send(proto::Ack {})?;
2836 Ok(())
2837}
2838
2839/// Updates other participants with changes to the diagnostics
2840async fn update_diagnostic_summary(
2841 message: proto::UpdateDiagnosticSummary,
2842 session: Session,
2843) -> Result<()> {
2844 let guest_connection_ids = session
2845 .db()
2846 .await
2847 .update_diagnostic_summary(&message, session.connection_id)
2848 .await?;
2849
2850 broadcast(
2851 Some(session.connection_id),
2852 guest_connection_ids.iter().copied(),
2853 |connection_id| {
2854 session
2855 .peer
2856 .forward_send(session.connection_id, connection_id, message.clone())
2857 },
2858 );
2859
2860 Ok(())
2861}
2862
2863/// Updates other participants with changes to the worktree settings
2864async fn update_worktree_settings(
2865 message: proto::UpdateWorktreeSettings,
2866 session: Session,
2867) -> Result<()> {
2868 let guest_connection_ids = session
2869 .db()
2870 .await
2871 .update_worktree_settings(&message, session.connection_id)
2872 .await?;
2873
2874 broadcast(
2875 Some(session.connection_id),
2876 guest_connection_ids.iter().copied(),
2877 |connection_id| {
2878 session
2879 .peer
2880 .forward_send(session.connection_id, connection_id, message.clone())
2881 },
2882 );
2883
2884 Ok(())
2885}
2886
2887/// Notify other participants that a language server has started.
2888async fn start_language_server(
2889 request: proto::StartLanguageServer,
2890 session: Session,
2891) -> Result<()> {
2892 let guest_connection_ids = session
2893 .db()
2894 .await
2895 .start_language_server(&request, session.connection_id)
2896 .await?;
2897
2898 broadcast(
2899 Some(session.connection_id),
2900 guest_connection_ids.iter().copied(),
2901 |connection_id| {
2902 session
2903 .peer
2904 .forward_send(session.connection_id, connection_id, request.clone())
2905 },
2906 );
2907 Ok(())
2908}
2909
2910/// Notify other participants that a language server has changed.
2911async fn update_language_server(
2912 request: proto::UpdateLanguageServer,
2913 session: Session,
2914) -> Result<()> {
2915 let project_id = ProjectId::from_proto(request.project_id);
2916 let project_connection_ids = session
2917 .db()
2918 .await
2919 .project_connection_ids(project_id, session.connection_id, true)
2920 .await?;
2921 broadcast(
2922 Some(session.connection_id),
2923 project_connection_ids.iter().copied(),
2924 |connection_id| {
2925 session
2926 .peer
2927 .forward_send(session.connection_id, connection_id, request.clone())
2928 },
2929 );
2930 Ok(())
2931}
2932
2933/// forward a project request to the host. These requests should be read only
2934/// as guests are allowed to send them.
2935async fn forward_read_only_project_request<T>(
2936 request: T,
2937 response: Response<T>,
2938 session: UserSession,
2939) -> Result<()>
2940where
2941 T: EntityMessage + RequestMessage,
2942{
2943 let project_id = ProjectId::from_proto(request.remote_entity_id());
2944 let host_connection_id = session
2945 .db()
2946 .await
2947 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2948 .await?;
2949 let payload = session
2950 .peer
2951 .forward_request(session.connection_id, host_connection_id, request)
2952 .await?;
2953 response.send(payload)?;
2954 Ok(())
2955}
2956
2957/// forward a project request to the dev server. Only allowed
2958/// if it's your dev server.
2959async fn forward_project_request_for_owner<T>(
2960 request: T,
2961 response: Response<T>,
2962 session: UserSession,
2963) -> Result<()>
2964where
2965 T: EntityMessage + RequestMessage,
2966{
2967 let project_id = ProjectId::from_proto(request.remote_entity_id());
2968
2969 let host_connection_id = session
2970 .db()
2971 .await
2972 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2973 .await?;
2974 let payload = session
2975 .peer
2976 .forward_request(session.connection_id, host_connection_id, request)
2977 .await?;
2978 response.send(payload)?;
2979 Ok(())
2980}
2981
2982/// forward a project request to the host. These requests are disallowed
2983/// for guests.
2984async fn forward_mutating_project_request<T>(
2985 request: T,
2986 response: Response<T>,
2987 session: UserSession,
2988) -> Result<()>
2989where
2990 T: EntityMessage + RequestMessage,
2991{
2992 let project_id = ProjectId::from_proto(request.remote_entity_id());
2993
2994 let host_connection_id = session
2995 .db()
2996 .await
2997 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
2998 .await?;
2999 let payload = session
3000 .peer
3001 .forward_request(session.connection_id, host_connection_id, request)
3002 .await?;
3003 response.send(payload)?;
3004 Ok(())
3005}
3006
3007/// forward a project request to the host. These requests are disallowed
3008/// for guests.
3009async fn forward_versioned_mutating_project_request<T>(
3010 request: T,
3011 response: Response<T>,
3012 session: UserSession,
3013) -> Result<()>
3014where
3015 T: EntityMessage + RequestMessage + VersionedMessage,
3016{
3017 let project_id = ProjectId::from_proto(request.remote_entity_id());
3018
3019 let host_connection_id = session
3020 .db()
3021 .await
3022 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3023 .await?;
3024 if let Some(host_version) = session
3025 .connection_pool()
3026 .await
3027 .connection(host_connection_id)
3028 .map(|c| c.zed_version)
3029 {
3030 if let Some(min_required_version) = request.required_host_version() {
3031 if min_required_version > host_version {
3032 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3033 .with_tag("required", &min_required_version.to_string())))?;
3034 }
3035 }
3036 }
3037
3038 let payload = session
3039 .peer
3040 .forward_request(session.connection_id, host_connection_id, request)
3041 .await?;
3042 response.send(payload)?;
3043 Ok(())
3044}
3045
3046/// Notify other participants that a new buffer has been created
3047async fn create_buffer_for_peer(
3048 request: proto::CreateBufferForPeer,
3049 session: Session,
3050) -> Result<()> {
3051 session
3052 .db()
3053 .await
3054 .check_user_is_project_host(
3055 ProjectId::from_proto(request.project_id),
3056 session.connection_id,
3057 )
3058 .await?;
3059 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3060 session
3061 .peer
3062 .forward_send(session.connection_id, peer_id.into(), request)?;
3063 Ok(())
3064}
3065
/// Applies a buffer update from a project participant and fans it out.
///
/// Operations that only update selections (or carry no variant) need just
/// read-only capability; any other operation requires read-write. The
/// required capability is passed to `connections_for_buffer_update`
/// (presumably so it can reject senders without that permission). The update
/// is broadcast to the other guests. When the sender is not the host, it is
/// also forwarded to the host as a request, and the sender is acked only
/// after the host answers.
async fn update_buffer(
    request: proto::UpdateBuffer,
    response: Response<proto::UpdateBuffer>,
    session: Session,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);
    // Start read-only and escalate if any operation does more than move selections.
    let mut capability = Capability::ReadOnly;

    for op in request.operations.iter() {
        match op.variant {
            None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
            Some(_) => capability = Capability::ReadWrite,
        }
    }

    // The database guard lives only inside this block, so it is dropped before
    // the host request below is awaited.
    let host = {
        let guard = session
            .db()
            .await
            .connections_for_buffer_update(
                project_id,
                session.principal_id(),
                session.connection_id,
                capability,
            )
            .await?;

        let (host, guests) = &*guard;

        broadcast(
            Some(session.connection_id),
            guests.clone(),
            |connection_id| {
                session
                    .peer
                    .forward_send(session.connection_id, connection_id, request.clone())
            },
        );

        *host
    };

    // The host receives the update as an awaited request rather than a
    // fire-and-forget message; skip this when the sender is the host.
    if host != session.connection_id {
        session
            .peer
            .forward_request(session.connection_id, host, request.clone())
            .await?;
    }

    response.send(proto::Ack {})?;
    Ok(())
}
3120
3121async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3122 let project_id = ProjectId::from_proto(message.project_id);
3123
3124 let operation = message.operation.as_ref().context("invalid operation")?;
3125 let capability = match operation.variant.as_ref() {
3126 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3127 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3128 match buffer_op.variant {
3129 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3130 Capability::ReadOnly
3131 }
3132 _ => Capability::ReadWrite,
3133 }
3134 } else {
3135 Capability::ReadWrite
3136 }
3137 }
3138 Some(_) => Capability::ReadWrite,
3139 None => Capability::ReadOnly,
3140 };
3141
3142 let guard = session
3143 .db()
3144 .await
3145 .connections_for_buffer_update(
3146 project_id,
3147 session.principal_id(),
3148 session.connection_id,
3149 capability,
3150 )
3151 .await?;
3152
3153 let (host, guests) = &*guard;
3154
3155 broadcast(
3156 Some(session.connection_id),
3157 guests.iter().chain([host]).copied(),
3158 |connection_id| {
3159 session
3160 .peer
3161 .forward_send(session.connection_id, connection_id, message.clone())
3162 },
3163 );
3164
3165 Ok(())
3166}
3167
3168/// Notify other participants that a project has been updated.
3169async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3170 request: T,
3171 session: Session,
3172) -> Result<()> {
3173 let project_id = ProjectId::from_proto(request.remote_entity_id());
3174 let project_connection_ids = session
3175 .db()
3176 .await
3177 .project_connection_ids(project_id, session.connection_id, false)
3178 .await?;
3179
3180 broadcast(
3181 Some(session.connection_id),
3182 project_connection_ids.iter().copied(),
3183 |connection_id| {
3184 session
3185 .peer
3186 .forward_send(session.connection_id, connection_id, request.clone())
3187 },
3188 );
3189 Ok(())
3190}
3191
3192/// Start following another user in a call.
3193async fn follow(
3194 request: proto::Follow,
3195 response: Response<proto::Follow>,
3196 session: UserSession,
3197) -> Result<()> {
3198 let room_id = RoomId::from_proto(request.room_id);
3199 let project_id = request.project_id.map(ProjectId::from_proto);
3200 let leader_id = request
3201 .leader_id
3202 .ok_or_else(|| anyhow!("invalid leader id"))?
3203 .into();
3204 let follower_id = session.connection_id;
3205
3206 session
3207 .db()
3208 .await
3209 .check_room_participants(room_id, leader_id, session.connection_id)
3210 .await?;
3211
3212 let response_payload = session
3213 .peer
3214 .forward_request(session.connection_id, leader_id, request)
3215 .await?;
3216 response.send(response_payload)?;
3217
3218 if let Some(project_id) = project_id {
3219 let room = session
3220 .db()
3221 .await
3222 .follow(room_id, project_id, leader_id, follower_id)
3223 .await?;
3224 room_updated(&room, &session.peer);
3225 }
3226
3227 Ok(())
3228}
3229
3230/// Stop following another user in a call.
3231async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3232 let room_id = RoomId::from_proto(request.room_id);
3233 let project_id = request.project_id.map(ProjectId::from_proto);
3234 let leader_id = request
3235 .leader_id
3236 .ok_or_else(|| anyhow!("invalid leader id"))?
3237 .into();
3238 let follower_id = session.connection_id;
3239
3240 session
3241 .db()
3242 .await
3243 .check_room_participants(room_id, leader_id, session.connection_id)
3244 .await?;
3245
3246 session
3247 .peer
3248 .forward_send(session.connection_id, leader_id, request)?;
3249
3250 if let Some(project_id) = project_id {
3251 let room = session
3252 .db()
3253 .await
3254 .unfollow(room_id, project_id, leader_id, follower_id)
3255 .await?;
3256 room_updated(&room, &session.peer);
3257 }
3258
3259 Ok(())
3260}
3261
3262/// Notify everyone following you of your current location.
3263async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3264 let room_id = RoomId::from_proto(request.room_id);
3265 let database = session.db.lock().await;
3266
3267 let connection_ids = if let Some(project_id) = request.project_id {
3268 let project_id = ProjectId::from_proto(project_id);
3269 database
3270 .project_connection_ids(project_id, session.connection_id, true)
3271 .await?
3272 } else {
3273 database
3274 .room_connection_ids(room_id, session.connection_id)
3275 .await?
3276 };
3277
3278 // For now, don't send view update messages back to that view's current leader.
3279 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3280 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3281 _ => None,
3282 });
3283
3284 for connection_id in connection_ids.iter().cloned() {
3285 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3286 session
3287 .peer
3288 .forward_send(session.connection_id, connection_id, request.clone())?;
3289 }
3290 }
3291 Ok(())
3292}
3293
3294/// Get public data about users.
3295async fn get_users(
3296 request: proto::GetUsers,
3297 response: Response<proto::GetUsers>,
3298 session: Session,
3299) -> Result<()> {
3300 let user_ids = request
3301 .user_ids
3302 .into_iter()
3303 .map(UserId::from_proto)
3304 .collect();
3305 let users = session
3306 .db()
3307 .await
3308 .get_users_by_ids(user_ids)
3309 .await?
3310 .into_iter()
3311 .map(|user| proto::User {
3312 id: user.id.to_proto(),
3313 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3314 github_login: user.github_login,
3315 })
3316 .collect();
3317 response.send(proto::UsersResponse { users })?;
3318 Ok(())
3319}
3320
3321/// Search for users (to invite) buy Github login
3322async fn fuzzy_search_users(
3323 request: proto::FuzzySearchUsers,
3324 response: Response<proto::FuzzySearchUsers>,
3325 session: UserSession,
3326) -> Result<()> {
3327 let query = request.query;
3328 let users = match query.len() {
3329 0 => vec![],
3330 1 | 2 => session
3331 .db()
3332 .await
3333 .get_user_by_github_login(&query)
3334 .await?
3335 .into_iter()
3336 .collect(),
3337 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3338 };
3339 let users = users
3340 .into_iter()
3341 .filter(|user| user.id != session.user_id())
3342 .map(|user| proto::User {
3343 id: user.id.to_proto(),
3344 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3345 github_login: user.github_login,
3346 })
3347 .collect();
3348 response.send(proto::UsersResponse { users })?;
3349 Ok(())
3350}
3351
3352/// Send a contact request to another user.
3353async fn request_contact(
3354 request: proto::RequestContact,
3355 response: Response<proto::RequestContact>,
3356 session: UserSession,
3357) -> Result<()> {
3358 let requester_id = session.user_id();
3359 let responder_id = UserId::from_proto(request.responder_id);
3360 if requester_id == responder_id {
3361 return Err(anyhow!("cannot add yourself as a contact"))?;
3362 }
3363
3364 let notifications = session
3365 .db()
3366 .await
3367 .send_contact_request(requester_id, responder_id)
3368 .await?;
3369
3370 // Update outgoing contact requests of requester
3371 let mut update = proto::UpdateContacts::default();
3372 update.outgoing_requests.push(responder_id.to_proto());
3373 for connection_id in session
3374 .connection_pool()
3375 .await
3376 .user_connection_ids(requester_id)
3377 {
3378 session.peer.send(connection_id, update.clone())?;
3379 }
3380
3381 // Update incoming contact requests of responder
3382 let mut update = proto::UpdateContacts::default();
3383 update
3384 .incoming_requests
3385 .push(proto::IncomingContactRequest {
3386 requester_id: requester_id.to_proto(),
3387 });
3388 let connection_pool = session.connection_pool().await;
3389 for connection_id in connection_pool.user_connection_ids(responder_id) {
3390 session.peer.send(connection_id, update.clone())?;
3391 }
3392
3393 send_notifications(&connection_pool, &session.peer, notifications);
3394
3395 response.send(proto::Ack {})?;
3396 Ok(())
3397}
3398
3399/// Accept or decline a contact request
3400async fn respond_to_contact_request(
3401 request: proto::RespondToContactRequest,
3402 response: Response<proto::RespondToContactRequest>,
3403 session: UserSession,
3404) -> Result<()> {
3405 let responder_id = session.user_id();
3406 let requester_id = UserId::from_proto(request.requester_id);
3407 let db = session.db().await;
3408 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3409 db.dismiss_contact_notification(responder_id, requester_id)
3410 .await?;
3411 } else {
3412 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3413
3414 let notifications = db
3415 .respond_to_contact_request(responder_id, requester_id, accept)
3416 .await?;
3417 let requester_busy = db.is_user_busy(requester_id).await?;
3418 let responder_busy = db.is_user_busy(responder_id).await?;
3419
3420 let pool = session.connection_pool().await;
3421 // Update responder with new contact
3422 let mut update = proto::UpdateContacts::default();
3423 if accept {
3424 update
3425 .contacts
3426 .push(contact_for_user(requester_id, requester_busy, &pool));
3427 }
3428 update
3429 .remove_incoming_requests
3430 .push(requester_id.to_proto());
3431 for connection_id in pool.user_connection_ids(responder_id) {
3432 session.peer.send(connection_id, update.clone())?;
3433 }
3434
3435 // Update requester with new contact
3436 let mut update = proto::UpdateContacts::default();
3437 if accept {
3438 update
3439 .contacts
3440 .push(contact_for_user(responder_id, responder_busy, &pool));
3441 }
3442 update
3443 .remove_outgoing_requests
3444 .push(responder_id.to_proto());
3445
3446 for connection_id in pool.user_connection_ids(requester_id) {
3447 session.peer.send(connection_id, update.clone())?;
3448 }
3449
3450 send_notifications(&pool, &session.peer, notifications);
3451 }
3452
3453 response.send(proto::Ack {})?;
3454 Ok(())
3455}
3456
3457/// Remove a contact.
3458async fn remove_contact(
3459 request: proto::RemoveContact,
3460 response: Response<proto::RemoveContact>,
3461 session: UserSession,
3462) -> Result<()> {
3463 let requester_id = session.user_id();
3464 let responder_id = UserId::from_proto(request.user_id);
3465 let db = session.db().await;
3466 let (contact_accepted, deleted_notification_id) =
3467 db.remove_contact(requester_id, responder_id).await?;
3468
3469 let pool = session.connection_pool().await;
3470 // Update outgoing contact requests of requester
3471 let mut update = proto::UpdateContacts::default();
3472 if contact_accepted {
3473 update.remove_contacts.push(responder_id.to_proto());
3474 } else {
3475 update
3476 .remove_outgoing_requests
3477 .push(responder_id.to_proto());
3478 }
3479 for connection_id in pool.user_connection_ids(requester_id) {
3480 session.peer.send(connection_id, update.clone())?;
3481 }
3482
3483 // Update incoming contact requests of responder
3484 let mut update = proto::UpdateContacts::default();
3485 if contact_accepted {
3486 update.remove_contacts.push(requester_id.to_proto());
3487 } else {
3488 update
3489 .remove_incoming_requests
3490 .push(requester_id.to_proto());
3491 }
3492 for connection_id in pool.user_connection_ids(responder_id) {
3493 session.peer.send(connection_id, update.clone())?;
3494 if let Some(notification_id) = deleted_notification_id {
3495 session.peer.send(
3496 connection_id,
3497 proto::DeleteNotification {
3498 notification_id: notification_id.to_proto(),
3499 },
3500 )?;
3501 }
3502 }
3503
3504 response.send(proto::Ack {})?;
3505 Ok(())
3506}
3507
3508fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3509 version.0.minor() < 139
3510}
3511
3512async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3513 subscribe_user_to_channels(
3514 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3515 &session,
3516 )
3517 .await?;
3518 Ok(())
3519}
3520
3521async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3522 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3523 let mut pool = session.connection_pool().await;
3524 for membership in &channels_for_user.channel_memberships {
3525 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3526 }
3527 session.peer.send(
3528 session.connection_id,
3529 build_update_user_channels(&channels_for_user),
3530 )?;
3531 session.peer.send(
3532 session.connection_id,
3533 build_channels_update(channels_for_user),
3534 )?;
3535 Ok(())
3536}
3537
3538/// Creates a new channel.
3539async fn create_channel(
3540 request: proto::CreateChannel,
3541 response: Response<proto::CreateChannel>,
3542 session: UserSession,
3543) -> Result<()> {
3544 let db = session.db().await;
3545
3546 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3547 let (channel, membership) = db
3548 .create_channel(&request.name, parent_id, session.user_id())
3549 .await?;
3550
3551 let root_id = channel.root_id();
3552 let channel = Channel::from_model(channel);
3553
3554 response.send(proto::CreateChannelResponse {
3555 channel: Some(channel.to_proto()),
3556 parent_id: request.parent_id,
3557 })?;
3558
3559 let mut connection_pool = session.connection_pool().await;
3560 if let Some(membership) = membership {
3561 connection_pool.subscribe_to_channel(
3562 membership.user_id,
3563 membership.channel_id,
3564 membership.role,
3565 );
3566 let update = proto::UpdateUserChannels {
3567 channel_memberships: vec![proto::ChannelMembership {
3568 channel_id: membership.channel_id.to_proto(),
3569 role: membership.role.into(),
3570 }],
3571 ..Default::default()
3572 };
3573 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3574 session.peer.send(connection_id, update.clone())?;
3575 }
3576 }
3577
3578 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3579 if !role.can_see_channel(channel.visibility) {
3580 continue;
3581 }
3582
3583 let update = proto::UpdateChannels {
3584 channels: vec![channel.to_proto()],
3585 ..Default::default()
3586 };
3587 session.peer.send(connection_id, update.clone())?;
3588 }
3589
3590 Ok(())
3591}
3592
3593/// Delete a channel
3594async fn delete_channel(
3595 request: proto::DeleteChannel,
3596 response: Response<proto::DeleteChannel>,
3597 session: UserSession,
3598) -> Result<()> {
3599 let db = session.db().await;
3600
3601 let channel_id = request.channel_id;
3602 let (root_channel, removed_channels) = db
3603 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3604 .await?;
3605 response.send(proto::Ack {})?;
3606
3607 // Notify members of removed channels
3608 let mut update = proto::UpdateChannels::default();
3609 update
3610 .delete_channels
3611 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3612
3613 let connection_pool = session.connection_pool().await;
3614 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3615 session.peer.send(connection_id, update.clone())?;
3616 }
3617
3618 Ok(())
3619}
3620
3621/// Invite someone to join a channel.
3622async fn invite_channel_member(
3623 request: proto::InviteChannelMember,
3624 response: Response<proto::InviteChannelMember>,
3625 session: UserSession,
3626) -> Result<()> {
3627 let db = session.db().await;
3628 let channel_id = ChannelId::from_proto(request.channel_id);
3629 let invitee_id = UserId::from_proto(request.user_id);
3630 let InviteMemberResult {
3631 channel,
3632 notifications,
3633 } = db
3634 .invite_channel_member(
3635 channel_id,
3636 invitee_id,
3637 session.user_id(),
3638 request.role().into(),
3639 )
3640 .await?;
3641
3642 let update = proto::UpdateChannels {
3643 channel_invitations: vec![channel.to_proto()],
3644 ..Default::default()
3645 };
3646
3647 let connection_pool = session.connection_pool().await;
3648 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3649 session.peer.send(connection_id, update.clone())?;
3650 }
3651
3652 send_notifications(&connection_pool, &session.peer, notifications);
3653
3654 response.send(proto::Ack {})?;
3655 Ok(())
3656}
3657
3658/// remove someone from a channel
3659async fn remove_channel_member(
3660 request: proto::RemoveChannelMember,
3661 response: Response<proto::RemoveChannelMember>,
3662 session: UserSession,
3663) -> Result<()> {
3664 let db = session.db().await;
3665 let channel_id = ChannelId::from_proto(request.channel_id);
3666 let member_id = UserId::from_proto(request.user_id);
3667
3668 let RemoveChannelMemberResult {
3669 membership_update,
3670 notification_id,
3671 } = db
3672 .remove_channel_member(channel_id, member_id, session.user_id())
3673 .await?;
3674
3675 let mut connection_pool = session.connection_pool().await;
3676 notify_membership_updated(
3677 &mut connection_pool,
3678 membership_update,
3679 member_id,
3680 &session.peer,
3681 );
3682 for connection_id in connection_pool.user_connection_ids(member_id) {
3683 if let Some(notification_id) = notification_id {
3684 session
3685 .peer
3686 .send(
3687 connection_id,
3688 proto::DeleteNotification {
3689 notification_id: notification_id.to_proto(),
3690 },
3691 )
3692 .trace_err();
3693 }
3694 }
3695
3696 response.send(proto::Ack {})?;
3697 Ok(())
3698}
3699
3700/// Toggle the channel between public and private.
3701/// Care is taken to maintain the invariant that public channels only descend from public channels,
3702/// (though members-only channels can appear at any point in the hierarchy).
3703async fn set_channel_visibility(
3704 request: proto::SetChannelVisibility,
3705 response: Response<proto::SetChannelVisibility>,
3706 session: UserSession,
3707) -> Result<()> {
3708 let db = session.db().await;
3709 let channel_id = ChannelId::from_proto(request.channel_id);
3710 let visibility = request.visibility().into();
3711
3712 let channel_model = db
3713 .set_channel_visibility(channel_id, visibility, session.user_id())
3714 .await?;
3715 let root_id = channel_model.root_id();
3716 let channel = Channel::from_model(channel_model);
3717
3718 let mut connection_pool = session.connection_pool().await;
3719 for (user_id, role) in connection_pool
3720 .channel_user_ids(root_id)
3721 .collect::<Vec<_>>()
3722 .into_iter()
3723 {
3724 let update = if role.can_see_channel(channel.visibility) {
3725 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3726 proto::UpdateChannels {
3727 channels: vec![channel.to_proto()],
3728 ..Default::default()
3729 }
3730 } else {
3731 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3732 proto::UpdateChannels {
3733 delete_channels: vec![channel.id.to_proto()],
3734 ..Default::default()
3735 }
3736 };
3737
3738 for connection_id in connection_pool.user_connection_ids(user_id) {
3739 session.peer.send(connection_id, update.clone())?;
3740 }
3741 }
3742
3743 response.send(proto::Ack {})?;
3744 Ok(())
3745}
3746
3747/// Alter the role for a user in the channel.
3748async fn set_channel_member_role(
3749 request: proto::SetChannelMemberRole,
3750 response: Response<proto::SetChannelMemberRole>,
3751 session: UserSession,
3752) -> Result<()> {
3753 let db = session.db().await;
3754 let channel_id = ChannelId::from_proto(request.channel_id);
3755 let member_id = UserId::from_proto(request.user_id);
3756 let result = db
3757 .set_channel_member_role(
3758 channel_id,
3759 session.user_id(),
3760 member_id,
3761 request.role().into(),
3762 )
3763 .await?;
3764
3765 match result {
3766 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3767 let mut connection_pool = session.connection_pool().await;
3768 notify_membership_updated(
3769 &mut connection_pool,
3770 membership_update,
3771 member_id,
3772 &session.peer,
3773 )
3774 }
3775 db::SetMemberRoleResult::InviteUpdated(channel) => {
3776 let update = proto::UpdateChannels {
3777 channel_invitations: vec![channel.to_proto()],
3778 ..Default::default()
3779 };
3780
3781 for connection_id in session
3782 .connection_pool()
3783 .await
3784 .user_connection_ids(member_id)
3785 {
3786 session.peer.send(connection_id, update.clone())?;
3787 }
3788 }
3789 }
3790
3791 response.send(proto::Ack {})?;
3792 Ok(())
3793}
3794
3795/// Change the name of a channel
3796async fn rename_channel(
3797 request: proto::RenameChannel,
3798 response: Response<proto::RenameChannel>,
3799 session: UserSession,
3800) -> Result<()> {
3801 let db = session.db().await;
3802 let channel_id = ChannelId::from_proto(request.channel_id);
3803 let channel_model = db
3804 .rename_channel(channel_id, session.user_id(), &request.name)
3805 .await?;
3806 let root_id = channel_model.root_id();
3807 let channel = Channel::from_model(channel_model);
3808
3809 response.send(proto::RenameChannelResponse {
3810 channel: Some(channel.to_proto()),
3811 })?;
3812
3813 let connection_pool = session.connection_pool().await;
3814 let update = proto::UpdateChannels {
3815 channels: vec![channel.to_proto()],
3816 ..Default::default()
3817 };
3818 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3819 if role.can_see_channel(channel.visibility) {
3820 session.peer.send(connection_id, update.clone())?;
3821 }
3822 }
3823
3824 Ok(())
3825}
3826
3827/// Move a channel to a new parent.
3828async fn move_channel(
3829 request: proto::MoveChannel,
3830 response: Response<proto::MoveChannel>,
3831 session: UserSession,
3832) -> Result<()> {
3833 let channel_id = ChannelId::from_proto(request.channel_id);
3834 let to = ChannelId::from_proto(request.to);
3835
3836 let (root_id, channels) = session
3837 .db()
3838 .await
3839 .move_channel(channel_id, to, session.user_id())
3840 .await?;
3841
3842 let connection_pool = session.connection_pool().await;
3843 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3844 let channels = channels
3845 .iter()
3846 .filter_map(|channel| {
3847 if role.can_see_channel(channel.visibility) {
3848 Some(channel.to_proto())
3849 } else {
3850 None
3851 }
3852 })
3853 .collect::<Vec<_>>();
3854 if channels.is_empty() {
3855 continue;
3856 }
3857
3858 let update = proto::UpdateChannels {
3859 channels,
3860 ..Default::default()
3861 };
3862
3863 session.peer.send(connection_id, update.clone())?;
3864 }
3865
3866 response.send(Ack {})?;
3867 Ok(())
3868}
3869
3870/// Get the list of channel members
3871async fn get_channel_members(
3872 request: proto::GetChannelMembers,
3873 response: Response<proto::GetChannelMembers>,
3874 session: UserSession,
3875) -> Result<()> {
3876 let db = session.db().await;
3877 let channel_id = ChannelId::from_proto(request.channel_id);
3878 let limit = if request.limit == 0 {
3879 u16::MAX as u64
3880 } else {
3881 request.limit
3882 };
3883 let (members, users) = db
3884 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3885 .await?;
3886 response.send(proto::GetChannelMembersResponse { members, users })?;
3887 Ok(())
3888}
3889
3890/// Accept or decline a channel invitation.
3891async fn respond_to_channel_invite(
3892 request: proto::RespondToChannelInvite,
3893 response: Response<proto::RespondToChannelInvite>,
3894 session: UserSession,
3895) -> Result<()> {
3896 let db = session.db().await;
3897 let channel_id = ChannelId::from_proto(request.channel_id);
3898 let RespondToChannelInvite {
3899 membership_update,
3900 notifications,
3901 } = db
3902 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3903 .await?;
3904
3905 let mut connection_pool = session.connection_pool().await;
3906 if let Some(membership_update) = membership_update {
3907 notify_membership_updated(
3908 &mut connection_pool,
3909 membership_update,
3910 session.user_id(),
3911 &session.peer,
3912 );
3913 } else {
3914 let update = proto::UpdateChannels {
3915 remove_channel_invitations: vec![channel_id.to_proto()],
3916 ..Default::default()
3917 };
3918
3919 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3920 session.peer.send(connection_id, update.clone())?;
3921 }
3922 };
3923
3924 send_notifications(&connection_pool, &session.peer, notifications);
3925
3926 response.send(proto::Ack {})?;
3927
3928 Ok(())
3929}
3930
3931/// Join the channels' room
3932async fn join_channel(
3933 request: proto::JoinChannel,
3934 response: Response<proto::JoinChannel>,
3935 session: UserSession,
3936) -> Result<()> {
3937 let channel_id = ChannelId::from_proto(request.channel_id);
3938 join_channel_internal(channel_id, Box::new(response), session).await
3939}
3940
/// A response handle that can complete a channel join.
///
/// Both `Response<proto::JoinChannel>` and `Response<proto::JoinRoom>` are
/// answered with a `proto::JoinRoomResponse`. `join_channel_internal` is
/// generic over this trait so that it can serve either request type.
trait JoinChannelInternalResponse {
    /// Sends `result` to the requester, consuming the response handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Forwards to the inherent `Response::send`. The fully-qualified path makes
// it explicit which `send` is invoked.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
// Same forwarding, for handlers that answer a `JoinRoom` request.
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3954
/// Shared implementation for joining a channel's room, used by `join_channel`
/// and by any other handler whose response implements
/// `JoinChannelInternalResponse`.
///
/// Steps, in order:
/// 1. If the database reports a stale room connection for this user, drop the
///    database guard, leave that room via `leave_room_for_session`, and
///    re-acquire the guard.
/// 2. Join the channel in the database. This yields the joined room, an
///    optional membership update, and the user's role.
/// 3. If a LiveKit client is configured, mint a token. Guests get a guest
///    token with `can_publish: false`; everyone else gets a room token with
///    `can_publish: true`. A token error goes through `trace_err` and results
///    in no LiveKit info rather than a failed join.
/// 4. Send the `JoinRoomResponse`, broadcast any membership update, and
///    broadcast the room update.
/// 5. With the database and pool guards from the inner block released,
///    broadcast the channel update and refresh the user's contacts.
///
/// # Errors
/// Fails if a database call, leaving the stale room, the response send, or the
/// contact update fails. It also fails if the joined room carries no channel;
/// that check runs after the response has already been sent.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the guard before leaving the stale room.
            // NOTE(review): presumably `leave_room_for_session` acquires the
            // session's database guard itself. Confirm.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // `None` if no LiveKit client is configured or if minting a token failed.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            // Guests receive a guest token and may not publish.
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // `db` and `connection_pool` from the block above are dropped by now; the
    // pool is re-acquired just for this broadcast.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4045
4046/// Start editing the channel notes
4047async fn join_channel_buffer(
4048 request: proto::JoinChannelBuffer,
4049 response: Response<proto::JoinChannelBuffer>,
4050 session: UserSession,
4051) -> Result<()> {
4052 let db = session.db().await;
4053 let channel_id = ChannelId::from_proto(request.channel_id);
4054
4055 let open_response = db
4056 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4057 .await?;
4058
4059 let collaborators = open_response.collaborators.clone();
4060 response.send(open_response)?;
4061
4062 let update = UpdateChannelBufferCollaborators {
4063 channel_id: channel_id.to_proto(),
4064 collaborators: collaborators.clone(),
4065 };
4066 channel_buffer_updated(
4067 session.connection_id,
4068 collaborators
4069 .iter()
4070 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4071 &update,
4072 &session.peer,
4073 );
4074
4075 Ok(())
4076}
4077
4078/// Edit the channel notes
4079async fn update_channel_buffer(
4080 request: proto::UpdateChannelBuffer,
4081 session: UserSession,
4082) -> Result<()> {
4083 let db = session.db().await;
4084 let channel_id = ChannelId::from_proto(request.channel_id);
4085
4086 let (collaborators, epoch, version) = db
4087 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4088 .await?;
4089
4090 channel_buffer_updated(
4091 session.connection_id,
4092 collaborators.clone(),
4093 &proto::UpdateChannelBuffer {
4094 channel_id: channel_id.to_proto(),
4095 operations: request.operations,
4096 },
4097 &session.peer,
4098 );
4099
4100 let pool = &*session.connection_pool().await;
4101
4102 let non_collaborators =
4103 pool.channel_connection_ids(channel_id)
4104 .filter_map(|(connection_id, _)| {
4105 if collaborators.contains(&connection_id) {
4106 None
4107 } else {
4108 Some(connection_id)
4109 }
4110 });
4111
4112 broadcast(None, non_collaborators, |peer_id| {
4113 session.peer.send(
4114 peer_id,
4115 proto::UpdateChannels {
4116 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4117 channel_id: channel_id.to_proto(),
4118 epoch: epoch as u64,
4119 version: version.clone(),
4120 }],
4121 ..Default::default()
4122 },
4123 )
4124 });
4125
4126 Ok(())
4127}
4128
4129/// Rejoin the channel notes after a connection blip
4130async fn rejoin_channel_buffers(
4131 request: proto::RejoinChannelBuffers,
4132 response: Response<proto::RejoinChannelBuffers>,
4133 session: UserSession,
4134) -> Result<()> {
4135 let db = session.db().await;
4136 let buffers = db
4137 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4138 .await?;
4139
4140 for rejoined_buffer in &buffers {
4141 let collaborators_to_notify = rejoined_buffer
4142 .buffer
4143 .collaborators
4144 .iter()
4145 .filter_map(|c| Some(c.peer_id?.into()));
4146 channel_buffer_updated(
4147 session.connection_id,
4148 collaborators_to_notify,
4149 &proto::UpdateChannelBufferCollaborators {
4150 channel_id: rejoined_buffer.buffer.channel_id,
4151 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4152 },
4153 &session.peer,
4154 );
4155 }
4156
4157 response.send(proto::RejoinChannelBuffersResponse {
4158 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4159 })?;
4160
4161 Ok(())
4162}
4163
4164/// Stop editing the channel notes
4165async fn leave_channel_buffer(
4166 request: proto::LeaveChannelBuffer,
4167 response: Response<proto::LeaveChannelBuffer>,
4168 session: UserSession,
4169) -> Result<()> {
4170 let db = session.db().await;
4171 let channel_id = ChannelId::from_proto(request.channel_id);
4172
4173 let left_buffer = db
4174 .leave_channel_buffer(channel_id, session.connection_id)
4175 .await?;
4176
4177 response.send(Ack {})?;
4178
4179 channel_buffer_updated(
4180 session.connection_id,
4181 left_buffer.connections,
4182 &proto::UpdateChannelBufferCollaborators {
4183 channel_id: channel_id.to_proto(),
4184 collaborators: left_buffer.collaborators,
4185 },
4186 &session.peer,
4187 );
4188
4189 Ok(())
4190}
4191
4192fn channel_buffer_updated<T: EnvelopedMessage>(
4193 sender_id: ConnectionId,
4194 collaborators: impl IntoIterator<Item = ConnectionId>,
4195 message: &T,
4196 peer: &Peer,
4197) {
4198 broadcast(Some(sender_id), collaborators, |peer_id| {
4199 peer.send(peer_id, message.clone())
4200 });
4201}
4202
4203fn send_notifications(
4204 connection_pool: &ConnectionPool,
4205 peer: &Peer,
4206 notifications: db::NotificationBatch,
4207) {
4208 for (user_id, notification) in notifications {
4209 for connection_id in connection_pool.user_connection_ids(user_id) {
4210 if let Err(error) = peer.send(
4211 connection_id,
4212 proto::AddNotification {
4213 notification: Some(notification.clone()),
4214 },
4215 ) {
4216 tracing::error!(
4217 "failed to send notification to {:?} {}",
4218 connection_id,
4219 error
4220 );
4221 }
4222 }
4223 }
4224}
4225
4226/// Send a message to the channel
4227async fn send_channel_message(
4228 request: proto::SendChannelMessage,
4229 response: Response<proto::SendChannelMessage>,
4230 session: UserSession,
4231) -> Result<()> {
4232 // Validate the message body.
4233 let body = request.body.trim().to_string();
4234 if body.len() > MAX_MESSAGE_LEN {
4235 return Err(anyhow!("message is too long"))?;
4236 }
4237 if body.is_empty() {
4238 return Err(anyhow!("message can't be blank"))?;
4239 }
4240
4241 // TODO: adjust mentions if body is trimmed
4242
4243 let timestamp = OffsetDateTime::now_utc();
4244 let nonce = request
4245 .nonce
4246 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4247
4248 let channel_id = ChannelId::from_proto(request.channel_id);
4249 let CreatedChannelMessage {
4250 message_id,
4251 participant_connection_ids,
4252 notifications,
4253 } = session
4254 .db()
4255 .await
4256 .create_channel_message(
4257 channel_id,
4258 session.user_id(),
4259 &body,
4260 &request.mentions,
4261 timestamp,
4262 nonce.clone().into(),
4263 match request.reply_to_message_id {
4264 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4265 None => None,
4266 },
4267 )
4268 .await?;
4269
4270 let message = proto::ChannelMessage {
4271 sender_id: session.user_id().to_proto(),
4272 id: message_id.to_proto(),
4273 body,
4274 mentions: request.mentions,
4275 timestamp: timestamp.unix_timestamp() as u64,
4276 nonce: Some(nonce),
4277 reply_to_message_id: request.reply_to_message_id,
4278 edited_at: None,
4279 };
4280 broadcast(
4281 Some(session.connection_id),
4282 participant_connection_ids.clone(),
4283 |connection| {
4284 session.peer.send(
4285 connection,
4286 proto::ChannelMessageSent {
4287 channel_id: channel_id.to_proto(),
4288 message: Some(message.clone()),
4289 },
4290 )
4291 },
4292 );
4293 response.send(proto::SendChannelMessageResponse {
4294 message: Some(message),
4295 })?;
4296
4297 let pool = &*session.connection_pool().await;
4298 let non_participants =
4299 pool.channel_connection_ids(channel_id)
4300 .filter_map(|(connection_id, _)| {
4301 if participant_connection_ids.contains(&connection_id) {
4302 None
4303 } else {
4304 Some(connection_id)
4305 }
4306 });
4307 broadcast(None, non_participants, |peer_id| {
4308 session.peer.send(
4309 peer_id,
4310 proto::UpdateChannels {
4311 latest_channel_message_ids: vec![proto::ChannelMessageId {
4312 channel_id: channel_id.to_proto(),
4313 message_id: message_id.to_proto(),
4314 }],
4315 ..Default::default()
4316 },
4317 )
4318 });
4319 send_notifications(pool, &session.peer, notifications);
4320
4321 Ok(())
4322}
4323
4324/// Delete a channel message
4325async fn remove_channel_message(
4326 request: proto::RemoveChannelMessage,
4327 response: Response<proto::RemoveChannelMessage>,
4328 session: UserSession,
4329) -> Result<()> {
4330 let channel_id = ChannelId::from_proto(request.channel_id);
4331 let message_id = MessageId::from_proto(request.message_id);
4332 let (connection_ids, existing_notification_ids) = session
4333 .db()
4334 .await
4335 .remove_channel_message(channel_id, message_id, session.user_id())
4336 .await?;
4337
4338 broadcast(
4339 Some(session.connection_id),
4340 connection_ids,
4341 move |connection| {
4342 session.peer.send(connection, request.clone())?;
4343
4344 for notification_id in &existing_notification_ids {
4345 session.peer.send(
4346 connection,
4347 proto::DeleteNotification {
4348 notification_id: (*notification_id).to_proto(),
4349 },
4350 )?;
4351 }
4352
4353 Ok(())
4354 },
4355 );
4356 response.send(proto::Ack {})?;
4357 Ok(())
4358}
4359
4360async fn update_channel_message(
4361 request: proto::UpdateChannelMessage,
4362 response: Response<proto::UpdateChannelMessage>,
4363 session: UserSession,
4364) -> Result<()> {
4365 let channel_id = ChannelId::from_proto(request.channel_id);
4366 let message_id = MessageId::from_proto(request.message_id);
4367 let updated_at = OffsetDateTime::now_utc();
4368 let UpdatedChannelMessage {
4369 message_id,
4370 participant_connection_ids,
4371 notifications,
4372 reply_to_message_id,
4373 timestamp,
4374 deleted_mention_notification_ids,
4375 updated_mention_notifications,
4376 } = session
4377 .db()
4378 .await
4379 .update_channel_message(
4380 channel_id,
4381 message_id,
4382 session.user_id(),
4383 request.body.as_str(),
4384 &request.mentions,
4385 updated_at,
4386 )
4387 .await?;
4388
4389 let nonce = request
4390 .nonce
4391 .clone()
4392 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4393
4394 let message = proto::ChannelMessage {
4395 sender_id: session.user_id().to_proto(),
4396 id: message_id.to_proto(),
4397 body: request.body.clone(),
4398 mentions: request.mentions.clone(),
4399 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4400 nonce: Some(nonce),
4401 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4402 edited_at: Some(updated_at.unix_timestamp() as u64),
4403 };
4404
4405 response.send(proto::Ack {})?;
4406
4407 let pool = &*session.connection_pool().await;
4408 broadcast(
4409 Some(session.connection_id),
4410 participant_connection_ids,
4411 |connection| {
4412 session.peer.send(
4413 connection,
4414 proto::ChannelMessageUpdate {
4415 channel_id: channel_id.to_proto(),
4416 message: Some(message.clone()),
4417 },
4418 )?;
4419
4420 for notification_id in &deleted_mention_notification_ids {
4421 session.peer.send(
4422 connection,
4423 proto::DeleteNotification {
4424 notification_id: (*notification_id).to_proto(),
4425 },
4426 )?;
4427 }
4428
4429 for notification in &updated_mention_notifications {
4430 session.peer.send(
4431 connection,
4432 proto::UpdateNotification {
4433 notification: Some(notification.clone()),
4434 },
4435 )?;
4436 }
4437
4438 Ok(())
4439 },
4440 );
4441
4442 send_notifications(pool, &session.peer, notifications);
4443
4444 Ok(())
4445}
4446
4447/// Mark a channel message as read
4448async fn acknowledge_channel_message(
4449 request: proto::AckChannelMessage,
4450 session: UserSession,
4451) -> Result<()> {
4452 let channel_id = ChannelId::from_proto(request.channel_id);
4453 let message_id = MessageId::from_proto(request.message_id);
4454 let notifications = session
4455 .db()
4456 .await
4457 .observe_channel_message(channel_id, session.user_id(), message_id)
4458 .await?;
4459 send_notifications(
4460 &*session.connection_pool().await,
4461 &session.peer,
4462 notifications,
4463 );
4464 Ok(())
4465}
4466
4467/// Mark a buffer version as synced
4468async fn acknowledge_buffer_version(
4469 request: proto::AckBufferOperation,
4470 session: UserSession,
4471) -> Result<()> {
4472 let buffer_id = BufferId::from_proto(request.buffer_id);
4473 session
4474 .db()
4475 .await
4476 .observe_buffer_version(
4477 buffer_id,
4478 session.user_id(),
4479 request.epoch as i32,
4480 &request.version,
4481 )
4482 .await?;
4483 Ok(())
4484}
4485
4486struct CompleteWithLanguageModelRateLimit;
4487
4488impl RateLimit for CompleteWithLanguageModelRateLimit {
4489 fn capacity() -> usize {
4490 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4491 .ok()
4492 .and_then(|v| v.parse().ok())
4493 .unwrap_or(120) // Picked arbitrarily
4494 }
4495
4496 fn refill_duration() -> chrono::Duration {
4497 chrono::Duration::hours(1)
4498 }
4499
4500 fn db_name() -> &'static str {
4501 "complete-with-language-model"
4502 }
4503}
4504
4505async fn complete_with_language_model(
4506 query: proto::QueryLanguageModel,
4507 response: StreamingResponse<proto::QueryLanguageModel>,
4508 session: Session,
4509 open_ai_api_key: Option<Arc<str>>,
4510 google_ai_api_key: Option<Arc<str>>,
4511 anthropic_api_key: Option<Arc<str>>,
4512) -> Result<()> {
4513 let Some(session) = session.for_user() else {
4514 return Err(anyhow!("user not found"))?;
4515 };
4516 authorize_access_to_language_models(&session).await?;
4517
4518 match proto::LanguageModelRequestKind::from_i32(query.kind) {
4519 Some(proto::LanguageModelRequestKind::Complete) => {
4520 session
4521 .rate_limiter
4522 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4523 .await?;
4524 }
4525 Some(proto::LanguageModelRequestKind::CountTokens) => {
4526 session
4527 .rate_limiter
4528 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4529 .await?;
4530 }
4531 None => Err(anyhow!("unknown request kind"))?,
4532 }
4533
4534 match proto::LanguageModelProvider::from_i32(query.provider) {
4535 Some(proto::LanguageModelProvider::Anthropic) => {
4536 let api_key =
4537 anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
4538 let mut chunks = anthropic::stream_completion(
4539 session.http_client.as_ref(),
4540 anthropic::ANTHROPIC_API_URL,
4541 &api_key,
4542 serde_json::from_str(&query.request)?,
4543 None,
4544 )
4545 .await?;
4546 while let Some(chunk) = chunks.next().await {
4547 let chunk = chunk?;
4548 response.send(proto::QueryLanguageModelResponse {
4549 response: serde_json::to_string(&chunk)?,
4550 })?;
4551 }
4552 }
4553 Some(proto::LanguageModelProvider::OpenAi) => {
4554 let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
4555 let mut chunks = open_ai::stream_completion(
4556 session.http_client.as_ref(),
4557 open_ai::OPEN_AI_API_URL,
4558 &api_key,
4559 serde_json::from_str(&query.request)?,
4560 None,
4561 )
4562 .await?;
4563 while let Some(chunk) = chunks.next().await {
4564 let chunk = chunk?;
4565 response.send(proto::QueryLanguageModelResponse {
4566 response: serde_json::to_string(&chunk)?,
4567 })?;
4568 }
4569 }
4570 Some(proto::LanguageModelProvider::Google) => {
4571 let api_key =
4572 google_ai_api_key.context("no Google AI API key configured on the server")?;
4573
4574 match proto::LanguageModelRequestKind::from_i32(query.kind) {
4575 Some(proto::LanguageModelRequestKind::Complete) => {
4576 let mut chunks = google_ai::stream_generate_content(
4577 session.http_client.as_ref(),
4578 google_ai::API_URL,
4579 &api_key,
4580 serde_json::from_str(&query.request)?,
4581 )
4582 .await?;
4583 while let Some(chunk) = chunks.next().await {
4584 let chunk = chunk?;
4585 response.send(proto::QueryLanguageModelResponse {
4586 response: serde_json::to_string(&chunk)?,
4587 })?;
4588 }
4589 }
4590 Some(proto::LanguageModelRequestKind::CountTokens) => {
4591 let tokens_response = google_ai::count_tokens(
4592 session.http_client.as_ref(),
4593 google_ai::API_URL,
4594 &api_key,
4595 serde_json::from_str(&query.request)?,
4596 )
4597 .await?;
4598 response.send(proto::QueryLanguageModelResponse {
4599 response: serde_json::to_string(&tokens_response)?,
4600 })?;
4601 }
4602 None => Err(anyhow!("unknown request kind"))?,
4603 }
4604 }
4605 None => return Err(anyhow!("unknown provider"))?,
4606 }
4607
4608 Ok(())
4609}
4610
4611struct CountTokensWithLanguageModelRateLimit;
4612
4613impl RateLimit for CountTokensWithLanguageModelRateLimit {
4614 fn capacity() -> usize {
4615 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4616 .ok()
4617 .and_then(|v| v.parse().ok())
4618 .unwrap_or(600) // Picked arbitrarily
4619 }
4620
4621 fn refill_duration() -> chrono::Duration {
4622 chrono::Duration::hours(1)
4623 }
4624
4625 fn db_name() -> &'static str {
4626 "count-tokens-with-language-model"
4627 }
4628}
4629
4630struct ComputeEmbeddingsRateLimit;
4631
4632impl RateLimit for ComputeEmbeddingsRateLimit {
4633 fn capacity() -> usize {
4634 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4635 .ok()
4636 .and_then(|v| v.parse().ok())
4637 .unwrap_or(5000) // Picked arbitrarily
4638 }
4639
4640 fn refill_duration() -> chrono::Duration {
4641 chrono::Duration::hours(1)
4642 }
4643
4644 fn db_name() -> &'static str {
4645 "compute-embeddings"
4646 }
4647}
4648
4649async fn compute_embeddings(
4650 request: proto::ComputeEmbeddings,
4651 response: Response<proto::ComputeEmbeddings>,
4652 session: UserSession,
4653 api_key: Option<Arc<str>>,
4654) -> Result<()> {
4655 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4656 authorize_access_to_language_models(&session).await?;
4657
4658 session
4659 .rate_limiter
4660 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4661 .await?;
4662
4663 let embeddings = match request.model.as_str() {
4664 "openai/text-embedding-3-small" => {
4665 open_ai::embed(
4666 session.http_client.as_ref(),
4667 OPEN_AI_API_URL,
4668 &api_key,
4669 OpenAiEmbeddingModel::TextEmbedding3Small,
4670 request.texts.iter().map(|text| text.as_str()),
4671 )
4672 .await?
4673 }
4674 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4675 };
4676
4677 let embeddings = request
4678 .texts
4679 .iter()
4680 .map(|text| {
4681 let mut hasher = sha2::Sha256::new();
4682 hasher.update(text.as_bytes());
4683 let result = hasher.finalize();
4684 result.to_vec()
4685 })
4686 .zip(
4687 embeddings
4688 .data
4689 .into_iter()
4690 .map(|embedding| embedding.embedding),
4691 )
4692 .collect::<HashMap<_, _>>();
4693
4694 let db = session.db().await;
4695 db.save_embeddings(&request.model, &embeddings)
4696 .await
4697 .context("failed to save embeddings")
4698 .trace_err();
4699
4700 response.send(proto::ComputeEmbeddingsResponse {
4701 embeddings: embeddings
4702 .into_iter()
4703 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4704 .collect(),
4705 })?;
4706 Ok(())
4707}
4708
4709async fn get_cached_embeddings(
4710 request: proto::GetCachedEmbeddings,
4711 response: Response<proto::GetCachedEmbeddings>,
4712 session: UserSession,
4713) -> Result<()> {
4714 authorize_access_to_language_models(&session).await?;
4715
4716 let db = session.db().await;
4717 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4718
4719 response.send(proto::GetCachedEmbeddingsResponse {
4720 embeddings: embeddings
4721 .into_iter()
4722 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4723 .collect(),
4724 })?;
4725 Ok(())
4726}
4727
4728async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4729 let db = session.db().await;
4730 let flags = db.get_user_flags(session.user_id()).await?;
4731 if flags.iter().any(|flag| flag == "language-models") {
4732 Ok(())
4733 } else {
4734 Err(anyhow!("permission denied"))?
4735 }
4736}
4737
4738/// Get a Supermaven API key for the user
4739async fn get_supermaven_api_key(
4740 _request: proto::GetSupermavenApiKey,
4741 response: Response<proto::GetSupermavenApiKey>,
4742 session: UserSession,
4743) -> Result<()> {
4744 let user_id: String = session.user_id().to_string();
4745 if !session.is_staff() {
4746 return Err(anyhow!("supermaven not enabled for this account"))?;
4747 }
4748
4749 let email = session
4750 .email()
4751 .ok_or_else(|| anyhow!("user must have an email"))?;
4752
4753 let supermaven_admin_api = session
4754 .supermaven_client
4755 .as_ref()
4756 .ok_or_else(|| anyhow!("supermaven not configured"))?;
4757
4758 let result = supermaven_admin_api
4759 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
4760 .await?;
4761
4762 response.send(proto::GetSupermavenApiKeyResponse {
4763 api_key: result.api_key,
4764 })?;
4765
4766 Ok(())
4767}
4768
4769/// Start receiving chat updates for a channel
4770async fn join_channel_chat(
4771 request: proto::JoinChannelChat,
4772 response: Response<proto::JoinChannelChat>,
4773 session: UserSession,
4774) -> Result<()> {
4775 let channel_id = ChannelId::from_proto(request.channel_id);
4776
4777 let db = session.db().await;
4778 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
4779 .await?;
4780 let messages = db
4781 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
4782 .await?;
4783 response.send(proto::JoinChannelChatResponse {
4784 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4785 messages,
4786 })?;
4787 Ok(())
4788}
4789
4790/// Stop receiving chat updates for a channel
4791async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
4792 let channel_id = ChannelId::from_proto(request.channel_id);
4793 session
4794 .db()
4795 .await
4796 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
4797 .await?;
4798 Ok(())
4799}
4800
4801/// Retrieve the chat history for a channel
4802async fn get_channel_messages(
4803 request: proto::GetChannelMessages,
4804 response: Response<proto::GetChannelMessages>,
4805 session: UserSession,
4806) -> Result<()> {
4807 let channel_id = ChannelId::from_proto(request.channel_id);
4808 let messages = session
4809 .db()
4810 .await
4811 .get_channel_messages(
4812 channel_id,
4813 session.user_id(),
4814 MESSAGE_COUNT_PER_PAGE,
4815 Some(MessageId::from_proto(request.before_message_id)),
4816 )
4817 .await?;
4818 response.send(proto::GetChannelMessagesResponse {
4819 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4820 messages,
4821 })?;
4822 Ok(())
4823}
4824
4825/// Retrieve specific chat messages
4826async fn get_channel_messages_by_id(
4827 request: proto::GetChannelMessagesById,
4828 response: Response<proto::GetChannelMessagesById>,
4829 session: UserSession,
4830) -> Result<()> {
4831 let message_ids = request
4832 .message_ids
4833 .iter()
4834 .map(|id| MessageId::from_proto(*id))
4835 .collect::<Vec<_>>();
4836 let messages = session
4837 .db()
4838 .await
4839 .get_channel_messages_by_id(session.user_id(), &message_ids)
4840 .await?;
4841 response.send(proto::GetChannelMessagesResponse {
4842 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
4843 messages,
4844 })?;
4845 Ok(())
4846}
4847
4848/// Retrieve the current users notifications
4849async fn get_notifications(
4850 request: proto::GetNotifications,
4851 response: Response<proto::GetNotifications>,
4852 session: UserSession,
4853) -> Result<()> {
4854 let notifications = session
4855 .db()
4856 .await
4857 .get_notifications(
4858 session.user_id(),
4859 NOTIFICATION_COUNT_PER_PAGE,
4860 request
4861 .before_id
4862 .map(|id| db::NotificationId::from_proto(id)),
4863 )
4864 .await?;
4865 response.send(proto::GetNotificationsResponse {
4866 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
4867 notifications,
4868 })?;
4869 Ok(())
4870}
4871
4872/// Mark notifications as read
4873async fn mark_notification_as_read(
4874 request: proto::MarkNotificationRead,
4875 response: Response<proto::MarkNotificationRead>,
4876 session: UserSession,
4877) -> Result<()> {
4878 let database = &session.db().await;
4879 let notifications = database
4880 .mark_notification_as_read_by_id(
4881 session.user_id(),
4882 NotificationId::from_proto(request.notification_id),
4883 )
4884 .await?;
4885 send_notifications(
4886 &*session.connection_pool().await,
4887 &session.peer,
4888 notifications,
4889 );
4890 response.send(proto::Ack {})?;
4891 Ok(())
4892}
4893
4894/// Get the current users information
4895async fn get_private_user_info(
4896 _request: proto::GetPrivateUserInfo,
4897 response: Response<proto::GetPrivateUserInfo>,
4898 session: UserSession,
4899) -> Result<()> {
4900 let db = session.db().await;
4901
4902 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
4903 let user = db
4904 .get_user_by_id(session.user_id())
4905 .await?
4906 .ok_or_else(|| anyhow!("user not found"))?;
4907 let flags = db.get_user_flags(session.user_id()).await?;
4908
4909 response.send(proto::GetPrivateUserInfoResponse {
4910 metrics_id,
4911 staff: user.admin,
4912 flags,
4913 })?;
4914 Ok(())
4915}
4916
4917fn to_axum_message(message: TungsteniteMessage) -> anyhow::Result<AxumMessage> {
4918 let message = match message {
4919 TungsteniteMessage::Text(payload) => AxumMessage::Text(payload),
4920 TungsteniteMessage::Binary(payload) => AxumMessage::Binary(payload),
4921 TungsteniteMessage::Ping(payload) => AxumMessage::Ping(payload),
4922 TungsteniteMessage::Pong(payload) => AxumMessage::Pong(payload),
4923 TungsteniteMessage::Close(frame) => AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
4924 code: frame.code.into(),
4925 reason: frame.reason,
4926 })),
4927 // We should never receive a frame while reading the message, according
4928 // to the `tungstenite` maintainers:
4929 //
4930 // > It cannot occur when you read messages from the WebSocket, but it
4931 // > can be used when you want to send the raw frames (e.g. you want to
4932 // > send the frames to the WebSocket without composing the full message first).
4933 // >
4934 // > — https://github.com/snapview/tungstenite-rs/issues/268
4935 TungsteniteMessage::Frame(_) => {
4936 bail!("received an unexpected frame while reading the message")
4937 }
4938 };
4939
4940 Ok(message)
4941}
4942
4943fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
4944 match message {
4945 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
4946 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
4947 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
4948 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
4949 AxumMessage::Close(frame) => {
4950 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
4951 code: frame.code.into(),
4952 reason: frame.reason,
4953 }))
4954 }
4955 }
4956}
4957
4958fn notify_membership_updated(
4959 connection_pool: &mut ConnectionPool,
4960 result: MembershipUpdated,
4961 user_id: UserId,
4962 peer: &Peer,
4963) {
4964 for membership in &result.new_channels.channel_memberships {
4965 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
4966 }
4967 for channel_id in &result.removed_channels {
4968 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
4969 }
4970
4971 let user_channels_update = proto::UpdateUserChannels {
4972 channel_memberships: result
4973 .new_channels
4974 .channel_memberships
4975 .iter()
4976 .map(|cm| proto::ChannelMembership {
4977 channel_id: cm.channel_id.to_proto(),
4978 role: cm.role.into(),
4979 })
4980 .collect(),
4981 ..Default::default()
4982 };
4983
4984 let mut update = build_channels_update(result.new_channels);
4985 update.delete_channels = result
4986 .removed_channels
4987 .into_iter()
4988 .map(|id| id.to_proto())
4989 .collect();
4990 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
4991
4992 for connection_id in connection_pool.user_connection_ids(user_id) {
4993 peer.send(connection_id, user_channels_update.clone())
4994 .trace_err();
4995 peer.send(connection_id, update.clone()).trace_err();
4996 }
4997}
4998
4999fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5000 proto::UpdateUserChannels {
5001 channel_memberships: channels
5002 .channel_memberships
5003 .iter()
5004 .map(|m| proto::ChannelMembership {
5005 channel_id: m.channel_id.to_proto(),
5006 role: m.role.into(),
5007 })
5008 .collect(),
5009 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5010 observed_channel_message_id: channels.observed_channel_messages.clone(),
5011 }
5012}
5013
5014fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5015 let mut update = proto::UpdateChannels::default();
5016
5017 for channel in channels.channels {
5018 update.channels.push(channel.to_proto());
5019 }
5020
5021 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5022 update.latest_channel_message_ids = channels.latest_channel_messages;
5023
5024 for (channel_id, participants) in channels.channel_participants {
5025 update
5026 .channel_participants
5027 .push(proto::ChannelParticipants {
5028 channel_id: channel_id.to_proto(),
5029 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5030 });
5031 }
5032
5033 for channel in channels.invited_channels {
5034 update.channel_invitations.push(channel.to_proto());
5035 }
5036
5037 update.hosted_projects = channels.hosted_projects;
5038 update
5039}
5040
5041fn build_initial_contacts_update(
5042 contacts: Vec<db::Contact>,
5043 pool: &ConnectionPool,
5044) -> proto::UpdateContacts {
5045 let mut update = proto::UpdateContacts::default();
5046
5047 for contact in contacts {
5048 match contact {
5049 db::Contact::Accepted { user_id, busy } => {
5050 update.contacts.push(contact_for_user(user_id, busy, &pool));
5051 }
5052 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5053 db::Contact::Incoming { user_id } => {
5054 update
5055 .incoming_requests
5056 .push(proto::IncomingContactRequest {
5057 requester_id: user_id.to_proto(),
5058 })
5059 }
5060 }
5061 }
5062
5063 update
5064}
5065
5066fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5067 proto::Contact {
5068 user_id: user_id.to_proto(),
5069 online: pool.is_user_online(user_id),
5070 busy,
5071 }
5072}
5073
5074fn room_updated(room: &proto::Room, peer: &Peer) {
5075 broadcast(
5076 None,
5077 room.participants
5078 .iter()
5079 .filter_map(|participant| Some(participant.peer_id?.into())),
5080 |peer_id| {
5081 peer.send(
5082 peer_id,
5083 proto::RoomUpdated {
5084 room: Some(room.clone()),
5085 },
5086 )
5087 },
5088 );
5089}
5090
5091fn channel_updated(
5092 channel: &db::channel::Model,
5093 room: &proto::Room,
5094 peer: &Peer,
5095 pool: &ConnectionPool,
5096) {
5097 let participants = room
5098 .participants
5099 .iter()
5100 .map(|p| p.user_id)
5101 .collect::<Vec<_>>();
5102
5103 broadcast(
5104 None,
5105 pool.channel_connection_ids(channel.root_id())
5106 .filter_map(|(channel_id, role)| {
5107 role.can_see_channel(channel.visibility).then(|| channel_id)
5108 }),
5109 |peer_id| {
5110 peer.send(
5111 peer_id,
5112 proto::UpdateChannels {
5113 channel_participants: vec![proto::ChannelParticipants {
5114 channel_id: channel.id.to_proto(),
5115 participant_user_ids: participants.clone(),
5116 }],
5117 ..Default::default()
5118 },
5119 )
5120 },
5121 );
5122}
5123
5124async fn send_dev_server_projects_update(
5125 user_id: UserId,
5126 mut status: proto::DevServerProjectsUpdate,
5127 session: &Session,
5128) {
5129 let pool = session.connection_pool().await;
5130 for dev_server in &mut status.dev_servers {
5131 dev_server.status =
5132 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5133 }
5134 let connections = pool.user_connection_ids(user_id);
5135 for connection_id in connections {
5136 session.peer.send(connection_id, status.clone()).trace_err();
5137 }
5138}
5139
5140async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5141 let db = session.db().await;
5142
5143 let contacts = db.get_contacts(user_id).await?;
5144 let busy = db.is_user_busy(user_id).await?;
5145
5146 let pool = session.connection_pool().await;
5147 let updated_contact = contact_for_user(user_id, busy, &pool);
5148 for contact in contacts {
5149 if let db::Contact::Accepted {
5150 user_id: contact_user_id,
5151 ..
5152 } = contact
5153 {
5154 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5155 session
5156 .peer
5157 .send(
5158 contact_conn_id,
5159 proto::UpdateContacts {
5160 contacts: vec![updated_contact.clone()],
5161 remove_contacts: Default::default(),
5162 incoming_requests: Default::default(),
5163 remove_incoming_requests: Default::default(),
5164 outgoing_requests: Default::default(),
5165 remove_outgoing_requests: Default::default(),
5166 },
5167 )
5168 .trace_err();
5169 }
5170 }
5171 }
5172 Ok(())
5173}
5174
5175async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5176 log::info!("lost dev server connection, unsharing projects");
5177 let project_ids = session
5178 .db()
5179 .await
5180 .get_stale_dev_server_projects(session.connection_id)
5181 .await?;
5182
5183 for project_id in project_ids {
5184 // not unshare re-checks the connection ids match, so we get away with no transaction
5185 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5186 }
5187
5188 let user_id = session.dev_server().user_id;
5189 let update = session
5190 .db()
5191 .await
5192 .dev_server_projects_update(user_id)
5193 .await?;
5194
5195 send_dev_server_projects_update(user_id, update, session).await;
5196
5197 Ok(())
5198}
5199
/// Remove `connection_id` from the room it is in (if any) and notify everyone
/// affected.
///
/// Returns `Ok(())` immediately if `leave_room` reports no room was left.
/// Otherwise, in order:
/// - notifies collaborators of each project left (`project_left`);
/// - broadcasts the updated room, and, if the room belongs to a channel, that
///   channel's participant list;
/// - sends `CallCanceled` to every connection of each user whose call was
///   canceled;
/// - refreshes contact status for the leaving user and those users;
/// - removes the participant from LiveKit, and deletes the LiveKit room when
///   the database reported it as deleted.
///
/// LiveKit and notification send failures are logged; database errors and
/// `update_user_contacts` errors are returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    let mut contacts_to_update = HashSet::default();

    // Assigned inside the `if let` below and used after it. The temporary
    // produced by `session.db().await` in the `if let` scrutinee lives until
    // the end of that `if let` statement.
    // NOTE(review): presumably structured this way so the database handle is
    // released before the remaining notification and LiveKit work. Confirm
    // before restructuring (e.g. into `let ... else`).
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scope the connection-pool guard so it is dropped before the
    // `update_user_contacts` loop below, which itself acquires
    // `session.db()` and `session.connection_pool()`.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    // Best effort: LiveKit failures are logged, not returned.
    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5273
5274async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5275 let left_channel_buffers = session
5276 .db()
5277 .await
5278 .leave_channel_buffers(session.connection_id)
5279 .await?;
5280
5281 for left_buffer in left_channel_buffers {
5282 channel_buffer_updated(
5283 session.connection_id,
5284 left_buffer.connections,
5285 &proto::UpdateChannelBufferCollaborators {
5286 channel_id: left_buffer.channel_id.to_proto(),
5287 collaborators: left_buffer.collaborators,
5288 },
5289 &session.peer,
5290 );
5291 }
5292
5293 Ok(())
5294}
5295
5296fn project_left(project: &db::LeftProject, session: &UserSession) {
5297 for connection_id in &project.connection_ids {
5298 if project.should_unshare {
5299 session
5300 .peer
5301 .send(
5302 *connection_id,
5303 proto::UnshareProject {
5304 project_id: project.id.to_proto(),
5305 },
5306 )
5307 .trace_err();
5308 } else {
5309 session
5310 .peer
5311 .send(
5312 *connection_id,
5313 proto::RemoveProjectCollaborator {
5314 project_id: project.id.to_proto(),
5315 peer_id: Some(session.connection_id.into()),
5316 },
5317 )
5318 .trace_err();
5319 }
5320 }
5321}
5322
/// Extension for fallible values that should be logged rather than
/// propagated.
pub trait ResultExt {
    /// The success value carried through on the happy path.
    type Ok;

    /// Consume `self`. Returns the success value as `Some`; on failure,
    /// reports it (implementation-defined) and returns `None`.
    fn trace_err(self) -> Option<Self::Ok>;
}
5328
5329impl<T, E> ResultExt for Result<T, E>
5330where
5331 E: std::fmt::Debug,
5332{
5333 type Ok = T;
5334
5335 #[track_caller]
5336 fn trace_err(self) -> Option<T> {
5337 match self {
5338 Ok(value) => Some(value),
5339 Err(error) => {
5340 tracing::error!("{:?}", error);
5341 None
5342 }
5343 }
5344 }
5345}