1mod connection_pool;
2
3use crate::{
4 auth,
5 db::{
6 self, dev_server, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
7 CreatedChannelMessage, Database, DevServerId, DevServerProjectId, InviteMemberResult,
8 MembershipUpdated, MessageId, NotificationId, PrincipalId, Project, ProjectId,
9 RejoinedProject, RemoveChannelMemberResult, ReplicaId, RespondToChannelInvite, RoomId,
10 ServerId, UpdatedChannelMessage, User, UserId,
11 },
12 executor::Executor,
13 AppState, Error, RateLimit, RateLimiter, Result,
14};
15use anyhow::{anyhow, bail, Context as _};
16use async_tungstenite::tungstenite::{
17 protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
18};
19use axum::{
20 body::Body,
21 extract::{
22 ws::{CloseFrame as AxumCloseFrame, Message as AxumMessage},
23 ConnectInfo, WebSocketUpgrade,
24 },
25 headers::{Header, HeaderName},
26 http::StatusCode,
27 middleware,
28 response::IntoResponse,
29 routing::get,
30 Extension, Router, TypedHeader,
31};
32use collections::{HashMap, HashSet};
33pub use connection_pool::{ConnectionPool, ZedVersion};
34use core::fmt::{self, Debug, Formatter};
35use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
36use sha2::Digest;
37use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
38
39use futures::{
40 channel::oneshot,
41 future::{self, BoxFuture},
42 stream::FuturesUnordered,
43 FutureExt, SinkExt, StreamExt, TryStreamExt,
44};
45use http_client::IsahcHttpClient;
46use prometheus::{register_int_gauge, IntGauge};
47use rpc::{
48 proto::{
49 self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
50 LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
51 },
52 Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
53};
54use semantic_version::SemanticVersion;
55use serde::{Serialize, Serializer};
56use std::{
57 any::TypeId,
58 future::Future,
59 marker::PhantomData,
60 mem,
61 net::SocketAddr,
62 ops::{Deref, DerefMut},
63 rc::Rc,
64 sync::{
65 atomic::{AtomicBool, Ordering::SeqCst},
66 Arc, OnceLock,
67 },
68 time::{Duration, Instant},
69};
70use time::OffsetDateTime;
71use tokio::sync::{watch, Semaphore};
72use tower::ServiceBuilder;
73use tracing::{
74 field::{self},
75 info_span, instrument, Instrument,
76};
77
78use self::connection_pool::VersionedMessage;
79
/// How long to wait before giving up on a peer that is expected to reconnect.
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);

/// Delay before `Server::start` sweeps resources left behind by stale servers.
///
/// Kubernetes gives terminated pods 10s to shut down gracefully; once they are gone we
/// can safely clean up their old resources, so we wait a little longer than that.
pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(15);

/// Number of channel messages returned per page.
const MESSAGE_COUNT_PER_PAGE: usize = 100;
/// Maximum accepted length of a channel message.
const MAX_MESSAGE_LEN: usize = 1024;
/// Number of notifications returned per page.
const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
88
89type MessageHandler =
90 Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
91
92struct Response<R> {
93 peer: Arc<Peer>,
94 receipt: Receipt<R>,
95 responded: Arc<AtomicBool>,
96}
97
98impl<R: RequestMessage> Response<R> {
99 fn send(self, payload: R::Response) -> Result<()> {
100 self.responded.store(true, SeqCst);
101 self.peer.respond(self.receipt, payload)?;
102 Ok(())
103 }
104}
105
106struct StreamingResponse<R: RequestMessage> {
107 peer: Arc<Peer>,
108 receipt: Receipt<R>,
109}
110
111impl<R: RequestMessage> StreamingResponse<R> {
112 fn send(&self, payload: R::Response) -> Result<()> {
113 self.peer.respond(self.receipt, payload)?;
114 Ok(())
115 }
116}
117
118#[derive(Clone, Debug)]
119pub enum Principal {
120 User(User),
121 Impersonated { user: User, admin: User },
122 DevServer(dev_server::Model),
123}
124
125impl Principal {
126 fn update_span(&self, span: &tracing::Span) {
127 match &self {
128 Principal::User(user) => {
129 span.record("user_id", &user.id.0);
130 span.record("login", &user.github_login);
131 }
132 Principal::Impersonated { user, admin } => {
133 span.record("user_id", &user.id.0);
134 span.record("login", &user.github_login);
135 span.record("impersonator", &admin.github_login);
136 }
137 Principal::DevServer(dev_server) => {
138 span.record("dev_server_id", &dev_server.id.0);
139 }
140 }
141 }
142}
143
144#[derive(Clone)]
145struct Session {
146 principal: Principal,
147 connection_id: ConnectionId,
148 db: Arc<tokio::sync::Mutex<DbHandle>>,
149 peer: Arc<Peer>,
150 connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
151 live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
152 supermaven_client: Option<Arc<SupermavenAdminApi>>,
153 http_client: Arc<IsahcHttpClient>,
154 rate_limiter: Arc<RateLimiter>,
155 _executor: Executor,
156}
157
158impl Session {
159 async fn db(&self) -> tokio::sync::MutexGuard<DbHandle> {
160 #[cfg(test)]
161 tokio::task::yield_now().await;
162 let guard = self.db.lock().await;
163 #[cfg(test)]
164 tokio::task::yield_now().await;
165 guard
166 }
167
168 async fn connection_pool(&self) -> ConnectionPoolGuard<'_> {
169 #[cfg(test)]
170 tokio::task::yield_now().await;
171 let guard = self.connection_pool.lock();
172 ConnectionPoolGuard {
173 guard,
174 _not_send: PhantomData,
175 }
176 }
177
178 fn for_user(self) -> Option<UserSession> {
179 UserSession::new(self)
180 }
181
182 fn for_dev_server(self) -> Option<DevServerSession> {
183 DevServerSession::new(self)
184 }
185
186 fn user_id(&self) -> Option<UserId> {
187 match &self.principal {
188 Principal::User(user) => Some(user.id),
189 Principal::Impersonated { user, .. } => Some(user.id),
190 Principal::DevServer(_) => None,
191 }
192 }
193
194 fn is_staff(&self) -> bool {
195 match &self.principal {
196 Principal::User(user) => user.admin,
197 Principal::Impersonated { .. } => true,
198 Principal::DevServer(_) => false,
199 }
200 }
201
202 fn dev_server_id(&self) -> Option<DevServerId> {
203 match &self.principal {
204 Principal::User(_) | Principal::Impersonated { .. } => None,
205 Principal::DevServer(dev_server) => Some(dev_server.id),
206 }
207 }
208
209 fn principal_id(&self) -> PrincipalId {
210 match &self.principal {
211 Principal::User(user) => PrincipalId::UserId(user.id),
212 Principal::Impersonated { user, .. } => PrincipalId::UserId(user.id),
213 Principal::DevServer(dev_server) => PrincipalId::DevServerId(dev_server.id),
214 }
215 }
216}
217
218impl Debug for Session {
219 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
220 let mut result = f.debug_struct("Session");
221 match &self.principal {
222 Principal::User(user) => {
223 result.field("user", &user.github_login);
224 }
225 Principal::Impersonated { user, admin } => {
226 result.field("user", &user.github_login);
227 result.field("impersonator", &admin.github_login);
228 }
229 Principal::DevServer(dev_server) => {
230 result.field("dev_server", &dev_server.id);
231 }
232 }
233 result.field("connection_id", &self.connection_id).finish()
234 }
235}
236
237struct UserSession(Session);
238
239impl UserSession {
240 pub fn new(s: Session) -> Option<Self> {
241 s.user_id().map(|_| UserSession(s))
242 }
243 pub fn user_id(&self) -> UserId {
244 self.0.user_id().unwrap()
245 }
246
247 pub fn email(&self) -> Option<String> {
248 match &self.0.principal {
249 Principal::User(user) => user.email_address.clone(),
250 Principal::Impersonated { user, .. } => user.email_address.clone(),
251 Principal::DevServer(..) => None,
252 }
253 }
254}
255
256impl Deref for UserSession {
257 type Target = Session;
258
259 fn deref(&self) -> &Self::Target {
260 &self.0
261 }
262}
263impl DerefMut for UserSession {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269struct DevServerSession(Session);
270
271impl DevServerSession {
272 pub fn new(s: Session) -> Option<Self> {
273 s.dev_server_id().map(|_| DevServerSession(s))
274 }
275 pub fn dev_server_id(&self) -> DevServerId {
276 self.0.dev_server_id().unwrap()
277 }
278
279 fn dev_server(&self) -> &dev_server::Model {
280 match &self.0.principal {
281 Principal::DevServer(dev_server) => dev_server,
282 _ => unreachable!(),
283 }
284 }
285}
286
287impl Deref for DevServerSession {
288 type Target = Session;
289
290 fn deref(&self) -> &Self::Target {
291 &self.0
292 }
293}
294impl DerefMut for DevServerSession {
295 fn deref_mut(&mut self) -> &mut Self::Target {
296 &mut self.0
297 }
298}
299
300fn user_handler<M: RequestMessage, Fut>(
301 handler: impl 'static + Send + Sync + Fn(M, Response<M>, UserSession) -> Fut,
302) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
303where
304 Fut: Send + Future<Output = Result<()>>,
305{
306 let handler = Arc::new(handler);
307 move |message, response, session| {
308 let handler = handler.clone();
309 Box::pin(async move {
310 if let Some(user_session) = session.for_user() {
311 Ok(handler(message, response, user_session).await?)
312 } else {
313 Err(Error::Internal(anyhow!(
314 "must be a user to call {}",
315 M::NAME
316 )))
317 }
318 })
319 }
320}
321
322fn dev_server_handler<M: RequestMessage, Fut>(
323 handler: impl 'static + Send + Sync + Fn(M, Response<M>, DevServerSession) -> Fut,
324) -> impl 'static + Send + Sync + Fn(M, Response<M>, Session) -> BoxFuture<'static, Result<()>>
325where
326 Fut: Send + Future<Output = Result<()>>,
327{
328 let handler = Arc::new(handler);
329 move |message, response, session| {
330 let handler = handler.clone();
331 Box::pin(async move {
332 if let Some(dev_server_session) = session.for_dev_server() {
333 Ok(handler(message, response, dev_server_session).await?)
334 } else {
335 Err(Error::Internal(anyhow!(
336 "must be a dev server to call {}",
337 M::NAME
338 )))
339 }
340 })
341 }
342}
343
344fn user_message_handler<M: EnvelopedMessage, InnertRetFut>(
345 handler: impl 'static + Send + Sync + Fn(M, UserSession) -> InnertRetFut,
346) -> impl 'static + Send + Sync + Fn(M, Session) -> BoxFuture<'static, Result<()>>
347where
348 InnertRetFut: Send + Future<Output = Result<()>>,
349{
350 let handler = Arc::new(handler);
351 move |message, session| {
352 let handler = handler.clone();
353 Box::pin(async move {
354 if let Some(user_session) = session.for_user() {
355 Ok(handler(message, user_session).await?)
356 } else {
357 Err(Error::Internal(anyhow!(
358 "must be a user to call {}",
359 M::NAME
360 )))
361 }
362 })
363 }
364}
365
366struct DbHandle(Arc<Database>);
367
368impl Deref for DbHandle {
369 type Target = Database;
370
371 fn deref(&self) -> &Self::Target {
372 self.0.as_ref()
373 }
374}
375
376pub struct Server {
377 id: parking_lot::Mutex<ServerId>,
378 peer: Arc<Peer>,
379 pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
380 app_state: Arc<AppState>,
381 handlers: HashMap<TypeId, MessageHandler>,
382 teardown: watch::Sender<bool>,
383}
384
385pub(crate) struct ConnectionPoolGuard<'a> {
386 guard: parking_lot::MutexGuard<'a, ConnectionPool>,
387 _not_send: PhantomData<Rc<()>>,
388}
389
390#[derive(Serialize)]
391pub struct ServerSnapshot<'a> {
392 peer: &'a Peer,
393 #[serde(serialize_with = "serialize_deref")]
394 connection_pool: ConnectionPoolGuard<'a>,
395}
396
397pub fn serialize_deref<S, T, U>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
398where
399 S: Serializer,
400 T: Deref<Target = U>,
401 U: Serialize,
402{
403 Serialize::serialize(value.deref(), serializer)
404}
405
406impl Server {
    /// Builds the collaboration server for `id` and registers every RPC handler it serves.
    ///
    /// Handlers wrapped in `user_handler` / `user_message_handler` fail unless the
    /// caller's principal is a user (including an impersonated user). Handlers wrapped in
    /// `dev_server_handler` require a dev-server principal. Handlers registered without a
    /// wrapper receive the raw `Session`.
    ///
    /// # Panics
    ///
    /// Panics (via `add_handler`) if two handlers are registered for the same message type.
    pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
        let mut server = Self {
            id: parking_lot::Mutex::new(id),
            // NOTE(review): `as u32` truncates if `ServerId`'s inner integer is wider than
            // `u32` — confirm server ids stay within range.
            peer: Peer::new(id.0 as u32),
            app_state: app_state.clone(),
            connection_pool: Default::default(),
            handlers: Default::default(),
            // Only the sender is kept; each connection subscribes its own receiver.
            teardown: watch::channel(false).0,
        };

        server
            // Rooms and calls.
            .add_request_handler(ping)
            .add_request_handler(user_handler(create_room))
            .add_request_handler(user_handler(join_room))
            .add_request_handler(user_handler(rejoin_room))
            .add_request_handler(user_handler(leave_room))
            .add_request_handler(user_handler(set_room_participant_role))
            .add_request_handler(user_handler(call))
            .add_request_handler(user_handler(cancel_call))
            .add_message_handler(user_message_handler(decline_call))
            .add_request_handler(user_handler(update_participant_location))
            // Projects and dev servers.
            .add_request_handler(user_handler(share_project))
            .add_message_handler(unshare_project)
            .add_request_handler(user_handler(join_project))
            .add_request_handler(user_handler(join_hosted_project))
            .add_request_handler(user_handler(rejoin_dev_server_projects))
            .add_request_handler(user_handler(create_dev_server_project))
            .add_request_handler(user_handler(update_dev_server_project))
            .add_request_handler(user_handler(delete_dev_server_project))
            .add_request_handler(user_handler(create_dev_server))
            .add_request_handler(user_handler(regenerate_dev_server_token))
            .add_request_handler(user_handler(rename_dev_server))
            .add_request_handler(user_handler(delete_dev_server))
            .add_request_handler(user_handler(list_remote_directory))
            .add_request_handler(dev_server_handler(share_dev_server_project))
            .add_request_handler(dev_server_handler(shutdown_dev_server))
            .add_request_handler(dev_server_handler(reconnect_dev_server))
            .add_message_handler(user_message_handler(leave_project))
            .add_request_handler(update_project)
            .add_request_handler(update_worktree)
            .add_message_handler(start_language_server)
            .add_message_handler(update_language_server)
            .add_message_handler(update_diagnostic_summary)
            .add_message_handler(update_worktree_settings)
            // Project requests forwarded to the project host.
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskContextForLocation>,
            ))
            .add_request_handler(user_handler(
                forward_project_request_for_owner::<proto::TaskTemplates>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetHover>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetTypeDefinition>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetReferences>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SearchProject>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetDocumentHighlights>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::GetProjectSymbols>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferForSymbol>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferById>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::SynchronizeBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::InlayHints>,
            ))
            .add_request_handler(user_handler(
                forward_read_only_project_request::<proto::OpenBufferByPath>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCompletions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::OpenNewBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::GetCodeActions>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ApplyCodeAction>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PrepareRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::PerformRename>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ReloadBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::FormatBuffers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CreateProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RenameProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::CopyProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::DeleteProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::ExpandProjectEntry>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OnTypeFormatting>,
            ))
            .add_request_handler(user_handler(
                forward_versioned_mutating_project_request::<proto::SaveBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::BlameBuffer>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::MultiLspQuery>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::RestartLanguageServers>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::LinkedEditingRange>,
            ))
            // Buffers and host-originated project broadcasts.
            .add_message_handler(create_buffer_for_peer)
            .add_request_handler(update_buffer)
            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
            // Users and contacts.
            .add_request_handler(get_users)
            .add_request_handler(user_handler(fuzzy_search_users))
            .add_request_handler(user_handler(request_contact))
            .add_request_handler(user_handler(remove_contact))
            .add_request_handler(user_handler(respond_to_contact_request))
            // Channels, channel buffers, chat and notifications.
            .add_message_handler(subscribe_to_channels)
            .add_request_handler(user_handler(create_channel))
            .add_request_handler(user_handler(delete_channel))
            .add_request_handler(user_handler(invite_channel_member))
            .add_request_handler(user_handler(remove_channel_member))
            .add_request_handler(user_handler(set_channel_member_role))
            .add_request_handler(user_handler(set_channel_visibility))
            .add_request_handler(user_handler(rename_channel))
            .add_request_handler(user_handler(join_channel_buffer))
            .add_request_handler(user_handler(leave_channel_buffer))
            .add_message_handler(user_message_handler(update_channel_buffer))
            .add_request_handler(user_handler(rejoin_channel_buffers))
            .add_request_handler(user_handler(get_channel_members))
            .add_request_handler(user_handler(respond_to_channel_invite))
            .add_request_handler(user_handler(join_channel))
            .add_request_handler(user_handler(join_channel_chat))
            .add_message_handler(user_message_handler(leave_channel_chat))
            .add_request_handler(user_handler(send_channel_message))
            .add_request_handler(user_handler(remove_channel_message))
            .add_request_handler(user_handler(update_channel_message))
            .add_request_handler(user_handler(get_channel_messages))
            .add_request_handler(user_handler(get_channel_messages_by_id))
            .add_request_handler(user_handler(get_notifications))
            .add_request_handler(user_handler(mark_notification_as_read))
            .add_request_handler(user_handler(move_channel))
            // Following.
            .add_request_handler(user_handler(follow))
            .add_message_handler(user_message_handler(unfollow))
            .add_message_handler(user_message_handler(update_followers))
            .add_request_handler(user_handler(get_private_user_info))
            .add_message_handler(user_message_handler(acknowledge_channel_message))
            .add_message_handler(user_message_handler(acknowledge_buffer_version))
            .add_request_handler(user_handler(get_supermaven_api_key))
            // Assistant contexts.
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::OpenContext>,
            ))
            .add_request_handler(user_handler(
                forward_mutating_project_request::<proto::SynchronizeContexts>,
            ))
            .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
            .add_message_handler(update_context)
            // Language-model completions stream their responses.
            // NOTE(review): unlike its neighbours this is not wrapped in `user_handler`, so
            // no user-principal check happens here — confirm `complete_with_language_model`
            // enforces its own.
            .add_streaming_request_handler({
                let app_state = app_state.clone();
                move |request, response, session| {
                    complete_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                        app_state.config.google_ai_api_key.clone(),
                        app_state.config.anthropic_api_key.clone(),
                    )
                }
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    count_tokens_with_language_model(
                        request,
                        response,
                        session,
                        app_state.config.google_ai_api_key.clone(),
                    )
                })
            })
            .add_request_handler({
                user_handler(move |request, response, session| {
                    get_cached_embeddings(request, response, session)
                })
            })
            .add_request_handler({
                let app_state = app_state.clone();
                user_handler(move |request, response, session| {
                    compute_embeddings(
                        request,
                        response,
                        session,
                        app_state.config.openai_api_key.clone(),
                    )
                })
            });

        Arc::new(server)
    }
651
    /// Schedules a background sweep of resources left behind by stale servers.
    ///
    /// A detached task waits `CLEANUP_TIMEOUT`, then asks the database
    /// (`stale_server_resource_ids`) for the ids of stale rooms and channel buffers.
    /// For each channel buffer it clears stale collaborators and sends the refreshed
    /// collaborator list to the remaining connections. For each room it:
    /// - clears stale participants and broadcasts the room (and channel) update;
    /// - sends `CallCanceled` for calls that were canceled;
    /// - pushes a refreshed contact entry to the accepted contacts of affected users;
    /// - deletes the LiveKit room if no participants remain.
    ///
    /// It then calls `delete_stale_servers`. Individual failures are traced and skipped
    /// rather than aborting the sweep.
    ///
    /// Always returns `Ok(())`; the cleanup itself runs in the background.
    pub async fn start(&self) -> Result<()> {
        // Snapshot the id now; the detached task uses this value throughout.
        let server_id = *self.id.lock();
        let app_state = self.app_state.clone();
        let peer = self.peer.clone();
        let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
        let pool = self.connection_pool.clone();
        let live_kit_client = self.app_state.live_kit_client.clone();

        let span = info_span!("start server");
        self.app_state.executor.spawn_detached(
            async move {
                tracing::info!("waiting for cleanup timeout");
                timeout.await;
                tracing::info!("cleanup timeout expired, retrieving stale rooms");
                if let Some((room_ids, channel_ids)) = app_state
                    .db
                    .stale_server_resource_ids(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err()
                {
                    tracing::info!(stale_room_count = room_ids.len(), "retrieved stale rooms");
                    tracing::info!(
                        stale_channel_buffer_count = channel_ids.len(),
                        "retrieved stale channel buffers"
                    );

                    // Drop stale collaborators from each channel buffer and tell the
                    // remaining connections about the new collaborator list.
                    for channel_id in channel_ids {
                        if let Some(refreshed_channel_buffer) = app_state
                            .db
                            .clear_stale_channel_buffer_collaborators(channel_id, server_id)
                            .await
                            .trace_err()
                        {
                            for connection_id in refreshed_channel_buffer.connection_ids {
                                peer.send(
                                    connection_id,
                                    proto::UpdateChannelBufferCollaborators {
                                        channel_id: channel_id.to_proto(),
                                        collaborators: refreshed_channel_buffer
                                            .collaborators
                                            .clone(),
                                    },
                                )
                                .trace_err();
                            }
                        }
                    }

                    for room_id in room_ids {
                        // Defaults used when clearing the room fails: nothing to notify,
                        // and the LiveKit room is left alone.
                        let mut contacts_to_update = HashSet::default();
                        let mut canceled_calls_to_user_ids = Vec::new();
                        let mut live_kit_room = String::new();
                        let mut delete_live_kit_room = false;

                        if let Some(mut refreshed_room) = app_state
                            .db
                            .clear_stale_room_participants(room_id, server_id)
                            .await
                            .trace_err()
                        {
                            tracing::info!(
                                room_id = room_id.0,
                                new_participant_count = refreshed_room.room.participants.len(),
                                "refreshed room"
                            );
                            room_updated(&refreshed_room.room, &peer);
                            if let Some(channel) = refreshed_room.channel.as_ref() {
                                channel_updated(channel, &refreshed_room.room, &peer, &pool.lock());
                            }
                            // Removed participants and canceled callees may have changed
                            // busy status, so their contacts need a refresh.
                            contacts_to_update
                                .extend(refreshed_room.stale_participant_user_ids.iter().copied());
                            contacts_to_update
                                .extend(refreshed_room.canceled_calls_to_user_ids.iter().copied());
                            canceled_calls_to_user_ids =
                                mem::take(&mut refreshed_room.canceled_calls_to_user_ids);
                            live_kit_room = mem::take(&mut refreshed_room.room.live_kit_room);
                            delete_live_kit_room = refreshed_room.room.participants.is_empty();
                        }

                        // Scoped so the pool lock is released before the awaits below.
                        {
                            let pool = pool.lock();
                            for canceled_user_id in canceled_calls_to_user_ids {
                                for connection_id in pool.user_connection_ids(canceled_user_id) {
                                    peer.send(
                                        connection_id,
                                        proto::CallCanceled {
                                            room_id: room_id.to_proto(),
                                        },
                                    )
                                    .trace_err();
                                }
                            }
                        }

                        for user_id in contacts_to_update {
                            let busy = app_state.db.is_user_busy(user_id).await.trace_err();
                            let contacts = app_state.db.get_contacts(user_id).await.trace_err();
                            // Only notify when both lookups succeeded.
                            if let Some((busy, contacts)) = busy.zip(contacts) {
                                // Locked after the awaits above; nothing is awaited while held.
                                let pool = pool.lock();
                                let updated_contact = contact_for_user(user_id, busy, &pool);
                                for contact in contacts {
                                    if let db::Contact::Accepted {
                                        user_id: contact_user_id,
                                        ..
                                    } = contact
                                    {
                                        for contact_conn_id in
                                            pool.user_connection_ids(contact_user_id)
                                        {
                                            peer.send(
                                                contact_conn_id,
                                                proto::UpdateContacts {
                                                    contacts: vec![updated_contact.clone()],
                                                    remove_contacts: Default::default(),
                                                    incoming_requests: Default::default(),
                                                    remove_incoming_requests: Default::default(),
                                                    outgoing_requests: Default::default(),
                                                    remove_outgoing_requests: Default::default(),
                                                },
                                            )
                                            .trace_err();
                                        }
                                    }
                                }
                            }
                        }

                        if let Some(live_kit) = live_kit_client.as_ref() {
                            if delete_live_kit_room {
                                live_kit.delete_room(live_kit_room).await.trace_err();
                            }
                        }
                    }
                }

                // Runs even if the stale-resource lookup above failed.
                app_state
                    .db
                    .delete_stale_servers(&app_state.config.zed_environment, server_id)
                    .await
                    .trace_err();
            }
            .instrument(span),
        );
        Ok(())
    }
797
798 pub fn teardown(&self) {
799 self.peer.teardown();
800 self.connection_pool.lock().reset();
801 let _ = self.teardown.send(true);
802 }
803
804 #[cfg(test)]
805 pub fn reset(&self, id: ServerId) {
806 self.teardown();
807 *self.id.lock() = id;
808 self.peer.reset(id.0 as u32);
809 let _ = self.teardown.send(false);
810 }
811
812 #[cfg(test)]
813 pub fn id(&self) -> ServerId {
814 *self.id.lock()
815 }
816
817 fn add_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
818 where
819 F: 'static + Send + Sync + Fn(TypedEnvelope<M>, Session) -> Fut,
820 Fut: 'static + Send + Future<Output = Result<()>>,
821 M: EnvelopedMessage,
822 {
823 let prev_handler = self.handlers.insert(
824 TypeId::of::<M>(),
825 Box::new(move |envelope, session| {
826 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
827 let received_at = envelope.received_at;
828 tracing::info!("message received");
829 let start_time = Instant::now();
830 let future = (handler)(*envelope, session);
831 async move {
832 let result = future.await;
833 let total_duration_ms = received_at.elapsed().as_micros() as f64 / 1000.0;
834 let processing_duration_ms = start_time.elapsed().as_micros() as f64 / 1000.0;
835 let queue_duration_ms = total_duration_ms - processing_duration_ms;
836 let payload_type = M::NAME;
837
838 match result {
839 Err(error) => {
840 tracing::error!(
841 ?error,
842 total_duration_ms,
843 processing_duration_ms,
844 queue_duration_ms,
845 payload_type,
846 "error handling message"
847 )
848 }
849 Ok(()) => tracing::info!(
850 total_duration_ms,
851 processing_duration_ms,
852 queue_duration_ms,
853 "finished handling message"
854 ),
855 }
856 }
857 .boxed()
858 }),
859 );
860 if prev_handler.is_some() {
861 panic!("registered a handler for the same message twice");
862 }
863 self
864 }
865
866 fn add_message_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
867 where
868 F: 'static + Send + Sync + Fn(M, Session) -> Fut,
869 Fut: 'static + Send + Future<Output = Result<()>>,
870 M: EnvelopedMessage,
871 {
872 self.add_handler(move |envelope, session| handler(envelope.payload, session));
873 self
874 }
875
876 fn add_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
877 where
878 F: 'static + Send + Sync + Fn(M, Response<M>, Session) -> Fut,
879 Fut: Send + Future<Output = Result<()>>,
880 M: RequestMessage,
881 {
882 let handler = Arc::new(handler);
883 self.add_handler(move |envelope, session| {
884 let receipt = envelope.receipt();
885 let handler = handler.clone();
886 async move {
887 let peer = session.peer.clone();
888 let responded = Arc::new(AtomicBool::default());
889 let response = Response {
890 peer: peer.clone(),
891 responded: responded.clone(),
892 receipt,
893 };
894 match (handler)(envelope.payload, response, session).await {
895 Ok(()) => {
896 if responded.load(std::sync::atomic::Ordering::SeqCst) {
897 Ok(())
898 } else {
899 Err(anyhow!("handler did not send a response"))?
900 }
901 }
902 Err(error) => {
903 let proto_err = match &error {
904 Error::Internal(err) => err.to_proto(),
905 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
906 };
907 peer.respond_with_error(receipt, proto_err)?;
908 Err(error)
909 }
910 }
911 }
912 })
913 }
914
915 fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
916 where
917 F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
918 Fut: Send + Future<Output = Result<()>>,
919 M: RequestMessage,
920 {
921 let handler = Arc::new(handler);
922 self.add_handler(move |envelope, session| {
923 let receipt = envelope.receipt();
924 let handler = handler.clone();
925 async move {
926 let peer = session.peer.clone();
927 let response = StreamingResponse {
928 peer: peer.clone(),
929 receipt,
930 };
931 match (handler)(envelope.payload, response, session).await {
932 Ok(()) => {
933 peer.end_stream(receipt)?;
934 Ok(())
935 }
936 Err(error) => {
937 let proto_err = match &error {
938 Error::Internal(err) => err.to_proto(),
939 _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
940 };
941 peer.respond_with_error(receipt, proto_err)?;
942 Err(error)
943 }
944 }
945 }
946 })
947 }
948
 /// Drives a single client connection from handshake to sign-out.
 ///
 /// The returned future is instrumented with a `handle connection` span and:
 /// 1. returns immediately if the server is already tearing down;
 /// 2. registers `connection` with the peer and records the assigned id on the span;
 /// 3. builds the per-connection [`Session`] (HTTP client, optional Supermaven admin
 ///    client, database handle, rate limiter, ...);
 /// 4. sends the initial client update via `send_initial_client_update` (which also
 ///    reports the connection id through `send_connection_id`, if given);
 /// 5. dispatches incoming messages to the registered handlers until the I/O future
 ///    completes, the incoming stream ends, or teardown is signalled;
 /// 6. after leaving the loop via `break`, drops pending foreground handlers and runs
 ///    `connection_lost`.
 ///
 /// NOTE(review): the early `return`s (teardown before start, HTTP client creation
 /// failure, initial update failure, and teardown inside the loop) skip
 /// `connection_lost`, so the peer connection and any pool entry added by
 /// `send_initial_client_update` are not cleaned up here. Confirm this is intended.
 #[allow(clippy::too_many_arguments)]
 pub fn handle_connection(
     self: &Arc<Self>,
     connection: Connection,
     address: String,
     principal: Principal,
     zed_version: ZedVersion,
     send_connection_id: Option<oneshot::Sender<ConnectionId>>,
     executor: Executor,
 ) -> impl Future<Output = ()> {
     let this = self.clone();
     // Fields declared `Empty` are recorded later: `connection_id` once the peer
     // assigns one, the principal-related ones presumably by `update_span`.
     let span = info_span!("handle connection", %address,
         connection_id=field::Empty,
         user_id=field::Empty,
         login=field::Empty,
         impersonator=field::Empty,
         dev_server_id=field::Empty
     );
     principal.update_span(&span);

     // Teardown watch: checked once before starting, then raced against message
     // handling in the loop below.
     let mut teardown = self.teardown.subscribe();
     async move {
         if *teardown.borrow() {
             tracing::error!("server is tearing down");
             return
         }
         let (connection_id, handle_io, mut incoming_rx) = this
             .peer
             .add_connection(connection, {
                 let executor = executor.clone();
                 move |duration| executor.sleep(duration)
             });
         tracing::Span::current().record("connection_id", format!("{}", connection_id));
         tracing::info!("connection opened");

         let http_client = match IsahcHttpClient::new() {
             Ok(http_client) => Arc::new(http_client),
             Err(error) => {
                 tracing::error!(?error, "failed to create HTTP client");
                 return;
             }
         };

         // A Supermaven admin client is only built when an admin API key is configured.
         let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() {
             Some(Arc::new(SupermavenAdminApi::new(
                 supermaven_admin_api_key.to_string(),
                 http_client.clone(),
             )))
         } else {
             None
         };

         let session = Session {
             principal: principal.clone(),
             connection_id,
             db: Arc::new(tokio::sync::Mutex::new(DbHandle(this.app_state.db.clone()))),
             peer: this.peer.clone(),
             connection_pool: this.connection_pool.clone(),
             live_kit_client: this.app_state.live_kit_client.clone(),
             http_client,
             rate_limiter: this.app_state.rate_limiter.clone(),
             _executor: executor.clone(),
             supermaven_client,
         };

         if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await {
             tracing::error!(?error, "failed to send initial client update");
             return;
         }

         // Fused and pinned so `select_biased!` can poll it on every iteration.
         let handle_io = handle_io.fuse();
         futures::pin_mut!(handle_io);

         // Handlers for foreground messages are pushed into the following `FuturesUnordered`.
         // This prevents deadlocks when e.g., client A performs a request to client B and
         // client B performs a request to client A. If both clients stop processing further
         // messages until their respective request completes, they won't have a chance to
         // respond to the other client's request and cause a deadlock.
         //
         // This arrangement ensures we will attempt to process earlier messages first, but fall
         // back to processing messages arrived later in the spirit of making progress.
         let mut foreground_message_handlers = FuturesUnordered::new();
         // At most 256 handlers (foreground or background) are in flight per connection.
         // A permit is acquired before each message is read and released when its
         // handler finishes (or immediately, if the message is not handled).
         let concurrent_handlers = Arc::new(Semaphore::new(256));
         loop {
             let next_message = async {
                 let permit = concurrent_handlers.clone().acquire_owned().await.unwrap();
                 let message = incoming_rx.next().await;
                 (permit, message)
             }.fuse();
             futures::pin_mut!(next_message);
             // Biased: teardown wins, then I/O completion, then progress on in-flight
             // foreground handlers, and only then a new incoming message.
             futures::select_biased! {
                 _ = teardown.changed().fuse() => return,
                 result = handle_io => {
                     if let Err(error) = result {
                         tracing::error!(?error, "error handling I/O");
                     }
                     break;
                 }
                 _ = foreground_message_handlers.next() => {}
                 next_message = next_message => {
                     let (permit, message) = next_message;
                     if let Some(message) = message {
                         let type_name = message.payload_type_name();
                         // note: we copy all the fields from the parent span so we can query them in the logs.
                         // (https://github.com/tokio-rs/tracing/issues/2670).
                         let span = tracing::info_span!("receive message", %connection_id, %address, type_name,
                             user_id=field::Empty,
                             login=field::Empty,
                             impersonator=field::Empty,
                             dev_server_id=field::Empty
                         );
                         principal.update_span(&span);
                         let span_enter = span.enter();
                         if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
                             let is_background = message.is_background();
                             let handle_message = (handler)(message, session.clone());
                             drop(span_enter);

                             // Moving the permit into the handler future keeps the
                             // concurrency slot taken until the handler completes.
                             let handle_message = async move {
                                 handle_message.await;
                                 drop(permit);
                             }.instrument(span);
                             // Background handlers run detached; foreground handlers are
                             // polled by this loop through `foreground_message_handlers`.
                             if is_background {
                                 executor.spawn_detached(handle_message);
                             } else {
                                 foreground_message_handlers.push(handle_message);
                             }
                         } else {
                             tracing::error!("no message handler");
                         }
                     } else {
                         tracing::info!("connection closed");
                         break;
                     }
                 }
             }
         }

         // Dropping the set cancels any foreground handlers that have not finished.
         drop(foreground_message_handlers);
         tracing::info!("signing out");
         if let Err(error) = connection_lost(session, teardown, executor).await {
             tracing::error!(?error, "error signing out");
         }

     }.instrument(span)
 }
1095
1096 async fn send_initial_client_update(
1097 &self,
1098 connection_id: ConnectionId,
1099 principal: &Principal,
1100 zed_version: ZedVersion,
1101 mut send_connection_id: Option<oneshot::Sender<ConnectionId>>,
1102 session: &Session,
1103 ) -> Result<()> {
1104 self.peer.send(
1105 connection_id,
1106 proto::Hello {
1107 peer_id: Some(connection_id.into()),
1108 },
1109 )?;
1110 tracing::info!("sent hello message");
1111 if let Some(send_connection_id) = send_connection_id.take() {
1112 let _ = send_connection_id.send(connection_id);
1113 }
1114
1115 match principal {
1116 Principal::User(user) | Principal::Impersonated { user, admin: _ } => {
1117 if !user.connected_once {
1118 self.peer.send(connection_id, proto::ShowContacts {})?;
1119 self.app_state
1120 .db
1121 .set_user_connected_once(user.id, true)
1122 .await?;
1123 }
1124
1125 let (contacts, dev_server_projects) = future::try_join(
1126 self.app_state.db.get_contacts(user.id),
1127 self.app_state.db.dev_server_projects_update(user.id),
1128 )
1129 .await?;
1130
1131 {
1132 let mut pool = self.connection_pool.lock();
1133 pool.add_connection(connection_id, user.id, user.admin, zed_version);
1134 self.peer.send(
1135 connection_id,
1136 build_initial_contacts_update(contacts, &pool),
1137 )?;
1138 }
1139
1140 if should_auto_subscribe_to_channels(zed_version) {
1141 subscribe_user_to_channels(user.id, session).await?;
1142 }
1143
1144 send_dev_server_projects_update(user.id, dev_server_projects, session).await;
1145
1146 if let Some(incoming_call) =
1147 self.app_state.db.incoming_call_for_user(user.id).await?
1148 {
1149 self.peer.send(connection_id, incoming_call)?;
1150 }
1151
1152 update_user_contacts(user.id, &session).await?;
1153 }
1154 Principal::DevServer(dev_server) => {
1155 {
1156 let mut pool = self.connection_pool.lock();
1157 if let Some(stale_connection_id) = pool.dev_server_connection_id(dev_server.id)
1158 {
1159 self.peer.send(
1160 stale_connection_id,
1161 proto::ShutdownDevServer {
1162 reason: Some(
1163 "another dev server connected with the same token".to_string(),
1164 ),
1165 },
1166 )?;
1167 pool.remove_connection(stale_connection_id)?;
1168 };
1169 pool.add_dev_server(connection_id, dev_server.id, zed_version);
1170 }
1171
1172 let projects = self
1173 .app_state
1174 .db
1175 .get_projects_for_dev_server(dev_server.id)
1176 .await?;
1177 self.peer
1178 .send(connection_id, proto::DevServerInstructions { projects })?;
1179
1180 let status = self
1181 .app_state
1182 .db
1183 .dev_server_projects_update(dev_server.user_id)
1184 .await?;
1185 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
1186 }
1187 }
1188
1189 Ok(())
1190 }
1191
1192 pub async fn invite_code_redeemed(
1193 self: &Arc<Self>,
1194 inviter_id: UserId,
1195 invitee_id: UserId,
1196 ) -> Result<()> {
1197 if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
1198 if let Some(code) = &user.invite_code {
1199 let pool = self.connection_pool.lock();
1200 let invitee_contact = contact_for_user(invitee_id, false, &pool);
1201 for connection_id in pool.user_connection_ids(inviter_id) {
1202 self.peer.send(
1203 connection_id,
1204 proto::UpdateContacts {
1205 contacts: vec![invitee_contact.clone()],
1206 ..Default::default()
1207 },
1208 )?;
1209 self.peer.send(
1210 connection_id,
1211 proto::UpdateInviteInfo {
1212 url: format!("{}{}", self.app_state.config.invite_link_prefix, &code),
1213 count: user.invite_count as u32,
1214 },
1215 )?;
1216 }
1217 }
1218 }
1219 Ok(())
1220 }
1221
1222 pub async fn invite_count_updated(self: &Arc<Self>, user_id: UserId) -> Result<()> {
1223 if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? {
1224 if let Some(invite_code) = &user.invite_code {
1225 let pool = self.connection_pool.lock();
1226 for connection_id in pool.user_connection_ids(user_id) {
1227 self.peer.send(
1228 connection_id,
1229 proto::UpdateInviteInfo {
1230 url: format!(
1231 "{}{}",
1232 self.app_state.config.invite_link_prefix, invite_code
1233 ),
1234 count: user.invite_count as u32,
1235 },
1236 )?;
1237 }
1238 }
1239 }
1240 Ok(())
1241 }
1242
1243 pub async fn snapshot<'a>(self: &'a Arc<Self>) -> ServerSnapshot<'a> {
1244 ServerSnapshot {
1245 connection_pool: ConnectionPoolGuard {
1246 guard: self.connection_pool.lock(),
1247 _not_send: PhantomData,
1248 },
1249 peer: &self.peer,
1250 }
1251 }
1252}
1253
1254impl<'a> Deref for ConnectionPoolGuard<'a> {
1255 type Target = ConnectionPool;
1256
1257 fn deref(&self) -> &Self::Target {
1258 &self.guard
1259 }
1260}
1261
1262impl<'a> DerefMut for ConnectionPoolGuard<'a> {
1263 fn deref_mut(&mut self) -> &mut Self::Target {
1264 &mut self.guard
1265 }
1266}
1267
impl<'a> Drop for ConnectionPoolGuard<'a> {
    /// In test builds, checks the pool's invariants every time a guard is dropped.
    /// In non-test builds the body compiles to nothing.
    fn drop(&mut self) {
        #[cfg(test)]
        self.check_invariants();
    }
}
1274
1275fn broadcast<F>(
1276 sender_id: Option<ConnectionId>,
1277 receiver_ids: impl IntoIterator<Item = ConnectionId>,
1278 mut f: F,
1279) where
1280 F: FnMut(ConnectionId) -> anyhow::Result<()>,
1281{
1282 for receiver_id in receiver_ids {
1283 if Some(receiver_id) != sender_id {
1284 if let Err(error) = f(receiver_id) {
1285 tracing::error!("failed to send to {:?} {}", receiver_id, error);
1286 }
1287 }
1288 }
1289}
1290
1291pub struct ProtocolVersion(u32);
1292
1293impl Header for ProtocolVersion {
1294 fn name() -> &'static HeaderName {
1295 static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
1296 ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
1297 }
1298
1299 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1300 where
1301 Self: Sized,
1302 I: Iterator<Item = &'i axum::http::HeaderValue>,
1303 {
1304 let version = values
1305 .next()
1306 .ok_or_else(axum::headers::Error::invalid)?
1307 .to_str()
1308 .map_err(|_| axum::headers::Error::invalid())?
1309 .parse()
1310 .map_err(|_| axum::headers::Error::invalid())?;
1311 Ok(Self(version))
1312 }
1313
1314 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1315 values.extend([self.0.to_string().parse().unwrap()]);
1316 }
1317}
1318
1319pub struct AppVersionHeader(SemanticVersion);
1320impl Header for AppVersionHeader {
1321 fn name() -> &'static HeaderName {
1322 static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
1323 ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
1324 }
1325
1326 fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
1327 where
1328 Self: Sized,
1329 I: Iterator<Item = &'i axum::http::HeaderValue>,
1330 {
1331 let version = values
1332 .next()
1333 .ok_or_else(axum::headers::Error::invalid)?
1334 .to_str()
1335 .map_err(|_| axum::headers::Error::invalid())?
1336 .parse()
1337 .map_err(|_| axum::headers::Error::invalid())?;
1338 Ok(Self(version))
1339 }
1340
1341 fn encode<E: Extend<axum::http::HeaderValue>>(&self, values: &mut E) {
1342 values.extend([self.0.to_string().parse().unwrap()]);
1343 }
1344}
1345
1346pub fn routes(server: Arc<Server>) -> Router<(), Body> {
1347 Router::new()
1348 .route("/rpc", get(handle_websocket_request))
1349 .layer(
1350 ServiceBuilder::new()
1351 .layer(Extension(server.app_state.clone()))
1352 .layer(middleware::from_fn(auth::validate_header)),
1353 )
1354 .route("/metrics", get(handle_metrics))
1355 .layer(Extension(server))
1356}
1357
1358pub async fn handle_websocket_request(
1359 TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
1360 app_version_header: Option<TypedHeader<AppVersionHeader>>,
1361 ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
1362 Extension(server): Extension<Arc<Server>>,
1363 Extension(principal): Extension<Principal>,
1364 ws: WebSocketUpgrade,
1365) -> axum::response::Response {
1366 if protocol_version != rpc::PROTOCOL_VERSION {
1367 return (
1368 StatusCode::UPGRADE_REQUIRED,
1369 "client must be upgraded".to_string(),
1370 )
1371 .into_response();
1372 }
1373
1374 let Some(version) = app_version_header.map(|header| ZedVersion(header.0 .0)) else {
1375 return (
1376 StatusCode::UPGRADE_REQUIRED,
1377 "no version header found".to_string(),
1378 )
1379 .into_response();
1380 };
1381
1382 if !version.can_collaborate() {
1383 return (
1384 StatusCode::UPGRADE_REQUIRED,
1385 "client must be upgraded".to_string(),
1386 )
1387 .into_response();
1388 }
1389
1390 let socket_address = socket_address.to_string();
1391 ws.on_upgrade(move |socket| {
1392 let socket = socket
1393 .map_ok(to_tungstenite_message)
1394 .err_into()
1395 .with(|message| async move {
1396 if let Some(message) = to_axum_message(message) {
1397 Ok(message)
1398 } else {
1399 bail!("Could not convert a tungstenite message to axum message");
1400 }
1401 });
1402 let connection = Connection::new(Box::pin(socket));
1403 async move {
1404 server
1405 .handle_connection(
1406 connection,
1407 socket_address,
1408 principal,
1409 version,
1410 None,
1411 Executor::Production,
1412 )
1413 .await;
1414 }
1415 })
1416}
1417
1418pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
1419 static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1420 let connections_metric = CONNECTIONS_METRIC
1421 .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
1422
1423 let connections = server
1424 .connection_pool
1425 .lock()
1426 .connections()
1427 .filter(|connection| !connection.admin)
1428 .count();
1429 connections_metric.set(connections as _);
1430
1431 static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
1432 let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
1433 register_int_gauge!(
1434 "shared_projects",
1435 "number of open projects with one or more guests"
1436 )
1437 .unwrap()
1438 });
1439
1440 let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
1441 shared_projects_metric.set(shared_projects as _);
1442
1443 let encoder = prometheus::TextEncoder::new();
1444 let metric_families = prometheus::gather();
1445 let encoded_metrics = encoder
1446 .encode_to_string(&metric_families)
1447 .map_err(|err| anyhow!("{}", err))?;
1448 Ok(encoded_metrics)
1449}
1450
/// Cleans up after a connection has gone away.
///
/// Immediate cleanup:
/// - disconnect the peer;
/// - remove the connection from the connection pool;
/// - tell the database the connection was lost (database errors are only logged).
///
/// It then waits for `RECONNECT_TIMEOUT`. If the teardown watch changes first, it
/// returns without further cleanup. Once the timeout elapses:
/// - for a user (or impersonated user):
///   - leave the room and any channel buffers;
///   - if the user has no other online connection, call `decline_call(None, ..)`
///     and broadcast the resulting room update;
///   - refresh the user's contacts.
/// - for a dev server: run `lost_dev_server_connection`.
///
/// NOTE(review): if `remove_connection` fails, the `?` returns early and the
/// database is never told about the lost connection. Confirm this is intended.
#[instrument(err, skip(executor))]
async fn connection_lost(
    session: Session,
    mut teardown: watch::Receiver<bool>,
    executor: Executor,
) -> Result<()> {
    session.peer.disconnect(session.connection_id);
    session
        .connection_pool()
        .await
        .remove_connection(session.connection_id)?;

    session
        .db()
        .await
        .connection_lost(session.connection_id)
        .await
        .trace_err();

    // Delay the room/buffer cleanup by `RECONNECT_TIMEOUT` (presumably to let the
    // client reconnect first); a server teardown cancels the wait.
    futures::select_biased! {
        _ = executor.sleep(RECONNECT_TIMEOUT).fuse() => {
            match &session.principal {
                Principal::User(_) | Principal::Impersonated{ user: _, admin:_ } => {
                    let session = session.for_user().unwrap();

                    log::info!("connection lost, removing all resources for user:{}, connection:{:?}", session.user_id(), session.connection_id);
                    leave_room_for_session(&session, session.connection_id).await.trace_err();
                    leave_channel_buffers_for_session(&session)
                        .await
                        .trace_err();

                    // Only decline a pending call when no other connection of this user is online.
                    if !session
                        .connection_pool()
                        .await
                        .is_user_online(session.user_id())
                    {
                        let db = session.db().await;
                        if let Some(room) = db.decline_call(None, session.user_id()).await.trace_err().flatten() {
                            room_updated(&room, &session.peer);
                        }
                    }

                    update_user_contacts(session.user_id(), &session).await?;
                },
                Principal::DevServer(_) => {
                    lost_dev_server_connection(&session.for_dev_server().unwrap()).await?;
                },
            }
        },
        _ = teardown.changed().fuse() => {}
    }

    Ok(())
}
1505
1506/// Acknowledges a ping from a client, used to keep the connection alive.
1507async fn ping(_: proto::Ping, response: Response<proto::Ping>, _session: Session) -> Result<()> {
1508 response.send(proto::Ack {})?;
1509 Ok(())
1510}
1511
1512/// Creates a new room for calling (outside of channels)
1513async fn create_room(
1514 _request: proto::CreateRoom,
1515 response: Response<proto::CreateRoom>,
1516 session: UserSession,
1517) -> Result<()> {
1518 let live_kit_room = nanoid::nanoid!(30);
1519
1520 let live_kit_connection_info = util::maybe!(async {
1521 let live_kit = session.live_kit_client.as_ref();
1522 let live_kit = live_kit?;
1523 let user_id = session.user_id().to_string();
1524
1525 let token = live_kit
1526 .room_token(&live_kit_room, &user_id.to_string())
1527 .trace_err()?;
1528
1529 Some(proto::LiveKitConnectionInfo {
1530 server_url: live_kit.url().into(),
1531 token,
1532 can_publish: true,
1533 })
1534 })
1535 .await;
1536
1537 let room = session
1538 .db()
1539 .await
1540 .create_room(session.user_id(), session.connection_id, &live_kit_room)
1541 .await?;
1542
1543 response.send(proto::CreateRoomResponse {
1544 room: Some(room.clone()),
1545 live_kit_connection_info,
1546 })?;
1547
1548 update_user_contacts(session.user_id(), &session).await?;
1549 Ok(())
1550}
1551
1552/// Join a room from an invitation. Equivalent to joining a channel if there is one.
1553async fn join_room(
1554 request: proto::JoinRoom,
1555 response: Response<proto::JoinRoom>,
1556 session: UserSession,
1557) -> Result<()> {
1558 let room_id = RoomId::from_proto(request.id);
1559
1560 let channel_id = session.db().await.channel_id_for_room(room_id).await?;
1561
1562 if let Some(channel_id) = channel_id {
1563 return join_channel_internal(channel_id, Box::new(response), session).await;
1564 }
1565
1566 let joined_room = {
1567 let room = session
1568 .db()
1569 .await
1570 .join_room(room_id, session.user_id(), session.connection_id)
1571 .await?;
1572 room_updated(&room.room, &session.peer);
1573 room.into_inner()
1574 };
1575
1576 for connection_id in session
1577 .connection_pool()
1578 .await
1579 .user_connection_ids(session.user_id())
1580 {
1581 session
1582 .peer
1583 .send(
1584 connection_id,
1585 proto::CallCanceled {
1586 room_id: room_id.to_proto(),
1587 },
1588 )
1589 .trace_err();
1590 }
1591
1592 let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() {
1593 if let Some(token) = live_kit
1594 .room_token(
1595 &joined_room.room.live_kit_room,
1596 &session.user_id().to_string(),
1597 )
1598 .trace_err()
1599 {
1600 Some(proto::LiveKitConnectionInfo {
1601 server_url: live_kit.url().into(),
1602 token,
1603 can_publish: true,
1604 })
1605 } else {
1606 None
1607 }
1608 } else {
1609 None
1610 };
1611
1612 response.send(proto::JoinRoomResponse {
1613 room: Some(joined_room.room),
1614 channel_id: None,
1615 live_kit_connection_info,
1616 })?;
1617
1618 update_user_contacts(session.user_id(), &session).await?;
1619 Ok(())
1620}
1621
/// Rejoin room is used to reconnect to a room after connection errors.
///
/// Steps:
/// 1. Restore the caller's place in the room through the database.
/// 2. Reply with the room, the re-shared projects (with their collaborators) and
///    the rejoined projects, then broadcast the room update.
/// 3. For every re-shared project:
///    - its collaborators are told the caller's new peer id;
///    - all collaborators other than the caller receive the project's worktrees
///      through a forwarded `UpdateProject`.
/// 4. Bring rejoined projects up to date with `notify_rejoined_projects`.
/// 5. Broadcast the channel update when the room belongs to a channel.
/// 6. Refresh the caller's contacts.
async fn rejoin_room(
    request: proto::RejoinRoom,
    response: Response<proto::RejoinRoom>,
    session: UserSession,
) -> Result<()> {
    // Filled in at the end of the block below, once the rejoin result is unwrapped.
    let room;
    let channel;
    {
        let mut rejoined_room = session
            .db()
            .await
            .rejoin_room(request, session.user_id(), session.connection_id)
            .await?;

        response.send(proto::RejoinRoomResponse {
            room: Some(rejoined_room.room.clone()),
            reshared_projects: rejoined_room
                .reshared_projects
                .iter()
                .map(|project| proto::ResharedProject {
                    id: project.id.to_proto(),
                    collaborators: project
                        .collaborators
                        .iter()
                        .map(|collaborator| collaborator.to_proto())
                        .collect(),
                })
                .collect(),
            rejoined_projects: rejoined_room
                .rejoined_projects
                .iter()
                .map(|rejoined_project| rejoined_project.to_proto())
                .collect(),
        })?;
        room_updated(&rejoined_room.room, &session.peer);

        for project in &rejoined_room.reshared_projects {
            // Tell collaborators the caller's peer id moved to the new connection.
            for collaborator in &project.collaborators {
                session
                    .peer
                    .send(
                        collaborator.connection_id,
                        proto::UpdateProjectCollaborator {
                            project_id: project.id.to_proto(),
                            old_peer_id: Some(project.old_connection_id.into()),
                            new_peer_id: Some(session.connection_id.into()),
                        },
                    )
                    .trace_err();
            }

            // Re-send the project's worktree metadata to everyone except the caller.
            broadcast(
                Some(session.connection_id),
                project
                    .collaborators
                    .iter()
                    .map(|collaborator| collaborator.connection_id),
                |connection_id| {
                    session.peer.forward_send(
                        session.connection_id,
                        connection_id,
                        proto::UpdateProject {
                            project_id: project.id.to_proto(),
                            worktrees: project.worktrees.clone(),
                        },
                    )
                },
            );
        }

        notify_rejoined_projects(&mut rejoined_room.rejoined_projects, &session)?;

        // Unwrap the guarded result; presumably this releases the guard before the
        // channel update below locks the connection pool.
        let rejoined_room = rejoined_room.into_inner();

        room = rejoined_room.room;
        channel = rejoined_room.channel;
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
1713
1714fn notify_rejoined_projects(
1715 rejoined_projects: &mut Vec<RejoinedProject>,
1716 session: &UserSession,
1717) -> Result<()> {
1718 for project in rejoined_projects.iter() {
1719 for collaborator in &project.collaborators {
1720 session
1721 .peer
1722 .send(
1723 collaborator.connection_id,
1724 proto::UpdateProjectCollaborator {
1725 project_id: project.id.to_proto(),
1726 old_peer_id: Some(project.old_connection_id.into()),
1727 new_peer_id: Some(session.connection_id.into()),
1728 },
1729 )
1730 .trace_err();
1731 }
1732 }
1733
1734 for project in rejoined_projects {
1735 for worktree in mem::take(&mut project.worktrees) {
1736 #[cfg(any(test, feature = "test-support"))]
1737 const MAX_CHUNK_SIZE: usize = 2;
1738 #[cfg(not(any(test, feature = "test-support")))]
1739 const MAX_CHUNK_SIZE: usize = 256;
1740
1741 // Stream this worktree's entries.
1742 let message = proto::UpdateWorktree {
1743 project_id: project.id.to_proto(),
1744 worktree_id: worktree.id,
1745 abs_path: worktree.abs_path.clone(),
1746 root_name: worktree.root_name,
1747 updated_entries: worktree.updated_entries,
1748 removed_entries: worktree.removed_entries,
1749 scan_id: worktree.scan_id,
1750 is_last_update: worktree.completed_scan_id == worktree.scan_id,
1751 updated_repositories: worktree.updated_repositories,
1752 removed_repositories: worktree.removed_repositories,
1753 };
1754 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
1755 session.peer.send(session.connection_id, update.clone())?;
1756 }
1757
1758 // Stream this worktree's diagnostics.
1759 for summary in worktree.diagnostic_summaries {
1760 session.peer.send(
1761 session.connection_id,
1762 proto::UpdateDiagnosticSummary {
1763 project_id: project.id.to_proto(),
1764 worktree_id: worktree.id,
1765 summary: Some(summary),
1766 },
1767 )?;
1768 }
1769
1770 for settings_file in worktree.settings_files {
1771 session.peer.send(
1772 session.connection_id,
1773 proto::UpdateWorktreeSettings {
1774 project_id: project.id.to_proto(),
1775 worktree_id: worktree.id,
1776 path: settings_file.path,
1777 content: Some(settings_file.content),
1778 },
1779 )?;
1780 }
1781 }
1782
1783 for language_server in &project.language_servers {
1784 session.peer.send(
1785 session.connection_id,
1786 proto::UpdateLanguageServer {
1787 project_id: project.id.to_proto(),
1788 language_server_id: language_server.id,
1789 variant: Some(
1790 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
1791 proto::LspDiskBasedDiagnosticsUpdated {},
1792 ),
1793 ),
1794 },
1795 )?;
1796 }
1797 }
1798 Ok(())
1799}
1800
1801/// leave room disconnects from the room.
1802async fn leave_room(
1803 _: proto::LeaveRoom,
1804 response: Response<proto::LeaveRoom>,
1805 session: UserSession,
1806) -> Result<()> {
1807 leave_room_for_session(&session, session.connection_id).await?;
1808 response.send(proto::Ack {})?;
1809 Ok(())
1810}
1811
1812/// Updates the permissions of someone else in the room.
1813async fn set_room_participant_role(
1814 request: proto::SetRoomParticipantRole,
1815 response: Response<proto::SetRoomParticipantRole>,
1816 session: UserSession,
1817) -> Result<()> {
1818 let user_id = UserId::from_proto(request.user_id);
1819 let role = ChannelRole::from(request.role());
1820
1821 let (live_kit_room, can_publish) = {
1822 let room = session
1823 .db()
1824 .await
1825 .set_room_participant_role(
1826 session.user_id(),
1827 RoomId::from_proto(request.room_id),
1828 user_id,
1829 role,
1830 )
1831 .await?;
1832
1833 let live_kit_room = room.live_kit_room.clone();
1834 let can_publish = ChannelRole::from(request.role()).can_use_microphone();
1835 room_updated(&room, &session.peer);
1836 (live_kit_room, can_publish)
1837 };
1838
1839 if let Some(live_kit) = session.live_kit_client.as_ref() {
1840 live_kit
1841 .update_participant(
1842 live_kit_room.clone(),
1843 request.user_id.to_string(),
1844 live_kit_server::proto::ParticipantPermission {
1845 can_subscribe: true,
1846 can_publish,
1847 can_publish_data: can_publish,
1848 hidden: false,
1849 recorder: false,
1850 },
1851 )
1852 .await
1853 .trace_err();
1854 }
1855
1856 response.send(proto::Ack {})?;
1857 Ok(())
1858}
1859
/// Call someone else into the current room
///
/// The callee must be a contact of the caller. The call is recorded in the database
/// and the room update broadcast; then the callee's contacts are refreshed and every
/// connection of the callee is rung concurrently. The request is acknowledged as
/// soon as one connection answers successfully.
///
/// If none does (including when the callee has no connections at all):
/// - the call is marked as failed;
/// - the room update and the callee's contacts are broadcast again;
/// - an error is returned.
async fn call(
    request: proto::Call,
    response: Response<proto::Call>,
    session: UserSession,
) -> Result<()> {
    let room_id = RoomId::from_proto(request.room_id);
    let calling_user_id = session.user_id();
    let calling_connection_id = session.connection_id;
    let called_user_id = UserId::from_proto(request.called_user_id);
    let initial_project_id = request.initial_project_id.map(ProjectId::from_proto);
    if !session
        .db()
        .await
        .has_contact(calling_user_id, called_user_id)
        .await?
    {
        return Err(anyhow!("cannot call a user who isn't a contact"))?;
    }

    // Move the incoming-call message out of the guarded result so the guard is
    // dropped at the end of this block, before anyone is rung.
    let incoming_call = {
        let (room, incoming_call) = &mut *session
            .db()
            .await
            .call(
                room_id,
                calling_user_id,
                calling_connection_id,
                called_user_id,
                initial_project_id,
            )
            .await?;
        room_updated(&room, &session.peer);
        mem::take(incoming_call)
    };
    update_user_contacts(called_user_id, &session).await?;

    // Ring all of the callee's connections at once. The pool lock is a temporary
    // released at the end of this statement, after the request futures are created.
    let mut calls = session
        .connection_pool()
        .await
        .user_connection_ids(called_user_id)
        .map(|connection_id| session.peer.request(connection_id, incoming_call.clone()))
        .collect::<FuturesUnordered<_>>();

    // The first successful answer acknowledges the call; failures are logged.
    while let Some(call_response) = calls.next().await {
        match call_response.as_ref() {
            Ok(_) => {
                response.send(proto::Ack {})?;
                return Ok(());
            }
            Err(_) => {
                call_response.trace_err();
            }
        }
    }

    // Nobody answered: record the failure and broadcast the resulting state.
    {
        let room = session
            .db()
            .await
            .call_failed(room_id, called_user_id)
            .await?;
        room_updated(&room, &session.peer);
    }
    update_user_contacts(called_user_id, &session).await?;

    Err(anyhow!("failed to ring user"))?
}
1928
1929/// Cancel an outgoing call.
1930async fn cancel_call(
1931 request: proto::CancelCall,
1932 response: Response<proto::CancelCall>,
1933 session: UserSession,
1934) -> Result<()> {
1935 let called_user_id = UserId::from_proto(request.called_user_id);
1936 let room_id = RoomId::from_proto(request.room_id);
1937 {
1938 let room = session
1939 .db()
1940 .await
1941 .cancel_call(room_id, session.connection_id, called_user_id)
1942 .await?;
1943 room_updated(&room, &session.peer);
1944 }
1945
1946 for connection_id in session
1947 .connection_pool()
1948 .await
1949 .user_connection_ids(called_user_id)
1950 {
1951 session
1952 .peer
1953 .send(
1954 connection_id,
1955 proto::CallCanceled {
1956 room_id: room_id.to_proto(),
1957 },
1958 )
1959 .trace_err();
1960 }
1961 response.send(proto::Ack {})?;
1962
1963 update_user_contacts(called_user_id, &session).await?;
1964 Ok(())
1965}
1966
1967/// Decline an incoming call.
1968async fn decline_call(message: proto::DeclineCall, session: UserSession) -> Result<()> {
1969 let room_id = RoomId::from_proto(message.room_id);
1970 {
1971 let room = session
1972 .db()
1973 .await
1974 .decline_call(Some(room_id), session.user_id())
1975 .await?
1976 .ok_or_else(|| anyhow!("failed to decline call"))?;
1977 room_updated(&room, &session.peer);
1978 }
1979
1980 for connection_id in session
1981 .connection_pool()
1982 .await
1983 .user_connection_ids(session.user_id())
1984 {
1985 session
1986 .peer
1987 .send(
1988 connection_id,
1989 proto::CallCanceled {
1990 room_id: room_id.to_proto(),
1991 },
1992 )
1993 .trace_err();
1994 }
1995 update_user_contacts(session.user_id(), &session).await?;
1996 Ok(())
1997}
1998
1999/// Updates other participants in the room with your current location.
2000async fn update_participant_location(
2001 request: proto::UpdateParticipantLocation,
2002 response: Response<proto::UpdateParticipantLocation>,
2003 session: UserSession,
2004) -> Result<()> {
2005 let room_id = RoomId::from_proto(request.room_id);
2006 let location = request
2007 .location
2008 .ok_or_else(|| anyhow!("invalid location"))?;
2009
2010 let db = session.db().await;
2011 let room = db
2012 .update_room_participant_location(room_id, session.connection_id, location)
2013 .await?;
2014
2015 room_updated(&room, &session.peer);
2016 response.send(proto::Ack {})?;
2017 Ok(())
2018}
2019
2020/// Share a project into the room.
2021async fn share_project(
2022 request: proto::ShareProject,
2023 response: Response<proto::ShareProject>,
2024 session: UserSession,
2025) -> Result<()> {
2026 let (project_id, room) = &*session
2027 .db()
2028 .await
2029 .share_project(
2030 RoomId::from_proto(request.room_id),
2031 session.connection_id,
2032 &request.worktrees,
2033 request
2034 .dev_server_project_id
2035 .map(|id| DevServerProjectId::from_proto(id)),
2036 )
2037 .await?;
2038 response.send(proto::ShareProjectResponse {
2039 project_id: project_id.to_proto(),
2040 })?;
2041 room_updated(&room, &session.peer);
2042
2043 Ok(())
2044}
2045
2046/// Unshare a project from the room.
2047async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> {
2048 let project_id = ProjectId::from_proto(message.project_id);
2049 unshare_project_internal(
2050 project_id,
2051 session.connection_id,
2052 session.user_id(),
2053 &session,
2054 )
2055 .await
2056}
2057
/// Stops sharing `project_id` on behalf of `connection_id` (and `user_id`, if known).
///
/// Steps:
/// 1. Send `UnshareProject` to every guest connection except `connection_id`.
/// 2. Broadcast the room update when the project belonged to a room.
/// 3. Delete the project from the database when the unshare result says so.
///
/// The deletion happens only after the guard returned by `unshare_project` has been
/// dropped.
async fn unshare_project_internal(
    project_id: ProjectId,
    connection_id: ConnectionId,
    user_id: Option<UserId>,
    session: &Session,
) -> Result<()> {
    let delete = {
        let room_guard = session
            .db()
            .await
            .unshare_project(project_id, connection_id, user_id)
            .await?;

        let (delete, room, guest_connection_ids) = &*room_guard;

        let message = proto::UnshareProject {
            project_id: project_id.to_proto(),
        };

        broadcast(
            Some(connection_id),
            guest_connection_ids.iter().copied(),
            |conn_id| session.peer.send(conn_id, message.clone()),
        );
        if let Some(room) = room {
            room_updated(room, &session.peer);
        }

        // Copy the flag out so `room_guard` can be dropped at the end of this block.
        *delete
    };

    if delete {
        let db = session.db().await;
        db.delete_project(project_id).await?;
    }

    Ok(())
}
2096
2097/// DevServer makes a project available online
2098async fn share_dev_server_project(
2099 request: proto::ShareDevServerProject,
2100 response: Response<proto::ShareDevServerProject>,
2101 session: DevServerSession,
2102) -> Result<()> {
2103 let (dev_server_project, user_id, status) = session
2104 .db()
2105 .await
2106 .share_dev_server_project(
2107 DevServerProjectId::from_proto(request.dev_server_project_id),
2108 session.dev_server_id(),
2109 session.connection_id,
2110 &request.worktrees,
2111 )
2112 .await?;
2113 let Some(project_id) = dev_server_project.project_id else {
2114 return Err(anyhow!("failed to share remote project"))?;
2115 };
2116
2117 send_dev_server_projects_update(user_id, status, &session).await;
2118
2119 response.send(proto::ShareProjectResponse { project_id })?;
2120
2121 Ok(())
2122}
2123
/// Join someone else's shared project.
///
/// Registers the caller as a collaborator of `project_id` in the database, then
/// hands off to `join_project_internal` to notify the other collaborators and send
/// the project state back to the caller.
async fn join_project(
    request: proto::JoinProject,
    response: Response<proto::JoinProject>,
    session: UserSession,
) -> Result<()> {
    let project_id = ProjectId::from_proto(request.project_id);

    tracing::info!(%project_id, "join project");

    let db = session.db().await;
    let (project, replica_id) = &mut *db
        .join_project(project_id, session.connection_id, session.user_id())
        .await?;
    // Release the session's database handle before the messaging work below; the
    // value borrowed by `project`/`replica_id` stays alive.
    drop(db);
    tracing::info!(%project_id, "join remote project");
    join_project_internal(response, session, project, replica_id)
}
2142
/// Abstracts over request responses that answer with a `proto::JoinProjectResponse`.
///
/// This lets `join_project_internal` serve both `JoinProject` and
/// `JoinHostedProject`.
trait JoinProjectInternalResponse {
    /// Sends `result` as the response to the original request.
    fn send(self, result: proto::JoinProjectResponse) -> Result<()>;
}
2146impl JoinProjectInternalResponse for Response<proto::JoinProject> {
2147 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2148 Response::<proto::JoinProject>::send(self, result)
2149 }
2150}
2151impl JoinProjectInternalResponse for Response<proto::JoinHostedProject> {
2152 fn send(self, result: proto::JoinProjectResponse) -> Result<()> {
2153 Response::<proto::JoinHostedProject>::send(self, result)
2154 }
2155}
2156
2157fn join_project_internal(
2158 response: impl JoinProjectInternalResponse,
2159 session: UserSession,
2160 project: &mut Project,
2161 replica_id: &ReplicaId,
2162) -> Result<()> {
2163 let collaborators = project
2164 .collaborators
2165 .iter()
2166 .filter(|collaborator| collaborator.connection_id != session.connection_id)
2167 .map(|collaborator| collaborator.to_proto())
2168 .collect::<Vec<_>>();
2169 let project_id = project.id;
2170 let guest_user_id = session.user_id();
2171
2172 let worktrees = project
2173 .worktrees
2174 .iter()
2175 .map(|(id, worktree)| proto::WorktreeMetadata {
2176 id: *id,
2177 root_name: worktree.root_name.clone(),
2178 visible: worktree.visible,
2179 abs_path: worktree.abs_path.clone(),
2180 })
2181 .collect::<Vec<_>>();
2182
2183 let add_project_collaborator = proto::AddProjectCollaborator {
2184 project_id: project_id.to_proto(),
2185 collaborator: Some(proto::Collaborator {
2186 peer_id: Some(session.connection_id.into()),
2187 replica_id: replica_id.0 as u32,
2188 user_id: guest_user_id.to_proto(),
2189 }),
2190 };
2191
2192 for collaborator in &collaborators {
2193 session
2194 .peer
2195 .send(
2196 collaborator.peer_id.unwrap().into(),
2197 add_project_collaborator.clone(),
2198 )
2199 .trace_err();
2200 }
2201
2202 // First, we send the metadata associated with each worktree.
2203 response.send(proto::JoinProjectResponse {
2204 project_id: project.id.0 as u64,
2205 worktrees: worktrees.clone(),
2206 replica_id: replica_id.0 as u32,
2207 collaborators: collaborators.clone(),
2208 language_servers: project.language_servers.clone(),
2209 role: project.role.into(),
2210 dev_server_project_id: project
2211 .dev_server_project_id
2212 .map(|dev_server_project_id| dev_server_project_id.0 as u64),
2213 })?;
2214
2215 for (worktree_id, worktree) in mem::take(&mut project.worktrees) {
2216 #[cfg(any(test, feature = "test-support"))]
2217 const MAX_CHUNK_SIZE: usize = 2;
2218 #[cfg(not(any(test, feature = "test-support")))]
2219 const MAX_CHUNK_SIZE: usize = 256;
2220
2221 // Stream this worktree's entries.
2222 let message = proto::UpdateWorktree {
2223 project_id: project_id.to_proto(),
2224 worktree_id,
2225 abs_path: worktree.abs_path.clone(),
2226 root_name: worktree.root_name,
2227 updated_entries: worktree.entries,
2228 removed_entries: Default::default(),
2229 scan_id: worktree.scan_id,
2230 is_last_update: worktree.scan_id == worktree.completed_scan_id,
2231 updated_repositories: worktree.repository_entries.into_values().collect(),
2232 removed_repositories: Default::default(),
2233 };
2234 for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) {
2235 session.peer.send(session.connection_id, update.clone())?;
2236 }
2237
2238 // Stream this worktree's diagnostics.
2239 for summary in worktree.diagnostic_summaries {
2240 session.peer.send(
2241 session.connection_id,
2242 proto::UpdateDiagnosticSummary {
2243 project_id: project_id.to_proto(),
2244 worktree_id: worktree.id,
2245 summary: Some(summary),
2246 },
2247 )?;
2248 }
2249
2250 for settings_file in worktree.settings_files {
2251 session.peer.send(
2252 session.connection_id,
2253 proto::UpdateWorktreeSettings {
2254 project_id: project_id.to_proto(),
2255 worktree_id: worktree.id,
2256 path: settings_file.path,
2257 content: Some(settings_file.content),
2258 },
2259 )?;
2260 }
2261 }
2262
2263 for language_server in &project.language_servers {
2264 session.peer.send(
2265 session.connection_id,
2266 proto::UpdateLanguageServer {
2267 project_id: project_id.to_proto(),
2268 language_server_id: language_server.id,
2269 variant: Some(
2270 proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
2271 proto::LspDiskBasedDiagnosticsUpdated {},
2272 ),
2273 ),
2274 },
2275 )?;
2276 }
2277
2278 Ok(())
2279}
2280
2281/// Leave someone elses shared project.
2282async fn leave_project(request: proto::LeaveProject, session: UserSession) -> Result<()> {
2283 let sender_id = session.connection_id;
2284 let project_id = ProjectId::from_proto(request.project_id);
2285 let db = session.db().await;
2286 if db.is_hosted_project(project_id).await? {
2287 let project = db.leave_hosted_project(project_id, sender_id).await?;
2288 project_left(&project, &session);
2289 return Ok(());
2290 }
2291
2292 let (room, project) = &*db.leave_project(project_id, sender_id).await?;
2293 tracing::info!(
2294 %project_id,
2295 "leave project"
2296 );
2297
2298 project_left(&project, &session);
2299 if let Some(room) = room {
2300 room_updated(&room, &session.peer);
2301 }
2302
2303 Ok(())
2304}
2305
2306async fn join_hosted_project(
2307 request: proto::JoinHostedProject,
2308 response: Response<proto::JoinHostedProject>,
2309 session: UserSession,
2310) -> Result<()> {
2311 let (mut project, replica_id) = session
2312 .db()
2313 .await
2314 .join_hosted_project(
2315 ProjectId(request.project_id as i32),
2316 session.user_id(),
2317 session.connection_id,
2318 )
2319 .await?;
2320
2321 join_project_internal(response, session, &mut project, &replica_id)
2322}
2323
2324async fn list_remote_directory(
2325 request: proto::ListRemoteDirectory,
2326 response: Response<proto::ListRemoteDirectory>,
2327 session: UserSession,
2328) -> Result<()> {
2329 let dev_server_id = DevServerId(request.dev_server_id as i32);
2330 let dev_server_connection_id = session
2331 .connection_pool()
2332 .await
2333 .dev_server_connection_id_supporting(dev_server_id, ZedVersion::with_list_directory())?;
2334
2335 session
2336 .db()
2337 .await
2338 .get_dev_server_for_user(dev_server_id, session.user_id())
2339 .await?;
2340
2341 response.send(
2342 session
2343 .peer
2344 .forward_request(session.connection_id, dev_server_connection_id, request)
2345 .await?,
2346 )?;
2347 Ok(())
2348}
2349
2350async fn update_dev_server_project(
2351 request: proto::UpdateDevServerProject,
2352 response: Response<proto::UpdateDevServerProject>,
2353 session: UserSession,
2354) -> Result<()> {
2355 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2356
2357 let (dev_server_project, update) = session
2358 .db()
2359 .await
2360 .update_dev_server_project(dev_server_project_id, &request.paths, session.user_id())
2361 .await?;
2362
2363 let projects = session
2364 .db()
2365 .await
2366 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2367 .await?;
2368
2369 let dev_server_connection_id = session
2370 .connection_pool()
2371 .await
2372 .dev_server_connection_id_supporting(
2373 dev_server_project.dev_server_id,
2374 ZedVersion::with_list_directory(),
2375 )?;
2376
2377 session.peer.send(
2378 dev_server_connection_id,
2379 proto::DevServerInstructions { projects },
2380 )?;
2381
2382 send_dev_server_projects_update(session.user_id(), update, &session).await;
2383
2384 response.send(proto::Ack {})
2385}
2386
2387async fn create_dev_server_project(
2388 request: proto::CreateDevServerProject,
2389 response: Response<proto::CreateDevServerProject>,
2390 session: UserSession,
2391) -> Result<()> {
2392 let dev_server_id = DevServerId(request.dev_server_id as i32);
2393 let dev_server_connection_id = session
2394 .connection_pool()
2395 .await
2396 .dev_server_connection_id(dev_server_id);
2397 let Some(dev_server_connection_id) = dev_server_connection_id else {
2398 Err(ErrorCode::DevServerOffline
2399 .message("Cannot create a remote project when the dev server is offline".to_string())
2400 .anyhow())?
2401 };
2402
2403 let path = request.path.clone();
2404 //Check that the path exists on the dev server
2405 session
2406 .peer
2407 .forward_request(
2408 session.connection_id,
2409 dev_server_connection_id,
2410 proto::ValidateDevServerProjectRequest { path: path.clone() },
2411 )
2412 .await?;
2413
2414 let (dev_server_project, update) = session
2415 .db()
2416 .await
2417 .create_dev_server_project(
2418 DevServerId(request.dev_server_id as i32),
2419 &request.path,
2420 session.user_id(),
2421 )
2422 .await?;
2423
2424 let projects = session
2425 .db()
2426 .await
2427 .get_projects_for_dev_server(dev_server_project.dev_server_id)
2428 .await?;
2429
2430 session.peer.send(
2431 dev_server_connection_id,
2432 proto::DevServerInstructions { projects },
2433 )?;
2434
2435 send_dev_server_projects_update(session.user_id(), update, &session).await;
2436
2437 response.send(proto::CreateDevServerProjectResponse {
2438 dev_server_project: Some(dev_server_project.to_proto(None)),
2439 })?;
2440 Ok(())
2441}
2442
2443async fn create_dev_server(
2444 request: proto::CreateDevServer,
2445 response: Response<proto::CreateDevServer>,
2446 session: UserSession,
2447) -> Result<()> {
2448 let access_token = auth::random_token();
2449 let hashed_access_token = auth::hash_access_token(&access_token);
2450
2451 if request.name.is_empty() {
2452 return Err(proto::ErrorCode::Forbidden
2453 .message("Dev server name cannot be empty".to_string())
2454 .anyhow())?;
2455 }
2456
2457 let (dev_server, status) = session
2458 .db()
2459 .await
2460 .create_dev_server(
2461 &request.name,
2462 request.ssh_connection_string.as_deref(),
2463 &hashed_access_token,
2464 session.user_id(),
2465 )
2466 .await?;
2467
2468 send_dev_server_projects_update(session.user_id(), status, &session).await;
2469
2470 response.send(proto::CreateDevServerResponse {
2471 dev_server_id: dev_server.id.0 as u64,
2472 access_token: auth::generate_dev_server_token(dev_server.id.0 as usize, access_token),
2473 name: request.name,
2474 })?;
2475 Ok(())
2476}
2477
2478async fn regenerate_dev_server_token(
2479 request: proto::RegenerateDevServerToken,
2480 response: Response<proto::RegenerateDevServerToken>,
2481 session: UserSession,
2482) -> Result<()> {
2483 let dev_server_id = DevServerId(request.dev_server_id as i32);
2484 let access_token = auth::random_token();
2485 let hashed_access_token = auth::hash_access_token(&access_token);
2486
2487 let connection_id = session
2488 .connection_pool()
2489 .await
2490 .dev_server_connection_id(dev_server_id);
2491 if let Some(connection_id) = connection_id {
2492 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2493 session.peer.send(
2494 connection_id,
2495 proto::ShutdownDevServer {
2496 reason: Some("dev server token was regenerated".to_string()),
2497 },
2498 )?;
2499 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2500 }
2501
2502 let status = session
2503 .db()
2504 .await
2505 .update_dev_server_token(dev_server_id, &hashed_access_token, session.user_id())
2506 .await?;
2507
2508 send_dev_server_projects_update(session.user_id(), status, &session).await;
2509
2510 response.send(proto::RegenerateDevServerTokenResponse {
2511 dev_server_id: dev_server_id.to_proto(),
2512 access_token: auth::generate_dev_server_token(dev_server_id.0 as usize, access_token),
2513 })?;
2514 Ok(())
2515}
2516
2517async fn rename_dev_server(
2518 request: proto::RenameDevServer,
2519 response: Response<proto::RenameDevServer>,
2520 session: UserSession,
2521) -> Result<()> {
2522 if request.name.trim().is_empty() {
2523 return Err(proto::ErrorCode::Forbidden
2524 .message("Dev server name cannot be empty".to_string())
2525 .anyhow())?;
2526 }
2527
2528 let dev_server_id = DevServerId(request.dev_server_id as i32);
2529 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2530 if dev_server.user_id != session.user_id() {
2531 return Err(anyhow!(ErrorCode::Forbidden))?;
2532 }
2533
2534 let status = session
2535 .db()
2536 .await
2537 .rename_dev_server(
2538 dev_server_id,
2539 &request.name,
2540 request.ssh_connection_string.as_deref(),
2541 session.user_id(),
2542 )
2543 .await?;
2544
2545 send_dev_server_projects_update(session.user_id(), status, &session).await;
2546
2547 response.send(proto::Ack {})?;
2548 Ok(())
2549}
2550
2551async fn delete_dev_server(
2552 request: proto::DeleteDevServer,
2553 response: Response<proto::DeleteDevServer>,
2554 session: UserSession,
2555) -> Result<()> {
2556 let dev_server_id = DevServerId(request.dev_server_id as i32);
2557 let dev_server = session.db().await.get_dev_server(dev_server_id).await?;
2558 if dev_server.user_id != session.user_id() {
2559 return Err(anyhow!(ErrorCode::Forbidden))?;
2560 }
2561
2562 let connection_id = session
2563 .connection_pool()
2564 .await
2565 .dev_server_connection_id(dev_server_id);
2566 if let Some(connection_id) = connection_id {
2567 shutdown_dev_server_internal(dev_server_id, connection_id, &session).await?;
2568 session.peer.send(
2569 connection_id,
2570 proto::ShutdownDevServer {
2571 reason: Some("dev server was deleted".to_string()),
2572 },
2573 )?;
2574 let _ = remove_dev_server_connection(dev_server_id, &session).await;
2575 }
2576
2577 let status = session
2578 .db()
2579 .await
2580 .delete_dev_server(dev_server_id, session.user_id())
2581 .await?;
2582
2583 send_dev_server_projects_update(session.user_id(), status, &session).await;
2584
2585 response.send(proto::Ack {})?;
2586 Ok(())
2587}
2588
2589async fn delete_dev_server_project(
2590 request: proto::DeleteDevServerProject,
2591 response: Response<proto::DeleteDevServerProject>,
2592 session: UserSession,
2593) -> Result<()> {
2594 let dev_server_project_id = DevServerProjectId(request.dev_server_project_id as i32);
2595 let dev_server_project = session
2596 .db()
2597 .await
2598 .get_dev_server_project(dev_server_project_id)
2599 .await?;
2600
2601 let dev_server = session
2602 .db()
2603 .await
2604 .get_dev_server(dev_server_project.dev_server_id)
2605 .await?;
2606 if dev_server.user_id != session.user_id() {
2607 return Err(anyhow!(ErrorCode::Forbidden))?;
2608 }
2609
2610 let dev_server_connection_id = session
2611 .connection_pool()
2612 .await
2613 .dev_server_connection_id(dev_server.id);
2614
2615 if let Some(dev_server_connection_id) = dev_server_connection_id {
2616 let project = session
2617 .db()
2618 .await
2619 .find_dev_server_project(dev_server_project_id)
2620 .await;
2621 if let Ok(project) = project {
2622 unshare_project_internal(
2623 project.id,
2624 dev_server_connection_id,
2625 Some(session.user_id()),
2626 &session,
2627 )
2628 .await?;
2629 }
2630 }
2631
2632 let (projects, status) = session
2633 .db()
2634 .await
2635 .delete_dev_server_project(dev_server_project_id, dev_server.id, session.user_id())
2636 .await?;
2637
2638 if let Some(dev_server_connection_id) = dev_server_connection_id {
2639 session.peer.send(
2640 dev_server_connection_id,
2641 proto::DevServerInstructions { projects },
2642 )?;
2643 }
2644
2645 send_dev_server_projects_update(session.user_id(), status, &session).await;
2646
2647 response.send(proto::Ack {})?;
2648 Ok(())
2649}
2650
2651async fn rejoin_dev_server_projects(
2652 request: proto::RejoinRemoteProjects,
2653 response: Response<proto::RejoinRemoteProjects>,
2654 session: UserSession,
2655) -> Result<()> {
2656 let mut rejoined_projects = {
2657 let db = session.db().await;
2658 db.rejoin_dev_server_projects(
2659 &request.rejoined_projects,
2660 session.user_id(),
2661 session.0.connection_id,
2662 )
2663 .await?
2664 };
2665 response.send(proto::RejoinRemoteProjectsResponse {
2666 rejoined_projects: rejoined_projects
2667 .iter()
2668 .map(|project| project.to_proto())
2669 .collect(),
2670 })?;
2671 notify_rejoined_projects(&mut rejoined_projects, &session)
2672}
2673
2674async fn reconnect_dev_server(
2675 request: proto::ReconnectDevServer,
2676 response: Response<proto::ReconnectDevServer>,
2677 session: DevServerSession,
2678) -> Result<()> {
2679 let reshared_projects = {
2680 let db = session.db().await;
2681 db.reshare_dev_server_projects(
2682 &request.reshared_projects,
2683 session.dev_server_id(),
2684 session.0.connection_id,
2685 )
2686 .await?
2687 };
2688
2689 for project in &reshared_projects {
2690 for collaborator in &project.collaborators {
2691 session
2692 .peer
2693 .send(
2694 collaborator.connection_id,
2695 proto::UpdateProjectCollaborator {
2696 project_id: project.id.to_proto(),
2697 old_peer_id: Some(project.old_connection_id.into()),
2698 new_peer_id: Some(session.connection_id.into()),
2699 },
2700 )
2701 .trace_err();
2702 }
2703
2704 broadcast(
2705 Some(session.connection_id),
2706 project
2707 .collaborators
2708 .iter()
2709 .map(|collaborator| collaborator.connection_id),
2710 |connection_id| {
2711 session.peer.forward_send(
2712 session.connection_id,
2713 connection_id,
2714 proto::UpdateProject {
2715 project_id: project.id.to_proto(),
2716 worktrees: project.worktrees.clone(),
2717 },
2718 )
2719 },
2720 );
2721 }
2722
2723 response.send(proto::ReconnectDevServerResponse {
2724 reshared_projects: reshared_projects
2725 .iter()
2726 .map(|project| proto::ResharedProject {
2727 id: project.id.to_proto(),
2728 collaborators: project
2729 .collaborators
2730 .iter()
2731 .map(|collaborator| collaborator.to_proto())
2732 .collect(),
2733 })
2734 .collect(),
2735 })?;
2736
2737 Ok(())
2738}
2739
2740async fn shutdown_dev_server(
2741 _: proto::ShutdownDevServer,
2742 response: Response<proto::ShutdownDevServer>,
2743 session: DevServerSession,
2744) -> Result<()> {
2745 response.send(proto::Ack {})?;
2746 shutdown_dev_server_internal(session.dev_server_id(), session.connection_id, &session).await?;
2747 remove_dev_server_connection(session.dev_server_id(), &session).await
2748}
2749
2750async fn shutdown_dev_server_internal(
2751 dev_server_id: DevServerId,
2752 connection_id: ConnectionId,
2753 session: &Session,
2754) -> Result<()> {
2755 let (dev_server_projects, dev_server) = {
2756 let db = session.db().await;
2757 let dev_server_projects = db.get_projects_for_dev_server(dev_server_id).await?;
2758 let dev_server = db.get_dev_server(dev_server_id).await?;
2759 (dev_server_projects, dev_server)
2760 };
2761
2762 for project_id in dev_server_projects.iter().filter_map(|p| p.project_id) {
2763 unshare_project_internal(
2764 ProjectId::from_proto(project_id),
2765 connection_id,
2766 None,
2767 session,
2768 )
2769 .await?;
2770 }
2771
2772 session
2773 .connection_pool()
2774 .await
2775 .set_dev_server_offline(dev_server_id);
2776
2777 let status = session
2778 .db()
2779 .await
2780 .dev_server_projects_update(dev_server.user_id)
2781 .await?;
2782 send_dev_server_projects_update(dev_server.user_id, status, &session).await;
2783
2784 Ok(())
2785}
2786
2787async fn remove_dev_server_connection(dev_server_id: DevServerId, session: &Session) -> Result<()> {
2788 let dev_server_connection = session
2789 .connection_pool()
2790 .await
2791 .dev_server_connection_id(dev_server_id);
2792
2793 if let Some(dev_server_connection) = dev_server_connection {
2794 session
2795 .connection_pool()
2796 .await
2797 .remove_connection(dev_server_connection)?;
2798 }
2799 Ok(())
2800}
2801
2802/// Updates other participants with changes to the project
2803async fn update_project(
2804 request: proto::UpdateProject,
2805 response: Response<proto::UpdateProject>,
2806 session: Session,
2807) -> Result<()> {
2808 let project_id = ProjectId::from_proto(request.project_id);
2809 let (room, guest_connection_ids) = &*session
2810 .db()
2811 .await
2812 .update_project(project_id, session.connection_id, &request.worktrees)
2813 .await?;
2814 broadcast(
2815 Some(session.connection_id),
2816 guest_connection_ids.iter().copied(),
2817 |connection_id| {
2818 session
2819 .peer
2820 .forward_send(session.connection_id, connection_id, request.clone())
2821 },
2822 );
2823 if let Some(room) = room {
2824 room_updated(&room, &session.peer);
2825 }
2826 response.send(proto::Ack {})?;
2827
2828 Ok(())
2829}
2830
2831/// Updates other participants with changes to the worktree
2832async fn update_worktree(
2833 request: proto::UpdateWorktree,
2834 response: Response<proto::UpdateWorktree>,
2835 session: Session,
2836) -> Result<()> {
2837 let guest_connection_ids = session
2838 .db()
2839 .await
2840 .update_worktree(&request, session.connection_id)
2841 .await?;
2842
2843 broadcast(
2844 Some(session.connection_id),
2845 guest_connection_ids.iter().copied(),
2846 |connection_id| {
2847 session
2848 .peer
2849 .forward_send(session.connection_id, connection_id, request.clone())
2850 },
2851 );
2852 response.send(proto::Ack {})?;
2853 Ok(())
2854}
2855
2856/// Updates other participants with changes to the diagnostics
2857async fn update_diagnostic_summary(
2858 message: proto::UpdateDiagnosticSummary,
2859 session: Session,
2860) -> Result<()> {
2861 let guest_connection_ids = session
2862 .db()
2863 .await
2864 .update_diagnostic_summary(&message, session.connection_id)
2865 .await?;
2866
2867 broadcast(
2868 Some(session.connection_id),
2869 guest_connection_ids.iter().copied(),
2870 |connection_id| {
2871 session
2872 .peer
2873 .forward_send(session.connection_id, connection_id, message.clone())
2874 },
2875 );
2876
2877 Ok(())
2878}
2879
2880/// Updates other participants with changes to the worktree settings
2881async fn update_worktree_settings(
2882 message: proto::UpdateWorktreeSettings,
2883 session: Session,
2884) -> Result<()> {
2885 let guest_connection_ids = session
2886 .db()
2887 .await
2888 .update_worktree_settings(&message, session.connection_id)
2889 .await?;
2890
2891 broadcast(
2892 Some(session.connection_id),
2893 guest_connection_ids.iter().copied(),
2894 |connection_id| {
2895 session
2896 .peer
2897 .forward_send(session.connection_id, connection_id, message.clone())
2898 },
2899 );
2900
2901 Ok(())
2902}
2903
2904/// Notify other participants that a language server has started.
2905async fn start_language_server(
2906 request: proto::StartLanguageServer,
2907 session: Session,
2908) -> Result<()> {
2909 let guest_connection_ids = session
2910 .db()
2911 .await
2912 .start_language_server(&request, session.connection_id)
2913 .await?;
2914
2915 broadcast(
2916 Some(session.connection_id),
2917 guest_connection_ids.iter().copied(),
2918 |connection_id| {
2919 session
2920 .peer
2921 .forward_send(session.connection_id, connection_id, request.clone())
2922 },
2923 );
2924 Ok(())
2925}
2926
2927/// Notify other participants that a language server has changed.
2928async fn update_language_server(
2929 request: proto::UpdateLanguageServer,
2930 session: Session,
2931) -> Result<()> {
2932 let project_id = ProjectId::from_proto(request.project_id);
2933 let project_connection_ids = session
2934 .db()
2935 .await
2936 .project_connection_ids(project_id, session.connection_id, true)
2937 .await?;
2938 broadcast(
2939 Some(session.connection_id),
2940 project_connection_ids.iter().copied(),
2941 |connection_id| {
2942 session
2943 .peer
2944 .forward_send(session.connection_id, connection_id, request.clone())
2945 },
2946 );
2947 Ok(())
2948}
2949
2950/// forward a project request to the host. These requests should be read only
2951/// as guests are allowed to send them.
2952async fn forward_read_only_project_request<T>(
2953 request: T,
2954 response: Response<T>,
2955 session: UserSession,
2956) -> Result<()>
2957where
2958 T: EntityMessage + RequestMessage,
2959{
2960 let project_id = ProjectId::from_proto(request.remote_entity_id());
2961 let host_connection_id = session
2962 .db()
2963 .await
2964 .host_for_read_only_project_request(project_id, session.connection_id, session.user_id())
2965 .await?;
2966 let payload = session
2967 .peer
2968 .forward_request(session.connection_id, host_connection_id, request)
2969 .await?;
2970 response.send(payload)?;
2971 Ok(())
2972}
2973
2974/// forward a project request to the dev server. Only allowed
2975/// if it's your dev server.
2976async fn forward_project_request_for_owner<T>(
2977 request: T,
2978 response: Response<T>,
2979 session: UserSession,
2980) -> Result<()>
2981where
2982 T: EntityMessage + RequestMessage,
2983{
2984 let project_id = ProjectId::from_proto(request.remote_entity_id());
2985
2986 let host_connection_id = session
2987 .db()
2988 .await
2989 .host_for_owner_project_request(project_id, session.connection_id, session.user_id())
2990 .await?;
2991 let payload = session
2992 .peer
2993 .forward_request(session.connection_id, host_connection_id, request)
2994 .await?;
2995 response.send(payload)?;
2996 Ok(())
2997}
2998
2999/// forward a project request to the host. These requests are disallowed
3000/// for guests.
3001async fn forward_mutating_project_request<T>(
3002 request: T,
3003 response: Response<T>,
3004 session: UserSession,
3005) -> Result<()>
3006where
3007 T: EntityMessage + RequestMessage,
3008{
3009 let project_id = ProjectId::from_proto(request.remote_entity_id());
3010
3011 let host_connection_id = session
3012 .db()
3013 .await
3014 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3015 .await?;
3016 let payload = session
3017 .peer
3018 .forward_request(session.connection_id, host_connection_id, request)
3019 .await?;
3020 response.send(payload)?;
3021 Ok(())
3022}
3023
3024/// forward a project request to the host. These requests are disallowed
3025/// for guests.
3026async fn forward_versioned_mutating_project_request<T>(
3027 request: T,
3028 response: Response<T>,
3029 session: UserSession,
3030) -> Result<()>
3031where
3032 T: EntityMessage + RequestMessage + VersionedMessage,
3033{
3034 let project_id = ProjectId::from_proto(request.remote_entity_id());
3035
3036 let host_connection_id = session
3037 .db()
3038 .await
3039 .host_for_mutating_project_request(project_id, session.connection_id, session.user_id())
3040 .await?;
3041 if let Some(host_version) = session
3042 .connection_pool()
3043 .await
3044 .connection(host_connection_id)
3045 .map(|c| c.zed_version)
3046 {
3047 if let Some(min_required_version) = request.required_host_version() {
3048 if min_required_version > host_version {
3049 return Err(anyhow!(ErrorCode::RemoteUpgradeRequired
3050 .with_tag("required", &min_required_version.to_string())))?;
3051 }
3052 }
3053 }
3054
3055 let payload = session
3056 .peer
3057 .forward_request(session.connection_id, host_connection_id, request)
3058 .await?;
3059 response.send(payload)?;
3060 Ok(())
3061}
3062
3063/// Notify other participants that a new buffer has been created
3064async fn create_buffer_for_peer(
3065 request: proto::CreateBufferForPeer,
3066 session: Session,
3067) -> Result<()> {
3068 session
3069 .db()
3070 .await
3071 .check_user_is_project_host(
3072 ProjectId::from_proto(request.project_id),
3073 session.connection_id,
3074 )
3075 .await?;
3076 let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
3077 session
3078 .peer
3079 .forward_send(session.connection_id, peer_id.into(), request)?;
3080 Ok(())
3081}
3082
3083/// Notify other participants that a buffer has been updated. This is
3084/// allowed for guests as long as the update is limited to selections.
3085async fn update_buffer(
3086 request: proto::UpdateBuffer,
3087 response: Response<proto::UpdateBuffer>,
3088 session: Session,
3089) -> Result<()> {
3090 let project_id = ProjectId::from_proto(request.project_id);
3091 let mut capability = Capability::ReadOnly;
3092
3093 for op in request.operations.iter() {
3094 match op.variant {
3095 None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
3096 Some(_) => capability = Capability::ReadWrite,
3097 }
3098 }
3099
3100 let host = {
3101 let guard = session
3102 .db()
3103 .await
3104 .connections_for_buffer_update(
3105 project_id,
3106 session.principal_id(),
3107 session.connection_id,
3108 capability,
3109 )
3110 .await?;
3111
3112 let (host, guests) = &*guard;
3113
3114 broadcast(
3115 Some(session.connection_id),
3116 guests.clone(),
3117 |connection_id| {
3118 session
3119 .peer
3120 .forward_send(session.connection_id, connection_id, request.clone())
3121 },
3122 );
3123
3124 *host
3125 };
3126
3127 if host != session.connection_id {
3128 session
3129 .peer
3130 .forward_request(session.connection_id, host, request.clone())
3131 .await?;
3132 }
3133
3134 response.send(proto::Ack {})?;
3135 Ok(())
3136}
3137
3138async fn update_context(message: proto::UpdateContext, session: Session) -> Result<()> {
3139 let project_id = ProjectId::from_proto(message.project_id);
3140
3141 let operation = message.operation.as_ref().context("invalid operation")?;
3142 let capability = match operation.variant.as_ref() {
3143 Some(proto::context_operation::Variant::BufferOperation(buffer_op)) => {
3144 if let Some(buffer_op) = buffer_op.operation.as_ref() {
3145 match buffer_op.variant {
3146 None | Some(proto::operation::Variant::UpdateSelections(_)) => {
3147 Capability::ReadOnly
3148 }
3149 _ => Capability::ReadWrite,
3150 }
3151 } else {
3152 Capability::ReadWrite
3153 }
3154 }
3155 Some(_) => Capability::ReadWrite,
3156 None => Capability::ReadOnly,
3157 };
3158
3159 let guard = session
3160 .db()
3161 .await
3162 .connections_for_buffer_update(
3163 project_id,
3164 session.principal_id(),
3165 session.connection_id,
3166 capability,
3167 )
3168 .await?;
3169
3170 let (host, guests) = &*guard;
3171
3172 broadcast(
3173 Some(session.connection_id),
3174 guests.iter().chain([host]).copied(),
3175 |connection_id| {
3176 session
3177 .peer
3178 .forward_send(session.connection_id, connection_id, message.clone())
3179 },
3180 );
3181
3182 Ok(())
3183}
3184
3185/// Notify other participants that a project has been updated.
3186async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
3187 request: T,
3188 session: Session,
3189) -> Result<()> {
3190 let project_id = ProjectId::from_proto(request.remote_entity_id());
3191 let project_connection_ids = session
3192 .db()
3193 .await
3194 .project_connection_ids(project_id, session.connection_id, false)
3195 .await?;
3196
3197 broadcast(
3198 Some(session.connection_id),
3199 project_connection_ids.iter().copied(),
3200 |connection_id| {
3201 session
3202 .peer
3203 .forward_send(session.connection_id, connection_id, request.clone())
3204 },
3205 );
3206 Ok(())
3207}
3208
3209/// Start following another user in a call.
3210async fn follow(
3211 request: proto::Follow,
3212 response: Response<proto::Follow>,
3213 session: UserSession,
3214) -> Result<()> {
3215 let room_id = RoomId::from_proto(request.room_id);
3216 let project_id = request.project_id.map(ProjectId::from_proto);
3217 let leader_id = request
3218 .leader_id
3219 .ok_or_else(|| anyhow!("invalid leader id"))?
3220 .into();
3221 let follower_id = session.connection_id;
3222
3223 session
3224 .db()
3225 .await
3226 .check_room_participants(room_id, leader_id, session.connection_id)
3227 .await?;
3228
3229 let response_payload = session
3230 .peer
3231 .forward_request(session.connection_id, leader_id, request)
3232 .await?;
3233 response.send(response_payload)?;
3234
3235 if let Some(project_id) = project_id {
3236 let room = session
3237 .db()
3238 .await
3239 .follow(room_id, project_id, leader_id, follower_id)
3240 .await?;
3241 room_updated(&room, &session.peer);
3242 }
3243
3244 Ok(())
3245}
3246
3247/// Stop following another user in a call.
3248async fn unfollow(request: proto::Unfollow, session: UserSession) -> Result<()> {
3249 let room_id = RoomId::from_proto(request.room_id);
3250 let project_id = request.project_id.map(ProjectId::from_proto);
3251 let leader_id = request
3252 .leader_id
3253 .ok_or_else(|| anyhow!("invalid leader id"))?
3254 .into();
3255 let follower_id = session.connection_id;
3256
3257 session
3258 .db()
3259 .await
3260 .check_room_participants(room_id, leader_id, session.connection_id)
3261 .await?;
3262
3263 session
3264 .peer
3265 .forward_send(session.connection_id, leader_id, request)?;
3266
3267 if let Some(project_id) = project_id {
3268 let room = session
3269 .db()
3270 .await
3271 .unfollow(room_id, project_id, leader_id, follower_id)
3272 .await?;
3273 room_updated(&room, &session.peer);
3274 }
3275
3276 Ok(())
3277}
3278
3279/// Notify everyone following you of your current location.
3280async fn update_followers(request: proto::UpdateFollowers, session: UserSession) -> Result<()> {
3281 let room_id = RoomId::from_proto(request.room_id);
3282 let database = session.db.lock().await;
3283
3284 let connection_ids = if let Some(project_id) = request.project_id {
3285 let project_id = ProjectId::from_proto(project_id);
3286 database
3287 .project_connection_ids(project_id, session.connection_id, true)
3288 .await?
3289 } else {
3290 database
3291 .room_connection_ids(room_id, session.connection_id)
3292 .await?
3293 };
3294
3295 // For now, don't send view update messages back to that view's current leader.
3296 let peer_id_to_omit = request.variant.as_ref().and_then(|variant| match variant {
3297 proto::update_followers::Variant::UpdateView(payload) => payload.leader_id,
3298 _ => None,
3299 });
3300
3301 for connection_id in connection_ids.iter().cloned() {
3302 if Some(connection_id.into()) != peer_id_to_omit && connection_id != session.connection_id {
3303 session
3304 .peer
3305 .forward_send(session.connection_id, connection_id, request.clone())?;
3306 }
3307 }
3308 Ok(())
3309}
3310
3311/// Get public data about users.
3312async fn get_users(
3313 request: proto::GetUsers,
3314 response: Response<proto::GetUsers>,
3315 session: Session,
3316) -> Result<()> {
3317 let user_ids = request
3318 .user_ids
3319 .into_iter()
3320 .map(UserId::from_proto)
3321 .collect();
3322 let users = session
3323 .db()
3324 .await
3325 .get_users_by_ids(user_ids)
3326 .await?
3327 .into_iter()
3328 .map(|user| proto::User {
3329 id: user.id.to_proto(),
3330 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3331 github_login: user.github_login,
3332 })
3333 .collect();
3334 response.send(proto::UsersResponse { users })?;
3335 Ok(())
3336}
3337
3338/// Search for users (to invite) buy Github login
3339async fn fuzzy_search_users(
3340 request: proto::FuzzySearchUsers,
3341 response: Response<proto::FuzzySearchUsers>,
3342 session: UserSession,
3343) -> Result<()> {
3344 let query = request.query;
3345 let users = match query.len() {
3346 0 => vec![],
3347 1 | 2 => session
3348 .db()
3349 .await
3350 .get_user_by_github_login(&query)
3351 .await?
3352 .into_iter()
3353 .collect(),
3354 _ => session.db().await.fuzzy_search_users(&query, 10).await?,
3355 };
3356 let users = users
3357 .into_iter()
3358 .filter(|user| user.id != session.user_id())
3359 .map(|user| proto::User {
3360 id: user.id.to_proto(),
3361 avatar_url: format!("https://github.com/{}.png?size=128", user.github_login),
3362 github_login: user.github_login,
3363 })
3364 .collect();
3365 response.send(proto::UsersResponse { users })?;
3366 Ok(())
3367}
3368
3369/// Send a contact request to another user.
3370async fn request_contact(
3371 request: proto::RequestContact,
3372 response: Response<proto::RequestContact>,
3373 session: UserSession,
3374) -> Result<()> {
3375 let requester_id = session.user_id();
3376 let responder_id = UserId::from_proto(request.responder_id);
3377 if requester_id == responder_id {
3378 return Err(anyhow!("cannot add yourself as a contact"))?;
3379 }
3380
3381 let notifications = session
3382 .db()
3383 .await
3384 .send_contact_request(requester_id, responder_id)
3385 .await?;
3386
3387 // Update outgoing contact requests of requester
3388 let mut update = proto::UpdateContacts::default();
3389 update.outgoing_requests.push(responder_id.to_proto());
3390 for connection_id in session
3391 .connection_pool()
3392 .await
3393 .user_connection_ids(requester_id)
3394 {
3395 session.peer.send(connection_id, update.clone())?;
3396 }
3397
3398 // Update incoming contact requests of responder
3399 let mut update = proto::UpdateContacts::default();
3400 update
3401 .incoming_requests
3402 .push(proto::IncomingContactRequest {
3403 requester_id: requester_id.to_proto(),
3404 });
3405 let connection_pool = session.connection_pool().await;
3406 for connection_id in connection_pool.user_connection_ids(responder_id) {
3407 session.peer.send(connection_id, update.clone())?;
3408 }
3409
3410 send_notifications(&connection_pool, &session.peer, notifications);
3411
3412 response.send(proto::Ack {})?;
3413 Ok(())
3414}
3415
3416/// Accept or decline a contact request
3417async fn respond_to_contact_request(
3418 request: proto::RespondToContactRequest,
3419 response: Response<proto::RespondToContactRequest>,
3420 session: UserSession,
3421) -> Result<()> {
3422 let responder_id = session.user_id();
3423 let requester_id = UserId::from_proto(request.requester_id);
3424 let db = session.db().await;
3425 if request.response == proto::ContactRequestResponse::Dismiss as i32 {
3426 db.dismiss_contact_notification(responder_id, requester_id)
3427 .await?;
3428 } else {
3429 let accept = request.response == proto::ContactRequestResponse::Accept as i32;
3430
3431 let notifications = db
3432 .respond_to_contact_request(responder_id, requester_id, accept)
3433 .await?;
3434 let requester_busy = db.is_user_busy(requester_id).await?;
3435 let responder_busy = db.is_user_busy(responder_id).await?;
3436
3437 let pool = session.connection_pool().await;
3438 // Update responder with new contact
3439 let mut update = proto::UpdateContacts::default();
3440 if accept {
3441 update
3442 .contacts
3443 .push(contact_for_user(requester_id, requester_busy, &pool));
3444 }
3445 update
3446 .remove_incoming_requests
3447 .push(requester_id.to_proto());
3448 for connection_id in pool.user_connection_ids(responder_id) {
3449 session.peer.send(connection_id, update.clone())?;
3450 }
3451
3452 // Update requester with new contact
3453 let mut update = proto::UpdateContacts::default();
3454 if accept {
3455 update
3456 .contacts
3457 .push(contact_for_user(responder_id, responder_busy, &pool));
3458 }
3459 update
3460 .remove_outgoing_requests
3461 .push(responder_id.to_proto());
3462
3463 for connection_id in pool.user_connection_ids(requester_id) {
3464 session.peer.send(connection_id, update.clone())?;
3465 }
3466
3467 send_notifications(&pool, &session.peer, notifications);
3468 }
3469
3470 response.send(proto::Ack {})?;
3471 Ok(())
3472}
3473
3474/// Remove a contact.
3475async fn remove_contact(
3476 request: proto::RemoveContact,
3477 response: Response<proto::RemoveContact>,
3478 session: UserSession,
3479) -> Result<()> {
3480 let requester_id = session.user_id();
3481 let responder_id = UserId::from_proto(request.user_id);
3482 let db = session.db().await;
3483 let (contact_accepted, deleted_notification_id) =
3484 db.remove_contact(requester_id, responder_id).await?;
3485
3486 let pool = session.connection_pool().await;
3487 // Update outgoing contact requests of requester
3488 let mut update = proto::UpdateContacts::default();
3489 if contact_accepted {
3490 update.remove_contacts.push(responder_id.to_proto());
3491 } else {
3492 update
3493 .remove_outgoing_requests
3494 .push(responder_id.to_proto());
3495 }
3496 for connection_id in pool.user_connection_ids(requester_id) {
3497 session.peer.send(connection_id, update.clone())?;
3498 }
3499
3500 // Update incoming contact requests of responder
3501 let mut update = proto::UpdateContacts::default();
3502 if contact_accepted {
3503 update.remove_contacts.push(requester_id.to_proto());
3504 } else {
3505 update
3506 .remove_incoming_requests
3507 .push(requester_id.to_proto());
3508 }
3509 for connection_id in pool.user_connection_ids(responder_id) {
3510 session.peer.send(connection_id, update.clone())?;
3511 if let Some(notification_id) = deleted_notification_id {
3512 session.peer.send(
3513 connection_id,
3514 proto::DeleteNotification {
3515 notification_id: notification_id.to_proto(),
3516 },
3517 )?;
3518 }
3519 }
3520
3521 response.send(proto::Ack {})?;
3522 Ok(())
3523}
3524
3525fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
3526 version.0.minor() < 139
3527}
3528
3529async fn subscribe_to_channels(_: proto::SubscribeToChannels, session: Session) -> Result<()> {
3530 subscribe_user_to_channels(
3531 session.user_id().ok_or_else(|| anyhow!("must be a user"))?,
3532 &session,
3533 )
3534 .await?;
3535 Ok(())
3536}
3537
3538async fn subscribe_user_to_channels(user_id: UserId, session: &Session) -> Result<(), Error> {
3539 let channels_for_user = session.db().await.get_channels_for_user(user_id).await?;
3540 let mut pool = session.connection_pool().await;
3541 for membership in &channels_for_user.channel_memberships {
3542 pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
3543 }
3544 session.peer.send(
3545 session.connection_id,
3546 build_update_user_channels(&channels_for_user),
3547 )?;
3548 session.peer.send(
3549 session.connection_id,
3550 build_channels_update(channels_for_user),
3551 )?;
3552 Ok(())
3553}
3554
3555/// Creates a new channel.
3556async fn create_channel(
3557 request: proto::CreateChannel,
3558 response: Response<proto::CreateChannel>,
3559 session: UserSession,
3560) -> Result<()> {
3561 let db = session.db().await;
3562
3563 let parent_id = request.parent_id.map(|id| ChannelId::from_proto(id));
3564 let (channel, membership) = db
3565 .create_channel(&request.name, parent_id, session.user_id())
3566 .await?;
3567
3568 let root_id = channel.root_id();
3569 let channel = Channel::from_model(channel);
3570
3571 response.send(proto::CreateChannelResponse {
3572 channel: Some(channel.to_proto()),
3573 parent_id: request.parent_id,
3574 })?;
3575
3576 let mut connection_pool = session.connection_pool().await;
3577 if let Some(membership) = membership {
3578 connection_pool.subscribe_to_channel(
3579 membership.user_id,
3580 membership.channel_id,
3581 membership.role,
3582 );
3583 let update = proto::UpdateUserChannels {
3584 channel_memberships: vec![proto::ChannelMembership {
3585 channel_id: membership.channel_id.to_proto(),
3586 role: membership.role.into(),
3587 }],
3588 ..Default::default()
3589 };
3590 for connection_id in connection_pool.user_connection_ids(membership.user_id) {
3591 session.peer.send(connection_id, update.clone())?;
3592 }
3593 }
3594
3595 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3596 if !role.can_see_channel(channel.visibility) {
3597 continue;
3598 }
3599
3600 let update = proto::UpdateChannels {
3601 channels: vec![channel.to_proto()],
3602 ..Default::default()
3603 };
3604 session.peer.send(connection_id, update.clone())?;
3605 }
3606
3607 Ok(())
3608}
3609
3610/// Delete a channel
3611async fn delete_channel(
3612 request: proto::DeleteChannel,
3613 response: Response<proto::DeleteChannel>,
3614 session: UserSession,
3615) -> Result<()> {
3616 let db = session.db().await;
3617
3618 let channel_id = request.channel_id;
3619 let (root_channel, removed_channels) = db
3620 .delete_channel(ChannelId::from_proto(channel_id), session.user_id())
3621 .await?;
3622 response.send(proto::Ack {})?;
3623
3624 // Notify members of removed channels
3625 let mut update = proto::UpdateChannels::default();
3626 update
3627 .delete_channels
3628 .extend(removed_channels.into_iter().map(|id| id.to_proto()));
3629
3630 let connection_pool = session.connection_pool().await;
3631 for (connection_id, _) in connection_pool.channel_connection_ids(root_channel) {
3632 session.peer.send(connection_id, update.clone())?;
3633 }
3634
3635 Ok(())
3636}
3637
3638/// Invite someone to join a channel.
3639async fn invite_channel_member(
3640 request: proto::InviteChannelMember,
3641 response: Response<proto::InviteChannelMember>,
3642 session: UserSession,
3643) -> Result<()> {
3644 let db = session.db().await;
3645 let channel_id = ChannelId::from_proto(request.channel_id);
3646 let invitee_id = UserId::from_proto(request.user_id);
3647 let InviteMemberResult {
3648 channel,
3649 notifications,
3650 } = db
3651 .invite_channel_member(
3652 channel_id,
3653 invitee_id,
3654 session.user_id(),
3655 request.role().into(),
3656 )
3657 .await?;
3658
3659 let update = proto::UpdateChannels {
3660 channel_invitations: vec![channel.to_proto()],
3661 ..Default::default()
3662 };
3663
3664 let connection_pool = session.connection_pool().await;
3665 for connection_id in connection_pool.user_connection_ids(invitee_id) {
3666 session.peer.send(connection_id, update.clone())?;
3667 }
3668
3669 send_notifications(&connection_pool, &session.peer, notifications);
3670
3671 response.send(proto::Ack {})?;
3672 Ok(())
3673}
3674
3675/// remove someone from a channel
3676async fn remove_channel_member(
3677 request: proto::RemoveChannelMember,
3678 response: Response<proto::RemoveChannelMember>,
3679 session: UserSession,
3680) -> Result<()> {
3681 let db = session.db().await;
3682 let channel_id = ChannelId::from_proto(request.channel_id);
3683 let member_id = UserId::from_proto(request.user_id);
3684
3685 let RemoveChannelMemberResult {
3686 membership_update,
3687 notification_id,
3688 } = db
3689 .remove_channel_member(channel_id, member_id, session.user_id())
3690 .await?;
3691
3692 let mut connection_pool = session.connection_pool().await;
3693 notify_membership_updated(
3694 &mut connection_pool,
3695 membership_update,
3696 member_id,
3697 &session.peer,
3698 );
3699 for connection_id in connection_pool.user_connection_ids(member_id) {
3700 if let Some(notification_id) = notification_id {
3701 session
3702 .peer
3703 .send(
3704 connection_id,
3705 proto::DeleteNotification {
3706 notification_id: notification_id.to_proto(),
3707 },
3708 )
3709 .trace_err();
3710 }
3711 }
3712
3713 response.send(proto::Ack {})?;
3714 Ok(())
3715}
3716
3717/// Toggle the channel between public and private.
3718/// Care is taken to maintain the invariant that public channels only descend from public channels,
3719/// (though members-only channels can appear at any point in the hierarchy).
3720async fn set_channel_visibility(
3721 request: proto::SetChannelVisibility,
3722 response: Response<proto::SetChannelVisibility>,
3723 session: UserSession,
3724) -> Result<()> {
3725 let db = session.db().await;
3726 let channel_id = ChannelId::from_proto(request.channel_id);
3727 let visibility = request.visibility().into();
3728
3729 let channel_model = db
3730 .set_channel_visibility(channel_id, visibility, session.user_id())
3731 .await?;
3732 let root_id = channel_model.root_id();
3733 let channel = Channel::from_model(channel_model);
3734
3735 let mut connection_pool = session.connection_pool().await;
3736 for (user_id, role) in connection_pool
3737 .channel_user_ids(root_id)
3738 .collect::<Vec<_>>()
3739 .into_iter()
3740 {
3741 let update = if role.can_see_channel(channel.visibility) {
3742 connection_pool.subscribe_to_channel(user_id, channel_id, role);
3743 proto::UpdateChannels {
3744 channels: vec![channel.to_proto()],
3745 ..Default::default()
3746 }
3747 } else {
3748 connection_pool.unsubscribe_from_channel(&user_id, &channel_id);
3749 proto::UpdateChannels {
3750 delete_channels: vec![channel.id.to_proto()],
3751 ..Default::default()
3752 }
3753 };
3754
3755 for connection_id in connection_pool.user_connection_ids(user_id) {
3756 session.peer.send(connection_id, update.clone())?;
3757 }
3758 }
3759
3760 response.send(proto::Ack {})?;
3761 Ok(())
3762}
3763
3764/// Alter the role for a user in the channel.
3765async fn set_channel_member_role(
3766 request: proto::SetChannelMemberRole,
3767 response: Response<proto::SetChannelMemberRole>,
3768 session: UserSession,
3769) -> Result<()> {
3770 let db = session.db().await;
3771 let channel_id = ChannelId::from_proto(request.channel_id);
3772 let member_id = UserId::from_proto(request.user_id);
3773 let result = db
3774 .set_channel_member_role(
3775 channel_id,
3776 session.user_id(),
3777 member_id,
3778 request.role().into(),
3779 )
3780 .await?;
3781
3782 match result {
3783 db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
3784 let mut connection_pool = session.connection_pool().await;
3785 notify_membership_updated(
3786 &mut connection_pool,
3787 membership_update,
3788 member_id,
3789 &session.peer,
3790 )
3791 }
3792 db::SetMemberRoleResult::InviteUpdated(channel) => {
3793 let update = proto::UpdateChannels {
3794 channel_invitations: vec![channel.to_proto()],
3795 ..Default::default()
3796 };
3797
3798 for connection_id in session
3799 .connection_pool()
3800 .await
3801 .user_connection_ids(member_id)
3802 {
3803 session.peer.send(connection_id, update.clone())?;
3804 }
3805 }
3806 }
3807
3808 response.send(proto::Ack {})?;
3809 Ok(())
3810}
3811
3812/// Change the name of a channel
3813async fn rename_channel(
3814 request: proto::RenameChannel,
3815 response: Response<proto::RenameChannel>,
3816 session: UserSession,
3817) -> Result<()> {
3818 let db = session.db().await;
3819 let channel_id = ChannelId::from_proto(request.channel_id);
3820 let channel_model = db
3821 .rename_channel(channel_id, session.user_id(), &request.name)
3822 .await?;
3823 let root_id = channel_model.root_id();
3824 let channel = Channel::from_model(channel_model);
3825
3826 response.send(proto::RenameChannelResponse {
3827 channel: Some(channel.to_proto()),
3828 })?;
3829
3830 let connection_pool = session.connection_pool().await;
3831 let update = proto::UpdateChannels {
3832 channels: vec![channel.to_proto()],
3833 ..Default::default()
3834 };
3835 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3836 if role.can_see_channel(channel.visibility) {
3837 session.peer.send(connection_id, update.clone())?;
3838 }
3839 }
3840
3841 Ok(())
3842}
3843
3844/// Move a channel to a new parent.
3845async fn move_channel(
3846 request: proto::MoveChannel,
3847 response: Response<proto::MoveChannel>,
3848 session: UserSession,
3849) -> Result<()> {
3850 let channel_id = ChannelId::from_proto(request.channel_id);
3851 let to = ChannelId::from_proto(request.to);
3852
3853 let (root_id, channels) = session
3854 .db()
3855 .await
3856 .move_channel(channel_id, to, session.user_id())
3857 .await?;
3858
3859 let connection_pool = session.connection_pool().await;
3860 for (connection_id, role) in connection_pool.channel_connection_ids(root_id) {
3861 let channels = channels
3862 .iter()
3863 .filter_map(|channel| {
3864 if role.can_see_channel(channel.visibility) {
3865 Some(channel.to_proto())
3866 } else {
3867 None
3868 }
3869 })
3870 .collect::<Vec<_>>();
3871 if channels.is_empty() {
3872 continue;
3873 }
3874
3875 let update = proto::UpdateChannels {
3876 channels,
3877 ..Default::default()
3878 };
3879
3880 session.peer.send(connection_id, update.clone())?;
3881 }
3882
3883 response.send(Ack {})?;
3884 Ok(())
3885}
3886
3887/// Get the list of channel members
3888async fn get_channel_members(
3889 request: proto::GetChannelMembers,
3890 response: Response<proto::GetChannelMembers>,
3891 session: UserSession,
3892) -> Result<()> {
3893 let db = session.db().await;
3894 let channel_id = ChannelId::from_proto(request.channel_id);
3895 let limit = if request.limit == 0 {
3896 u16::MAX as u64
3897 } else {
3898 request.limit
3899 };
3900 let (members, users) = db
3901 .get_channel_participant_details(channel_id, &request.query, limit, session.user_id())
3902 .await?;
3903 response.send(proto::GetChannelMembersResponse { members, users })?;
3904 Ok(())
3905}
3906
3907/// Accept or decline a channel invitation.
3908async fn respond_to_channel_invite(
3909 request: proto::RespondToChannelInvite,
3910 response: Response<proto::RespondToChannelInvite>,
3911 session: UserSession,
3912) -> Result<()> {
3913 let db = session.db().await;
3914 let channel_id = ChannelId::from_proto(request.channel_id);
3915 let RespondToChannelInvite {
3916 membership_update,
3917 notifications,
3918 } = db
3919 .respond_to_channel_invite(channel_id, session.user_id(), request.accept)
3920 .await?;
3921
3922 let mut connection_pool = session.connection_pool().await;
3923 if let Some(membership_update) = membership_update {
3924 notify_membership_updated(
3925 &mut connection_pool,
3926 membership_update,
3927 session.user_id(),
3928 &session.peer,
3929 );
3930 } else {
3931 let update = proto::UpdateChannels {
3932 remove_channel_invitations: vec![channel_id.to_proto()],
3933 ..Default::default()
3934 };
3935
3936 for connection_id in connection_pool.user_connection_ids(session.user_id()) {
3937 session.peer.send(connection_id, update.clone())?;
3938 }
3939 };
3940
3941 send_notifications(&connection_pool, &session.peer, notifications);
3942
3943 response.send(proto::Ack {})?;
3944
3945 Ok(())
3946}
3947
3948/// Join the channels' room
3949async fn join_channel(
3950 request: proto::JoinChannel,
3951 response: Response<proto::JoinChannel>,
3952 session: UserSession,
3953) -> Result<()> {
3954 let channel_id = ChannelId::from_proto(request.channel_id);
3955 join_channel_internal(channel_id, Box::new(response), session).await
3956}
3957
/// A response handle that can be answered with a `proto::JoinRoomResponse`.
///
/// Implemented for both `Response<proto::JoinChannel>` and
/// `Response<proto::JoinRoom>`, so `join_channel_internal` can answer either
/// request type. (Presumably the `JoinRoom` handler calls it too; that caller
/// is not visible here.)
trait JoinChannelInternalResponse {
    /// Send `result` back to the requester, consuming the handle.
    fn send(self, result: proto::JoinRoomResponse) -> Result<()>;
}
// Each body calls `send` through an explicit `Response::<T>::send` type path so
// that it reaches `Response`'s own `send` (as used elsewhere in this file)
// rather than this trait's method.
impl JoinChannelInternalResponse for Response<proto::JoinChannel> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinChannel>::send(self, result)
    }
}
impl JoinChannelInternalResponse for Response<proto::JoinRoom> {
    fn send(self, result: proto::JoinRoomResponse) -> Result<()> {
        Response::<proto::JoinRoom>::send(self, result)
    }
}
3971
/// Shared implementation for joining the call room attached to `channel_id`.
///
/// In order:
/// 1. If the database reports a stale room connection for this user, that
///    connection is removed via `leave_room_for_session`.
/// 2. This connection joins the channel's room in the database.
/// 3. The requester is answered with the room, the channel id and, when a
///    LiveKit client is configured, LiveKit connection info.
/// 4. Any membership change is propagated, and `room_updated` is broadcast.
/// 5. `channel_updated` and `update_user_contacts` run after the database
///    guard has been released.
///
/// `response` may be any [`JoinChannelInternalResponse`] handle (a
/// `JoinChannel` or `JoinRoom` response).
///
/// # Errors
/// Returns an error if a database call fails, if leaving the stale room fails,
/// if the response cannot be sent, if updating contacts fails, or
/// ("channel not returned") if the join result has no channel.
async fn join_channel_internal(
    channel_id: ChannelId,
    response: Box<impl JoinChannelInternalResponse>,
    session: UserSession,
) -> Result<()> {
    // The database guard `db` is confined to this block, so it is released
    // before `channel_updated` and `update_user_contacts` run below.
    let joined_room = {
        let mut db = session.db().await;
        // If zed quits without leaving the room, and the user re-opens zed before the
        // RECONNECT_TIMEOUT, we need to make sure that we kick the user out of the previous
        // room they were in.
        if let Some(connection) = db.stale_room_connection(session.user_id()).await? {
            tracing::info!(
                stale_connection_id = %connection,
                "cleaning up stale connection",
            );
            // Release the session's database guard while leaving the stale room
            // (presumably `leave_room_for_session` acquires it itself), then
            // reacquire it for the join below.
            drop(db);
            leave_room_for_session(&session, connection).await?;
            db = session.db().await;
        }

        let (joined_room, membership_updated, role) = db
            .join_channel(channel_id, session.user_id(), session.connection_id)
            .await?;

        // Guests get a `guest_token` with `can_publish: false`; every other role
        // gets a `room_token` and may publish. A token error is logged by
        // `trace_err` and `?` turns the whole value into `None`, exactly as when
        // no LiveKit client is configured.
        let live_kit_connection_info = session.live_kit_client.as_ref().and_then(|live_kit| {
            let (can_publish, token) = if role == ChannelRole::Guest {
                (
                    false,
                    live_kit
                        .guest_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            } else {
                (
                    true,
                    live_kit
                        .room_token(
                            &joined_room.room.live_kit_room,
                            &session.user_id().to_string(),
                        )
                        .trace_err()?,
                )
            };

            Some(LiveKitConnectionInfo {
                server_url: live_kit.url().into(),
                token,
                can_publish,
            })
        });

        // Answer the requester before broadcasting anything to other peers.
        response.send(proto::JoinRoomResponse {
            room: Some(joined_room.room.clone()),
            channel_id: joined_room
                .channel
                .as_ref()
                .map(|channel| channel.id.to_proto()),
            live_kit_connection_info,
        })?;

        let mut connection_pool = session.connection_pool().await;
        if let Some(membership_updated) = membership_updated {
            notify_membership_updated(
                &mut connection_pool,
                membership_updated,
                session.user_id(),
                &session.peer,
            );
        }

        // Note: the connection pool guard is still held at this point.
        room_updated(&joined_room.room, &session.peer);

        joined_room
    };

    // Fails with "channel not returned" if the join result carries no channel.
    channel_updated(
        &joined_room
            .channel
            .ok_or_else(|| anyhow!("channel not returned"))?,
        &joined_room.room,
        &session.peer,
        &*session.connection_pool().await,
    );

    update_user_contacts(session.user_id(), &session).await?;
    Ok(())
}
4062
4063/// Start editing the channel notes
4064async fn join_channel_buffer(
4065 request: proto::JoinChannelBuffer,
4066 response: Response<proto::JoinChannelBuffer>,
4067 session: UserSession,
4068) -> Result<()> {
4069 let db = session.db().await;
4070 let channel_id = ChannelId::from_proto(request.channel_id);
4071
4072 let open_response = db
4073 .join_channel_buffer(channel_id, session.user_id(), session.connection_id)
4074 .await?;
4075
4076 let collaborators = open_response.collaborators.clone();
4077 response.send(open_response)?;
4078
4079 let update = UpdateChannelBufferCollaborators {
4080 channel_id: channel_id.to_proto(),
4081 collaborators: collaborators.clone(),
4082 };
4083 channel_buffer_updated(
4084 session.connection_id,
4085 collaborators
4086 .iter()
4087 .filter_map(|collaborator| Some(collaborator.peer_id?.into())),
4088 &update,
4089 &session.peer,
4090 );
4091
4092 Ok(())
4093}
4094
4095/// Edit the channel notes
4096async fn update_channel_buffer(
4097 request: proto::UpdateChannelBuffer,
4098 session: UserSession,
4099) -> Result<()> {
4100 let db = session.db().await;
4101 let channel_id = ChannelId::from_proto(request.channel_id);
4102
4103 let (collaborators, epoch, version) = db
4104 .update_channel_buffer(channel_id, session.user_id(), &request.operations)
4105 .await?;
4106
4107 channel_buffer_updated(
4108 session.connection_id,
4109 collaborators.clone(),
4110 &proto::UpdateChannelBuffer {
4111 channel_id: channel_id.to_proto(),
4112 operations: request.operations,
4113 },
4114 &session.peer,
4115 );
4116
4117 let pool = &*session.connection_pool().await;
4118
4119 let non_collaborators =
4120 pool.channel_connection_ids(channel_id)
4121 .filter_map(|(connection_id, _)| {
4122 if collaborators.contains(&connection_id) {
4123 None
4124 } else {
4125 Some(connection_id)
4126 }
4127 });
4128
4129 broadcast(None, non_collaborators, |peer_id| {
4130 session.peer.send(
4131 peer_id,
4132 proto::UpdateChannels {
4133 latest_channel_buffer_versions: vec![proto::ChannelBufferVersion {
4134 channel_id: channel_id.to_proto(),
4135 epoch: epoch as u64,
4136 version: version.clone(),
4137 }],
4138 ..Default::default()
4139 },
4140 )
4141 });
4142
4143 Ok(())
4144}
4145
4146/// Rejoin the channel notes after a connection blip
4147async fn rejoin_channel_buffers(
4148 request: proto::RejoinChannelBuffers,
4149 response: Response<proto::RejoinChannelBuffers>,
4150 session: UserSession,
4151) -> Result<()> {
4152 let db = session.db().await;
4153 let buffers = db
4154 .rejoin_channel_buffers(&request.buffers, session.user_id(), session.connection_id)
4155 .await?;
4156
4157 for rejoined_buffer in &buffers {
4158 let collaborators_to_notify = rejoined_buffer
4159 .buffer
4160 .collaborators
4161 .iter()
4162 .filter_map(|c| Some(c.peer_id?.into()));
4163 channel_buffer_updated(
4164 session.connection_id,
4165 collaborators_to_notify,
4166 &proto::UpdateChannelBufferCollaborators {
4167 channel_id: rejoined_buffer.buffer.channel_id,
4168 collaborators: rejoined_buffer.buffer.collaborators.clone(),
4169 },
4170 &session.peer,
4171 );
4172 }
4173
4174 response.send(proto::RejoinChannelBuffersResponse {
4175 buffers: buffers.into_iter().map(|b| b.buffer).collect(),
4176 })?;
4177
4178 Ok(())
4179}
4180
4181/// Stop editing the channel notes
4182async fn leave_channel_buffer(
4183 request: proto::LeaveChannelBuffer,
4184 response: Response<proto::LeaveChannelBuffer>,
4185 session: UserSession,
4186) -> Result<()> {
4187 let db = session.db().await;
4188 let channel_id = ChannelId::from_proto(request.channel_id);
4189
4190 let left_buffer = db
4191 .leave_channel_buffer(channel_id, session.connection_id)
4192 .await?;
4193
4194 response.send(Ack {})?;
4195
4196 channel_buffer_updated(
4197 session.connection_id,
4198 left_buffer.connections,
4199 &proto::UpdateChannelBufferCollaborators {
4200 channel_id: channel_id.to_proto(),
4201 collaborators: left_buffer.collaborators,
4202 },
4203 &session.peer,
4204 );
4205
4206 Ok(())
4207}
4208
4209fn channel_buffer_updated<T: EnvelopedMessage>(
4210 sender_id: ConnectionId,
4211 collaborators: impl IntoIterator<Item = ConnectionId>,
4212 message: &T,
4213 peer: &Peer,
4214) {
4215 broadcast(Some(sender_id), collaborators, |peer_id| {
4216 peer.send(peer_id, message.clone())
4217 });
4218}
4219
4220fn send_notifications(
4221 connection_pool: &ConnectionPool,
4222 peer: &Peer,
4223 notifications: db::NotificationBatch,
4224) {
4225 for (user_id, notification) in notifications {
4226 for connection_id in connection_pool.user_connection_ids(user_id) {
4227 if let Err(error) = peer.send(
4228 connection_id,
4229 proto::AddNotification {
4230 notification: Some(notification.clone()),
4231 },
4232 ) {
4233 tracing::error!(
4234 "failed to send notification to {:?} {}",
4235 connection_id,
4236 error
4237 );
4238 }
4239 }
4240 }
4241}
4242
4243/// Send a message to the channel
4244async fn send_channel_message(
4245 request: proto::SendChannelMessage,
4246 response: Response<proto::SendChannelMessage>,
4247 session: UserSession,
4248) -> Result<()> {
4249 // Validate the message body.
4250 let body = request.body.trim().to_string();
4251 if body.len() > MAX_MESSAGE_LEN {
4252 return Err(anyhow!("message is too long"))?;
4253 }
4254 if body.is_empty() {
4255 return Err(anyhow!("message can't be blank"))?;
4256 }
4257
4258 // TODO: adjust mentions if body is trimmed
4259
4260 let timestamp = OffsetDateTime::now_utc();
4261 let nonce = request
4262 .nonce
4263 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4264
4265 let channel_id = ChannelId::from_proto(request.channel_id);
4266 let CreatedChannelMessage {
4267 message_id,
4268 participant_connection_ids,
4269 notifications,
4270 } = session
4271 .db()
4272 .await
4273 .create_channel_message(
4274 channel_id,
4275 session.user_id(),
4276 &body,
4277 &request.mentions,
4278 timestamp,
4279 nonce.clone().into(),
4280 match request.reply_to_message_id {
4281 Some(reply_to_message_id) => Some(MessageId::from_proto(reply_to_message_id)),
4282 None => None,
4283 },
4284 )
4285 .await?;
4286
4287 let message = proto::ChannelMessage {
4288 sender_id: session.user_id().to_proto(),
4289 id: message_id.to_proto(),
4290 body,
4291 mentions: request.mentions,
4292 timestamp: timestamp.unix_timestamp() as u64,
4293 nonce: Some(nonce),
4294 reply_to_message_id: request.reply_to_message_id,
4295 edited_at: None,
4296 };
4297 broadcast(
4298 Some(session.connection_id),
4299 participant_connection_ids.clone(),
4300 |connection| {
4301 session.peer.send(
4302 connection,
4303 proto::ChannelMessageSent {
4304 channel_id: channel_id.to_proto(),
4305 message: Some(message.clone()),
4306 },
4307 )
4308 },
4309 );
4310 response.send(proto::SendChannelMessageResponse {
4311 message: Some(message),
4312 })?;
4313
4314 let pool = &*session.connection_pool().await;
4315 let non_participants =
4316 pool.channel_connection_ids(channel_id)
4317 .filter_map(|(connection_id, _)| {
4318 if participant_connection_ids.contains(&connection_id) {
4319 None
4320 } else {
4321 Some(connection_id)
4322 }
4323 });
4324 broadcast(None, non_participants, |peer_id| {
4325 session.peer.send(
4326 peer_id,
4327 proto::UpdateChannels {
4328 latest_channel_message_ids: vec![proto::ChannelMessageId {
4329 channel_id: channel_id.to_proto(),
4330 message_id: message_id.to_proto(),
4331 }],
4332 ..Default::default()
4333 },
4334 )
4335 });
4336 send_notifications(pool, &session.peer, notifications);
4337
4338 Ok(())
4339}
4340
4341/// Delete a channel message
4342async fn remove_channel_message(
4343 request: proto::RemoveChannelMessage,
4344 response: Response<proto::RemoveChannelMessage>,
4345 session: UserSession,
4346) -> Result<()> {
4347 let channel_id = ChannelId::from_proto(request.channel_id);
4348 let message_id = MessageId::from_proto(request.message_id);
4349 let (connection_ids, existing_notification_ids) = session
4350 .db()
4351 .await
4352 .remove_channel_message(channel_id, message_id, session.user_id())
4353 .await?;
4354
4355 broadcast(
4356 Some(session.connection_id),
4357 connection_ids,
4358 move |connection| {
4359 session.peer.send(connection, request.clone())?;
4360
4361 for notification_id in &existing_notification_ids {
4362 session.peer.send(
4363 connection,
4364 proto::DeleteNotification {
4365 notification_id: (*notification_id).to_proto(),
4366 },
4367 )?;
4368 }
4369
4370 Ok(())
4371 },
4372 );
4373 response.send(proto::Ack {})?;
4374 Ok(())
4375}
4376
4377async fn update_channel_message(
4378 request: proto::UpdateChannelMessage,
4379 response: Response<proto::UpdateChannelMessage>,
4380 session: UserSession,
4381) -> Result<()> {
4382 let channel_id = ChannelId::from_proto(request.channel_id);
4383 let message_id = MessageId::from_proto(request.message_id);
4384 let updated_at = OffsetDateTime::now_utc();
4385 let UpdatedChannelMessage {
4386 message_id,
4387 participant_connection_ids,
4388 notifications,
4389 reply_to_message_id,
4390 timestamp,
4391 deleted_mention_notification_ids,
4392 updated_mention_notifications,
4393 } = session
4394 .db()
4395 .await
4396 .update_channel_message(
4397 channel_id,
4398 message_id,
4399 session.user_id(),
4400 request.body.as_str(),
4401 &request.mentions,
4402 updated_at,
4403 )
4404 .await?;
4405
4406 let nonce = request
4407 .nonce
4408 .clone()
4409 .ok_or_else(|| anyhow!("nonce can't be blank"))?;
4410
4411 let message = proto::ChannelMessage {
4412 sender_id: session.user_id().to_proto(),
4413 id: message_id.to_proto(),
4414 body: request.body.clone(),
4415 mentions: request.mentions.clone(),
4416 timestamp: timestamp.assume_utc().unix_timestamp() as u64,
4417 nonce: Some(nonce),
4418 reply_to_message_id: reply_to_message_id.map(|id| id.to_proto()),
4419 edited_at: Some(updated_at.unix_timestamp() as u64),
4420 };
4421
4422 response.send(proto::Ack {})?;
4423
4424 let pool = &*session.connection_pool().await;
4425 broadcast(
4426 Some(session.connection_id),
4427 participant_connection_ids,
4428 |connection| {
4429 session.peer.send(
4430 connection,
4431 proto::ChannelMessageUpdate {
4432 channel_id: channel_id.to_proto(),
4433 message: Some(message.clone()),
4434 },
4435 )?;
4436
4437 for notification_id in &deleted_mention_notification_ids {
4438 session.peer.send(
4439 connection,
4440 proto::DeleteNotification {
4441 notification_id: (*notification_id).to_proto(),
4442 },
4443 )?;
4444 }
4445
4446 for notification in &updated_mention_notifications {
4447 session.peer.send(
4448 connection,
4449 proto::UpdateNotification {
4450 notification: Some(notification.clone()),
4451 },
4452 )?;
4453 }
4454
4455 Ok(())
4456 },
4457 );
4458
4459 send_notifications(pool, &session.peer, notifications);
4460
4461 Ok(())
4462}
4463
4464/// Mark a channel message as read
4465async fn acknowledge_channel_message(
4466 request: proto::AckChannelMessage,
4467 session: UserSession,
4468) -> Result<()> {
4469 let channel_id = ChannelId::from_proto(request.channel_id);
4470 let message_id = MessageId::from_proto(request.message_id);
4471 let notifications = session
4472 .db()
4473 .await
4474 .observe_channel_message(channel_id, session.user_id(), message_id)
4475 .await?;
4476 send_notifications(
4477 &*session.connection_pool().await,
4478 &session.peer,
4479 notifications,
4480 );
4481 Ok(())
4482}
4483
4484/// Mark a buffer version as synced
4485async fn acknowledge_buffer_version(
4486 request: proto::AckBufferOperation,
4487 session: UserSession,
4488) -> Result<()> {
4489 let buffer_id = BufferId::from_proto(request.buffer_id);
4490 session
4491 .db()
4492 .await
4493 .observe_buffer_version(
4494 buffer_id,
4495 session.user_id(),
4496 request.epoch as i32,
4497 &request.version,
4498 )
4499 .await?;
4500 Ok(())
4501}
4502
4503struct CompleteWithLanguageModelRateLimit;
4504
4505impl RateLimit for CompleteWithLanguageModelRateLimit {
4506 fn capacity() -> usize {
4507 std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4508 .ok()
4509 .and_then(|v| v.parse().ok())
4510 .unwrap_or(120) // Picked arbitrarily
4511 }
4512
4513 fn refill_duration() -> chrono::Duration {
4514 chrono::Duration::hours(1)
4515 }
4516
4517 fn db_name() -> &'static str {
4518 "complete-with-language-model"
4519 }
4520}
4521
4522async fn complete_with_language_model(
4523 mut request: proto::CompleteWithLanguageModel,
4524 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4525 session: Session,
4526 open_ai_api_key: Option<Arc<str>>,
4527 google_ai_api_key: Option<Arc<str>>,
4528 anthropic_api_key: Option<Arc<str>>,
4529) -> Result<()> {
4530 let Some(session) = session.for_user() else {
4531 return Err(anyhow!("user not found"))?;
4532 };
4533 authorize_access_to_language_models(&session).await?;
4534 session
4535 .rate_limiter
4536 .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
4537 .await?;
4538
4539 let mut provider_and_model = request.model.split('/');
4540 let (provider, model) = match (
4541 provider_and_model.next().unwrap(),
4542 provider_and_model.next(),
4543 ) {
4544 (provider, Some(model)) => (provider, model),
4545 (model, None) => {
4546 if model.starts_with("gpt") {
4547 ("openai", model)
4548 } else if model.starts_with("gemini") {
4549 ("google", model)
4550 } else if model.starts_with("claude") {
4551 ("anthropic", model)
4552 } else {
4553 ("unknown", model)
4554 }
4555 }
4556 };
4557 let provider = provider.to_string();
4558 request.model = model.to_string();
4559
4560 match provider.as_str() {
4561 "openai" => {
4562 let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
4563 complete_with_open_ai(request, response, session, api_key).await?;
4564 }
4565 "anthropic" => {
4566 let api_key =
4567 anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
4568 complete_with_anthropic(request, response, session, api_key).await?;
4569 }
4570 "google" => {
4571 let api_key =
4572 google_ai_api_key.context("no Google AI API key configured on the server")?;
4573 complete_with_google_ai(request, response, session, api_key).await?;
4574 }
4575 provider => return Err(anyhow!("unknown provider {:?}", provider))?,
4576 }
4577
4578 Ok(())
4579}
4580
4581async fn complete_with_open_ai(
4582 request: proto::CompleteWithLanguageModel,
4583 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4584 session: UserSession,
4585 api_key: Arc<str>,
4586) -> Result<()> {
4587 let mut completion_stream = open_ai::stream_completion(
4588 session.http_client.as_ref(),
4589 OPEN_AI_API_URL,
4590 &api_key,
4591 crate::ai::language_model_request_to_open_ai(request)?,
4592 None,
4593 )
4594 .await
4595 .context("open_ai::stream_completion request failed within collab")?;
4596
4597 while let Some(event) = completion_stream.next().await {
4598 let event = event?;
4599 response.send(proto::LanguageModelResponse {
4600 choices: event
4601 .choices
4602 .into_iter()
4603 .map(|choice| proto::LanguageModelChoiceDelta {
4604 index: choice.index,
4605 delta: Some(proto::LanguageModelResponseMessage {
4606 role: choice.delta.role.map(|role| match role {
4607 open_ai::Role::User => LanguageModelRole::LanguageModelUser,
4608 open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
4609 open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
4610 open_ai::Role::Tool => LanguageModelRole::LanguageModelTool,
4611 } as i32),
4612 content: choice.delta.content,
4613 tool_calls: choice
4614 .delta
4615 .tool_calls
4616 .unwrap_or_default()
4617 .into_iter()
4618 .map(|delta| proto::ToolCallDelta {
4619 index: delta.index as u32,
4620 id: delta.id,
4621 variant: match delta.function {
4622 Some(function) => {
4623 let name = function.name;
4624 let arguments = function.arguments;
4625
4626 Some(proto::tool_call_delta::Variant::Function(
4627 proto::tool_call_delta::FunctionCallDelta {
4628 name,
4629 arguments,
4630 },
4631 ))
4632 }
4633 None => None,
4634 },
4635 })
4636 .collect(),
4637 }),
4638 finish_reason: choice.finish_reason,
4639 })
4640 .collect(),
4641 })?;
4642 }
4643
4644 Ok(())
4645}
4646
4647async fn complete_with_google_ai(
4648 request: proto::CompleteWithLanguageModel,
4649 response: StreamingResponse<proto::CompleteWithLanguageModel>,
4650 session: UserSession,
4651 api_key: Arc<str>,
4652) -> Result<()> {
4653 let mut stream = google_ai::stream_generate_content(
4654 session.http_client.clone(),
4655 google_ai::API_URL,
4656 api_key.as_ref(),
4657 &request.model.clone(),
4658 crate::ai::language_model_request_to_google_ai(request)?,
4659 )
4660 .await
4661 .context("google_ai::stream_generate_content request failed")?;
4662
4663 while let Some(event) = stream.next().await {
4664 let event = event?;
4665 response.send(proto::LanguageModelResponse {
4666 choices: event
4667 .candidates
4668 .unwrap_or_default()
4669 .into_iter()
4670 .map(|candidate| proto::LanguageModelChoiceDelta {
4671 index: candidate.index as u32,
4672 delta: Some(proto::LanguageModelResponseMessage {
4673 role: Some(match candidate.content.role {
4674 google_ai::Role::User => LanguageModelRole::LanguageModelUser,
4675 google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
4676 } as i32),
4677 content: Some(
4678 candidate
4679 .content
4680 .parts
4681 .into_iter()
4682 .filter_map(|part| match part {
4683 google_ai::Part::TextPart(part) => Some(part.text),
4684 google_ai::Part::InlineDataPart(_) => None,
4685 })
4686 .collect(),
4687 ),
4688 // Tool calls are not supported for Google
4689 tool_calls: Vec::new(),
4690 }),
4691 finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
4692 })
4693 .collect(),
4694 })?;
4695 }
4696
4697 Ok(())
4698}
4699
/// Stream a completion from Anthropic back to the client.
///
/// Messages are mapped as follows:
/// - User and assistant messages are forwarded unchanged.
/// - System messages are joined, separated by blank lines, into the request's
///   top-level `system` field.
/// - Tool messages are dropped.
///
/// Streamed text is relayed as `LanguageModelResponse` chunks on choice index 0,
/// tagged with the role from the most recent `MessageStart`. A `stop_reason`
/// carried by a `MessageDelta` event is relayed as that choice's `finish_reason`.
async fn complete_with_anthropic(
    request: proto::CompleteWithLanguageModel,
    response: StreamingResponse<proto::CompleteWithLanguageModel>,
    session: UserSession,
    api_key: Arc<str>,
) -> Result<()> {
    // Fails if `from_id` cannot map the requested model id.
    let model = anthropic::Model::from_id(&request.model)?;

    // Accumulates the text of every system message, filled in by the
    // `filter_map` closure below.
    let mut system_message = String::new();
    let messages = request
        .messages
        .into_iter()
        .filter_map(|message| {
            match message.role() {
                LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
                    role: anthropic::Role::User,
                    content: message.content,
                }),
                LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
                    role: anthropic::Role::Assistant,
                    content: message.content,
                }),
                // Anthropic's API breaks system instructions out as a separate field rather
                // than having a system message role.
                LanguageModelRole::LanguageModelSystem => {
                    if !system_message.is_empty() {
                        system_message.push_str("\n\n");
                    }
                    system_message.push_str(&message.content);

                    None
                }
                // We don't yet support tool calls for Anthropic
                LanguageModelRole::LanguageModelTool => None,
            }
        })
        .collect();

    let mut stream = anthropic::stream_completion(
        session.http_client.as_ref(),
        anthropic::ANTHROPIC_API_URL,
        &api_key,
        anthropic::Request {
            model,
            messages,
            stream: true,
            system: system_message,
            // NOTE(review): 4092 may be a typo for 4096; confirm the intended
            // output-token cap.
            max_tokens: 4092,
        },
        None,
    )
    .await?;

    // Role attached to relayed text. It defaults to assistant and is updated only
    // when a `MessageStart` reports "assistant" or "user"; other values leave it
    // unchanged.
    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;

    while let Some(event) = stream.next().await {
        let event = event?;

        match event {
            anthropic::ResponseEvent::MessageStart { message } => {
                if let Some(role) = message.role {
                    if role == "assistant" {
                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
                    } else if role == "user" {
                        current_role = proto::LanguageModelRole::LanguageModelUser;
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
                match content_block {
                    anthropic::ContentBlock::Text { text } => {
                        // A block may open with empty text; don't relay an empty chunk.
                        if !text.is_empty() {
                            response.send(proto::LanguageModelResponse {
                                choices: vec![proto::LanguageModelChoiceDelta {
                                    index: 0,
                                    delta: Some(proto::LanguageModelResponseMessage {
                                        role: Some(current_role as i32),
                                        content: Some(text),
                                        tool_calls: Vec::new(),
                                    }),
                                    finish_reason: None,
                                }],
                            })?;
                        }
                    }
                }
            }
            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
                anthropic::TextDelta::TextDelta { text } => {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: Some(proto::LanguageModelResponseMessage {
                                role: Some(current_role as i32),
                                content: Some(text),
                                tool_calls: Vec::new(),
                            }),
                            finish_reason: None,
                        }],
                    })?;
                }
            },
            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
                // Only a stop reason is relayed from message-level deltas.
                if let Some(stop_reason) = delta.stop_reason {
                    response.send(proto::LanguageModelResponse {
                        choices: vec![proto::LanguageModelChoiceDelta {
                            index: 0,
                            delta: None,
                            finish_reason: Some(stop_reason),
                        }],
                    })?;
                }
            }
            // These events carry nothing the client needs.
            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
            anthropic::ResponseEvent::MessageStop {} => {}
            anthropic::ResponseEvent::Ping {} => {}
        }
    }

    Ok(())
}
4821
4822struct CountTokensWithLanguageModelRateLimit;
4823
4824impl RateLimit for CountTokensWithLanguageModelRateLimit {
4825 fn capacity() -> usize {
4826 std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
4827 .ok()
4828 .and_then(|v| v.parse().ok())
4829 .unwrap_or(600) // Picked arbitrarily
4830 }
4831
4832 fn refill_duration() -> chrono::Duration {
4833 chrono::Duration::hours(1)
4834 }
4835
4836 fn db_name() -> &'static str {
4837 "count-tokens-with-language-model"
4838 }
4839}
4840
4841async fn count_tokens_with_language_model(
4842 request: proto::CountTokensWithLanguageModel,
4843 response: Response<proto::CountTokensWithLanguageModel>,
4844 session: UserSession,
4845 google_ai_api_key: Option<Arc<str>>,
4846) -> Result<()> {
4847 authorize_access_to_language_models(&session).await?;
4848
4849 if !request.model.starts_with("gemini") {
4850 return Err(anyhow!(
4851 "counting tokens for model: {:?} is not supported",
4852 request.model
4853 ))?;
4854 }
4855
4856 session
4857 .rate_limiter
4858 .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
4859 .await?;
4860
4861 let api_key = google_ai_api_key
4862 .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
4863 let tokens_response = google_ai::count_tokens(
4864 session.http_client.as_ref(),
4865 google_ai::API_URL,
4866 &api_key,
4867 crate::ai::count_tokens_request_to_google_ai(request)?,
4868 )
4869 .await?;
4870 response.send(proto::CountTokensResponse {
4871 token_count: tokens_response.total_tokens as u32,
4872 })?;
4873 Ok(())
4874}
4875
4876struct ComputeEmbeddingsRateLimit;
4877
4878impl RateLimit for ComputeEmbeddingsRateLimit {
4879 fn capacity() -> usize {
4880 std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
4881 .ok()
4882 .and_then(|v| v.parse().ok())
4883 .unwrap_or(5000) // Picked arbitrarily
4884 }
4885
4886 fn refill_duration() -> chrono::Duration {
4887 chrono::Duration::hours(1)
4888 }
4889
4890 fn db_name() -> &'static str {
4891 "compute-embeddings"
4892 }
4893}
4894
4895async fn compute_embeddings(
4896 request: proto::ComputeEmbeddings,
4897 response: Response<proto::ComputeEmbeddings>,
4898 session: UserSession,
4899 api_key: Option<Arc<str>>,
4900) -> Result<()> {
4901 let api_key = api_key.context("no OpenAI API key configured on the server")?;
4902 authorize_access_to_language_models(&session).await?;
4903
4904 session
4905 .rate_limiter
4906 .check::<ComputeEmbeddingsRateLimit>(session.user_id())
4907 .await?;
4908
4909 let embeddings = match request.model.as_str() {
4910 "openai/text-embedding-3-small" => {
4911 open_ai::embed(
4912 session.http_client.as_ref(),
4913 OPEN_AI_API_URL,
4914 &api_key,
4915 OpenAiEmbeddingModel::TextEmbedding3Small,
4916 request.texts.iter().map(|text| text.as_str()),
4917 )
4918 .await?
4919 }
4920 provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
4921 };
4922
4923 let embeddings = request
4924 .texts
4925 .iter()
4926 .map(|text| {
4927 let mut hasher = sha2::Sha256::new();
4928 hasher.update(text.as_bytes());
4929 let result = hasher.finalize();
4930 result.to_vec()
4931 })
4932 .zip(
4933 embeddings
4934 .data
4935 .into_iter()
4936 .map(|embedding| embedding.embedding),
4937 )
4938 .collect::<HashMap<_, _>>();
4939
4940 let db = session.db().await;
4941 db.save_embeddings(&request.model, &embeddings)
4942 .await
4943 .context("failed to save embeddings")
4944 .trace_err();
4945
4946 response.send(proto::ComputeEmbeddingsResponse {
4947 embeddings: embeddings
4948 .into_iter()
4949 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4950 .collect(),
4951 })?;
4952 Ok(())
4953}
4954
4955async fn get_cached_embeddings(
4956 request: proto::GetCachedEmbeddings,
4957 response: Response<proto::GetCachedEmbeddings>,
4958 session: UserSession,
4959) -> Result<()> {
4960 authorize_access_to_language_models(&session).await?;
4961
4962 let db = session.db().await;
4963 let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
4964
4965 response.send(proto::GetCachedEmbeddingsResponse {
4966 embeddings: embeddings
4967 .into_iter()
4968 .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
4969 .collect(),
4970 })?;
4971 Ok(())
4972}
4973
4974async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
4975 let db = session.db().await;
4976 let flags = db.get_user_flags(session.user_id()).await?;
4977 if flags.iter().any(|flag| flag == "language-models") {
4978 Ok(())
4979 } else {
4980 Err(anyhow!("permission denied"))?
4981 }
4982}
4983
4984/// Get a Supermaven API key for the user
4985async fn get_supermaven_api_key(
4986 _request: proto::GetSupermavenApiKey,
4987 response: Response<proto::GetSupermavenApiKey>,
4988 session: UserSession,
4989) -> Result<()> {
4990 let user_id: String = session.user_id().to_string();
4991 if !session.is_staff() {
4992 return Err(anyhow!("supermaven not enabled for this account"))?;
4993 }
4994
4995 let email = session
4996 .email()
4997 .ok_or_else(|| anyhow!("user must have an email"))?;
4998
4999 let supermaven_admin_api = session
5000 .supermaven_client
5001 .as_ref()
5002 .ok_or_else(|| anyhow!("supermaven not configured"))?;
5003
5004 let result = supermaven_admin_api
5005 .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email })
5006 .await?;
5007
5008 response.send(proto::GetSupermavenApiKeyResponse {
5009 api_key: result.api_key,
5010 })?;
5011
5012 Ok(())
5013}
5014
5015/// Start receiving chat updates for a channel
5016async fn join_channel_chat(
5017 request: proto::JoinChannelChat,
5018 response: Response<proto::JoinChannelChat>,
5019 session: UserSession,
5020) -> Result<()> {
5021 let channel_id = ChannelId::from_proto(request.channel_id);
5022
5023 let db = session.db().await;
5024 db.join_channel_chat(channel_id, session.connection_id, session.user_id())
5025 .await?;
5026 let messages = db
5027 .get_channel_messages(channel_id, session.user_id(), MESSAGE_COUNT_PER_PAGE, None)
5028 .await?;
5029 response.send(proto::JoinChannelChatResponse {
5030 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5031 messages,
5032 })?;
5033 Ok(())
5034}
5035
5036/// Stop receiving chat updates for a channel
5037async fn leave_channel_chat(request: proto::LeaveChannelChat, session: UserSession) -> Result<()> {
5038 let channel_id = ChannelId::from_proto(request.channel_id);
5039 session
5040 .db()
5041 .await
5042 .leave_channel_chat(channel_id, session.connection_id, session.user_id())
5043 .await?;
5044 Ok(())
5045}
5046
5047/// Retrieve the chat history for a channel
5048async fn get_channel_messages(
5049 request: proto::GetChannelMessages,
5050 response: Response<proto::GetChannelMessages>,
5051 session: UserSession,
5052) -> Result<()> {
5053 let channel_id = ChannelId::from_proto(request.channel_id);
5054 let messages = session
5055 .db()
5056 .await
5057 .get_channel_messages(
5058 channel_id,
5059 session.user_id(),
5060 MESSAGE_COUNT_PER_PAGE,
5061 Some(MessageId::from_proto(request.before_message_id)),
5062 )
5063 .await?;
5064 response.send(proto::GetChannelMessagesResponse {
5065 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5066 messages,
5067 })?;
5068 Ok(())
5069}
5070
5071/// Retrieve specific chat messages
5072async fn get_channel_messages_by_id(
5073 request: proto::GetChannelMessagesById,
5074 response: Response<proto::GetChannelMessagesById>,
5075 session: UserSession,
5076) -> Result<()> {
5077 let message_ids = request
5078 .message_ids
5079 .iter()
5080 .map(|id| MessageId::from_proto(*id))
5081 .collect::<Vec<_>>();
5082 let messages = session
5083 .db()
5084 .await
5085 .get_channel_messages_by_id(session.user_id(), &message_ids)
5086 .await?;
5087 response.send(proto::GetChannelMessagesResponse {
5088 done: messages.len() < MESSAGE_COUNT_PER_PAGE,
5089 messages,
5090 })?;
5091 Ok(())
5092}
5093
5094/// Retrieve the current users notifications
5095async fn get_notifications(
5096 request: proto::GetNotifications,
5097 response: Response<proto::GetNotifications>,
5098 session: UserSession,
5099) -> Result<()> {
5100 let notifications = session
5101 .db()
5102 .await
5103 .get_notifications(
5104 session.user_id(),
5105 NOTIFICATION_COUNT_PER_PAGE,
5106 request
5107 .before_id
5108 .map(|id| db::NotificationId::from_proto(id)),
5109 )
5110 .await?;
5111 response.send(proto::GetNotificationsResponse {
5112 done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
5113 notifications,
5114 })?;
5115 Ok(())
5116}
5117
5118/// Mark notifications as read
5119async fn mark_notification_as_read(
5120 request: proto::MarkNotificationRead,
5121 response: Response<proto::MarkNotificationRead>,
5122 session: UserSession,
5123) -> Result<()> {
5124 let database = &session.db().await;
5125 let notifications = database
5126 .mark_notification_as_read_by_id(
5127 session.user_id(),
5128 NotificationId::from_proto(request.notification_id),
5129 )
5130 .await?;
5131 send_notifications(
5132 &*session.connection_pool().await,
5133 &session.peer,
5134 notifications,
5135 );
5136 response.send(proto::Ack {})?;
5137 Ok(())
5138}
5139
5140/// Get the current users information
5141async fn get_private_user_info(
5142 _request: proto::GetPrivateUserInfo,
5143 response: Response<proto::GetPrivateUserInfo>,
5144 session: UserSession,
5145) -> Result<()> {
5146 let db = session.db().await;
5147
5148 let metrics_id = db.get_user_metrics_id(session.user_id()).await?;
5149 let user = db
5150 .get_user_by_id(session.user_id())
5151 .await?
5152 .ok_or_else(|| anyhow!("user not found"))?;
5153 let flags = db.get_user_flags(session.user_id()).await?;
5154
5155 response.send(proto::GetPrivateUserInfoResponse {
5156 metrics_id,
5157 staff: user.admin,
5158 flags,
5159 })?;
5160 Ok(())
5161}
5162
5163fn to_axum_message(message: TungsteniteMessage) -> Option<AxumMessage> {
5164 match message {
5165 TungsteniteMessage::Text(payload) => Some(AxumMessage::Text(payload)),
5166 TungsteniteMessage::Binary(payload) => Some(AxumMessage::Binary(payload)),
5167 TungsteniteMessage::Ping(payload) => Some(AxumMessage::Ping(payload)),
5168 TungsteniteMessage::Pong(payload) => Some(AxumMessage::Pong(payload)),
5169 TungsteniteMessage::Close(frame) => {
5170 Some(AxumMessage::Close(frame.map(|frame| AxumCloseFrame {
5171 code: frame.code.into(),
5172 reason: frame.reason,
5173 })))
5174 }
5175 // we can ignore `Frame` frames as recommended by the tungstenite maintainers
5176 // https://github.com/snapview/tungstenite-rs/issues/268
5177 TungsteniteMessage::Frame(_) => None,
5178 }
5179}
5180
5181fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
5182 match message {
5183 AxumMessage::Text(payload) => TungsteniteMessage::Text(payload),
5184 AxumMessage::Binary(payload) => TungsteniteMessage::Binary(payload),
5185 AxumMessage::Ping(payload) => TungsteniteMessage::Ping(payload),
5186 AxumMessage::Pong(payload) => TungsteniteMessage::Pong(payload),
5187 AxumMessage::Close(frame) => {
5188 TungsteniteMessage::Close(frame.map(|frame| TungsteniteCloseFrame {
5189 code: frame.code.into(),
5190 reason: frame.reason,
5191 }))
5192 }
5193 }
5194}
5195
5196fn notify_membership_updated(
5197 connection_pool: &mut ConnectionPool,
5198 result: MembershipUpdated,
5199 user_id: UserId,
5200 peer: &Peer,
5201) {
5202 for membership in &result.new_channels.channel_memberships {
5203 connection_pool.subscribe_to_channel(user_id, membership.channel_id, membership.role)
5204 }
5205 for channel_id in &result.removed_channels {
5206 connection_pool.unsubscribe_from_channel(&user_id, channel_id)
5207 }
5208
5209 let user_channels_update = proto::UpdateUserChannels {
5210 channel_memberships: result
5211 .new_channels
5212 .channel_memberships
5213 .iter()
5214 .map(|cm| proto::ChannelMembership {
5215 channel_id: cm.channel_id.to_proto(),
5216 role: cm.role.into(),
5217 })
5218 .collect(),
5219 ..Default::default()
5220 };
5221
5222 let mut update = build_channels_update(result.new_channels);
5223 update.delete_channels = result
5224 .removed_channels
5225 .into_iter()
5226 .map(|id| id.to_proto())
5227 .collect();
5228 update.remove_channel_invitations = vec![result.channel_id.to_proto()];
5229
5230 for connection_id in connection_pool.user_connection_ids(user_id) {
5231 peer.send(connection_id, user_channels_update.clone())
5232 .trace_err();
5233 peer.send(connection_id, update.clone()).trace_err();
5234 }
5235}
5236
5237fn build_update_user_channels(channels: &ChannelsForUser) -> proto::UpdateUserChannels {
5238 proto::UpdateUserChannels {
5239 channel_memberships: channels
5240 .channel_memberships
5241 .iter()
5242 .map(|m| proto::ChannelMembership {
5243 channel_id: m.channel_id.to_proto(),
5244 role: m.role.into(),
5245 })
5246 .collect(),
5247 observed_channel_buffer_version: channels.observed_buffer_versions.clone(),
5248 observed_channel_message_id: channels.observed_channel_messages.clone(),
5249 }
5250}
5251
5252fn build_channels_update(channels: ChannelsForUser) -> proto::UpdateChannels {
5253 let mut update = proto::UpdateChannels::default();
5254
5255 for channel in channels.channels {
5256 update.channels.push(channel.to_proto());
5257 }
5258
5259 update.latest_channel_buffer_versions = channels.latest_buffer_versions;
5260 update.latest_channel_message_ids = channels.latest_channel_messages;
5261
5262 for (channel_id, participants) in channels.channel_participants {
5263 update
5264 .channel_participants
5265 .push(proto::ChannelParticipants {
5266 channel_id: channel_id.to_proto(),
5267 participant_user_ids: participants.into_iter().map(|id| id.to_proto()).collect(),
5268 });
5269 }
5270
5271 for channel in channels.invited_channels {
5272 update.channel_invitations.push(channel.to_proto());
5273 }
5274
5275 update.hosted_projects = channels.hosted_projects;
5276 update
5277}
5278
5279fn build_initial_contacts_update(
5280 contacts: Vec<db::Contact>,
5281 pool: &ConnectionPool,
5282) -> proto::UpdateContacts {
5283 let mut update = proto::UpdateContacts::default();
5284
5285 for contact in contacts {
5286 match contact {
5287 db::Contact::Accepted { user_id, busy } => {
5288 update.contacts.push(contact_for_user(user_id, busy, &pool));
5289 }
5290 db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
5291 db::Contact::Incoming { user_id } => {
5292 update
5293 .incoming_requests
5294 .push(proto::IncomingContactRequest {
5295 requester_id: user_id.to_proto(),
5296 })
5297 }
5298 }
5299 }
5300
5301 update
5302}
5303
5304fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
5305 proto::Contact {
5306 user_id: user_id.to_proto(),
5307 online: pool.is_user_online(user_id),
5308 busy,
5309 }
5310}
5311
5312fn room_updated(room: &proto::Room, peer: &Peer) {
5313 broadcast(
5314 None,
5315 room.participants
5316 .iter()
5317 .filter_map(|participant| Some(participant.peer_id?.into())),
5318 |peer_id| {
5319 peer.send(
5320 peer_id,
5321 proto::RoomUpdated {
5322 room: Some(room.clone()),
5323 },
5324 )
5325 },
5326 );
5327}
5328
5329fn channel_updated(
5330 channel: &db::channel::Model,
5331 room: &proto::Room,
5332 peer: &Peer,
5333 pool: &ConnectionPool,
5334) {
5335 let participants = room
5336 .participants
5337 .iter()
5338 .map(|p| p.user_id)
5339 .collect::<Vec<_>>();
5340
5341 broadcast(
5342 None,
5343 pool.channel_connection_ids(channel.root_id())
5344 .filter_map(|(channel_id, role)| {
5345 role.can_see_channel(channel.visibility).then(|| channel_id)
5346 }),
5347 |peer_id| {
5348 peer.send(
5349 peer_id,
5350 proto::UpdateChannels {
5351 channel_participants: vec![proto::ChannelParticipants {
5352 channel_id: channel.id.to_proto(),
5353 participant_user_ids: participants.clone(),
5354 }],
5355 ..Default::default()
5356 },
5357 )
5358 },
5359 );
5360}
5361
5362async fn send_dev_server_projects_update(
5363 user_id: UserId,
5364 mut status: proto::DevServerProjectsUpdate,
5365 session: &Session,
5366) {
5367 let pool = session.connection_pool().await;
5368 for dev_server in &mut status.dev_servers {
5369 dev_server.status =
5370 pool.dev_server_status(DevServerId(dev_server.dev_server_id as i32)) as i32;
5371 }
5372 let connections = pool.user_connection_ids(user_id);
5373 for connection_id in connections {
5374 session.peer.send(connection_id, status.clone()).trace_err();
5375 }
5376}
5377
5378async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> {
5379 let db = session.db().await;
5380
5381 let contacts = db.get_contacts(user_id).await?;
5382 let busy = db.is_user_busy(user_id).await?;
5383
5384 let pool = session.connection_pool().await;
5385 let updated_contact = contact_for_user(user_id, busy, &pool);
5386 for contact in contacts {
5387 if let db::Contact::Accepted {
5388 user_id: contact_user_id,
5389 ..
5390 } = contact
5391 {
5392 for contact_conn_id in pool.user_connection_ids(contact_user_id) {
5393 session
5394 .peer
5395 .send(
5396 contact_conn_id,
5397 proto::UpdateContacts {
5398 contacts: vec![updated_contact.clone()],
5399 remove_contacts: Default::default(),
5400 incoming_requests: Default::default(),
5401 remove_incoming_requests: Default::default(),
5402 outgoing_requests: Default::default(),
5403 remove_outgoing_requests: Default::default(),
5404 },
5405 )
5406 .trace_err();
5407 }
5408 }
5409 }
5410 Ok(())
5411}
5412
5413async fn lost_dev_server_connection(session: &DevServerSession) -> Result<()> {
5414 log::info!("lost dev server connection, unsharing projects");
5415 let project_ids = session
5416 .db()
5417 .await
5418 .get_stale_dev_server_projects(session.connection_id)
5419 .await?;
5420
5421 for project_id in project_ids {
5422 // not unshare re-checks the connection ids match, so we get away with no transaction
5423 unshare_project_internal(project_id, session.connection_id, None, &session).await?;
5424 }
5425
5426 let user_id = session.dev_server().user_id;
5427 let update = session
5428 .db()
5429 .await
5430 .dev_server_projects_update(user_id)
5431 .await?;
5432
5433 send_dev_server_projects_update(user_id, update, session).await;
5434
5435 Ok(())
5436}
5437
/// Remove `connection_id` from its room and propagate the consequences.
///
/// Returns `Ok(())` without doing anything if the database reports no room was
/// left. Otherwise it:
/// - runs `project_left` for each project left;
/// - broadcasts the updated room via `room_updated`;
/// - updates channel watchers when the room belongs to a channel;
/// - sends `CallCanceled` to the users whose calls were canceled;
/// - refreshes contact status for the leaving user and those users;
/// - removes the user from the LiveKit room, deleting the room when the database
///   marked it as deleted.
///
/// LiveKit failures are passed to `trace_err` rather than returned.
async fn leave_room_for_session(session: &UserSession, connection_id: ConnectionId) -> Result<()> {
    // The leaving user plus every user whose call was canceled. Their contacts
    // get a status refresh at the end.
    let mut contacts_to_update = HashSet::default();

    // Filled from `left_room` inside the `if let` below. They are declared up
    // front so they outlive that block.
    let room_id;
    let canceled_calls_to_user_ids;
    let live_kit_room;
    let delete_live_kit_room;
    let room;
    let channel;

    // NOTE(review): the `session.db().await` guard in this `if let` scrutinee is a
    // temporary, so it stays alive for the whole `if let` body (including
    // `project_left` and `room_updated`). Keep that in mind before refactoring
    // this into a `let ... else`, which would drop it earlier.
    if let Some(mut left_room) = session.db().await.leave_room(connection_id).await? {
        contacts_to_update.insert(session.user_id());

        for project in left_room.left_projects.values() {
            project_left(project, session);
        }

        room_id = RoomId::from_proto(left_room.room.id);
        canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids);
        live_kit_room = mem::take(&mut left_room.room.live_kit_room);
        delete_live_kit_room = left_room.deleted;
        room = mem::take(&mut left_room.room);
        channel = mem::take(&mut left_room.channel);

        room_updated(&room, &session.peer);
    } else {
        return Ok(());
    }

    if let Some(channel) = channel {
        channel_updated(
            &channel,
            &room,
            &session.peer,
            &*session.connection_pool().await,
        );
    }

    // Scoped so the connection-pool guard is released before the contact updates
    // below reacquire it.
    {
        let pool = session.connection_pool().await;
        for canceled_user_id in canceled_calls_to_user_ids {
            for connection_id in pool.user_connection_ids(canceled_user_id) {
                session
                    .peer
                    .send(
                        connection_id,
                        proto::CallCanceled {
                            room_id: room_id.to_proto(),
                        },
                    )
                    .trace_err();
            }
            contacts_to_update.insert(canceled_user_id);
        }
    }

    for contact_user_id in contacts_to_update {
        update_user_contacts(contact_user_id, &session).await?;
    }

    if let Some(live_kit) = session.live_kit_client.as_ref() {
        live_kit
            .remove_participant(live_kit_room.clone(), session.user_id().to_string())
            .await
            .trace_err();

        if delete_live_kit_room {
            live_kit.delete_room(live_kit_room).await.trace_err();
        }
    }

    Ok(())
}
5511
5512async fn leave_channel_buffers_for_session(session: &Session) -> Result<()> {
5513 let left_channel_buffers = session
5514 .db()
5515 .await
5516 .leave_channel_buffers(session.connection_id)
5517 .await?;
5518
5519 for left_buffer in left_channel_buffers {
5520 channel_buffer_updated(
5521 session.connection_id,
5522 left_buffer.connections,
5523 &proto::UpdateChannelBufferCollaborators {
5524 channel_id: left_buffer.channel_id.to_proto(),
5525 collaborators: left_buffer.collaborators,
5526 },
5527 &session.peer,
5528 );
5529 }
5530
5531 Ok(())
5532}
5533
5534fn project_left(project: &db::LeftProject, session: &UserSession) {
5535 for connection_id in &project.connection_ids {
5536 if project.should_unshare {
5537 session
5538 .peer
5539 .send(
5540 *connection_id,
5541 proto::UnshareProject {
5542 project_id: project.id.to_proto(),
5543 },
5544 )
5545 .trace_err();
5546 } else {
5547 session
5548 .peer
5549 .send(
5550 *connection_id,
5551 proto::RemoveProjectCollaborator {
5552 project_id: project.id.to_proto(),
5553 peer_id: Some(session.connection_id.into()),
5554 },
5555 )
5556 .trace_err();
5557 }
5558 }
5559}
5560
5561pub trait ResultExt {
5562 type Ok;
5563
5564 fn trace_err(self) -> Option<Self::Ok>;
5565}
5566
5567impl<T, E> ResultExt for Result<T, E>
5568where
5569 E: std::fmt::Debug,
5570{
5571 type Ok = T;
5572
5573 #[track_caller]
5574 fn trace_err(self) -> Option<T> {
5575 match self {
5576 Ok(value) => Some(value),
5577 Err(error) => {
5578 tracing::error!("{:?}", error);
5579 None
5580 }
5581 }
5582 }
5583}